240 lines
8.6 KiB
Python
240 lines
8.6 KiB
Python
import os
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
from typing import List, Optional
|
|
|
|
|
|
def get_db_connection():
    """Open and return a new psycopg2 connection.

    Each setting is read from an environment variable (DB_HOST, DB_PORT,
    DB_NAME, DB_USER, DB_PASSWORD). When a variable is unset, a hard-coded
    local default is used instead. DB_PORT is converted to ``int``.
    """
    # NOTE(review): the fallback credentials are development defaults; make sure
    # production deployments always set the DB_* variables explicitly.
    env = os.environ
    settings = {
        "host": env.get("DB_HOST", "localhost"),
        "port": int(env.get("DB_PORT", "5432")),
        "database": env.get("DB_NAME", "recipes_db"),
        "user": env.get("DB_USER", "recipes_user"),
        "password": env.get("DB_PASSWORD", "recipes_password"),
    }
    return psycopg2.connect(**settings)
|
|
|
|
|
|
# ============= Conversations & Messages =============
|
|
|
|
def create_conversation(user_ids: List[int], is_group: bool = False, name: Optional[str] = None, created_by: Optional[int] = None):
    """Create a new conversation, or reuse an existing private one.

    For a private (non-group) chat between exactly two distinct users, an
    existing private conversation that has both users as members is returned
    instead of inserting a duplicate.

    Args:
        user_ids: IDs of the users to add as members. Repeated IDs are added
            only once; the first occurrence determines ordering.
        is_group: Whether the conversation is a group chat.
        name: Optional conversation name.
        created_by: Optional ID of the creating user, stored on the row.

    Returns:
        A dict describing the conversation. It always carries both ``id`` and
        ``conversation_id``, with the same value. A newly created conversation
        contains the inserted columns (``id``, ``name``, ``is_group``,
        ``created_by``, ``created_at``). A reused private conversation is the
        result of ``get_conversation`` and additionally includes ``members``.
    """
    # De-duplicate while preserving order. Without this, a repeated ID would
    # insert the same membership twice, and [a, a] would make the private-chat
    # lookup below match *any* private conversation that `a` belongs to.
    member_ids = list(dict.fromkeys(user_ids))

    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # For private chats, reuse an existing conversation between the pair.
        if not is_group and len(member_ids) == 2:
            cur.execute(
                """
                SELECT c.id FROM conversations c
                JOIN conversation_members cm1 ON c.id = cm1.conversation_id
                JOIN conversation_members cm2 ON c.id = cm2.conversation_id
                WHERE c.is_group = FALSE
                AND cm1.user_id = %s AND cm2.user_id = %s
                """,
                (member_ids[0], member_ids[1]),
            )
            existing = cur.fetchone()
            if existing:
                conversation = get_conversation(existing["id"])
                if conversation is not None:
                    # Return the same key set that new conversations get.
                    conversation["conversation_id"] = conversation["id"]
                    return conversation
                # The conversation disappeared between the two queries, so
                # fall through and create a fresh one.

        cur.execute(
            """
            INSERT INTO conversations (name, is_group, created_by)
            VALUES (%s, %s, %s)
            RETURNING id, name, is_group, created_by, created_at
            """,
            (name, is_group, created_by),
        )
        conversation = dict(cur.fetchone())
        conversation_id = conversation["id"]

        cur.executemany(
            "INSERT INTO conversation_members (conversation_id, user_id) VALUES (%s, %s)",
            [(conversation_id, member_id) for member_id in member_ids],
        )

        conn.commit()

        # Also expose the id under `conversation_id`, the key other functions use.
        conversation["conversation_id"] = conversation_id
        return conversation
    finally:
        cur.close()
        conn.close()
|
|
|
|
|
|
def get_conversation(conversation_id: int):
    """Load one conversation together with its members.

    Returns ``None`` when no conversation has ``conversation_id``. Otherwise it
    returns a dict with the conversation's ``id``, ``name``, ``is_group``,
    ``created_by`` and ``created_at``, plus ``members``. ``members`` is a list
    of dicts holding each member's ``id``, ``username`` and ``display_name``.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        cur.execute(
            "SELECT id, name, is_group, created_by, created_at FROM conversations WHERE id = %s",
            (conversation_id,),
        )
        header = cur.fetchone()
        if header is None:
            return None

        # Member profiles come from the users table via the membership rows.
        cur.execute(
            """
            SELECT u.id, u.username, u.display_name
            FROM conversation_members cm
            JOIN users u ON u.id = cm.user_id
            WHERE cm.conversation_id = %s
            """,
            (conversation_id,),
        )
        return {**header, "members": list(map(dict, cur.fetchall()))}
    finally:
        cur.close()
        conn.close()
|
|
|
|
|
|
def get_user_conversations(user_id: int):
    """List every conversation ``user_id`` belongs to, most recently active first.

    Each entry is a dict with these keys:

    - ``conversation_id``, ``name``, ``is_group``, ``created_at``
    - ``unread_count``, ``last_message``, ``last_message_at``
    - ``members``: the *other* members, each with ``id``, ``username``,
      ``display_name`` and ``email``.

    Private chats also get ``other_member_name``. It is the other member's
    display name, username or email, whichever is first non-empty.

    Soft-deleted messages (``is_deleted``) are ignored for both the unread
    count and the last-message preview, matching ``get_messages``. The
    user's own messages never count as unread. A NULL ``last_read_at`` is
    treated as "never read", so every other member's message counts.
    Conversations without messages sort after those with messages.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # The LATERAL subquery picks the latest visible message once, so the
        # preview text and timestamp always come from the same row. m.id breaks
        # ties between messages that share a created_at.
        cur.execute(
            """
            SELECT c.id AS conversation_id, c.name, c.is_group, c.created_at,
                   (SELECT COUNT(*) FROM messages m
                     WHERE m.conversation_id = c.id
                       AND m.is_deleted = FALSE
                       AND m.sender_id != cm.user_id
                       AND (cm.last_read_at IS NULL OR m.created_at > cm.last_read_at)) AS unread_count,
                   lm.content AS last_message,
                   lm.created_at AS last_message_at
            FROM conversations c
            JOIN conversation_members cm ON c.id = cm.conversation_id
            LEFT JOIN LATERAL (
                SELECT m.content, m.created_at
                FROM messages m
                WHERE m.conversation_id = c.id AND m.is_deleted = FALSE
                ORDER BY m.created_at DESC, m.id DESC
                LIMIT 1
            ) lm ON TRUE
            WHERE cm.user_id = %s
            ORDER BY last_message_at DESC NULLS LAST, c.created_at DESC
            """,
            (user_id,),
        )
        conversations = [dict(row) for row in cur.fetchall()]
        if not conversations:
            # No conversations means no members to look up; skip the second query.
            return conversations

        # Fetch the other members of every conversation in one round trip
        # instead of one query per conversation.
        cur.execute(
            """
            SELECT cm.conversation_id, u.id, u.username, u.display_name, u.email
            FROM conversation_members cm
            JOIN users u ON u.id = cm.user_id
            WHERE cm.conversation_id = ANY(%s) AND u.id != %s
            ORDER BY cm.conversation_id, u.id
            """,
            ([conv["conversation_id"] for conv in conversations], user_id),
        )
        members_by_conversation = {}
        for row in cur.fetchall():
            member = dict(row)
            conversation_id = member.pop("conversation_id")
            members_by_conversation.setdefault(conversation_id, []).append(member)

        for conv in conversations:
            members = members_by_conversation.get(conv["conversation_id"], [])
            conv["members"] = members
            # Private chats get a display name for the other participant.
            if not conv["is_group"] and members:
                other = members[0]
                conv["other_member_name"] = other.get("display_name") or other.get("username") or other.get("email")

        return conversations
    finally:
        cur.close()
        conn.close()
|
|
|
|
|
|
def send_message(conversation_id: int, sender_id: int, content: str):
    """Post ``content`` to a conversation on behalf of one of its members.

    Returns the inserted message as a dict (``id``, ``conversation_id``,
    ``sender_id``, ``content``, ``created_at``). If ``sender_id`` is not a
    member of the conversation, nothing is written and the function returns
    ``{"error": "Not a member of this conversation"}``.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # Only members may post.
        cur.execute(
            "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, sender_id),
        )
        is_member = cur.fetchone() is not None
        if not is_member:
            return {"error": "Not a member of this conversation"}

        cur.execute(
            """
            INSERT INTO messages (conversation_id, sender_id, content)
            VALUES (%s, %s, %s)
            RETURNING id, conversation_id, sender_id, content, created_at
            """,
            (conversation_id, sender_id, content),
        )
        inserted = cur.fetchone()

        # Record the activity on the conversation row as well.
        cur.execute(
            "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = %s",
            (conversation_id,),
        )

        conn.commit()
        return dict(inserted)
    finally:
        cur.close()
        conn.close()
|
|
|
|
|
|
def get_messages(conversation_id: int, user_id: int, limit: int = 50, before_id: Optional[int] = None):
    """Return a page of a conversation's messages and mark the conversation read.

    Messages are selected newest-first, so ``limit`` keeps the most recent
    ones, and are then returned in chronological order. Soft-deleted messages
    are excluded.

    Args:
        conversation_id: Conversation to read.
        user_id: Requesting user; must be a member of the conversation.
        limit: Maximum number of messages to return.
        before_id: When not None, only messages whose id is lower than this
            are returned. Used to page back through history.

    Returns:
        A list of message dicts with ``message_id``, ``sender_id``,
        ``content``, ``created_at``, ``edited_at``, ``sender_username``,
        ``sender_display_name``, ``sender_email`` and ``is_mine``.
        If ``user_id`` is not a member, returns
        ``{"error": "Not a member of this conversation"}`` instead.
    """
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # Verify the requester is a member.
        cur.execute(
            "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, user_id),
        )
        if not cur.fetchone():
            return {"error": "Not a member of this conversation"}

        # The first page and older pages differ only by an optional cursor
        # condition. The fragment is a constant; all values still go through
        # query parameters. `is not None` keeps an explicit id of 0 meaningful.
        params = [user_id, conversation_id]
        before_clause = ""
        if before_id is not None:
            before_clause = "AND m.id < %s"
            params.append(before_id)
        params.append(limit)

        # m.id breaks created_at ties. Without it, rows sharing a timestamp can
        # land on either side of the LIMIT, and pages become nondeterministic.
        cur.execute(
            f"""
            SELECT m.id AS message_id, m.sender_id, m.content, m.created_at, m.edited_at,
                   u.username AS sender_username, u.display_name AS sender_display_name, u.email AS sender_email,
                   CASE WHEN m.sender_id = %s THEN TRUE ELSE FALSE END AS is_mine
            FROM messages m
            JOIN users u ON u.id = m.sender_id
            WHERE m.conversation_id = %s AND m.is_deleted = FALSE {before_clause}
            ORDER BY m.created_at DESC, m.id DESC
            LIMIT %s
            """,
            tuple(params),
        )

        messages = [dict(row) for row in cur.fetchall()]
        messages.reverse()  # Selected newest-first; return oldest-first.

        # Mark as read. This also happens when paging back through older
        # messages, matching the previous behavior.
        cur.execute(
            "UPDATE conversation_members SET last_read_at = CURRENT_TIMESTAMP WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, user_id),
        )
        conn.commit()

        return messages
    finally:
        cur.close()
        conn.close()
|