89 lines
2.5 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from typing import List, Optional
from auth_utils import get_current_user
from chat_db_utils import (
create_conversation,
get_conversation,
get_user_conversations,
send_message,
get_messages,
)
# Router for every chat endpoint in this module: all routes below are mounted
# under the "/conversations" prefix and grouped under the "chat" tag.
router = APIRouter(
    tags=["chat"],
    prefix="/conversations",
)
class ConversationCreate(BaseModel):
    """Request body for creating a conversation (private or group chat)."""

    # IDs of the users to include; the create endpoint adds the caller if absent.
    user_ids: List[int]
    # True for a group chat; defaults to False (private chat).
    is_group: bool = False
    # Optional display name; presumably only meaningful for group chats — confirm.
    name: Optional[str] = None
class MessageCreate(BaseModel):
    """Request body for sending a message in a conversation."""

    # Message text. No length or non-blank constraint is declared on the model.
    content: str
@router.post("")
def create_conversation_endpoint(
    data: ConversationCreate,
    current_user: dict = Depends(get_current_user)
):
    """Create a new conversation (private or group chat).

    The authenticated caller is always made a participant. Duplicate user IDs
    in the request are collapsed (first occurrence kept, order preserved), so
    the data layer never receives the same participant twice.

    Args:
        data: Participant IDs, group flag and optional name from the request body.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        Whatever ``create_conversation`` returns.
    """
    caller_id = current_user["user_id"]
    # Build a fresh list rather than appending to data.user_ids, so the parsed
    # request model is not mutated. dict.fromkeys de-duplicates while keeping
    # insertion order; the caller is appended last unless already present.
    user_ids = list(dict.fromkeys([*data.user_ids, caller_id]))
    return create_conversation(
        user_ids=user_ids,
        is_group=data.is_group,
        name=data.name,
        created_by=caller_id,
    )
@router.get("")
def get_conversations_endpoint(
    current_user: dict = Depends(get_current_user)
):
    """List the conversations of the authenticated user.

    Looks up the caller's ID and returns the result of
    ``get_user_conversations`` for it, unchanged.
    """
    user_id = current_user["user_id"]
    conversations = get_user_conversations(user_id)
    return conversations
@router.get("/{conversation_id}")
def get_conversation_endpoint(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return the details of a single conversation.

    Raises:
        HTTPException: 404 when ``get_conversation`` returns a falsy value.
    """
    # NOTE(review): current_user is never used here, so there is no check that
    # the caller participates in this conversation. Any authenticated user can
    # read any conversation by ID. The /messages routes, by contrast, pass the
    # user ID to the data layer. Add a membership check (ideally answering 404
    # for non-members so existence is not leaked) once the shape of the
    # get_conversation() result is confirmed.
    conversation = get_conversation(conversation_id)
    if conversation:
        return conversation
    raise HTTPException(status_code=404, detail="Conversation not found")
@router.post("/{conversation_id}/messages")
def send_message_endpoint(
    conversation_id: int,
    data: MessageCreate,
    current_user: dict = Depends(get_current_user)
):
    """Send a message in a conversation as the authenticated user.

    Args:
        conversation_id: Target conversation.
        data: Request body carrying the message text.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        Whatever ``send_message`` returns on success.

    Raises:
        HTTPException: 400 if the message content is empty or whitespace-only;
            403 if ``send_message`` returns a dict containing an ``"error"`` key.
    """
    if not data.content.strip():
        raise HTTPException(status_code=400, detail="Message content cannot be empty")
    result = send_message(conversation_id, current_user["user_id"], data.content)
    # Only a dict result is treated as an error payload, matching
    # get_messages_endpoint. A bare `"error" in result` would do a substring
    # test on a str result and raise TypeError on None.
    if isinstance(result, dict) and "error" in result:
        raise HTTPException(status_code=403, detail=result["error"])
    return result
@router.get("/{conversation_id}/messages")
def get_messages_endpoint(
    conversation_id: int,
    limit: int = Query(50, ge=1, le=100),
    before_id: Optional[int] = None,
    current_user: dict = Depends(get_current_user)
):
    """Get messages from a conversation on behalf of the authenticated user.

    Args:
        conversation_id: Conversation to read from.
        limit: Page size, validated to 1-100 (default 50). The lower bound
            matters. Without it, 0 or a negative value reaches the data layer,
            and if that value ends up in a SQL LIMIT clause, some backends
            (e.g. SQLite) treat a negative LIMIT as "no limit". That would
            bypass the 100-row cap.
        before_id: Optional cursor forwarded to ``get_messages``; presumably
            restricts results to messages older than this ID — confirm.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        Whatever ``get_messages`` returns on success.

    Raises:
        HTTPException: 403 if ``get_messages`` returns a dict with an
            ``"error"`` key.
    """
    result = get_messages(conversation_id, current_user["user_id"], limit, before_id)
    if isinstance(result, dict) and "error" in result:
        raise HTTPException(status_code=403, detail=result["error"])
    return result