"""Authentication routes: email/password auth, token refresh, and Google OAuth login."""

import logging
import secrets
from datetime import datetime, timedelta, timezone
from urllib.parse import quote, urlencode

import requests
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError, jwt
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.auth import create_access_token, create_refresh_token, hash_password, verify_password, decode_token
from app.core.config import settings
from app.deps import get_db, get_current_user
from app.models import User, UserRole, LoginEvent
from app.schemas import TokenPair, UserCreate, UserOut, TokenRefresh

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/auth", tags=["auth"])

_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
_GOOGLE_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
_GOOGLE_HTTP_TIMEOUT = 10  # seconds, passed as `timeout=` to requests
_OAUTH_STATE_TTL = timedelta(minutes=10)


def _issue_token_pair(user_id: object) -> TokenPair:
    """Return a fresh access/refresh token pair whose subject is ``str(user_id)``."""
    subject = str(user_id)
    return TokenPair(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject),
    )


def _record_login(db: Session, user: User) -> None:
    """Persist a ``LoginEvent`` row for *user* and commit it."""
    db.add(LoginEvent(user_id=user.id))
    db.commit()


@router.post("/register", response_model=UserOut)
def register(payload: UserCreate, db: Session = Depends(get_db)):
    """Create a new learner account with a hashed password.

    Raises:
        HTTPException(400): if the email is already registered.
    """
    existing = db.query(User).filter(User.email == payload.email).first()
    if existing:
        raise HTTPException(status_code=400, detail="Email already registered")
    user = User(email=payload.email, password_hash=hash_password(payload.password), role=UserRole.learner)
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # A concurrent request can insert the same email between the lookup above and
        # this commit. NOTE(review): assumes the only constraint a new row can violate
        # here is a unique index on email; confirm against the User model.
        db.rollback()
        raise HTTPException(status_code=400, detail="Email already registered") from None
    db.refresh(user)
    return user


@router.post("/login", response_model=TokenPair)
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate with email (sent as the OAuth2 ``username`` field) and password.

    Records a ``LoginEvent`` and returns a new token pair.

    Raises:
        HTTPException(401): unknown email or wrong password.
        HTTPException(403): the account is inactive.
    """
    user = db.query(User).filter(User.email == form.username).first()
    if not user or not verify_password(form.password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")
    _record_login(db, user)
    return _issue_token_pair(user.id)


@router.post("/refresh", response_model=TokenPair)
def refresh(payload: TokenRefresh, db: Session = Depends(get_db)):
    """Exchange a refresh token for a new token pair.

    The user named by the token's ``sub`` claim must still exist and be active.
    Otherwise a deleted or deactivated account could keep minting tokens forever,
    because every refresh also returns a new refresh token.

    Raises:
        HTTPException(401): wrong token type, missing ``sub``, or unknown user.
        HTTPException(403): the account is inactive (same rule as ``/login``).
    """
    # NOTE(review): decode_token (app.auth) presumably raises on a bad signature or an
    # expired token; that behaviour is not visible from this module.
    decoded = decode_token(payload.refresh_token, "refresh")
    if decoded.get("type") != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
    user_id = decoded.get("sub")
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")
    return _issue_token_pair(user.id)


@router.get("/me", response_model=UserOut)
def me(user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency."""
    return user


def build_google_oauth_url(state: str) -> str:
    """Build the Google OAuth 2.0 authorization URL carrying the signed *state*."""
    params = {
        "client_id": settings.google_client_id,
        "redirect_uri": settings.google_redirect_uri,
        "response_type": "code",
        "scope": "openid email profile",
        "access_type": "online",
        "include_granted_scopes": "true",
        "state": state,
        "prompt": "select_account",
    }
    # quote_via=quote percent-encodes spaces as %20 (not '+'), like the original manual quoting.
    return f"{_GOOGLE_AUTH_URL}?{urlencode(params, quote_via=quote)}"


@router.get("/google/login")
def google_login():
    """Redirect the browser to Google's consent screen with a signed, short-lived state.

    Raises:
        HTTPException(400): Google client id/secret are not configured.
    """
    if not settings.google_client_id or not settings.google_client_secret:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")
    # Use an aware UTC datetime: naive datetime.utcnow().timestamp() is interpreted as
    # *local* time, which shifts "exp" by the server's UTC offset.
    expires_at = datetime.now(timezone.utc) + _OAUTH_STATE_TTL
    state_payload = {
        "nonce": secrets.token_urlsafe(16),
        "exp": int(expires_at.timestamp()),
        "type": "oauth_state",
    }
    # NOTE(review): the nonce is signed into the state but not bound to the browser
    # (e.g. via an HttpOnly cookie compared in the callback). The state therefore only
    # proves "issued by us within the TTL" and does not prevent login CSRF.
    state = jwt.encode(state_payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
    auth_url = build_google_oauth_url(state)
    logger.info("Redirecting to Google OAuth: %s...", auth_url[:100])
    return RedirectResponse(auth_url, status_code=302)


def _validate_oauth_state(state: str) -> dict:
    """Verify the signature, expiry, and type of an OAuth *state* JWT; return its claims.

    Raises:
        HTTPException(400): the state is invalid, expired, or not an ``oauth_state`` token.
    """
    try:
        decoded = jwt.decode(state, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError as exc:
        logger.error("Failed to decode state: %s", exc)
        raise HTTPException(status_code=400, detail="Invalid state") from None
    if decoded.get("type") != "oauth_state":
        logger.error("State type mismatch")
        raise HTTPException(status_code=400, detail="Invalid state")
    return decoded


def _exchange_code_for_access_token(code: str) -> str:
    """Trade a Google authorization *code* for an access token.

    Raises:
        HTTPException(400): network error, non-200 reply, unparseable body, or no token.
    """
    try:
        token_res = requests.post(
            _GOOGLE_TOKEN_URL,
            data={
                "code": code,
                "client_id": settings.google_client_id,
                "client_secret": settings.google_client_secret,
                "redirect_uri": settings.google_redirect_uri,
                "grant_type": "authorization_code",
            },
            timeout=_GOOGLE_HTTP_TIMEOUT,
        )
    except requests.RequestException as exc:
        logger.error("Google token request failed: %s", exc)
        raise HTTPException(status_code=400, detail="Failed to exchange code") from None
    if token_res.status_code != 200:
        logger.error("Google token endpoint returned HTTP %s", token_res.status_code)
        raise HTTPException(status_code=400, detail="Failed to exchange code")
    try:
        token_data = token_res.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        raise HTTPException(status_code=400, detail="Failed to exchange code") from None
    access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
    if not access_token:
        raise HTTPException(status_code=400, detail="Missing access token")
    return access_token


def _fetch_google_profile(access_token: str) -> dict:
    """Fetch the OpenID userinfo document for *access_token*.

    Raises:
        HTTPException(400): network error, non-200 reply, or a body that is not a JSON object.
    """
    try:
        userinfo_res = requests.get(
            _GOOGLE_USERINFO_URL,
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=_GOOGLE_HTTP_TIMEOUT,
        )
    except requests.RequestException as exc:
        logger.error("Google userinfo request failed: %s", exc)
        raise HTTPException(status_code=400, detail="Failed to fetch user info") from None
    if userinfo_res.status_code != 200:
        raise HTTPException(status_code=400, detail="Failed to fetch user info")
    try:
        profile = userinfo_res.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Failed to fetch user info") from None
    if not isinstance(profile, dict):
        raise HTTPException(status_code=400, detail="Failed to fetch user info")
    return profile


def _upsert_google_user(db: Session, email: str, sub: str | None, picture: str | None) -> User:
    """Return the user for *email*, creating it or linking the Google identity, then commit.

    Raises:
        HTTPException(403): an existing account with this email is inactive.
    """
    user = db.query(User).filter(User.email == email).first()
    if user is None:
        user = User(
            email=email,
            # Random password the user never sees, so the column is populated but the
            # account can only sign in through Google until a password is set.
            password_hash=hash_password(secrets.token_urlsafe(24)),
            role=UserRole.learner,
            oauth_provider="google",
            oauth_subject=sub,
            picture_url=picture,
        )
        db.add(user)
    else:
        # Same rule as /login; checked before touching the row.
        if not user.is_active:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")
        if user.oauth_provider != "google":
            user.oauth_provider = "google"
        # NOTE(review): an existing oauth_subject that differs from `sub` is silently
        # accepted here; consider rejecting mismatched Google subjects.
        if not user.oauth_subject:
            user.oauth_subject = sub
        if picture and not user.picture_url:
            user.picture_url = picture
    db.commit()
    db.refresh(user)
    return user


@router.get("/google/callback")
def google_callback(code: str | None = None, state: str | None = None, db: Session = Depends(get_db)):
    """Complete Google sign-in and redirect to the frontend with a token pair.

    Validates *state*, exchanges *code* for a Google access token, and loads the
    profile. It then requires a verified email, creates or links the local user,
    records a ``LoginEvent``, and redirects to ``{frontend_url}/auth/callback``.

    Raises:
        HTTPException(400): missing/invalid parameters, Google errors, or an
            absent/unverified email.
        HTTPException(403): the matching local account is inactive.
    """
    logger.info("Google OAuth callback received")
    if not code or not state:
        logger.error("Missing code or state: code=%s, state=%s", bool(code), bool(state))
        raise HTTPException(status_code=400, detail="Missing code or state")
    _validate_oauth_state(state)

    access_token = _exchange_code_for_access_token(code)
    profile = _fetch_google_profile(access_token)

    email = profile.get("email")
    if not email:
        raise HTTPException(status_code=400, detail="Email not available")
    # Accounts are matched by email, so an unverified address must never be trusted:
    # it would let someone take over an existing local account.
    if profile.get("email_verified") not in (True, "true"):
        raise HTTPException(status_code=400, detail="Email not verified")

    user = _upsert_google_user(db, email, profile.get("sub"), profile.get("picture"))
    _record_login(db, user)
    token_pair = _issue_token_pair(user.id)

    # NOTE(review): tokens in a query string end up in browser history, proxy logs and
    # Referer headers; a URL fragment or a short-lived one-time code would be safer, but
    # that needs a matching frontend change.
    query = urlencode(
        {"access_token": token_pair.access_token, "refresh_token": token_pair.refresh_token}
    )
    redirect_url = f"{settings.frontend_url}/auth/callback?{query}"
    logger.info("OAuth successful for user %s, redirecting to frontend", user.email)
    return RedirectResponse(redirect_url, status_code=302)