Refactor main_db to databases

This commit is contained in:
Labrys of Knossos 2023-01-03 17:27:27 -05:00
commit 2bf7c64da9
8 changed files with 235 additions and 249 deletions

View file

@ -4,8 +4,8 @@ import os
import sys import sys
import nzb2media import nzb2media
import nzb2media.databases
import nzb2media.torrent import nzb2media.torrent
from nzb2media import main_db
from nzb2media.auto_process import comics, games, movies, music, tv, books from nzb2media.auto_process import comics, games, movies, music, tv, books
from nzb2media.auto_process.common import ProcessResult from nzb2media.auto_process.common import ProcessResult
from nzb2media.plex import plex_update from nzb2media.plex import plex_update
@ -25,7 +25,7 @@ def process_torrent(input_directory, input_name, input_category, input_hash, inp
if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO: if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO:
log.debug(f'Adding TORRENT download info for directory {input_directory} to database') log.debug(f'Adding TORRENT download info for directory {input_directory} to database')
my_db = main_db.DBConnection() my_db = nzb2media.databases.DBConnection()
input_directory1 = input_directory input_directory1 = input_directory
input_name1 = input_name input_name1 = input_name

View file

@ -11,13 +11,12 @@ from typing import Any
import setuptools_scm import setuptools_scm
import nzb2media.databases
import nzb2media.fork.medusa import nzb2media.fork.medusa
import nzb2media.fork.sickbeard import nzb2media.fork.sickbeard
import nzb2media.fork.sickchill import nzb2media.fork.sickchill
import nzb2media.fork.sickgear import nzb2media.fork.sickgear
import nzb2media.tool import nzb2media.tool
from nzb2media import databases
from nzb2media import main_db
from nzb2media.configuration import Config from nzb2media.configuration import Config
from nzb2media.transcoder import configure_transcoder from nzb2media.transcoder import configure_transcoder
from nzb2media.utils.network import wake_up from nzb2media.utils.network import wake_up
@ -184,7 +183,6 @@ def initialize(section=None):
return False return False
configure_migration() configure_migration()
# initialize the main SB database # initialize the main SB database
main_db.upgrade_database(main_db.DBConnection(), databases.InitialSchema)
configure_general() configure_general()
configure_wake_on_lan() configure_wake_on_lan()
configure_remote_paths() configure_remote_paths()

View file

@ -1,9 +1,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import re
import sqlite3
import sys import sys
import time
from nzb2media import main_db import nzb2media
from nzb2media.utils.files import backup_versioned_file from nzb2media.utils.files import backup_versioned_file
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -14,7 +17,7 @@ MAX_DB_VERSION = 2
def backup_database(version): def backup_database(version):
log.info('Backing up database before upgrade') log.info('Backing up database before upgrade')
if not backup_versioned_file(main_db.db_filename(), version): if not backup_versioned_file(db_filename(), version):
logging.critical('Database backup failed, abort upgrading database') logging.critical('Database backup failed, abort upgrading database')
sys.exit(1) sys.exit(1)
else: else:
@ -25,7 +28,33 @@ def backup_database(version):
# = Main DB Migrations = # = Main DB Migrations =
# ====================== # ======================
# Add new migrations at the bottom of the list; subclass the previous migration. # Add new migrations at the bottom of the list; subclass the previous migration.
class InitialSchema(main_db.SchemaUpgrade): class SchemaUpgrade:
def __init__(self, connection):
self.connection = connection
def has_table(self, table_name):
return len(self.connection.action('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)).fetchall()) > 0
def has_column(self, table_name, column):
return column in self.connection.table_info(table_name)
def add_column(self, table, column, data_type='NUMERIC', default=0):
self.connection.action(f'ALTER TABLE {table} ADD {column} {data_type}')
self.connection.action(f'UPDATE {table} SET {column} = ?', (default,))
def check_db_version(self):
result = self.connection.select('SELECT db_version FROM db_version')
if result:
return int(result[-1]['db_version'])
return 0
def inc_db_version(self):
new_version = self.check_db_version() + 1
self.connection.action('UPDATE db_version SET db_version = ?', [new_version])
return new_version
class InitialSchema(SchemaUpgrade):
def test(self): def test(self):
no_update = False no_update = False
if self.has_table('db_version'): if self.has_table('db_version'):
@ -84,3 +113,199 @@ class InitialSchema(main_db.SchemaUpgrade):
] ]
for query in queries: for query in queries:
self.connection.action(query) self.connection.action(query)
def db_filename(filename: str = 'nzbtomedia.db', suffix: str | None = None):
"""Return the correct location of the database file.
@param filename: The sqlite database filename to use. If not specified, will be made to be nzbtomedia.db
@param suffix: The suffix to append to the filename. A '.' will be added
automatically, i.e. suffix='v0' will make dbfile.db.v0
@return: the correct location of the database file.
"""
if suffix:
filename = f'{filename}.{suffix}'
return nzb2media.os.path.join(nzb2media.APP_ROOT, filename)
class DBConnection:
def __init__(self, filename='nzbtomedia.db'):
self.filename = filename
self.connection = sqlite3.connect(db_filename(filename), 20)
self.connection.row_factory = sqlite3.Row
def check_db_version(self):
result = None
try:
result = self.select('SELECT db_version FROM db_version')
except sqlite3.OperationalError as error:
if 'no such table: db_version' in error.args[0]:
return 0
if result:
return int(result[0]['db_version'])
return 0
def fetch(self, query, args=None):
if query is None:
return
sql_result = None
attempt = 0
while attempt < 5:
try:
if args is None:
log.debug(f'{self.filename}: {query}')
cursor = self.connection.cursor()
cursor.execute(query)
sql_result = cursor.fetchone()[0]
else:
log.debug(f'{self.filename}: {query} with args {args}')
cursor = self.connection.cursor()
cursor.execute(query, args)
sql_result = cursor.fetchone()[0]
# get out of the connection attempt loop since we were successful
break
except sqlite3.OperationalError as error:
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def mass_action(self, querylist, log_transaction=False):
if querylist is None:
return
sql_result = []
attempt = 0
while attempt < 5:
try:
for query in querylist:
if len(query) == 1:
if log_transaction:
log.debug(query[0])
sql_result.append(self.connection.execute(query[0]))
elif len(query) > 1:
if log_transaction:
log.debug(f'{query[0]} with args {query[1]}')
sql_result.append(self.connection.execute(query[0], query[1]))
self.connection.commit()
log.debug(f'Transaction with {len(querylist)} query\'s executed')
return sql_result
except sqlite3.OperationalError as error:
sql_result = []
if self.connection:
self.connection.rollback()
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
if self.connection:
self.connection.rollback()
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def action(self, query, args=None):
if query is None:
return
sql_result = None
attempt = 0
while attempt < 5:
try:
if args is None:
log.debug(f'{self.filename}: {query}')
sql_result = self.connection.execute(query)
else:
log.debug(f'{self.filename}: {query} with args {args}')
sql_result = self.connection.execute(query, args)
self.connection.commit()
# get out of the connection attempt loop since we were successful
break
except sqlite3.OperationalError as error:
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def select(self, query, args=None):
sql_results = self.action(query, args).fetchall()
if sql_results is None:
return []
return sql_results
def upsert(self, table_name, value_dict, key_dict):
def gen_params(my_dict):
return [f'{k} = ?' for k in my_dict.keys()]
changes_before = self.connection.total_changes
items = list(value_dict.values()) + list(key_dict.values())
_params = ', '.join(gen_params(value_dict))
_conditions = ' AND '.join(gen_params(key_dict))
self.action(f'UPDATE {table_name} SET {_params} WHERE {_conditions}', items)
if self.connection.total_changes == changes_before:
_cols = ', '.join(map(str, value_dict.keys()))
values = list(value_dict.values())
_vals = ', '.join(['?'] * len(values))
self.action(f'INSERT OR IGNORE INTO {table_name} ({_cols}) VALUES ({_vals})', values)
def table_info(self, table_name):
# FIXME ? binding is not supported here, but I cannot find a way to escape a string manually
cursor = self.connection.execute(f'PRAGMA table_info({table_name})')
return {column['name']: {'type': column['type']} for column in cursor}
def sanity_check_database(connection, sanity_check):
sanity_check(connection).check()
class DBSanityCheck:
def __init__(self, connection):
self.connection = connection
def check(self):
pass
def upgrade_database(connection, schema):
log.info('Checking database structure...')
_process_upgrade(connection, schema)
def pretty_name(class_name):
return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)])
def _process_upgrade(connection, upgrade_class):
instance = upgrade_class(connection)
log.debug(f'Checking {pretty_name(upgrade_class.__name__)} database upgrade')
if not instance.test():
log.info(f'Database upgrade required: {pretty_name(upgrade_class.__name__)}')
try:
instance.execute()
except sqlite3.DatabaseError as error:
print(f'Error in {upgrade_class.__name__}: {error}')
raise
log.debug(f'{upgrade_class.__name__} upgrade completed')
else:
log.debug(f'{upgrade_class.__name__} upgrade not required')
for upgrade_sub_class in upgrade_class.__subclasses__():
_process_upgrade(connection, upgrade_sub_class)
upgrade_database(nzb2media.databases.DBConnection(), nzb2media.databases.InitialSchema)

0
nzb2media/exceptions.py Normal file
View file

View file

@ -1,234 +0,0 @@
from __future__ import annotations
import logging
import re
import sqlite3
import time
import nzb2media
log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())
def db_filename(filename: str = 'nzbtomedia.db', suffix: str | None = None):
"""Return the correct location of the database file.
@param filename: The sqlite database filename to use. If not specified, will be made to be nzbtomedia.db
@param suffix: The suffix to append to the filename. A '.' will be added
automatically, i.e. suffix='v0' will make dbfile.db.v0
@return: the correct location of the database file.
"""
if suffix:
filename = f'{filename}.{suffix}'
return nzb2media.os.path.join(nzb2media.APP_ROOT, filename)
class DBConnection:
def __init__(self, filename='nzbtomedia.db'):
self.filename = filename
self.connection = sqlite3.connect(db_filename(filename), 20)
self.connection.row_factory = sqlite3.Row
def check_db_version(self):
result = None
try:
result = self.select('SELECT db_version FROM db_version')
except sqlite3.OperationalError as error:
if 'no such table: db_version' in error.args[0]:
return 0
if result:
return int(result[0]['db_version'])
return 0
def fetch(self, query, args=None):
if query is None:
return
sql_result = None
attempt = 0
while attempt < 5:
try:
if args is None:
log.debug(f'{self.filename}: {query}')
cursor = self.connection.cursor()
cursor.execute(query)
sql_result = cursor.fetchone()[0]
else:
log.debug(f'{self.filename}: {query} with args {args}')
cursor = self.connection.cursor()
cursor.execute(query, args)
sql_result = cursor.fetchone()[0]
# get out of the connection attempt loop since we were successful
break
except sqlite3.OperationalError as error:
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def mass_action(self, querylist, log_transaction=False):
if querylist is None:
return
sql_result = []
attempt = 0
while attempt < 5:
try:
for query in querylist:
if len(query) == 1:
if log_transaction:
log.debug(query[0])
sql_result.append(self.connection.execute(query[0]))
elif len(query) > 1:
if log_transaction:
log.debug(f'{query[0]} with args {query[1]}')
sql_result.append(self.connection.execute(query[0], query[1]))
self.connection.commit()
log.debug(f'Transaction with {len(querylist)} query\'s executed')
return sql_result
except sqlite3.OperationalError as error:
sql_result = []
if self.connection:
self.connection.rollback()
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
if self.connection:
self.connection.rollback()
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def action(self, query, args=None):
if query is None:
return
sql_result = None
attempt = 0
while attempt < 5:
try:
if args is None:
log.debug(f'{self.filename}: {query}')
sql_result = self.connection.execute(query)
else:
log.debug(f'{self.filename}: {query} with args {args}')
sql_result = self.connection.execute(query, args)
self.connection.commit()
# get out of the connection attempt loop since we were successful
break
except sqlite3.OperationalError as error:
if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]:
log.warning(f'DB error: {error}')
attempt += 1
time.sleep(1)
else:
log.error(f'DB error: {error}')
raise
except sqlite3.DatabaseError as error:
log.error(f'Fatal error executing query: {error}')
raise
return sql_result
def select(self, query, args=None):
sql_results = self.action(query, args).fetchall()
if sql_results is None:
return []
return sql_results
def upsert(self, table_name, value_dict, key_dict):
def gen_params(my_dict):
return [f'{k} = ?' for k in my_dict.keys()]
changes_before = self.connection.total_changes
items = list(value_dict.values()) + list(key_dict.values())
_params = ', '.join(gen_params(value_dict))
_conditions = ' AND '.join(gen_params(key_dict))
self.action(f'UPDATE {table_name} SET {_params} WHERE {_conditions}', items)
if self.connection.total_changes == changes_before:
_cols = ', '.join(map(str, value_dict.keys()))
values = list(value_dict.values())
_vals = ', '.join(['?'] * len(values))
self.action(f'INSERT OR IGNORE INTO {table_name} ({_cols}) VALUES ({_vals})', values)
def table_info(self, table_name):
# FIXME ? binding is not supported here, but I cannot find a way to escape a string manually
cursor = self.connection.execute(f'PRAGMA table_info({table_name})')
return {column['name']: {'type': column['type']} for column in cursor}
def sanity_check_database(connection, sanity_check):
sanity_check(connection).check()
class DBSanityCheck:
def __init__(self, connection):
self.connection = connection
def check(self):
pass
# ===============
# = Upgrade API =
# ===============
def upgrade_database(connection, schema):
log.info('Checking database structure...')
_process_upgrade(connection, schema)
def pretty_name(class_name):
return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)])
def _process_upgrade(connection, upgrade_class):
instance = upgrade_class(connection)
log.debug(f'Checking {pretty_name(upgrade_class.__name__)} database upgrade')
if not instance.test():
log.info(f'Database upgrade required: {pretty_name(upgrade_class.__name__)}')
try:
instance.execute()
except sqlite3.DatabaseError as error:
print(f'Error in {upgrade_class.__name__}: {error}')
raise
log.debug(f'{upgrade_class.__name__} upgrade completed')
else:
log.debug(f'{upgrade_class.__name__} upgrade not required')
for upgrade_sub_class in upgrade_class.__subclasses__():
_process_upgrade(connection, upgrade_sub_class)
# Base migration class. All future DB changes should be subclassed from this class
class SchemaUpgrade:
def __init__(self, connection):
self.connection = connection
def has_table(self, table_name):
return len(self.connection.action('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)).fetchall()) > 0
def has_column(self, table_name, column):
return column in self.connection.table_info(table_name)
def add_column(self, table, column, data_type='NUMERIC', default=0):
self.connection.action(f'ALTER TABLE {table} ADD {column} {data_type}')
self.connection.action(f'UPDATE {table} SET {column} = ?', (default,))
def check_db_version(self):
result = self.connection.select('SELECT db_version FROM db_version')
if result:
return int(result[-1]['db_version'])
return 0
def inc_db_version(self):
new_version = self.check_db_version() + 1
self.connection.action('UPDATE db_version SET db_version = ?', [new_version])
return new_version

View file

@ -4,8 +4,8 @@ import datetime
import logging import logging
import nzb2media import nzb2media
import nzb2media.databases
import nzb2media.nzb import nzb2media.nzb
from nzb2media import main_db
from nzb2media.auto_process import books from nzb2media.auto_process import books
from nzb2media.auto_process import comics from nzb2media.auto_process import comics
from nzb2media.auto_process import games from nzb2media.auto_process import games
@ -34,7 +34,7 @@ def process(*, input_directory, input_name=None, status=0, client_agent='manual'
download_id = get_nzoid(input_name) download_id = get_nzoid(input_name)
if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO: if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO:
log.debug(f'Adding NZB download info for directory {input_directory} to database') log.debug(f'Adding NZB download info for directory {input_directory} to database')
my_db = main_db.DBConnection() my_db = nzb2media.databases.DBConnection()
input_directory1 = input_directory input_directory1 = input_directory
input_name1 = input_name input_name1 = input_name
try: try:

View file

@ -3,11 +3,11 @@ from __future__ import annotations
import datetime import datetime
import logging import logging
from nzb2media import main_db import nzb2media.databases
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler()) log.addHandler(logging.NullHandler())
database = main_db.DBConnection() database = nzb2media.databases.DBConnection()
def update_download_info_status(input_name, status): def update_download_info_status(input_name, status):

View file

@ -135,9 +135,6 @@ def test_import_nzb2media():
import nzb2media.databases import nzb2media.databases
assert nzb2media.databases assert nzb2media.databases
import nzb2media.main_db
assert nzb2media.main_db
import nzb2media.scene_exceptions import nzb2media.scene_exceptions
assert nzb2media.scene_exceptions assert nzb2media.scene_exceptions