# my-recipes/backend/chat_db_utils.py
# (240 lines, 8.6 KiB, Python)

import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Optional
def get_db_connection():
    """Open a new PostgreSQL connection configured from environment variables.

    Reads DB_HOST, DB_PORT, DB_NAME, DB_USER and DB_PASSWORD, falling back to
    local-development defaults for any variable that is unset. Every call
    returns a fresh psycopg2 connection; the caller is responsible for
    closing it.
    """
    env = os.getenv
    # NOTE(review): the password fallback is a development default; confirm
    # DB_PASSWORD is always set in deployed environments.
    settings = {
        "host": env("DB_HOST", "localhost"),
        "port": int(env("DB_PORT", "5432")),
        "database": env("DB_NAME", "recipes_db"),
        "user": env("DB_USER", "recipes_user"),
        "password": env("DB_PASSWORD", "recipes_password"),
    }
    return psycopg2.connect(**settings)
# ============= Conversations & Messages =============
def create_conversation(
    user_ids: List[int],
    is_group: bool = False,
    name: Optional[str] = None,
    created_by: Optional[int] = None,
):
    """Create a conversation, or reuse the existing private chat between two users.

    Args:
        user_ids: IDs of the users to add as members. Repeated IDs are
            ignored (the first occurrence is kept).
        is_group: True for a group conversation. For a non-group call with
            exactly two distinct users, an existing private conversation that
            contains both users is returned instead of creating a duplicate.
        name: Optional conversation name.
        created_by: ID of the user creating the conversation, if known.

    Returns:
        A dict with ``id``, ``name``, ``is_group``, ``created_by`` and
        ``created_at``, plus ``conversation_id`` (an alias of ``id``) on both
        paths. When an existing private chat is reused, the dict comes from
        ``get_conversation`` and additionally carries ``members``.
    """
    # Drop repeated IDs while keeping order. Without this, [a, a] reaches the
    # private-chat lookup where cm1 and cm2 can match the SAME membership
    # row, returning an arbitrary private chat of user a; and a repeated ID
    # in a group would insert the same membership twice.
    member_ids = list(dict.fromkeys(user_ids))
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # For private chats, reuse an existing conversation between the pair.
        if not is_group and len(member_ids) == 2:
            cur.execute(
                """
                SELECT c.id FROM conversations c
                JOIN conversation_members cm1 ON c.id = cm1.conversation_id
                JOIN conversation_members cm2 ON c.id = cm2.conversation_id
                WHERE c.is_group = FALSE
                  AND cm1.user_id = %s AND cm2.user_id = %s
                ORDER BY c.id
                LIMIT 1
                """,
                (member_ids[0], member_ids[1]),
            )
            existing = cur.fetchone()
            if existing:
                conversation = get_conversation(existing["id"])
                # Expose the same "conversation_id" key as the creation path
                # below, so callers can rely on it whichever path ran.
                if conversation is not None:
                    conversation["conversation_id"] = conversation["id"]
                return conversation

        cur.execute(
            """
            INSERT INTO conversations (name, is_group, created_by)
            VALUES (%s, %s, %s)
            RETURNING id, name, is_group, created_by, created_at
            """,
            (name, is_group, created_by),
        )
        conversation = dict(cur.fetchone())
        conversation_id = conversation["id"]

        for member_id in member_ids:
            cur.execute(
                "INSERT INTO conversation_members (conversation_id, user_id) VALUES (%s, %s)",
                (conversation_id, member_id),
            )
        # Conversation row and memberships are committed together.
        conn.commit()

        conversation["conversation_id"] = conversation_id
        return conversation
    finally:
        cur.close()
        conn.close()
def get_conversation(conversation_id: int):
    """Load a single conversation together with its member list.

    Returns ``None`` when no row in ``conversations`` has this id. Otherwise
    returns a dict of the conversation columns (id, name, is_group,
    created_by, created_at) with an extra ``members`` key: a list of dicts
    holding each member's id, username and display_name.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        cur.execute(
            "SELECT id, name, is_group, created_by, created_at FROM conversations WHERE id = %s",
            (conversation_id,),
        )
        header = cur.fetchone()
        if not header:
            return None

        members_sql = """
            SELECT u.id, u.username, u.display_name
            FROM conversation_members cm
            JOIN users u ON u.id = cm.user_id
            WHERE cm.conversation_id = %s
        """
        cur.execute(members_sql, (conversation_id,))
        member_rows = cur.fetchall()

        return {**header, "members": list(map(dict, member_rows))}
    finally:
        cur.close()
        conn.close()
def get_user_conversations(user_id: int):
    """List every conversation ``user_id`` belongs to, most recent activity first.

    Each item is a dict with ``conversation_id``, ``name``, ``is_group``,
    ``created_at``, ``unread_count``, ``last_message``, ``last_message_at``
    and ``members``. ``members`` lists the OTHER members of the conversation
    (id, username, display_name, email). Private chats that have another
    member also get ``other_member_name``: that member's display name,
    falling back to username, then email.

    A message counts as unread when it meets all of these conditions:
    - it was sent by someone other than ``user_id``;
    - it is not soft-deleted;
    - it was created after the user's ``last_read_at``. A NULL
      ``last_read_at`` means nothing has been read yet.

    Soft-deleted messages are also excluded from the last-message preview,
    matching ``get_messages``.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        cur.execute(
            """
            SELECT c.id AS conversation_id, c.name, c.is_group, c.created_at,
                   (SELECT COUNT(*) FROM messages m
                     WHERE m.conversation_id = c.id
                       AND m.is_deleted = FALSE
                       AND m.sender_id IS DISTINCT FROM cm.user_id
                       AND (cm.last_read_at IS NULL OR m.created_at > cm.last_read_at)
                   ) AS unread_count,
                   (SELECT m.content FROM messages m
                     WHERE m.conversation_id = c.id AND m.is_deleted = FALSE
                     ORDER BY m.created_at DESC, m.id DESC LIMIT 1
                   ) AS last_message,
                   (SELECT m.created_at FROM messages m
                     WHERE m.conversation_id = c.id AND m.is_deleted = FALSE
                     ORDER BY m.created_at DESC, m.id DESC LIMIT 1
                   ) AS last_message_at
            FROM conversations c
            JOIN conversation_members cm ON c.id = cm.conversation_id
            WHERE cm.user_id = %s
            ORDER BY last_message_at DESC NULLS LAST, c.created_at DESC
            """,
            (user_id,),
        )
        conversations = [dict(row) for row in cur.fetchall()]
        if not conversations:
            return conversations

        # Fetch the other members of ALL conversations in one round trip
        # instead of one query per conversation, then group them in Python.
        by_id = {}
        for conv in conversations:
            conv["members"] = []
            by_id[conv["conversation_id"]] = conv
        cur.execute(
            """
            SELECT cm.conversation_id, u.id, u.username, u.display_name, u.email
            FROM conversation_members cm
            JOIN users u ON u.id = cm.user_id
            WHERE cm.conversation_id = ANY(%s) AND u.id != %s
            ORDER BY cm.conversation_id, u.id
            """,
            (list(by_id), user_id),
        )
        for row in cur.fetchall():
            member = dict(row)
            by_id[member.pop("conversation_id")]["members"].append(member)

        # For private chats, expose a display name for the other participant.
        for conv in conversations:
            members = conv["members"]
            if not conv["is_group"] and members:
                other = members[0]
                conv["other_member_name"] = (
                    other.get("display_name") or other.get("username") or other.get("email")
                )
        return conversations
    finally:
        cur.close()
        conn.close()
def send_message(conversation_id: int, sender_id: int, content: str):
    """Post ``content`` to a conversation on behalf of ``sender_id``.

    Returns the inserted message row (id, conversation_id, sender_id,
    content, created_at) as a dict. If the sender is not a member of the
    conversation, returns ``{"error": ...}`` instead and writes nothing. On
    success the conversation's ``updated_at`` is refreshed in the same
    transaction as the insert.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        membership_sql = "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
        cur.execute(membership_sql, (conversation_id, sender_id))
        if cur.fetchone() is None:
            return {"error": "Not a member of this conversation"}

        insert_sql = """
            INSERT INTO messages (conversation_id, sender_id, content)
            VALUES (%s, %s, %s)
            RETURNING id, conversation_id, sender_id, content, created_at
        """
        cur.execute(insert_sql, (conversation_id, sender_id, content))
        row = cur.fetchone()

        # Bump the conversation's activity timestamp alongside the insert.
        cur.execute(
            "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = %s",
            (conversation_id,),
        )
        conn.commit()
        return dict(row)
    finally:
        cur.close()
        conn.close()
def get_messages(conversation_id: int, user_id: int, limit: int = 50, before_id: Optional[int] = None):
    """Return one page of a conversation's messages and mark it read for ``user_id``.

    Args:
        conversation_id: Conversation to read from.
        user_id: Requesting user; must be a member of the conversation.
        limit: Maximum number of messages to return.
        before_id: Pagination cursor. When truthy, only messages with an id
            lower than this are returned. ``None`` (or 0) returns the latest
            page.

    Returns:
        A list of message dicts in chronological order (oldest first), with
        these keys:
        - ``message_id``, ``sender_id``, ``content``, ``created_at``,
          ``edited_at``;
        - ``sender_username``, ``sender_display_name``, ``sender_email``;
        - ``is_mine``.
        Soft-deleted messages are excluded. If ``user_id`` is not a member,
        returns ``{"error": ...}`` instead.

    Side effects:
        Sets the caller's ``last_read_at`` to CURRENT_TIMESTAMP and commits.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # Verify user is member
        cur.execute(
            "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, user_id),
        )
        if not cur.fetchone():
            return {"error": "Not a member of this conversation"}

        # A NULL cursor disables the "older than" filter, so one statement
        # serves both the first page and later pages. Falsy before_id (None
        # or 0) means "latest page", as in the original truthiness check.
        cursor_id = before_id or None
        # m.id breaks created_at ties (rows inserted in one transaction share
        # CURRENT_TIMESTAMP); without it, LIMIT can split ties arbitrarily
        # and the id-based cursor may skip or repeat messages across pages.
        cur.execute(
            """
            SELECT m.id AS message_id, m.sender_id, m.content, m.created_at, m.edited_at,
                   u.username AS sender_username, u.display_name AS sender_display_name,
                   u.email AS sender_email,
                   CASE WHEN m.sender_id = %s THEN TRUE ELSE FALSE END AS is_mine
            FROM messages m
            JOIN users u ON u.id = m.sender_id
            WHERE m.conversation_id = %s
              AND m.is_deleted = FALSE
              AND (%s IS NULL OR m.id < %s)
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT %s
            """,
            (user_id, conversation_id, cursor_id, cursor_id, limit),
        )
        messages = [dict(row) for row in cur.fetchall()]
        messages.reverse()  # newest-first query -> chronological order

        # Mark as read
        cur.execute(
            "UPDATE conversation_members SET last_read_at = CURRENT_TIMESTAMP WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, user_id),
        )
        conn.commit()
        return messages
    finally:
        cur.close()
        conn.close()