nzbToMedia/core/main_db.py
2019-01-06 12:09:07 -05:00

291 lines
10 KiB
Python

# coding=utf-8
from __future__ import print_function
import re
import sqlite3
import time
from six import text_type
import core
from core import logger
def db_filename(filename='nzbtomedia.db', suffix=None):
    """
    Build the full path of a sqlite database file under the application root.

    @param filename: Name of the sqlite database file. Defaults to
                     nzbtomedia.db.
    @param suffix: Optional suffix appended after a '.', i.e. suffix='v0'
                   turns dbfile.db into dbfile.db.v0. A falsy suffix leaves
                   the filename untouched.
    @return: The (optionally suffixed) filename joined onto core.APP_ROOT.
    """
    name = '{0}.{1}'.format(filename, suffix) if suffix else filename
    return core.os.path.join(core.APP_ROOT, name)
class DBConnection(object):
    """
    Small wrapper around a sqlite3 connection that adds query logging and
    retries when the database file is busy or temporarily unavailable.

    Statements that fail with a retryable sqlite3.OperationalError are
    attempted up to _MAX_ATTEMPTS times, sleeping one second in between.
    When every attempt fails the methods return their "empty" result
    (None, or [] for mass_action/select) instead of raising.
    """

    # Number of times a statement is attempted before giving up.
    _MAX_ATTEMPTS = 5
    # OperationalError message fragments that indicate a transient problem.
    _RETRYABLE_MESSAGES = ('unable to open database file', 'database is locked')

    def __init__(self, filename='nzbtomedia.db', suffix=None, row_type=None):
        """
        @param filename: Database file name, resolved through db_filename().
        @param suffix: Optional suffix forwarded to db_filename().
        @param row_type: 'dict' to get rows as plain dicts; anything else
                         yields sqlite3.Row objects.
        """
        self.filename = filename
        # Forward the suffix so that it actually selects the suffixed file;
        # 20 is the sqlite3 busy timeout in seconds.
        self.connection = sqlite3.connect(db_filename(filename, suffix), 20)
        if row_type == 'dict':
            self.connection.row_factory = self._dict_factory
        else:
            self.connection.row_factory = sqlite3.Row

    @classmethod
    def _should_retry(cls, error):
        """Return True when error describes a transient, retryable failure."""
        message = str(error.args[0]) if error.args else ''
        return any(fragment in message for fragment in cls._RETRYABLE_MESSAGES)

    def check_db_version(self):
        """
        @return: The value stored in db_version.db_version as an int, or 0
                 when the table is missing, empty or cannot be read.
        """
        try:
            result = self.select('SELECT db_version FROM db_version')
        except sqlite3.OperationalError:
            # A missing db_version table means a fresh database. Any other
            # operational error is likewise reported as version 0; action()
            # has already logged it before re-raising.
            return 0
        return int(result[0]['db_version']) if result else 0

    def fetch(self, query, args=None):
        """
        Run a query and return the first column of its first row.

        @param query: SQL text; None makes this a no-op.
        @param args: Optional sequence of bind parameters.
        @return: The first column of the first row, or None when the query
                 is None, yields no rows, or all retry attempts failed.
        """
        if query is None:
            return None

        sql_result = None
        attempt = 0

        while attempt < self._MAX_ATTEMPTS:
            try:
                cursor = self.connection.cursor()
                if args is None:
                    logger.log('{name}: {query}'.format(name=self.filename, query=query), logger.DB)
                    cursor.execute(query)
                else:
                    logger.log('{name}: {query} with args {args}'.format
                               (name=self.filename, query=query, args=args), logger.DB)
                    cursor.execute(query, args)

                row = cursor.fetchone()
                if row is None:
                    # No matching row: report None rather than failing on row[0].
                    sql_result = None
                elif isinstance(row, dict):
                    # Rows built by _dict_factory are keyed by column name,
                    # so look the first column up by its name.
                    sql_result = row[cursor.description[0][0]]
                else:
                    sql_result = row[0]

                # get out of the connection attempt loop since we were successful
                break
            except sqlite3.OperationalError as error:
                if self._should_retry(error):
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.WARNING)
                    attempt += 1
                    time.sleep(1)
                else:
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.ERROR)
                    raise
            except sqlite3.DatabaseError as error:
                logger.log(u'Fatal error executing query: {msg}'.format(msg=error), logger.ERROR)
                raise

        return sql_result

    def mass_action(self, querylist, log_transaction=False):
        """
        Execute several statements and commit them together.

        @param querylist: Iterable of sequences, each either (query,) or
                          (query, args). Empty sequences are skipped. None
                          makes this a no-op.
        @param log_transaction: When true, log every statement at DEBUG level.
        @return: List of cursors, one per executed statement; [] when all
                 retry attempts failed; None when querylist is None.
        """
        if querylist is None:
            return None

        sql_result = []
        attempt = 0

        while attempt < self._MAX_ATTEMPTS:
            try:
                for qu in querylist:
                    if len(qu) == 1:
                        if log_transaction:
                            logger.log(qu[0], logger.DEBUG)
                        sql_result.append(self.connection.execute(qu[0]))
                    elif len(qu) > 1:
                        if log_transaction:
                            logger.log(u'{query} with args {args}'.format(query=qu[0], args=qu[1]), logger.DEBUG)
                        sql_result.append(self.connection.execute(qu[0], qu[1]))
                self.connection.commit()
                logger.log(u'Transaction with {x} query\'s executed'.format(x=len(querylist)), logger.DEBUG)
                return sql_result
            except sqlite3.OperationalError as error:
                # Discard partial results and undo whatever part of the batch ran.
                sql_result = []
                if self.connection:
                    self.connection.rollback()
                if self._should_retry(error):
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.WARNING)
                    attempt += 1
                    time.sleep(1)
                else:
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.ERROR)
                    raise
            except sqlite3.DatabaseError as error:
                if self.connection:
                    self.connection.rollback()
                logger.log(u'Fatal error executing query: {msg}'.format(msg=error), logger.ERROR)
                raise

        return sql_result

    def action(self, query, args=None):
        """
        Execute a single statement and commit.

        @param query: SQL text; None makes this a no-op.
        @param args: Optional sequence of bind parameters.
        @return: The sqlite3 cursor of the executed statement, or None when
                 the query is None or all retry attempts failed.
        """
        if query is None:
            return None

        sql_result = None
        attempt = 0

        while attempt < self._MAX_ATTEMPTS:
            try:
                if args is None:
                    logger.log(u'{name}: {query}'.format(name=self.filename, query=query), logger.DB)
                    sql_result = self.connection.execute(query)
                else:
                    logger.log(u'{name}: {query} with args {args}'.format
                               (name=self.filename, query=query, args=args), logger.DB)
                    sql_result = self.connection.execute(query, args)
                self.connection.commit()
                # get out of the connection attempt loop since we were successful
                break
            except sqlite3.OperationalError as error:
                if self._should_retry(error):
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.WARNING)
                    attempt += 1
                    time.sleep(1)
                else:
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.ERROR)
                    raise
            except sqlite3.DatabaseError as error:
                logger.log(u'Fatal error executing query: {msg}'.format(msg=error), logger.ERROR)
                raise

        return sql_result

    def select(self, query, args=None):
        """
        Execute a query and return all of its rows.

        @return: List of rows; [] when there are none, when the query is
                 None, or when action() gave up after its retries.
        """
        cursor = self.action(query, args)
        if cursor is None:
            # action() returns None for a None query or exhausted retries.
            return []
        return cursor.fetchall()

    def upsert(self, table_name, value_dict, key_dict):
        """
        Update the rows matching key_dict with value_dict; insert a new row
        when the update changed nothing.

        @param table_name: Table to modify (interpolated into the SQL text).
        @param value_dict: Mapping of column name -> new value.
        @param key_dict: Mapping of column name -> value identifying the row.
        """
        def gen_params(my_dict):
            return [
                '{key} = ?'.format(key=k)
                for k in my_dict.keys()
            ]

        changes_before = self.connection.total_changes

        items = list(value_dict.values()) + list(key_dict.values())
        self.action(
            'UPDATE {table} '
            'SET {params} '
            'WHERE {conditions}'.format(
                table=table_name,
                params=', '.join(gen_params(value_dict)),
                conditions=' AND '.join(gen_params(key_dict))
            ),
            items
        )

        if self.connection.total_changes == changes_before:
            # Nothing was updated, so create the row. The key columns must be
            # written as well, otherwise the new row could never be matched
            # by the WHERE clause of a later upsert. Key columns that also
            # appear in value_dict are not repeated.
            columns = list(value_dict.keys())
            values = list(value_dict.values())
            for key_column, key_value in key_dict.items():
                if key_column not in value_dict:
                    columns.append(key_column)
                    values.append(key_value)

            self.action(
                'INSERT OR IGNORE INTO {table} ({columns}) '
                'VALUES ({values})'.format(
                    table=table_name,
                    columns=', '.join(map(text_type, columns)),
                    values=', '.join(['?'] * len(values))
                ),
                values
            )

    def table_info(self, table_name):
        """
        @param table_name: Name of the table to inspect.
        @return: Dict mapping each column name to {'type': declared type};
                 empty when the table does not exist.
        """
        # PRAGMA arguments cannot be bound with '?', so pass the name as a
        # double-quoted SQL identifier with embedded quotes doubled.
        quoted_name = '"{0}"'.format(table_name.replace('"', '""'))
        cursor = self.connection.execute('PRAGMA table_info({0})'.format(quoted_name))
        return {
            column['name']: {'type': column['type']}
            for column in cursor
        }

    # http://stackoverflow.com/questions/3300464/how-can-i-get-dict-from-sqlite-query
    def _dict_factory(self, cursor, row):
        """Row factory turning a result row into a {column name: value} dict."""
        return {
            col[0]: value
            for col, value in zip(cursor.description, row)
        }
def sanity_check_database(connection, sanity_check):
    """
    Instantiate the given sanity check for a connection and run it.

    @param connection: Object handed to the sanity_check constructor.
    @param sanity_check: Callable (typically a class) that accepts the
                         connection and returns an object with a check()
                         method.
    """
    checker = sanity_check(connection)
    checker.check()
class DBSanityCheck(object):
    """Holds a connection and exposes a check() hook that does nothing here."""

    def __init__(self, connection):
        # Keep the connection available to check().
        self.connection = connection

    def check(self):
        """Do nothing and return None."""
        return None
# ===============
# = Upgrade API =
# ===============
def upgrade_database(connection, schema):
    """
    Log that the database structure is being checked, then hand the
    connection and the schema class to _process_upgrade().

    @param connection: Connection object passed on to _process_upgrade().
    @param schema: Upgrade class passed on to _process_upgrade().
    """
    message = u'Checking database structure...'
    logger.log(message, logger.MESSAGE)
    _process_upgrade(connection, schema)
def pretty_name(class_name):
    """
    Split a CamelCase name into space-separated words.

    A word is one capital letter followed by at least one lowercase letter
    or digit; anything that does not fit that shape is dropped.

    @param class_name: The name to split.
    @return: The matched words joined by single spaces ('' when none match).
    """
    words = (match.group() for match in re.finditer('([A-Z])([a-z0-9]+)', class_name))
    return ' '.join(words)
def _process_upgrade(connection, upgrade_class):
    """
    Apply one upgrade class if its test() reports it is needed, then recurse
    into every direct subclass of that class.

    @param connection: Connection object handed to each upgrade class.
    @param upgrade_class: Class providing test() and execute(); its
                          __subclasses__() are processed afterwards.
    @raise sqlite3.DatabaseError: Re-raised when execute() fails, after the
                                  error has been printed.
    """
    upgrade = upgrade_class(connection)
    raw_name = upgrade_class.__name__
    display_name = pretty_name(raw_name)

    logger.log(u'Checking {name} database upgrade'.format
               (name=display_name), logger.DEBUG)

    if upgrade.test():
        logger.log(u'{name} upgrade not required'.format
                   (name=raw_name), logger.DEBUG)
    else:
        logger.log(u'Database upgrade required: {name}'.format
                   (name=display_name), logger.MESSAGE)
        try:
            upgrade.execute()
        except sqlite3.DatabaseError as error:
            print(u'Error in {name}: {msg}'.format
                  (name=raw_name, msg=error))
            raise
        logger.log(u'{name} upgrade completed'.format
                   (name=raw_name), logger.DEBUG)

    # Subclasses are processed whether or not this class needed upgrading.
    for child_class in upgrade_class.__subclasses__():
        _process_upgrade(connection, child_class)
# Base migration class. All future DB changes should be subclassed from this class
class SchemaUpgrade(object):
    """
    Schema helpers for migrations, built on a connection object that
    provides action(), select() and table_info().
    """

    def __init__(self, connection):
        self.connection = connection

    def has_table(self, table_name):
        """Return True when sqlite_master has an entry named table_name."""
        cursor = self.connection.action('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,))
        matches = cursor.fetchall()
        return len(matches) > 0

    def has_column(self, table_name, column):
        """Return True when column is listed in table_info(table_name)."""
        known_columns = self.connection.table_info(table_name)
        return column in known_columns

    def add_column(self, table, column, data_type='NUMERIC', default=0):
        """
        Add a column to a table, then set it to default on every row.

        @param table: Table name (interpolated into the SQL text).
        @param column: Name of the new column (interpolated into the SQL text).
        @param data_type: Declared type of the new column.
        @param default: Value written to the new column for all rows.
        """
        alter_sql = 'ALTER TABLE {0} ADD {1} {2}'.format(table, column, data_type)
        self.connection.action(alter_sql)
        fill_sql = 'UPDATE {0} SET {1} = ?'.format(table, column)
        self.connection.action(fill_sql, (default,))

    def check_db_version(self):
        """Return db_version from the last row of the db_version table, or 0 when it has no rows."""
        rows = self.connection.select('SELECT db_version FROM db_version')
        if not rows:
            return 0
        return int(rows[-1]['db_version'])

    def inc_db_version(self):
        """Write the current version plus one to db_version and return it."""
        bumped = self.check_db_version() + 1
        self.connection.action('UPDATE db_version SET db_version = ?', [bumped])
        return bumped