172 lines
6.7 KiB
Python
172 lines
6.7 KiB
Python
from app.db import get_db_connection
|
|
from app.schemas import MessageResponse, ConversationResponse
|
|
from datetime import datetime
|
|
|
|
class ChatService:
    """Handle chat messages and conversations.

    Each public method opens its own connection via ``get_db_connection()``.
    Methods that act on one specific conversation first check that the acting
    user is one of its two participants (``user_id_1`` / ``user_id_2``).
    """

    @staticmethod
    def _verify_participant(cur, user_id: int, conversation_id: int, error_message: str) -> None:
        """Raise ``ValueError(error_message)`` unless *user_id* belongs to the conversation.

        A conversation id that does not exist raises the same error as one the
        user is not part of, so the two cases are indistinguishable to callers.
        """
        cur.execute(
            "SELECT user_id_1, user_id_2 FROM conversations WHERE id = %s",
            (conversation_id,),
        )
        row = cur.fetchone()
        if not row or user_id not in (row[0], row[1]):
            raise ValueError(error_message)

    @staticmethod
    def send_message(sender_id: int, conversation_id: int, content: str) -> MessageResponse:
        """Send a message in a conversation.

        Inserts the message, bumps the conversation's ``updated_at`` and
        commits once after both statements.

        Args:
            sender_id: User sending the message; must be a participant.
            conversation_id: Conversation the message belongs to.
            content: Message text, stored as given.

        Returns:
            The stored message, with ``created_at`` rendered via ``isoformat()``.

        Raises:
            ValueError: If the conversation does not exist or *sender_id* is
                not one of its participants.
        """
        with get_db_connection() as conn:
            cur = conn.cursor()
            ChatService._verify_participant(
                cur, sender_id, conversation_id,
                "Not authorized to message in this conversation",
            )

            cur.execute(
                "INSERT INTO messages (conversation_id, sender_id, content) "
                "VALUES (%s, %s, %s) RETURNING id, created_at",
                (conversation_id, sender_id, content),
            )
            message_id, created_at = cur.fetchone()

            # Bumping updated_at moves this conversation to the top of
            # get_conversations(), which orders by updated_at DESC.
            cur.execute(
                "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = %s",
                (conversation_id,),
            )
            conn.commit()

        return MessageResponse(
            id=message_id,
            conversation_id=conversation_id,
            sender_id=sender_id,
            content=content,
            created_at=created_at.isoformat(),
        )

    @staticmethod
    def _fetch_profile_summary(cur, user_id: int) -> tuple:
        """Return ``(display_name, photo_path)`` for *user_id*.

        Falls back to ``("Unknown", None)`` when no profile row is found.
        ``photo_path`` is None when the joined photo path is NULL or empty.
        The photo chosen is the first one by ``display_order``.
        """
        cur.execute(
            """SELECT p.display_name, ph.file_path
            FROM profiles p
            LEFT JOIN photos ph ON p.id = ph.profile_id
            WHERE p.user_id = %s
            ORDER BY ph.display_order
            LIMIT 1""",
            (user_id,),
        )
        row = cur.fetchone()
        if not row:
            return "Unknown", None
        return row[0], row[1] or None

    @staticmethod
    def _fetch_latest_message(cur, conversation_id: int) -> str:
        """Return the content of the newest message, or ``""`` if there is none."""
        # `id DESC` breaks ties between messages sharing the same created_at,
        # so the preview is stable across calls.
        cur.execute(
            "SELECT content FROM messages WHERE conversation_id = %s "
            "ORDER BY created_at DESC, id DESC LIMIT 1",
            (conversation_id,),
        )
        row = cur.fetchone()
        return row[0] if row else ""

    @staticmethod
    def _count_unread(cur, conversation_id: int, reader_id: int) -> int:
        """Count messages not sent by *reader_id* whose ``read_at`` is still NULL."""
        cur.execute(
            "SELECT COUNT(*) FROM messages "
            "WHERE conversation_id = %s AND sender_id != %s AND read_at IS NULL",
            (conversation_id, reader_id),
        )
        return cur.fetchone()[0]

    @staticmethod
    def get_conversations(user_id: int) -> list:
        """Get all conversations for a user, most recently updated first.

        Args:
            user_id: User whose conversations are listed.

        Returns:
            A list of ConversationResponse objects. Each carries the other
            participant's display name and photo, the latest message preview
            and the count of unread messages addressed to *user_id*.
        """
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute(
                """SELECT id, user_id_1, user_id_2, created_at, updated_at
                FROM conversations
                WHERE user_id_1 = %s OR user_id_2 = %s
                ORDER BY updated_at DESC, id DESC""",
                (user_id, user_id),
            )
            # Materialise the listing first: the same cursor is reused for the
            # per-conversation lookups below.
            rows = cur.fetchall()

            # NOTE(review): three extra queries per conversation (N+1). Consider
            # folding them into the listing query if conversation lists grow large.
            conversations = []
            for conv_id, user_1, user_2, created_at, _updated_at in rows:
                other_user_id = user_2 if user_1 == user_id else user_1
                other_name, other_photo = ChatService._fetch_profile_summary(cur, other_user_id)
                latest_msg = ChatService._fetch_latest_message(cur, conv_id)
                unread_count = ChatService._count_unread(cur, conv_id, user_id)

                conversations.append(ConversationResponse(
                    id=conv_id,
                    user_id_1=user_1,
                    user_id_2=user_2,
                    other_user_id=other_user_id,
                    other_user_display_name=other_name,
                    other_user_photo=other_photo,
                    latest_message=latest_msg,
                    unread_count=unread_count,
                    created_at=created_at.isoformat(),
                ))

        return conversations

    @staticmethod
    def get_messages(user_id: int, conversation_id: int, limit: int = 50) -> list:
        """Get the most recent messages of a conversation, oldest first.

        Args:
            user_id: Requesting user; must be a participant.
            conversation_id: Conversation to read.
            limit: Maximum number of most-recent messages to return.

        Returns:
            Up to *limit* MessageResponse objects in chronological order.

        Raises:
            ValueError: If the conversation does not exist or *user_id* is
                not one of its participants.
        """
        with get_db_connection() as conn:
            cur = conn.cursor()
            ChatService._verify_participant(
                cur, user_id, conversation_id,
                "Not authorized to view this conversation",
            )

            # Take the newest `limit` rows. `id DESC` breaks created_at ties so
            # both the window boundary and the order within it are deterministic.
            cur.execute(
                """SELECT id, conversation_id, sender_id, content, created_at
                FROM messages
                WHERE conversation_id = %s
                ORDER BY created_at DESC, id DESC
                LIMIT %s""",
                (conversation_id, limit),
            )
            rows = cur.fetchall()

        # The query returns newest-first; flip to chronological order.
        return [
            MessageResponse(
                id=msg_id,
                conversation_id=conv_id,
                sender_id=sender_id,
                content=content,
                created_at=created_at.isoformat(),
            )
            for msg_id, conv_id, sender_id, content, created_at in reversed(rows)
        ]

    @staticmethod
    def mark_messages_as_read(user_id: int, conversation_id: int) -> dict:
        """Mark every unread message not sent by *user_id* as read.

        Sets ``read_at`` to ``CURRENT_TIMESTAMP`` on rows where it is NULL, then
        commits.

        Args:
            user_id: Reader; must be a participant.
            conversation_id: Conversation whose messages are marked.

        Returns:
            ``{"message": "Messages marked as read"}``.

        Raises:
            ValueError: If the conversation does not exist or *user_id* is
                not one of its participants.
        """
        with get_db_connection() as conn:
            cur = conn.cursor()
            ChatService._verify_participant(
                cur, user_id, conversation_id,
                "Not authorized to view this conversation",
            )

            cur.execute(
                """UPDATE messages
                SET read_at = CURRENT_TIMESTAMP
                WHERE conversation_id = %s
                AND sender_id != %s
                AND read_at IS NULL""",
                (conversation_id, user_id),
            )
            conn.commit()

        return {"message": "Messages marked as read"}