89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
from auth_utils import get_current_user
|
|
from chat_db_utils import (
|
|
create_conversation,
|
|
get_conversation,
|
|
get_user_conversations,
|
|
send_message,
|
|
get_messages,
|
|
)
|
|
|
|
# Shared router for every chat endpoint in this module: all routes are mounted
# under the "/conversations" prefix and grouped under the "chat" tag.
router = APIRouter(tags=["chat"], prefix="/conversations")
|
|
|
|
|
|
class ConversationCreate(BaseModel):
    # Request body for POST /conversations (create a private or group chat).

    # IDs of the intended participants. The create endpoint also adds the
    # authenticated caller when their ID is missing, so clients may omit it.
    user_ids: List[int]
    # True for a group chat; defaults to False (a private conversation).
    is_group: bool = False
    # Optional conversation name; None when the client does not supply one.
    # NOTE(review): presumably only meaningful for group chats — confirm.
    name: Optional[str] = None
|
|
|
|
|
|
class MessageCreate(BaseModel):
    # Request body for POST /conversations/{conversation_id}/messages.

    # Message text. Only type-checked as a string here: no length limit or
    # blank-content validation is applied at this layer.
    content: str
|
|
|
|
|
|
@router.post("")
def create_conversation_endpoint(
    data: ConversationCreate,
    current_user: dict = Depends(get_current_user)
):
    """Create a new conversation (private or group chat)"""
    # Params:
    #   data: participant IDs, group flag and optional name from the request.
    #   current_user: authenticated user from get_current_user; its "user_id"
    #     identifies the caller, who becomes the creator.
    # Returns whatever create_conversation returns for the new conversation.
    creator_id = current_user["user_id"]

    # Build a fresh participant list rather than appending to the request
    # model's own list in place, and drop repeated IDs (keeping first-seen
    # order) so the DB layer never receives the same participant twice.
    user_ids = list(dict.fromkeys(data.user_ids))

    # The creator is always a participant, even if the client omitted them.
    if creator_id not in user_ids:
        user_ids.append(creator_id)

    # NOTE(review): no participant-count check is made for is_group=False
    # (e.g. exactly two users); confirm whether create_conversation enforces it.
    return create_conversation(
        user_ids=user_ids,
        is_group=data.is_group,
        name=data.name,
        created_by=creator_id,
    )
|
|
|
|
|
|
@router.get("")
def get_conversations_endpoint(
    current_user: dict = Depends(get_current_user)
):
    """Get all conversations for current user"""
    # The caller is identified solely by the "user_id" of the authenticated
    # user; the helper's result is handed back to the client unchanged.
    caller_id = current_user["user_id"]
    conversations = get_user_conversations(caller_id)
    return conversations
|
|
|
|
|
|
@router.get("/{conversation_id}")
def get_conversation_endpoint(
    conversation_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Get conversation details"""
    # Params:
    #   conversation_id: path parameter identifying the conversation.
    #   current_user: authenticated user; its "user_id" identifies the caller.
    # Returns the object produced by get_conversation.
    # Raises HTTPException 404 when the conversation does not exist, and 403
    # when the caller-aware message helper refuses access.
    conversation = get_conversation(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # get_conversation() takes no user id, so without this check any
    # authenticated user could read any conversation just by its id. The
    # message helpers do receive the caller's id and report refusal as a dict
    # with an "error" key (the message endpoints map that to 403). Probe with
    # a one-message read to reuse that check — presumably a participant
    # check; confirm in chat_db_utils.
    # TODO(review): replace the probe with a dedicated participant-check helper.
    access = get_messages(conversation_id, current_user["user_id"], 1, None)
    if isinstance(access, dict) and "error" in access:
        raise HTTPException(status_code=403, detail=access["error"])

    return conversation
|
|
|
|
|
|
@router.post("/{conversation_id}/messages")
def send_message_endpoint(
    conversation_id: int,
    data: MessageCreate,
    current_user: dict = Depends(get_current_user)
):
    """Send a message in a conversation"""
    # send_message signals failure by including an "error" key in its result;
    # that is surfaced as 403 Forbidden, and any other result is passed through.
    sender_id = current_user["user_id"]
    outcome = send_message(conversation_id, sender_id, data.content)
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=403, detail=outcome["error"])
|
|
|
|
|
|
@router.get("/{conversation_id}/messages")
def get_messages_endpoint(
    conversation_id: int,
    limit: int = Query(50, ge=1, le=100),
    before_id: Optional[int] = None,
    current_user: dict = Depends(get_current_user)
):
    """Get messages from a conversation"""
    # Params:
    #   conversation_id: path parameter identifying the conversation.
    #   limit: page size, validated to 1..100 (default 50). The lower bound is
    #     needed: previously limit=0 or a negative value reached get_messages.
    #     If that value ends up in a SQL LIMIT, a negative LIMIT means
    #     "no upper bound" in SQLite, silently defeating the le=100 cap.
    #   before_id: optional cursor forwarded unchanged to get_messages
    #     (presumably "messages older than this id" — confirm in chat_db_utils).
    #   current_user: authenticated user; its "user_id" identifies the caller.
    # Raises HTTPException 403 when get_messages returns a dict with "error".
    result = get_messages(conversation_id, current_user["user_id"], limit, before_id)
    # Only a dict result can be an error report; list results are returned as-is.
    if isinstance(result, dict) and "error" in result:
        raise HTTPException(status_code=403, detail=result["error"])
    return result
|