225 lines
6.4 KiB
Python
225 lines
6.4 KiB
Python
"""
Database restore script for Cloudflare R2 storage.

Downloads a gzip-compressed backup from R2 and restores it into PostgreSQL.
"""
|
|
import os
|
|
import subprocess
|
|
import gzip
|
|
from pathlib import Path
|
|
import boto3
|
|
from botocore.config import Config
|
|
from dotenv import load_dotenv
|
|
|
|
# Pull variables from a local .env file (when present) into the process
# environment before any of the settings below are read.
load_dotenv()

# Cloudflare R2 connection settings. None of these has a default; main()
# refuses to run if any of them is missing.
R2_ENDPOINT, R2_ACCESS_KEY, R2_SECRET_KEY, R2_BUCKET = (
    os.environ.get(var_name)
    for var_name in ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET")
)

# PostgreSQL connection settings, each with a local-development fallback.
# NOTE(review): DB_PASSWORD falls back to a hard-coded value when unset;
# consider requiring it explicitly outside local development.
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = os.environ.get("DB_PORT", "5432")
DB_NAME = os.environ.get("DB_NAME", "recipes_db")
DB_USER = os.environ.get("DB_USER", "recipes_user")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "recipes_password")

# Downloaded and decompressed backups are stored in a "restores" directory
# next to this script; it is created at import time if missing.
RESTORE_DIR = Path(__file__).parent.joinpath("restores")
RESTORE_DIR.mkdir(exist_ok=True)
|
|
|
|
|
|
def list_r2_backups():
    """List every object in the R2 bucket, newest first.

    A single ``list_objects_v2`` call returns at most 1000 keys. This
    function therefore walks all result pages with a paginator, so that
    backups beyond the first page are not silently dropped.

    Returns:
        list[dict]: The object summaries from the ``Contents`` of every
        listing page, sorted by ``LastModified`` in descending order.
        Returns an empty list when the bucket is empty or when listing
        fails; in the failure case the error is printed.
    """
    s3_client = boto3.client(
        's3',
        endpoint_url=R2_ENDPOINT,
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        config=Config(
            signature_version='s3v4',
            s3={'addressing_style': 'path'}
        ),
        region_name='auto'
    )

    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        # Pages with no objects omit 'Contents' entirely.
        backups = [
            obj
            for page in paginator.paginate(Bucket=R2_BUCKET)
            for obj in page.get('Contents', [])
        ]
    except Exception as e:
        # Best-effort: report the problem and let the caller treat it as
        # "no backups available".
        print(f"✗ Error listing backups: {e}")
        return []

    backups.sort(key=lambda obj: obj['LastModified'], reverse=True)
    return backups
|
|
|
|
|
|
def download_from_r2(backup_name):
    """Download one backup object from R2 into RESTORE_DIR.

    Args:
        backup_name: Object key in R2_BUCKET. The key may contain "/"
            prefixes; only its final component is used as the local
            file name.

    Returns:
        Path: The downloaded local file.

    Raises:
        ValueError: If the key has no usable final component (for example
            it ends in "/" or is "..").
        Exception: Any download error is printed and then re-raised.
    """
    # Object keys use "/" as a separator. Joining the raw key onto
    # RESTORE_DIR has two problems: it points into subdirectories that do
    # not exist (so the download fails), and a ".." component could write
    # outside RESTORE_DIR.
    file_name = backup_name.rsplit('/', 1)[-1]
    if file_name in ('', '.', '..'):
        raise ValueError(f"Cannot derive a local file name from key: {backup_name!r}")
    local_file = RESTORE_DIR / file_name

    print(f"Downloading {backup_name} from R2...")

    s3_client = boto3.client(
        's3',
        endpoint_url=R2_ENDPOINT,
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        config=Config(
            signature_version='s3v4',
            s3={'addressing_style': 'path'}
        ),
        region_name='auto'
    )

    try:
        s3_client.download_file(R2_BUCKET, backup_name, str(local_file))
        size_mb = local_file.stat().st_size / (1024 * 1024)
        print(f"✓ Downloaded: {local_file.name} ({size_mb:.2f} MB)")
        return local_file
    except Exception as e:
        print(f"✗ Error downloading from R2: {e}")
        raise
|
|
|
|
|
|
def decompress_file(compressed_file):
    """Decompress a gzip file next to the original and return the new path.

    The output path is the input path with a trailing ``.gz`` suffix removed
    (``backup.sql.gz`` -> ``backup.sql``). If the input has no ``.gz``
    suffix, ``.out`` is appended instead, so the source is never overwritten.

    Args:
        compressed_file: Path, or path string, of the gzip-compressed file.

    Returns:
        Path: The decompressed file.

    Raises:
        OSError: If the input cannot be read or is not valid gzip data
            (``gzip.BadGzipFile`` is a subclass of OSError), or if the
            output cannot be written.
    """
    import shutil

    compressed_file = Path(compressed_file)

    # Strip only the final suffix. str.replace('.gz', '') would also rewrite
    # '.gz' inside directory names. For an input without a '.gz' suffix it
    # would also produce the same path, which is truncated when opened for
    # writing, destroying the source before it is read.
    if compressed_file.suffix.lower() == '.gz':
        decompressed_file = compressed_file.with_suffix('')
    else:
        decompressed_file = compressed_file.with_name(compressed_file.name + '.out')

    print(f"Decompressing {compressed_file.name}...")

    with gzip.open(compressed_file, 'rb') as f_in, open(decompressed_file, 'wb') as f_out:
        # Stream in chunks rather than loading the whole dump into memory.
        shutil.copyfileobj(f_in, f_out)

    compressed_size = compressed_file.stat().st_size
    decompressed_size = decompressed_file.stat().st_size

    print(f"✓ Decompressed to {decompressed_file.name}")
    print(f" Compressed: {compressed_size / 1024:.2f} KB")
    print(f" Decompressed: {decompressed_size / 1024:.2f} KB")

    return decompressed_file
|
|
|
|
|
|
def restore_database(sql_file):
    """Restore the PostgreSQL database from a plain-SQL file using psql.

    The restore proceeds in three steps:
    1. Ask for interactive confirmation.
    2. Drop and recreate the ``public`` schema. A failure here is only a
       warning.
    3. Replay ``sql_file`` with psql, stopping at the first failing
       statement.

    Args:
        sql_file: Path, or path string, of the SQL file to execute.

    Returns:
        bool: True if the restore ran to completion, False if the user
        declined the confirmation prompt.

    Raises:
        subprocess.CalledProcessError: If psql fails while replaying the
            file. Its stderr is printed before the error is re-raised.
    """
    sql_file = Path(sql_file)

    print(f"\nRestoring database from {sql_file.name}...")
    print("WARNING: This will overwrite the current database!")

    response = input("Are you sure you want to continue? (yes/no): ")
    if response.lower() != 'yes':
        print("Restore cancelled")
        return False

    # psql reads the password from PGPASSWORD, so it never prompts.
    env = os.environ.copy()
    env['PGPASSWORD'] = DB_PASSWORD

    # Drop and recreate the schema (optional; remove this step to merge into
    # the existing data instead).
    print("Dropping existing tables...")
    drop_cmd = [
        "psql",
        "-h", DB_HOST,
        "-p", DB_PORT,
        "-U", DB_USER,
        "-d", DB_NAME,
        "-c", "DROP SCHEMA public CASCADE; CREATE SCHEMA public;"
    ]

    try:
        subprocess.run(drop_cmd, env=env, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Warning: Could not drop schema: {e.stderr}")

    # Restore from backup
    print("Restoring database...")
    restore_cmd = [
        "psql",
        "-h", DB_HOST,
        "-p", DB_PORT,
        "-U", DB_USER,
        "-d", DB_NAME,
        # Without ON_ERROR_STOP, psql continues after a failed statement and
        # still exits 0. check=True would then never see the failure, and a
        # partial restore would be reported as a success.
        "-v", "ON_ERROR_STOP=1",
        "-f", str(sql_file)
    ]

    try:
        subprocess.run(restore_cmd, env=env, check=True, capture_output=True, text=True)
        print("✓ Database restored successfully!")
        return True
    except subprocess.CalledProcessError as e:
        print(f"✗ Error restoring database: {e.stderr}")
        raise
|
|
|
|
|
|
def main():
    """Interactively choose an R2 backup and restore it into PostgreSQL.

    The steps are:
    1. Validate the R2 settings.
    2. List the available backups and prompt for a selection.
    3. Download, decompress and restore the chosen backup.

    Returns early, without error, if no backups exist, the selection is
    invalid, or the user cancels. Any exception is reported and re-raised
    so that the process exits with a non-zero status.
    """
    print("=" * 60)
    print("Database Restore from Cloudflare R2")
    print("=" * 60)
    print()

    try:
        # Verify R2 credentials
        if not all([R2_ENDPOINT, R2_ACCESS_KEY, R2_SECRET_KEY, R2_BUCKET]):
            raise ValueError("Missing R2 credentials in environment variables")

        # List available backups
        print("Available backups:")
        backups = list_r2_backups()

        if not backups:
            print("No backups found in R2")
            return

        for i, backup in enumerate(backups, 1):
            size_mb = backup['Size'] / (1024 * 1024)
            print(f"{i}. {backup['Key']}")
            print(f" Size: {size_mb:.2f} MB, Date: {backup['LastModified']}")
        print()

        # Select backup (surrounding whitespace is ignored)
        choice = input(f"Select backup to restore (1-{len(backups)}) or 'q' to quit: ").strip()

        if choice.lower() == 'q':
            print("Restore cancelled")
            return

        try:
            backup_index = int(choice) - 1
        except ValueError:
            backup_index = -1  # Non-numeric input is handled as out of range.
        if not 0 <= backup_index < len(backups):
            print("Invalid selection")
            return

        selected_backup = backups[backup_index]['Key']

        compressed_file = download_from_r2(selected_backup)
        sql_file = decompress_file(compressed_file)

        # restore_database returns False when the user declines the
        # confirmation prompt; that must not be reported as a completed
        # restore.
        if not restore_database(sql_file):
            return

        print("\n" + "=" * 60)
        print("✓ Restore completed successfully!")
        print("=" * 60)

    except Exception as e:
        print("\n" + "=" * 60)
        print(f"✗ Restore failed: {e}")
        print("=" * 60)
        raise
|
|
|
|
|
|
# Run the interactive restore only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|