diff --git a/TorrentToMedia.py b/TorrentToMedia.py index f8fc3493..3940235f 100755 --- a/TorrentToMedia.py +++ b/TorrentToMedia.py @@ -4,8 +4,8 @@ import os import sys import nzb2media +import nzb2media.databases import nzb2media.torrent -from nzb2media import main_db from nzb2media.auto_process import comics, games, movies, music, tv, books from nzb2media.auto_process.common import ProcessResult from nzb2media.plex import plex_update @@ -25,7 +25,7 @@ def process_torrent(input_directory, input_name, input_category, input_hash, inp if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO: log.debug(f'Adding TORRENT download info for directory {input_directory} to database') - my_db = main_db.DBConnection() + my_db = nzb2media.databases.DBConnection() input_directory1 = input_directory input_name1 = input_name diff --git a/nzb2media/__init__.py b/nzb2media/__init__.py index c5e76eea..dd68a112 100644 --- a/nzb2media/__init__.py +++ b/nzb2media/__init__.py @@ -11,13 +11,12 @@ from typing import Any import setuptools_scm +import nzb2media.databases import nzb2media.fork.medusa import nzb2media.fork.sickbeard import nzb2media.fork.sickchill import nzb2media.fork.sickgear import nzb2media.tool -from nzb2media import databases -from nzb2media import main_db from nzb2media.configuration import Config from nzb2media.transcoder import configure_transcoder from nzb2media.utils.network import wake_up @@ -184,7 +183,6 @@ def initialize(section=None): return False configure_migration() # initialize the main SB database - main_db.upgrade_database(main_db.DBConnection(), databases.InitialSchema) configure_general() configure_wake_on_lan() configure_remote_paths() diff --git a/nzb2media/databases.py b/nzb2media/databases.py index 88004403..9b2f876d 100644 --- a/nzb2media/databases.py +++ b/nzb2media/databases.py @@ -1,9 +1,12 @@ from __future__ import annotations import logging +import re +import sqlite3 import sys +import time -from nzb2media import main_db +import nzb2media from nzb2media.utils.files import backup_versioned_file log = logging.getLogger(__name__) @@ -14,7 +17,7 @@ MAX_DB_VERSION = 2 def backup_database(version): log.info('Backing up database before upgrade') - if not backup_versioned_file(main_db.db_filename(), version): + if not backup_versioned_file(db_filename(), version): logging.critical('Database backup failed, abort upgrading database') sys.exit(1) else: @@ -25,7 +28,33 @@ def backup_database(version): # = Main DB Migrations = # ====================== # Add new migrations at the bottom of the list; subclass the previous migration. -class InitialSchema(main_db.SchemaUpgrade): +class SchemaUpgrade: + def __init__(self, connection): + self.connection = connection + + def has_table(self, table_name): + return len(self.connection.action('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)).fetchall()) > 0 + + def has_column(self, table_name, column): + return column in self.connection.table_info(table_name) + + def add_column(self, table, column, data_type='NUMERIC', default=0): + self.connection.action(f'ALTER TABLE {table} ADD {column} {data_type}') + self.connection.action(f'UPDATE {table} SET {column} = ?', (default,)) + + def check_db_version(self): + result = self.connection.select('SELECT db_version FROM db_version') + if result: + return int(result[-1]['db_version']) + return 0 + + def inc_db_version(self): + new_version = self.check_db_version() + 1 + self.connection.action('UPDATE db_version SET db_version = ?', [new_version]) + return new_version + + +class InitialSchema(SchemaUpgrade): def test(self): no_update = False if self.has_table('db_version'): @@ -84,3 +113,199 @@ class InitialSchema(main_db.SchemaUpgrade): ] for query in queries: self.connection.action(query) + + +def db_filename(filename: str = 'nzbtomedia.db', suffix: str | None = None): + """Return the correct location of the database file. + + @param filename: The sqlite database filename to use. If not specified, will be made to be nzbtomedia.db + @param suffix: The suffix to append to the filename. A '.' will be added + automatically, i.e. suffix='v0' will make dbfile.db.v0 + @return: the correct location of the database file. + """ + if suffix: + filename = f'{filename}.{suffix}' + return nzb2media.os.path.join(nzb2media.APP_ROOT, filename) + + +class DBConnection: + def __init__(self, filename='nzbtomedia.db'): + self.filename = filename + self.connection = sqlite3.connect(db_filename(filename), 20) + self.connection.row_factory = sqlite3.Row + + def check_db_version(self): + result = None + try: + result = self.select('SELECT db_version FROM db_version') + except sqlite3.OperationalError as error: + if 'no such table: db_version' in error.args[0]: + return 0 + if result: + return int(result[0]['db_version']) + return 0 + + def fetch(self, query, args=None): + if query is None: + return + sql_result = None + attempt = 0 + while attempt < 5: + try: + if args is None: + log.debug(f'{self.filename}: {query}') + cursor = self.connection.cursor() + cursor.execute(query) + sql_result = cursor.fetchone()[0] + else: + log.debug(f'{self.filename}: {query} with args {args}') + cursor = self.connection.cursor() + cursor.execute(query, args) + sql_result = cursor.fetchone()[0] + # get out of the connection attempt loop since we were successful + break + except sqlite3.OperationalError as error: + if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: + log.warning(f'DB error: {error}') + attempt += 1 + time.sleep(1) + else: + log.error(f'DB error: {error}') + raise + except sqlite3.DatabaseError as error: + log.error(f'Fatal error executing query: {error}') + raise + return sql_result + + def mass_action(self, querylist, log_transaction=False): + if querylist is None: + return + sql_result = [] + attempt = 0 + while attempt < 5: + try: + for query in querylist: + if len(query) == 1: + if log_transaction: + log.debug(query[0]) + sql_result.append(self.connection.execute(query[0])) + elif len(query) > 1: + if log_transaction: + log.debug(f'{query[0]} with args {query[1]}') + sql_result.append(self.connection.execute(query[0], query[1])) + self.connection.commit() + log.debug(f'Transaction with {len(querylist)} query\'s executed') + return sql_result + except sqlite3.OperationalError as error: + sql_result = [] + if self.connection: + self.connection.rollback() + if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: + log.warning(f'DB error: {error}') + attempt += 1 + time.sleep(1) + else: + log.error(f'DB error: {error}') + raise + except sqlite3.DatabaseError as error: + if self.connection: + self.connection.rollback() + log.error(f'Fatal error executing query: {error}') + raise + return sql_result + + def action(self, query, args=None): + if query is None: + return + sql_result = None + attempt = 0 + while attempt < 5: + try: + if args is None: + log.debug(f'{self.filename}: {query}') + sql_result = self.connection.execute(query) + else: + log.debug(f'{self.filename}: {query} with args {args}') + sql_result = self.connection.execute(query, args) + self.connection.commit() + # get out of the connection attempt loop since we were successful + break + except sqlite3.OperationalError as error: + if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: + log.warning(f'DB error: {error}') + attempt += 1 + time.sleep(1) + else: + log.error(f'DB error: {error}') + raise + except sqlite3.DatabaseError as error: + log.error(f'Fatal error executing query: {error}') + raise + return sql_result + + def select(self, query, args=None): + sql_results = self.action(query, args).fetchall() + if sql_results is None: + return [] + return sql_results + + def upsert(self, table_name, value_dict, key_dict): + def gen_params(my_dict): + return [f'{k} = ?' for k in my_dict.keys()] + + changes_before = self.connection.total_changes + items = list(value_dict.values()) + list(key_dict.values()) + _params = ', '.join(gen_params(value_dict)) + _conditions = ' AND '.join(gen_params(key_dict)) + self.action(f'UPDATE {table_name} SET {_params} WHERE {_conditions}', items) + if self.connection.total_changes == changes_before: + _cols = ', '.join(map(str, value_dict.keys())) + values = list(value_dict.values()) + _vals = ', '.join(['?'] * len(values)) + self.action(f'INSERT OR IGNORE INTO {table_name} ({_cols}) VALUES ({_vals})', values) + + def table_info(self, table_name): + # FIXME ? binding is not supported here, but I cannot find a way to escape a string manually + cursor = self.connection.execute(f'PRAGMA table_info({table_name})') + return {column['name']: {'type': column['type']} for column in cursor} + + +def sanity_check_database(connection, sanity_check): + sanity_check(connection).check() + + +class DBSanityCheck: + def __init__(self, connection): + self.connection = connection + + def check(self): + pass + + +def upgrade_database(connection, schema): + log.info('Checking database structure...') + _process_upgrade(connection, schema) + + +def pretty_name(class_name): + return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)]) + + +def _process_upgrade(connection, upgrade_class): + instance = upgrade_class(connection) + log.debug(f'Checking {pretty_name(upgrade_class.__name__)} database upgrade') + if not instance.test(): + log.info(f'Database upgrade required: {pretty_name(upgrade_class.__name__)}') + try: + instance.execute() + except sqlite3.DatabaseError as error: + print(f'Error in {upgrade_class.__name__}: {error}') + raise + log.debug(f'{upgrade_class.__name__} upgrade completed') + else: + log.debug(f'{upgrade_class.__name__} upgrade not required') + for upgrade_sub_class in upgrade_class.__subclasses__(): + _process_upgrade(connection, upgrade_sub_class) + + +upgrade_database(nzb2media.databases.DBConnection(), nzb2media.databases.InitialSchema) diff --git a/nzb2media/exceptions.py b/nzb2media/exceptions.py new file mode 100644 index 00000000..e69de29b diff --git a/nzb2media/main_db.py b/nzb2media/main_db.py deleted file mode 100644 index e842e988..00000000 --- a/nzb2media/main_db.py +++ /dev/null @@ -1,234 +0,0 @@ -from __future__ import annotations - -import logging -import re -import sqlite3 -import time - -import nzb2media - -log = logging.getLogger(__name__) -log.addHandler(logging.NullHandler()) - - -def db_filename(filename: str = 'nzbtomedia.db', suffix: str | None = None): - """Return the correct location of the database file. - - @param filename: The sqlite database filename to use. If not specified, will be made to be nzbtomedia.db - @param suffix: The suffix to append to the filename. A '.' will be added - automatically, i.e. suffix='v0' will make dbfile.db.v0 - @return: the correct location of the database file. - """ - if suffix: - filename = f'{filename}.{suffix}' - return nzb2media.os.path.join(nzb2media.APP_ROOT, filename) - - -class DBConnection: - def __init__(self, filename='nzbtomedia.db'): - self.filename = filename - self.connection = sqlite3.connect(db_filename(filename), 20) - self.connection.row_factory = sqlite3.Row - - def check_db_version(self): - result = None - try: - result = self.select('SELECT db_version FROM db_version') - except sqlite3.OperationalError as error: - if 'no such table: db_version' in error.args[0]: - return 0 - if result: - return int(result[0]['db_version']) - return 0 - - def fetch(self, query, args=None): - if query is None: - return - sql_result = None - attempt = 0 - while attempt < 5: - try: - if args is None: - log.debug(f'{self.filename}: {query}') - cursor = self.connection.cursor() - cursor.execute(query) - sql_result = cursor.fetchone()[0] - else: - log.debug(f'{self.filename}: {query} with args {args}') - cursor = self.connection.cursor() - cursor.execute(query, args) - sql_result = cursor.fetchone()[0] - # get out of the connection attempt loop since we were successful - break - except sqlite3.OperationalError as error: - if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: - log.warning(f'DB error: {error}') - attempt += 1 - time.sleep(1) - else: - log.error(f'DB error: {error}') - raise - except sqlite3.DatabaseError as error: - log.error(f'Fatal error executing query: {error}') - raise - return sql_result - - def mass_action(self, querylist, log_transaction=False): - if querylist is None: - return - sql_result = [] - attempt = 0 - while attempt < 5: - try: - for query in querylist: - if len(query) == 1: - if log_transaction: - log.debug(query[0]) - sql_result.append(self.connection.execute(query[0])) - elif len(query) > 1: - if log_transaction: - log.debug(f'{query[0]} with args {query[1]}') - sql_result.append(self.connection.execute(query[0], query[1])) - self.connection.commit() - log.debug(f'Transaction with {len(querylist)} query\'s executed') - return sql_result - except sqlite3.OperationalError as error: - sql_result = [] - if self.connection: - self.connection.rollback() - if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: - log.warning(f'DB error: {error}') - attempt += 1 - time.sleep(1) - else: - log.error(f'DB error: {error}') - raise - except sqlite3.DatabaseError as error: - if self.connection: - self.connection.rollback() - log.error(f'Fatal error executing query: {error}') - raise - return sql_result - - def action(self, query, args=None): - if query is None: - return - sql_result = None - attempt = 0 - while attempt < 5: - try: - if args is None: - log.debug(f'{self.filename}: {query}') - sql_result = self.connection.execute(query) - else: - log.debug(f'{self.filename}: {query} with args {args}') - sql_result = self.connection.execute(query, args) - self.connection.commit() - # get out of the connection attempt loop since we were successful - break - except sqlite3.OperationalError as error: - if 'unable to open database file' in error.args[0] or 'database is locked' in error.args[0]: - log.warning(f'DB error: {error}') - attempt += 1 - time.sleep(1) - else: - log.error(f'DB error: {error}') - raise - except sqlite3.DatabaseError as error: - log.error(f'Fatal error executing query: {error}') - raise - return sql_result - - def select(self, query, args=None): - sql_results = self.action(query, args).fetchall() - if sql_results is None: - return [] - return sql_results - - def upsert(self, table_name, value_dict, key_dict): - def gen_params(my_dict): - return [f'{k} = ?' for k in my_dict.keys()] - - changes_before = self.connection.total_changes - items = list(value_dict.values()) + list(key_dict.values()) - _params = ', '.join(gen_params(value_dict)) - _conditions = ' AND '.join(gen_params(key_dict)) - self.action(f'UPDATE {table_name} SET {_params} WHERE {_conditions}', items) - if self.connection.total_changes == changes_before: - _cols = ', '.join(map(str, value_dict.keys())) - values = list(value_dict.values()) - _vals = ', '.join(['?'] * len(values)) - self.action(f'INSERT OR IGNORE INTO {table_name} ({_cols}) VALUES ({_vals})', values) - - def table_info(self, table_name): - # FIXME ? binding is not supported here, but I cannot find a way to escape a string manually - cursor = self.connection.execute(f'PRAGMA table_info({table_name})') - return {column['name']: {'type': column['type']} for column in cursor} - - -def sanity_check_database(connection, sanity_check): - sanity_check(connection).check() - - -class DBSanityCheck: - def __init__(self, connection): - self.connection = connection - - def check(self): - pass - - -# =============== -# = Upgrade API = -# =============== -def upgrade_database(connection, schema): - log.info('Checking database structure...') - _process_upgrade(connection, schema) - - -def pretty_name(class_name): - return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)]) - - -def _process_upgrade(connection, upgrade_class): - instance = upgrade_class(connection) - log.debug(f'Checking {pretty_name(upgrade_class.__name__)} database upgrade') - if not instance.test(): - log.info(f'Database upgrade required: {pretty_name(upgrade_class.__name__)}') - try: - instance.execute() - except sqlite3.DatabaseError as error: - print(f'Error in {upgrade_class.__name__}: {error}') - raise - log.debug(f'{upgrade_class.__name__} upgrade completed') - else: - log.debug(f'{upgrade_class.__name__} upgrade not required') - for upgrade_sub_class in upgrade_class.__subclasses__(): - _process_upgrade(connection, upgrade_sub_class) - - -# Base migration class. All future DB changes should be subclassed from this class -class SchemaUpgrade: - def __init__(self, connection): - self.connection = connection - - def has_table(self, table_name): - return len(self.connection.action('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)).fetchall()) > 0 - - def has_column(self, table_name, column): - return column in self.connection.table_info(table_name) - - def add_column(self, table, column, data_type='NUMERIC', default=0): - self.connection.action(f'ALTER TABLE {table} ADD {column} {data_type}') - self.connection.action(f'UPDATE {table} SET {column} = ?', (default,)) - - def check_db_version(self): - result = self.connection.select('SELECT db_version FROM db_version') - if result: - return int(result[-1]['db_version']) - return 0 - - def inc_db_version(self): - new_version = self.check_db_version() + 1 - self.connection.action('UPDATE db_version SET db_version = ?', [new_version]) - return new_version diff --git a/nzb2media/processor/nzb.py b/nzb2media/processor/nzb.py index 396cfcae..9bac24ba 100644 --- a/nzb2media/processor/nzb.py +++ b/nzb2media/processor/nzb.py @@ -4,8 +4,8 @@ import datetime import logging import nzb2media +import nzb2media.databases import nzb2media.nzb -from nzb2media import main_db from nzb2media.auto_process import books from nzb2media.auto_process import comics from nzb2media.auto_process import games @@ -34,7 +34,7 @@ def process(*, input_directory, input_name=None, status=0, client_agent='manual' download_id = get_nzoid(input_name) if client_agent != 'manual' and not nzb2media.DOWNLOAD_INFO: log.debug(f'Adding NZB download info for directory {input_directory} to database') - my_db = main_db.DBConnection() + my_db = nzb2media.databases.DBConnection() input_directory1 = input_directory input_name1 = input_name try: diff --git a/nzb2media/utils/download_info.py b/nzb2media/utils/download_info.py index f05ed676..cc428fba 100644 --- a/nzb2media/utils/download_info.py +++ b/nzb2media/utils/download_info.py @@ -3,11 +3,11 @@ from __future__ import annotations import datetime import logging -from nzb2media import main_db +import nzb2media.databases log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) -database = main_db.DBConnection() +database = nzb2media.databases.DBConnection() def update_download_info_status(input_name, status): diff --git a/tests/import_test.py b/tests/import_test.py index 800bd4ef..2a487861 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -135,9 +135,6 @@ def test_import_nzb2media(): import nzb2media.databases assert nzb2media.databases - import nzb2media.main_db - assert nzb2media.main_db - import nzb2media.scene_exceptions assert nzb2media.scene_exceptions