refactor: implement database connection and migration handling

This commit is contained in:
Sam Chau
2025-02-02 02:51:02 +10:30
parent f1d162a129
commit df82ddaa73
12 changed files with 206 additions and 100 deletions

0
backend/__init__.py Normal file
View File

0
backend/app/__init__.py Normal file
View File

View File

@@ -0,0 +1,10 @@
# backend/app/db/__init__.py
from .connection import get_db
from .queries.settings import get_settings, get_secret_key, save_settings
from .queries.arr import get_unique_arrs
from .migrations.runner import run_migrations
__all__ = [
'get_db', 'get_settings', 'get_secret_key', 'save_settings',
'get_unique_arrs', 'run_migrations'
]

View File

@@ -0,0 +1,12 @@
# backend/app/db/connection.py
import sqlite3
from ..config import config
DB_PATH = config.DB_PATH
def get_db():
"""Create and return a database connection with Row factory."""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn

View File

View File

@@ -0,0 +1,64 @@
# backend/app/db/migrations/runner.py
import os
import importlib
from pathlib import Path
from ..connection import get_db
def init_migrations():
"""Create migrations table if it doesn't exist."""
with get_db() as conn:
conn.execute('''
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version INTEGER NOT NULL,
name TEXT NOT NULL,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
conn.commit()
def get_applied_migrations():
"""Get list of already applied migrations."""
with get_db() as conn:
result = conn.execute(
'SELECT version FROM migrations ORDER BY version')
return [row[0] for row in result.fetchall()]
def get_available_migrations():
"""Get all migration files from versions directory."""
versions_dir = Path(__file__).parent / 'versions'
migrations = []
for file in versions_dir.glob('[0-9]*.py'):
if file.stem != '__init__':
# Import the migration module
module = importlib.import_module(f'.versions.{file.stem}',
package='app.db.migrations')
migrations.append((module.version, module.name, module))
return sorted(migrations, key=lambda x: x[0])
def run_migrations():
"""Run all pending migrations in order."""
init_migrations()
applied = set(get_applied_migrations())
available = get_available_migrations()
for version, name, module in available:
if version not in applied:
print(f"Applying migration {version}: {name}")
try:
module.up()
with get_db() as conn:
conn.execute(
'INSERT INTO migrations (version, name) VALUES (?, ?)',
(version, name))
conn.commit()
print(f"Successfully applied migration {version}")
except Exception as e:
print(f"Error applying migration {version}: {str(e)}")
raise

View File

@@ -1,20 +1,16 @@
# db.py
import sqlite3
# backend/app/db/migrations/versions/001_initial_schema.py
import os
import secrets
from .config import config
from ...connection import get_db
DB_PATH = config.DB_PATH
version = 1
name = "initial_schema"
def get_db():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def init_db():
def up():
"""Apply the initial database schema."""
with get_db() as conn:
# Create backups table
conn.execute('''
CREATE TABLE IF NOT EXISTS backups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -24,6 +20,7 @@ def init_db():
)
''')
# Create arr_config table
conn.execute('''
CREATE TABLE IF NOT EXISTS arr_config (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -44,6 +41,7 @@ def init_db():
)
''')
# Create scheduled_tasks table
conn.execute('''
CREATE TABLE IF NOT EXISTS scheduled_tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -56,7 +54,36 @@ def init_db():
)
''')
# Insert required tasks if missing
# Create settings table
conn.execute('''
CREATE TABLE IF NOT EXISTS settings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
value TEXT,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Create auth table
conn.execute('''
CREATE TABLE IF NOT EXISTS auth (
username TEXT NOT NULL,
password_hash TEXT NOT NULL,
api_key TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Create failed_attempts table
conn.execute('''
CREATE TABLE IF NOT EXISTS failed_attempts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip_address TEXT NOT NULL,
attempt_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Insert initial required data
required_tasks = [
('Repository Sync', 'Sync', 2),
('Backup', 'Backup', 1440),
@@ -72,21 +99,13 @@ def init_db():
VALUES (?, ?, ?)
''', (task_name, task_type, interval))
conn.execute('''
CREATE TABLE IF NOT EXISTS settings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
value TEXT,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Insert or ignore example
# Insert initial settings
conn.execute('''
INSERT OR IGNORE INTO settings (key, value, updated_at)
VALUES ('auto_pull_enabled', 0, CURRENT_TIMESTAMP)
VALUES ('auto_pull_enabled', '0', CURRENT_TIMESTAMP)
''')
# Handle profilarr_pat setting
profilarr_pat = os.environ.get('PROFILARR_PAT')
conn.execute(
'''
@@ -97,8 +116,8 @@ def init_db():
updated_at = CURRENT_TIMESTAMP
''', (str(bool(profilarr_pat)).lower(), str(
bool(profilarr_pat)).lower()))
conn.commit()
# Handle secret_key setting
secret_key = conn.execute(
'SELECT value FROM settings WHERE key = "secret_key"').fetchone()
if not secret_key:
@@ -108,82 +127,18 @@ def init_db():
INSERT INTO settings (key, value, updated_at)
VALUES ('secret_key', ?, CURRENT_TIMESTAMP)
''', (new_secret_key, ))
conn.commit()
conn.execute('''
CREATE TABLE IF NOT EXISTS auth (
username TEXT NOT NULL,
password_hash TEXT NOT NULL,
api_key TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
conn.execute('''
CREATE TABLE IF NOT EXISTS failed_attempts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip_address TEXT NOT NULL,
attempt_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
def get_settings():
with get_db() as conn:
result = conn.execute(
'SELECT key, value FROM settings WHERE key NOT IN ("secret_key")'
).fetchall()
settings = {row['key']: row['value'] for row in result}
return settings if 'gitRepo' in settings else None
def get_secret_key():
with get_db() as conn:
result = conn.execute(
'SELECT value FROM settings WHERE key = "secret_key"').fetchone()
return result['value'] if result else None
def save_settings(settings_dict):
with get_db() as conn:
for key, value in settings_dict.items():
conn.execute(
'''
INSERT INTO settings (key, value, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(key) DO UPDATE SET
value = excluded.value,
updated_at = CURRENT_TIMESTAMP
''', (key, value))
conn.commit()
def get_unique_arrs(arr_ids):
"""
Get import_as_unique settings for a list of arr IDs.
Args:
arr_ids (list): List of arr configuration IDs
Returns:
dict: Dictionary mapping arr IDs to their import_as_unique settings and names
"""
if not arr_ids:
return {}
def down():
"""Revert the initial schema migration."""
with get_db() as conn:
placeholders = ','.join('?' * len(arr_ids))
query = f'''
SELECT id, name, import_as_unique
FROM arr_config
WHERE id IN ({placeholders})
'''
results = conn.execute(query, arr_ids).fetchall()
return {
row['id']: {
'import_as_unique': bool(row['import_as_unique']),
'name': row['name']
}
for row in results
}
# Drop all tables in reverse order of creation
tables = [
'failed_attempts', 'auth', 'settings', 'scheduled_tasks',
'arr_config', 'backups'
]
for table in tables:
conn.execute(f'DROP TABLE IF EXISTS {table}')
conn.commit()

View File

View File

@@ -0,0 +1,33 @@
# backend/app/db/queries/arr.py
from ..connection import get_db
def get_unique_arrs(arr_ids):
"""
Get import_as_unique settings for a list of arr IDs.
Args:
arr_ids (list): List of arr configuration IDs
Returns:
dict: Dictionary mapping arr IDs to their import_as_unique settings and names
"""
if not arr_ids:
return {}
with get_db() as conn:
placeholders = ','.join('?' * len(arr_ids))
query = f'''
SELECT id, name, import_as_unique
FROM arr_config
WHERE id IN ({placeholders})
'''
results = conn.execute(query, arr_ids).fetchall()
return {
row['id']: {
'import_as_unique': bool(row['import_as_unique']),
'name': row['name']
}
for row in results
}

View File

@@ -0,0 +1,32 @@
# backend/app/db/queries/settings.py
from ..connection import get_db
def get_settings():
with get_db() as conn:
result = conn.execute(
'SELECT key, value FROM settings WHERE key NOT IN ("secret_key")'
).fetchall()
settings = {row['key']: row['value'] for row in result}
return settings if 'gitRepo' in settings else None
def get_secret_key():
with get_db() as conn:
result = conn.execute(
'SELECT value FROM settings WHERE key = "secret_key"').fetchone()
return result['value'] if result else None
def save_settings(settings_dict):
with get_db() as conn:
for key, value in settings_dict.items():
conn.execute(
'''
INSERT INTO settings (key, value, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(key) DO UPDATE SET
value = excluded.value,
updated_at = CURRENT_TIMESTAMP
''', (key, value))
conn.commit()

View File

@@ -7,7 +7,7 @@ from .data import bp as data_bp
from .importarr import bp as importarr_bp
from .task import bp as tasks_bp, TaskScheduler
from .backup import bp as backup_bp
from .db import init_db, get_settings
from .db import run_migrations, get_settings
from .auth import bp as auth_bp
from .logs import bp as logs_bp
from .middleware import init_middleware
@@ -27,7 +27,7 @@ def create_app():
config.ensure_directories()
logger.info("Initializing database")
init_db()
run_migrations()
# Initialize Git user configuration
logger.info("Initializing Git user")