184 lines
6.8 KiB
Python
184 lines
6.8 KiB
Python
from datetime import datetime, timedelta
|
|
import secrets
|
|
import requests
|
|
import logging
|
|
from jose import jwt
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.responses import RedirectResponse
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.auth import create_access_token, create_refresh_token, hash_password, verify_password, decode_token
|
|
from app.core.config import settings
|
|
from app.deps import get_db, get_current_user
|
|
from app.models import User, UserRole, LoginEvent
|
|
from app.schemas import TokenPair, UserCreate, UserOut, TokenRefresh
|
|
|
|
# Module-wide logger; the OAuth endpoints below report progress and failures through it.
logger: logging.Logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /auth and grouped under the "auth" tag.
router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
|
|
@router.post("/register", response_model=UserOut)
def register(payload: UserCreate, db: Session = Depends(get_db)):
    """Create a new learner account from *payload* and return it.

    The password is stored only as ``hash_password(payload.password)``.

    Raises:
        HTTPException: 400 if an account with ``payload.email`` already exists.
    """
    # NOTE(review): check-then-insert is not atomic; two concurrent registrations
    # for the same email rely on a DB unique constraint (not visible here).
    if db.query(User).filter(User.email == payload.email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    new_user = User(
        email=payload.email,
        password_hash=hash_password(payload.password),
        role=UserRole.learner,
    )
    db.add(new_user)
    db.commit()
    # Reload DB-populated columns (e.g. the primary key) before serializing.
    db.refresh(new_user)
    return new_user
|
|
|
|
|
|
# Lazily computed hash of a random password; see _dummy_password_hash().
_DUMMY_PASSWORD_HASH: str | None = None


def _dummy_password_hash() -> str:
    """Return a hash of a random password, computed on first use and then reused.

    ``login`` verifies against it when no account matches the submitted email,
    so that branch also pays for a password verification. Concurrent first calls
    may each compute a hash; any of them is equally valid.
    """
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password(secrets.token_urlsafe(16))
    return _DUMMY_PASSWORD_HASH


@router.post("/login", response_model=TokenPair)
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate with email + password and return a fresh access/refresh token pair.

    ``form.username`` is matched against ``User.email``. Each successful login
    is recorded as a ``LoginEvent``.

    Raises:
        HTTPException: 401 for an unknown email or a wrong password (identical
            detail in both cases); 403 if the credentials are valid but the
            account is inactive.
    """
    user = db.query(User).filter(User.email == form.username).first()
    if user is None:
        # Verify anyway so an unknown email does not answer measurably faster than
        # a wrong password, which would let callers probe which emails exist.
        # NOTE(review): assumes hash_password/verify_password use a deliberately
        # slow hash — confirm in app.auth.
        verify_password(form.password, _dummy_password_hash())
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not verify_password(form.password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    # Checked after the password so account status is not revealed without valid credentials.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")

    # Track login event
    db.add(LoginEvent(user_id=user.id))
    db.commit()

    return TokenPair(
        access_token=create_access_token(str(user.id)),
        refresh_token=create_refresh_token(str(user.id)),
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=TokenPair)
def refresh(payload: TokenRefresh):
    """Exchange a refresh token for a new access/refresh token pair.

    Raises:
        HTTPException: 401 if the decoded token is not of type ``"refresh"`` or
            carries no ``sub`` claim. ``decode_token`` may raise its own errors
            for malformed or expired tokens (behavior defined in app.auth).
    """
    decoded = decode_token(payload.refresh_token, "refresh")
    # Defensive re-check of the token type, independent of decode_token's own checks.
    if decoded.get("type") != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    user_id = decoded.get("sub")
    # Without this guard a token lacking "sub" would mint new tokens for the
    # literal subject "None" (str(None)).
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    # NOTE(review): the user is not reloaded here, so refresh tokens of deactivated
    # or deleted accounts keep working until they expire; consider checking
    # User.is_active the way /login does.
    subject = str(user_id)
    return TokenPair(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject),
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserOut)
def me(user: User = Depends(get_current_user)):
    """Return the currently authenticated user, serialized through ``UserOut``.

    Authentication is delegated entirely to the ``get_current_user`` dependency.
    """
    return user
|
|
|
|
|
|
def build_google_oauth_url(state: str) -> str:
|
|
params = {
|
|
"client_id": settings.google_client_id,
|
|
"redirect_uri": settings.google_redirect_uri,
|
|
"response_type": "code",
|
|
"scope": "openid email profile",
|
|
"access_type": "online",
|
|
"include_granted_scopes": "true",
|
|
"state": state,
|
|
"prompt": "select_account",
|
|
}
|
|
base = "https://accounts.google.com/o/oauth2/v2/auth"
|
|
query = "&".join([f"{k}={requests.utils.quote(str(v))}" for k, v in params.items()])
|
|
return f"{base}?{query}"
|
|
|
|
|
|
@router.get("/google/login")
def google_login():
    """Start the Google OAuth flow by redirecting the browser to Google's consent page.

    A signed JWT of type ``"oauth_state"``, expiring 10 minutes after issue,
    is sent as the OAuth ``state`` parameter. The ``/google/callback``
    endpoint decodes it with the same secret and algorithm.

    Raises:
        HTTPException: 400 if the Google client ID or secret is not configured.
    """
    from datetime import timezone

    if not settings.google_client_id or not settings.google_client_secret:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")

    # Use an aware UTC datetime: .timestamp() on a naive utcnow() value is
    # interpreted as *local* time, which skews "exp" by the server's UTC offset
    # (states are born already expired east of UTC, overlong west of it).
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
    state_payload = {
        "nonce": secrets.token_urlsafe(16),
        "exp": int(expires_at.timestamp()),
        "type": "oauth_state",
    }
    state = jwt.encode(state_payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)

    auth_url = build_google_oauth_url(state)
    logger.info("Redirecting to Google OAuth: %s...", auth_url[:100])
    return RedirectResponse(auth_url, status_code=302)
|
|
|
|
|
|
def _google_fetch_json(method: str, url: str, failure_detail: str, **kwargs) -> dict:
    """Call a Google OAuth endpoint and return its decoded JSON object.

    Transport errors (connection failure, timeout), non-200 responses, and bodies
    that are not a JSON object all become an HTTP 400 carrying
    ``failure_detail``, rather than escaping as an unhandled exception.
    """
    try:
        res = requests.request(method, url, timeout=10, **kwargs)
    except requests.RequestException as e:
        logger.error("Google request to %s failed: %s", url, e)
        raise HTTPException(status_code=400, detail=failure_detail) from e
    if res.status_code != 200:
        logger.error("Google request to %s returned HTTP %s", url, res.status_code)
        raise HTTPException(status_code=400, detail=failure_detail)
    try:
        body = res.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        logger.error("Google response from %s was not valid JSON", url)
        raise HTTPException(status_code=400, detail=failure_detail) from e
    if not isinstance(body, dict):
        logger.error("Google response from %s was not a JSON object", url)
        raise HTTPException(status_code=400, detail=failure_detail)
    return body


def _google_verify_state(state: str) -> None:
    """Reject an OAuth ``state`` that is not a valid ``oauth_state`` JWT signed with our secret.

    Raises:
        HTTPException: 400 if ``jwt.decode`` rejects the token, or if its
            ``type`` claim is not ``"oauth_state"``.
    """
    from jose import JWTError

    try:
        decoded = jwt.decode(state, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError as e:
        logger.error("Failed to decode state: %s", e)
        raise HTTPException(status_code=400, detail="Invalid state") from e
    if decoded.get("type") != "oauth_state":
        logger.error("State type mismatch")
        raise HTTPException(status_code=400, detail="Invalid state")


def _google_resolve_user(db: Session, email: str, sub: str | None, picture: str | None) -> User:
    """Return the local user for a Google identity, creating or linking it by email.

    A new account gets a hash of a random, never-disclosed password, so password
    login is effectively unusable for it. For an existing account matched by
    email, ``oauth_provider`` is set to ``"google"``. Its ``oauth_subject`` and
    ``picture_url`` are filled in only when they were empty.
    """
    user = db.query(User).filter(User.email == email).first()
    if user is None:
        user = User(
            email=email,
            password_hash=hash_password(secrets.token_urlsafe(24)),
            role=UserRole.learner,
            oauth_provider="google",
            oauth_subject=sub,
            picture_url=picture,
        )
        db.add(user)
    else:
        if user.oauth_provider != "google":
            user.oauth_provider = "google"
        if not user.oauth_subject:
            user.oauth_subject = sub
        if picture and not user.picture_url:
            user.picture_url = picture
    db.commit()
    db.refresh(user)
    return user


@router.get("/google/callback")
def google_callback(code: str | None = None, state: str | None = None, db: Session = Depends(get_db)):
    """Complete the Google OAuth flow and hand a token pair to the frontend.

    Steps:
        1. Validate the ``state`` issued by ``google_login``.
        2. Exchange ``code`` for a Google access token.
        3. Fetch the user's profile.
        4. Create or link the local account.
        5. Record a ``LoginEvent``.
        6. Redirect to ``{frontend_url}/auth/callback`` with the tokens in the
           query string.

    Raises:
        HTTPException: 400 for a missing or invalid ``code``/``state``, any
            failed Google call, or a missing or unverified email; 403 if the
            account is inactive (same rule as ``/login``).
    """
    from urllib.parse import urlencode

    logger.info("Google OAuth callback received")
    if not code or not state:
        logger.error("Missing code or state: code=%s, state=%s", bool(code), bool(state))
        raise HTTPException(status_code=400, detail="Missing code or state")
    # NOTE(review): state proves this server issued it, but its nonce is not tied
    # to the browser that started the flow (no cookie/session check), so
    # login-CSRF remains possible.
    _google_verify_state(state)

    token_data = _google_fetch_json(
        "POST",
        "https://oauth2.googleapis.com/token",
        "Failed to exchange code",
        data={
            "code": code,
            "client_id": settings.google_client_id,
            "client_secret": settings.google_client_secret,
            "redirect_uri": settings.google_redirect_uri,
            "grant_type": "authorization_code",
        },
    )
    access_token = token_data.get("access_token")
    if not access_token:
        raise HTTPException(status_code=400, detail="Missing access token")

    profile = _google_fetch_json(
        "GET",
        "https://www.googleapis.com/oauth2/v3/userinfo",
        "Failed to fetch user info",
        headers={"Authorization": f"Bearer {access_token}"},
    )
    email = profile.get("email")
    if not email:
        raise HTTPException(status_code=400, detail="Email not available")
    # Accounts are matched purely by email, so an address Google has not verified
    # could be used to take over an existing account with that email. The
    # str(...).lower() form accepts both a boolean and its string spelling.
    if str(profile.get("email_verified")).lower() != "true":
        raise HTTPException(status_code=400, detail="Email not verified")

    user = _google_resolve_user(db, email, profile.get("sub"), profile.get("picture"))
    # Same rule as /login: deactivated accounts must not obtain tokens.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")

    # Track login event
    db.add(LoginEvent(user_id=user.id))
    db.commit()

    token_pair = TokenPair(
        access_token=create_access_token(str(user.id)),
        refresh_token=create_refresh_token(str(user.id)),
    )
    # NOTE(review): tokens in a query string can leak via browser history, access
    # logs and Referer headers; a URL fragment (#) avoids that but needs a
    # matching frontend change.
    query = urlencode({"access_token": token_pair.access_token, "refresh_token": token_pair.refresh_token})
    redirect_url = f"{settings.frontend_url}/auth/callback?{query}"
    logger.info("OAuth successful for user %s, redirecting to frontend", user.id)
    return RedirectResponse(redirect_url, status_code=302)
|
|
|