diff --git a/core/main_db.py b/core/main_db.py index 4be8aa3c..d9d7c1b9 100644 --- a/core/main_db.py +++ b/core/main_db.py @@ -11,11 +11,37 @@ import re import sqlite3 import time -from six import text_type +from six import text_type, PY2 import core from core import logger +if PY2: + class Row(sqlite3.Row, object): + """ + Row factory that uses Byte Strings for keys. + + The sqlite3.Row in Python 2 does not support unicode keys. + This overrides __getitem__ to attempt to encode the key to bytes first. + """ + + def __getitem__(self, item): + """ + Get an item from the row by index or key. + + :param item: Index or Key of item to return. + :return: An item from the sqlite3.Row. + """ + try: + # sqlite3.Row column names should be Bytes in Python 2 + item = item.encode() + except AttributeError: + pass # assume item is a numeric index + + return super(Row, self).__getitem__(item) +else: + from sqlite3 import Row + def db_filename(filename='nzbtomedia.db', suffix=None): """ @@ -37,10 +63,7 @@ class DBConnection(object): self.filename = filename self.connection = sqlite3.connect(db_filename(filename), 20) - if row_type == 'dict': - self.connection.row_factory = self._dict_factory - else: - self.connection.row_factory = sqlite3.Row + self.connection.row_factory = Row def check_db_version(self): result = None @@ -214,13 +237,6 @@ class DBConnection(object): for column in cursor } - # http://stackoverflow.com/questions/3300464/how-can-i-get-dict-from-sqlite-query - def _dict_factory(self, cursor, row): - return { - col[0]: row[idx] - for idx, col in enumerate(cursor.description) - } - def sanity_check_database(connection, sanity_check): sanity_check(connection).check()