mirror of
https://github.com/clinton-hall/nzbToMedia.git
synced 2025-08-14 10:36:52 -07:00
Update vendored beets to 1.6.0
Updates colorama to 0.4.6 Adds confuse version 1.7.0 Updates jellyfish to 0.9.0 Adds mediafile 0.10.1 Updates munkres to 1.1.4 Updates musicbrainzngs to 0.7.1 Updates mutagen to 1.46.0 Updates pyyaml to 6.0 Updates unidecode to 1.3.6
This commit is contained in:
parent
5073ec0c6f
commit
56c6773c6b
385 changed files with 25143 additions and 18080 deletions
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is part of beets.
|
||||
# Copyright 2016, Adrian Sampson.
|
||||
#
|
||||
|
@ -16,7 +15,6 @@
|
|||
"""DBCore is an abstract database package that forms the basis for beets'
|
||||
Library.
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
from .db import Model, Database
|
||||
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is part of beets.
|
||||
# Copyright 2016, Adrian Sampson.
|
||||
#
|
||||
|
@ -15,22 +14,21 @@
|
|||
|
||||
"""The central Model and Database constructs for DBCore.
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
import sqlite3
|
||||
import contextlib
|
||||
import collections
|
||||
|
||||
import beets
|
||||
from beets.util.functemplate import Template
|
||||
from beets.util import functemplate
|
||||
from beets.util import py3_path
|
||||
from beets.dbcore import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery
|
||||
import six
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
class DBAccessError(Exception):
|
||||
|
@ -42,20 +40,30 @@ class DBAccessError(Exception):
|
|||
"""
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
class FormattedMapping(Mapping):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor `mapping[key]` returns the formatted version of
|
||||
`model[key]` as a unicode string.
|
||||
|
||||
The `included_keys` parameter allows filtering the fields that are
|
||||
returned. By default all fields are returned. Limiting to specific keys can
|
||||
avoid expensive per-item database queries.
|
||||
|
||||
If `for_path` is true, all path separators in the formatted values
|
||||
are replaced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
ALL_KEYS = '*'
|
||||
|
||||
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
if included_keys == self.ALL_KEYS:
|
||||
# Performance note: this triggers a database query.
|
||||
self.model_keys = self.model.keys(True)
|
||||
else:
|
||||
self.model_keys = included_keys
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
|
@ -72,7 +80,7 @@ class FormattedMapping(collections.Mapping):
|
|||
def get(self, key, default=None):
|
||||
if default is None:
|
||||
default = self.model._type(key).format(None)
|
||||
return super(FormattedMapping, self).get(key, default)
|
||||
return super().get(key, default)
|
||||
|
||||
def _get_formatted(self, model, key):
|
||||
value = model._type(key).format(model.get(key))
|
||||
|
@ -81,6 +89,11 @@ class FormattedMapping(collections.Mapping):
|
|||
|
||||
if self.for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].as_str()
|
||||
sep_drive = beets.config['drive_sep_replace'].as_str()
|
||||
|
||||
if re.match(r'^\w:', value):
|
||||
value = re.sub(r'(?<=^\w):', sep_drive, value)
|
||||
|
||||
for sep in (os.path.sep, os.path.altsep):
|
||||
if sep:
|
||||
value = value.replace(sep, sep_repl)
|
||||
|
@ -88,11 +101,105 @@ class FormattedMapping(collections.Mapping):
|
|||
return value
|
||||
|
||||
|
||||
class LazyConvertDict:
|
||||
"""Lazily convert types for attributes fetched from the database
|
||||
"""
|
||||
|
||||
def __init__(self, model_cls):
|
||||
"""Initialize the object empty
|
||||
"""
|
||||
self.data = {}
|
||||
self.model_cls = model_cls
|
||||
self._converted = {}
|
||||
|
||||
def init(self, data):
|
||||
"""Set the base data that should be lazily converted
|
||||
"""
|
||||
self.data = data
|
||||
|
||||
def _convert(self, key, value):
|
||||
"""Convert the attribute type according the the SQL type
|
||||
"""
|
||||
return self.model_cls._type(key).from_sql(value)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Set an attribute value, assume it's already converted
|
||||
"""
|
||||
self._converted[key] = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get an attribute value, converting the type on demand
|
||||
if needed
|
||||
"""
|
||||
if key in self._converted:
|
||||
return self._converted[key]
|
||||
elif key in self.data:
|
||||
value = self._convert(key, self.data[key])
|
||||
self._converted[key] = value
|
||||
return value
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Delete both converted and base data
|
||||
"""
|
||||
if key in self._converted:
|
||||
del self._converted[key]
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
|
||||
def keys(self):
|
||||
"""Get a list of available field names for this object.
|
||||
"""
|
||||
return list(self._converted.keys()) + list(self.data.keys())
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of the object.
|
||||
"""
|
||||
new = self.__class__(self.model_cls)
|
||||
new.data = self.data.copy()
|
||||
new._converted = self._converted.copy()
|
||||
return new
|
||||
|
||||
# Act like a dictionary.
|
||||
|
||||
def update(self, values):
|
||||
"""Assign all values in the given dict.
|
||||
"""
|
||||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys()
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
return iter(self.keys())
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
||||
class Model(object):
|
||||
class Model:
|
||||
"""An abstract object representing an object in the database. Model
|
||||
objects act like dictionaries (i.e., the allow subscript access like
|
||||
objects act like dictionaries (i.e., they allow subscript access like
|
||||
``obj['field']``). The same field set is available via attribute
|
||||
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
|
||||
available:
|
||||
|
@ -143,12 +250,22 @@ class Model(object):
|
|||
are subclasses of `Sort`.
|
||||
"""
|
||||
|
||||
_queries = {}
|
||||
"""Named queries that use a field-like `name:value` syntax but which
|
||||
do not relate to any specific field.
|
||||
"""
|
||||
|
||||
_always_dirty = False
|
||||
"""By default, fields only become "dirty" when their value actually
|
||||
changes. Enabling this flag marks fields as dirty even when the new
|
||||
value is the same as the old value (e.g., `o.f = o.f`).
|
||||
"""
|
||||
|
||||
_revision = -1
|
||||
"""A revision number from when the model was loaded from or written
|
||||
to the database.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
|
@ -172,8 +289,8 @@ class Model(object):
|
|||
"""
|
||||
self._db = db
|
||||
self._dirty = set()
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
self._values_flex = LazyConvertDict(self)
|
||||
|
||||
# Initial contents.
|
||||
self.update(values)
|
||||
|
@ -187,23 +304,25 @@ class Model(object):
|
|||
ordinary construction are bypassed.
|
||||
"""
|
||||
obj = cls(db)
|
||||
for key, value in fixed_values.items():
|
||||
obj._values_fixed[key] = cls._type(key).from_sql(value)
|
||||
for key, value in flex_values.items():
|
||||
obj._values_flex[key] = cls._type(key).from_sql(value)
|
||||
|
||||
obj._values_fixed.init(fixed_values)
|
||||
obj._values_flex.init(flex_values)
|
||||
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return '{0}({1})'.format(
|
||||
return '{}({})'.format(
|
||||
type(self).__name__,
|
||||
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
|
||||
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
|
||||
)
|
||||
|
||||
def clear_dirty(self):
|
||||
"""Mark all fields as *clean* (i.e., not needing to be stored to
|
||||
the database).
|
||||
the database). Also update the revision.
|
||||
"""
|
||||
self._dirty = set()
|
||||
if self._db:
|
||||
self._revision = self._db.revision
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
|
@ -212,10 +331,10 @@ class Model(object):
|
|||
"""
|
||||
if not self._db:
|
||||
raise ValueError(
|
||||
u'{0} has no database'.format(type(self).__name__)
|
||||
'{} has no database'.format(type(self).__name__)
|
||||
)
|
||||
if need_id and not self.id:
|
||||
raise ValueError(u'{0} has no id'.format(type(self).__name__))
|
||||
raise ValueError('{} has no id'.format(type(self).__name__))
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of the model object.
|
||||
|
@ -243,19 +362,32 @@ class Model(object):
|
|||
"""
|
||||
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
def _get(self, key, default=None, raise_=False):
|
||||
"""Get the value for a field, or `default`. Alternatively,
|
||||
raise a KeyError if the field is not available.
|
||||
"""
|
||||
getters = self._getters()
|
||||
if key in getters: # Computed.
|
||||
return getters[key](self)
|
||||
elif key in self._fields: # Fixed.
|
||||
return self._values_fixed.get(key, self._type(key).null)
|
||||
if key in self._values_fixed:
|
||||
return self._values_fixed[key]
|
||||
else:
|
||||
return self._type(key).null
|
||||
elif key in self._values_flex: # Flexible.
|
||||
return self._values_flex[key]
|
||||
else:
|
||||
elif raise_:
|
||||
raise KeyError(key)
|
||||
else:
|
||||
return default
|
||||
|
||||
get = _get
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
"""
|
||||
return self._get(key, raise_=True)
|
||||
|
||||
def _setitem(self, key, value):
|
||||
"""Assign the value for a field, return whether new and old value
|
||||
|
@ -290,12 +422,12 @@ class Model(object):
|
|||
if key in self._values_flex: # Flexible.
|
||||
del self._values_flex[key]
|
||||
self._dirty.add(key) # Mark for dropping on store.
|
||||
elif key in self._fields: # Fixed
|
||||
setattr(self, key, self._type(key).null)
|
||||
elif key in self._getters(): # Computed.
|
||||
raise KeyError(u'computed field {0} cannot be deleted'.format(key))
|
||||
elif key in self._fields: # Fixed.
|
||||
raise KeyError(u'fixed field {0} cannot be deleted'.format(key))
|
||||
raise KeyError(f'computed field {key} cannot be deleted')
|
||||
else:
|
||||
raise KeyError(u'no such field {0}'.format(key))
|
||||
raise KeyError(f'no such field {key}')
|
||||
|
||||
def keys(self, computed=False):
|
||||
"""Get a list of available field names for this object. The
|
||||
|
@ -330,19 +462,10 @@ class Model(object):
|
|||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(True)
|
||||
return key in self.keys(computed=True)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the available field names (excluding computed
|
||||
|
@ -354,22 +477,22 @@ class Model(object):
|
|||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError(u'model has no attribute {0!r}'.format(key))
|
||||
raise AttributeError(f'model has no attribute {key!r}')
|
||||
else:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(u'no such field {0!r}'.format(key))
|
||||
raise AttributeError(f'no such field {key!r}')
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__setattr__(key, value)
|
||||
super().__setattr__(key, value)
|
||||
else:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__delattr__(key)
|
||||
super().__delattr__(key)
|
||||
else:
|
||||
del self[key]
|
||||
|
||||
|
@ -398,7 +521,7 @@ class Model(object):
|
|||
with self._db.transaction() as tx:
|
||||
# Main table update.
|
||||
if assignments:
|
||||
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
|
||||
query = 'UPDATE {} SET {} WHERE id=?'.format(
|
||||
self._table, assignments
|
||||
)
|
||||
subvars.append(self.id)
|
||||
|
@ -409,7 +532,7 @@ class Model(object):
|
|||
if key in self._dirty:
|
||||
self._dirty.remove(key)
|
||||
tx.mutate(
|
||||
'INSERT INTO {0} '
|
||||
'INSERT INTO {} '
|
||||
'(entity_id, key, value) '
|
||||
'VALUES (?, ?, ?);'.format(self._flex_table),
|
||||
(self.id, key, value),
|
||||
|
@ -418,7 +541,7 @@ class Model(object):
|
|||
# Deleted flexible attributes.
|
||||
for key in self._dirty:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} '
|
||||
'DELETE FROM {} '
|
||||
'WHERE entity_id=? AND key=?'.format(self._flex_table),
|
||||
(self.id, key)
|
||||
)
|
||||
|
@ -427,12 +550,18 @@ class Model(object):
|
|||
|
||||
def load(self):
|
||||
"""Refresh the object's metadata from the library database.
|
||||
|
||||
If check_revision is true, the database is only queried loaded when a
|
||||
transaction has been committed since the item was last loaded.
|
||||
"""
|
||||
self._check_db()
|
||||
if not self._dirty and self._db.revision == self._revision:
|
||||
# Exit early
|
||||
return
|
||||
stored_obj = self._db._get(type(self), self.id)
|
||||
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
assert stored_obj is not None, f"object {self.id} not in DB"
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
self._values_flex = LazyConvertDict(self)
|
||||
self.update(dict(stored_obj))
|
||||
self.clear_dirty()
|
||||
|
||||
|
@ -442,11 +571,11 @@ class Model(object):
|
|||
self._check_db()
|
||||
with self._db.transaction() as tx:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE id=?'.format(self._table),
|
||||
f'DELETE FROM {self._table} WHERE id=?',
|
||||
(self.id,)
|
||||
)
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
|
||||
f'DELETE FROM {self._flex_table} WHERE entity_id=?',
|
||||
(self.id,)
|
||||
)
|
||||
|
||||
|
@ -464,7 +593,7 @@ class Model(object):
|
|||
|
||||
with self._db.transaction() as tx:
|
||||
new_id = tx.mutate(
|
||||
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
|
||||
f'INSERT INTO {self._table} DEFAULT VALUES'
|
||||
)
|
||||
self.id = new_id
|
||||
self.added = time.time()
|
||||
|
@ -479,11 +608,11 @@ class Model(object):
|
|||
|
||||
_formatter = FormattedMapping
|
||||
|
||||
def formatted(self, for_path=False):
|
||||
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return self._formatter(self, for_path)
|
||||
return self._formatter(self, included_keys, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
|
@ -491,9 +620,9 @@ class Model(object):
|
|||
separators will be added to the template.
|
||||
"""
|
||||
# Perform substitution.
|
||||
if isinstance(template, six.string_types):
|
||||
template = Template(template)
|
||||
return template.substitute(self.formatted(for_path),
|
||||
if isinstance(template, str):
|
||||
template = functemplate.template(template)
|
||||
return template.substitute(self.formatted(for_path=for_path),
|
||||
self._template_funcs())
|
||||
|
||||
# Parsing.
|
||||
|
@ -502,8 +631,8 @@ class Model(object):
|
|||
def _parse(cls, key, string):
|
||||
"""Parse a string as a value for the given key.
|
||||
"""
|
||||
if not isinstance(string, six.string_types):
|
||||
raise TypeError(u"_parse() argument must be a string")
|
||||
if not isinstance(string, str):
|
||||
raise TypeError("_parse() argument must be a string")
|
||||
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
|
@ -515,11 +644,13 @@ class Model(object):
|
|||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
class Results(object):
|
||||
class Results:
|
||||
"""An item query result set. Iterating over the collection lazily
|
||||
constructs LibModel objects that reflect database rows.
|
||||
"""
|
||||
def __init__(self, model_class, rows, db, query=None, sort=None):
|
||||
|
||||
def __init__(self, model_class, rows, db, flex_rows,
|
||||
query=None, sort=None):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
|
@ -539,6 +670,7 @@ class Results(object):
|
|||
self.db = db
|
||||
self.query = query
|
||||
self.sort = sort
|
||||
self.flex_rows = flex_rows
|
||||
|
||||
# We keep a queue of rows we haven't yet consumed for
|
||||
# materialization. We preserve the original total number of
|
||||
|
@ -560,6 +692,10 @@ class Results(object):
|
|||
a `Results` object a second time should be much faster than the
|
||||
first.
|
||||
"""
|
||||
|
||||
# Index flexible attributes by the item ID, so we have easier access
|
||||
flex_attrs = self._get_indexed_flex_attrs()
|
||||
|
||||
index = 0 # Position in the materialized objects.
|
||||
while index < len(self._objects) or self._rows:
|
||||
# Are there previously-materialized objects to produce?
|
||||
|
@ -572,7 +708,7 @@ class Results(object):
|
|||
else:
|
||||
while self._rows:
|
||||
row = self._rows.pop(0)
|
||||
obj = self._make_model(row)
|
||||
obj = self._make_model(row, flex_attrs.get(row['id'], {}))
|
||||
# If there is a slow-query predicate, ensurer that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
|
@ -594,20 +730,24 @@ class Results(object):
|
|||
# Objects are pre-sorted (i.e., by the database).
|
||||
return self._get_objects()
|
||||
|
||||
def _make_model(self, row):
|
||||
# Get the flexible attributes for the object.
|
||||
with self.db.transaction() as tx:
|
||||
flex_rows = tx.query(
|
||||
'SELECT * FROM {0} WHERE entity_id=?'.format(
|
||||
self.model_class._flex_table
|
||||
),
|
||||
(row['id'],)
|
||||
)
|
||||
def _get_indexed_flex_attrs(self):
|
||||
""" Index flexible attributes by the entity id they belong to
|
||||
"""
|
||||
flex_values = {}
|
||||
for row in self.flex_rows:
|
||||
if row['entity_id'] not in flex_values:
|
||||
flex_values[row['entity_id']] = {}
|
||||
|
||||
flex_values[row['entity_id']][row['key']] = row['value']
|
||||
|
||||
return flex_values
|
||||
|
||||
def _make_model(self, row, flex_values={}):
|
||||
""" Create a Model object for the given row
|
||||
"""
|
||||
cols = dict(row)
|
||||
values = dict((k, v) for (k, v) in cols.items()
|
||||
if not k[:4] == 'flex')
|
||||
flex_values = dict((row['key'], row['value']) for row in flex_rows)
|
||||
values = {k: v for (k, v) in cols.items()
|
||||
if not k[:4] == 'flex'}
|
||||
|
||||
# Construct the Python object
|
||||
obj = self.model_class._awaken(self.db, values, flex_values)
|
||||
|
@ -656,7 +796,7 @@ class Results(object):
|
|||
next(it)
|
||||
return next(it)
|
||||
except StopIteration:
|
||||
raise IndexError(u'result index {0} out of range'.format(n))
|
||||
raise IndexError(f'result index {n} out of range')
|
||||
|
||||
def get(self):
|
||||
"""Return the first matching object, or None if no objects
|
||||
|
@ -669,10 +809,16 @@ class Results(object):
|
|||
return None
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
class Transaction:
|
||||
"""A context manager for safe, concurrent access to the database.
|
||||
All SQL commands should be executed through a transaction.
|
||||
"""
|
||||
|
||||
_mutated = False
|
||||
"""A flag storing whether a mutation has been executed in the
|
||||
current transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
|
@ -694,12 +840,15 @@ class Transaction(object):
|
|||
entered but not yet exited transaction. If it is the last active
|
||||
transaction, the database updates are committed.
|
||||
"""
|
||||
# Beware of races; currently secured by db._db_lock
|
||||
self.db.revision += self._mutated
|
||||
with self.db._tx_stack() as stack:
|
||||
assert stack.pop() is self
|
||||
empty = not stack
|
||||
if empty:
|
||||
# Ending a "root" transaction. End the SQLite transaction.
|
||||
self.db._connection().commit()
|
||||
self._mutated = False
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement, subvals=()):
|
||||
|
@ -715,7 +864,6 @@ class Transaction(object):
|
|||
"""
|
||||
try:
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.lastrowid
|
||||
except sqlite3.OperationalError as e:
|
||||
# In two specific cases, SQLite reports an error while accessing
|
||||
# the underlying database file. We surface these exceptions as
|
||||
|
@ -725,26 +873,41 @@ class Transaction(object):
|
|||
raise DBAccessError(e.args[0])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
self._mutated = True
|
||||
return cursor.lastrowid
|
||||
|
||||
def script(self, statements):
|
||||
"""Execute a string containing multiple SQL statements."""
|
||||
# We don't know whether this mutates, but quite likely it does.
|
||||
self._mutated = True
|
||||
self.db._connection().executescript(statements)
|
||||
|
||||
|
||||
class Database(object):
|
||||
class Database:
|
||||
"""A container for Model objects that wraps an SQLite database as
|
||||
the backend.
|
||||
"""
|
||||
|
||||
_models = ()
|
||||
"""The Model subclasses representing tables in this database.
|
||||
"""
|
||||
|
||||
supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
|
||||
"""Whether or not the current version of SQLite supports extensions"""
|
||||
|
||||
revision = 0
|
||||
"""The current revision of the database. To be increased whenever
|
||||
data is written in a transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, path, timeout=5.0):
|
||||
self.path = path
|
||||
self.timeout = timeout
|
||||
|
||||
self._connections = {}
|
||||
self._tx_stacks = defaultdict(list)
|
||||
self._extensions = []
|
||||
|
||||
# A lock to protect the _connections and _tx_stacks maps, which
|
||||
# both map thread IDs to private resources.
|
||||
|
@ -794,6 +957,13 @@ class Database(object):
|
|||
py3_path(self.path), timeout=self.timeout
|
||||
)
|
||||
|
||||
if self.supports_extensions:
|
||||
conn.enable_load_extension(True)
|
||||
|
||||
# Load any extension that are already loaded for other connections.
|
||||
for path in self._extensions:
|
||||
conn.load_extension(path)
|
||||
|
||||
# Access SELECT results like dictionaries.
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
@ -822,6 +992,18 @@ class Database(object):
|
|||
"""
|
||||
return Transaction(self)
|
||||
|
||||
def load_extension(self, path):
|
||||
"""Load an SQLite extension into all open connections."""
|
||||
if not self.supports_extensions:
|
||||
raise ValueError(
|
||||
'this sqlite3 installation does not support extensions')
|
||||
|
||||
self._extensions.append(path)
|
||||
|
||||
# Load the extension into every open connection.
|
||||
for conn in self._connections.values():
|
||||
conn.load_extension(path)
|
||||
|
||||
# Schema setup and migration.
|
||||
|
||||
def _make_table(self, table, fields):
|
||||
|
@ -831,7 +1013,7 @@ class Database(object):
|
|||
# Get current schema.
|
||||
with self.transaction() as tx:
|
||||
rows = tx.query('PRAGMA table_info(%s)' % table)
|
||||
current_fields = set([row[1] for row in rows])
|
||||
current_fields = {row[1] for row in rows}
|
||||
|
||||
field_names = set(fields.keys())
|
||||
if current_fields.issuperset(field_names):
|
||||
|
@ -842,9 +1024,9 @@ class Database(object):
|
|||
# No table exists.
|
||||
columns = []
|
||||
for name, typ in fields.items():
|
||||
columns.append('{0} {1}'.format(name, typ.sql))
|
||||
setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
|
||||
', '.join(columns))
|
||||
columns.append(f'{name} {typ.sql}')
|
||||
setup_sql = 'CREATE TABLE {} ({});\n'.format(table,
|
||||
', '.join(columns))
|
||||
|
||||
else:
|
||||
# Table exists does not match the field set.
|
||||
|
@ -852,7 +1034,7 @@ class Database(object):
|
|||
for name, typ in fields.items():
|
||||
if name in current_fields:
|
||||
continue
|
||||
setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
|
||||
setup_sql += 'ALTER TABLE {} ADD COLUMN {} {};\n'.format(
|
||||
table, name, typ.sql
|
||||
)
|
||||
|
||||
|
@ -888,17 +1070,31 @@ class Database(object):
|
|||
where, subvals = query.clause()
|
||||
order_by = sort.order_clause()
|
||||
|
||||
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
|
||||
sql = ("SELECT * FROM {} WHERE {} {}").format(
|
||||
model_cls._table,
|
||||
where or '1',
|
||||
"ORDER BY {0}".format(order_by) if order_by else '',
|
||||
f"ORDER BY {order_by}" if order_by else '',
|
||||
)
|
||||
|
||||
# Fetch flexible attributes for items matching the main query.
|
||||
# Doing the per-item filtering in python is faster than issuing
|
||||
# one query per item to sqlite.
|
||||
flex_sql = ("""
|
||||
SELECT * FROM {} WHERE entity_id IN
|
||||
(SELECT id FROM {} WHERE {});
|
||||
""".format(
|
||||
model_cls._flex_table,
|
||||
model_cls._table,
|
||||
where or '1',
|
||||
)
|
||||
)
|
||||
|
||||
with self.transaction() as tx:
|
||||
rows = tx.query(sql, subvals)
|
||||
flex_rows = tx.query(flex_sql, subvals)
|
||||
|
||||
return Results(
|
||||
model_cls, rows, self,
|
||||
model_cls, rows, self, flex_rows,
|
||||
None if where else query, # Slow query component.
|
||||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is part of beets.
|
||||
# Copyright 2016, Adrian Sampson.
|
||||
#
|
||||
|
@ -15,7 +14,6 @@
|
|||
|
||||
"""The Query type hierarchy for DBCore.
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import re
|
||||
from operator import mul
|
||||
|
@ -23,10 +21,6 @@ from beets import util
|
|||
from datetime import datetime, timedelta
|
||||
import unicodedata
|
||||
from functools import reduce
|
||||
import six
|
||||
|
||||
if not six.PY2:
|
||||
buffer = memoryview # sqlite won't accept memoryview in python 2
|
||||
|
||||
|
||||
class ParsingError(ValueError):
|
||||
|
@ -44,8 +38,8 @@ class InvalidQueryError(ParsingError):
|
|||
def __init__(self, query, explanation):
|
||||
if isinstance(query, list):
|
||||
query = " ".join(query)
|
||||
message = u"'{0}': {1}".format(query, explanation)
|
||||
super(InvalidQueryError, self).__init__(message)
|
||||
message = f"'{query}': {explanation}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvalidQueryArgumentValueError(ParsingError):
|
||||
|
@ -56,13 +50,13 @@ class InvalidQueryArgumentValueError(ParsingError):
|
|||
"""
|
||||
|
||||
def __init__(self, what, expected, detail=None):
|
||||
message = u"'{0}' is not {1}".format(what, expected)
|
||||
message = f"'{what}' is not {expected}"
|
||||
if detail:
|
||||
message = u"{0}: {1}".format(message, detail)
|
||||
super(InvalidQueryArgumentValueError, self).__init__(message)
|
||||
message = f"{message}: {detail}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Query(object):
|
||||
class Query:
|
||||
"""An abstract class representing a query into the item database.
|
||||
"""
|
||||
|
||||
|
@ -82,7 +76,7 @@ class Query(object):
|
|||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{0.__class__.__name__}()".format(self)
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) == type(other)
|
||||
|
@ -129,7 +123,7 @@ class FieldQuery(Query):
|
|||
"{0.fast})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(FieldQuery, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and self.pattern == other.pattern
|
||||
|
||||
def __hash__(self):
|
||||
|
@ -151,17 +145,13 @@ class NoneQuery(FieldQuery):
|
|||
"""A query that checks whether a field is null."""
|
||||
|
||||
def __init__(self, field, fast=True):
|
||||
super(NoneQuery, self).__init__(field, None, fast)
|
||||
super().__init__(field, None, fast)
|
||||
|
||||
def col_clause(self):
|
||||
return self.field + " IS NULL", ()
|
||||
|
||||
@classmethod
|
||||
def match(cls, item):
|
||||
try:
|
||||
return item[cls.field] is None
|
||||
except KeyError:
|
||||
return True
|
||||
def match(self, item):
|
||||
return item.get(self.field) is None
|
||||
|
||||
def __repr__(self):
|
||||
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
|
||||
|
@ -214,14 +204,14 @@ class RegexpQuery(StringFieldQuery):
|
|||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(RegexpQuery, self).__init__(field, pattern, fast)
|
||||
super().__init__(field, pattern, fast)
|
||||
pattern = self._normalize(pattern)
|
||||
try:
|
||||
self.pattern = re.compile(self.pattern)
|
||||
except re.error as exc:
|
||||
# Invalid regular expression.
|
||||
raise InvalidQueryArgumentValueError(pattern,
|
||||
u"a regular expression",
|
||||
"a regular expression",
|
||||
format(exc))
|
||||
|
||||
@staticmethod
|
||||
|
@ -242,8 +232,8 @@ class BooleanQuery(MatchQuery):
|
|||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(BooleanQuery, self).__init__(field, pattern, fast)
|
||||
if isinstance(pattern, six.string_types):
|
||||
super().__init__(field, pattern, fast)
|
||||
if isinstance(pattern, str):
|
||||
self.pattern = util.str2bool(pattern)
|
||||
self.pattern = int(self.pattern)
|
||||
|
||||
|
@ -256,16 +246,16 @@ class BytesQuery(MatchQuery):
|
|||
"""
|
||||
|
||||
def __init__(self, field, pattern):
|
||||
super(BytesQuery, self).__init__(field, pattern)
|
||||
super().__init__(field, pattern)
|
||||
|
||||
# Use a buffer/memoryview representation of the pattern for SQLite
|
||||
# matching. This instructs SQLite to treat the blob as binary
|
||||
# rather than encoded Unicode.
|
||||
if isinstance(self.pattern, (six.text_type, bytes)):
|
||||
if isinstance(self.pattern, six.text_type):
|
||||
if isinstance(self.pattern, (str, bytes)):
|
||||
if isinstance(self.pattern, str):
|
||||
self.pattern = self.pattern.encode('utf-8')
|
||||
self.buf_pattern = buffer(self.pattern)
|
||||
elif isinstance(self.pattern, buffer):
|
||||
self.buf_pattern = memoryview(self.pattern)
|
||||
elif isinstance(self.pattern, memoryview):
|
||||
self.buf_pattern = self.pattern
|
||||
self.pattern = bytes(self.pattern)
|
||||
|
||||
|
@ -297,10 +287,10 @@ class NumericQuery(FieldQuery):
|
|||
try:
|
||||
return float(s)
|
||||
except ValueError:
|
||||
raise InvalidQueryArgumentValueError(s, u"an int or a float")
|
||||
raise InvalidQueryArgumentValueError(s, "an int or a float")
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(NumericQuery, self).__init__(field, pattern, fast)
|
||||
super().__init__(field, pattern, fast)
|
||||
|
||||
parts = pattern.split('..', 1)
|
||||
if len(parts) == 1:
|
||||
|
@ -318,7 +308,7 @@ class NumericQuery(FieldQuery):
|
|||
if self.field not in item:
|
||||
return False
|
||||
value = item[self.field]
|
||||
if isinstance(value, six.string_types):
|
||||
if isinstance(value, str):
|
||||
value = self._convert(value)
|
||||
|
||||
if self.point is not None:
|
||||
|
@ -335,14 +325,14 @@ class NumericQuery(FieldQuery):
|
|||
return self.field + '=?', (self.point,)
|
||||
else:
|
||||
if self.rangemin is not None and self.rangemax is not None:
|
||||
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
|
||||
return ('{0} >= ? AND {0} <= ?'.format(self.field),
|
||||
(self.rangemin, self.rangemax))
|
||||
elif self.rangemin is not None:
|
||||
return u'{0} >= ?'.format(self.field), (self.rangemin,)
|
||||
return f'{self.field} >= ?', (self.rangemin,)
|
||||
elif self.rangemax is not None:
|
||||
return u'{0} <= ?'.format(self.field), (self.rangemax,)
|
||||
return f'{self.field} <= ?', (self.rangemax,)
|
||||
else:
|
||||
return u'1', ()
|
||||
return '1', ()
|
||||
|
||||
|
||||
class CollectionQuery(Query):
|
||||
|
@ -387,7 +377,7 @@ class CollectionQuery(Query):
|
|||
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(CollectionQuery, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.subqueries == other.subqueries
|
||||
|
||||
def __hash__(self):
|
||||
|
@ -411,7 +401,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
subqueries = []
|
||||
for field in self.fields:
|
||||
subqueries.append(cls(field, pattern, True))
|
||||
super(AnyFieldQuery, self).__init__(subqueries)
|
||||
super().__init__(subqueries)
|
||||
|
||||
def clause(self):
|
||||
return self.clause_with_joiner('or')
|
||||
|
@ -427,7 +417,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
"{0.query_class.__name__})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(AnyFieldQuery, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.query_class == other.query_class
|
||||
|
||||
def __hash__(self):
|
||||
|
@ -453,7 +443,7 @@ class AndQuery(MutableCollectionQuery):
|
|||
return self.clause_with_joiner('and')
|
||||
|
||||
def match(self, item):
|
||||
return all([q.match(item) for q in self.subqueries])
|
||||
return all(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
class OrQuery(MutableCollectionQuery):
|
||||
|
@ -463,7 +453,7 @@ class OrQuery(MutableCollectionQuery):
|
|||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
return any([q.match(item) for q in self.subqueries])
|
||||
return any(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
class NotQuery(Query):
|
||||
|
@ -477,7 +467,7 @@ class NotQuery(Query):
|
|||
def clause(self):
|
||||
clause, subvals = self.subquery.clause()
|
||||
if clause:
|
||||
return 'not ({0})'.format(clause), subvals
|
||||
return f'not ({clause})', subvals
|
||||
else:
|
||||
# If there is no clause, there is nothing to negate. All the logic
|
||||
# is handled by match() for slow queries.
|
||||
|
@ -490,7 +480,7 @@ class NotQuery(Query):
|
|||
return "{0.__class__.__name__}({0.subquery!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(NotQuery, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.subquery == other.subquery
|
||||
|
||||
def __hash__(self):
|
||||
|
@ -546,7 +536,7 @@ def _parse_periods(pattern):
|
|||
return (start, end)
|
||||
|
||||
|
||||
class Period(object):
|
||||
class Period:
|
||||
"""A period of time given by a date, time and precision.
|
||||
|
||||
Example: 2014-01-01 10:50:30 with precision 'month' represents all
|
||||
|
@ -572,7 +562,7 @@ class Period(object):
|
|||
or "second").
|
||||
"""
|
||||
if precision not in Period.precisions:
|
||||
raise ValueError(u'Invalid precision {0}'.format(precision))
|
||||
raise ValueError(f'Invalid precision {precision}')
|
||||
self.date = date
|
||||
self.precision = precision
|
||||
|
||||
|
@ -653,10 +643,10 @@ class Period(object):
|
|||
elif 'second' == precision:
|
||||
return date + timedelta(seconds=1)
|
||||
else:
|
||||
raise ValueError(u'unhandled precision {0}'.format(precision))
|
||||
raise ValueError(f'unhandled precision {precision}')
|
||||
|
||||
|
||||
class DateInterval(object):
|
||||
class DateInterval:
|
||||
"""A closed-open interval of dates.
|
||||
|
||||
A left endpoint of None means since the beginning of time.
|
||||
|
@ -665,7 +655,7 @@ class DateInterval(object):
|
|||
|
||||
def __init__(self, start, end):
|
||||
if start is not None and end is not None and not start < end:
|
||||
raise ValueError(u"start date {0} is not before end date {1}"
|
||||
raise ValueError("start date {} is not before end date {}"
|
||||
.format(start, end))
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
@ -686,7 +676,7 @@ class DateInterval(object):
|
|||
return True
|
||||
|
||||
def __str__(self):
|
||||
return '[{0}, {1})'.format(self.start, self.end)
|
||||
return f'[{self.start}, {self.end})'
|
||||
|
||||
|
||||
class DateQuery(FieldQuery):
|
||||
|
@ -700,7 +690,7 @@ class DateQuery(FieldQuery):
|
|||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(DateQuery, self).__init__(field, pattern, fast)
|
||||
super().__init__(field, pattern, fast)
|
||||
start, end = _parse_periods(pattern)
|
||||
self.interval = DateInterval.from_periods(start, end)
|
||||
|
||||
|
@ -759,12 +749,12 @@ class DurationQuery(NumericQuery):
|
|||
except ValueError:
|
||||
raise InvalidQueryArgumentValueError(
|
||||
s,
|
||||
u"a M:SS string or a float")
|
||||
"a M:SS string or a float")
|
||||
|
||||
|
||||
# Sorting.
|
||||
|
||||
class Sort(object):
|
||||
class Sort:
|
||||
"""An abstract class representing a sort operation for a query into
|
||||
the item database.
|
||||
"""
|
||||
|
@ -851,13 +841,13 @@ class MultipleSort(Sort):
|
|||
return items
|
||||
|
||||
def __repr__(self):
|
||||
return 'MultipleSort({!r})'.format(self.sorts)
|
||||
return f'MultipleSort({self.sorts!r})'
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.sorts))
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(MultipleSort, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.sorts == other.sorts
|
||||
|
||||
|
||||
|
@ -878,14 +868,14 @@ class FieldSort(Sort):
|
|||
|
||||
def key(item):
|
||||
field_val = item.get(self.field, '')
|
||||
if self.case_insensitive and isinstance(field_val, six.text_type):
|
||||
if self.case_insensitive and isinstance(field_val, str):
|
||||
field_val = field_val.lower()
|
||||
return field_val
|
||||
|
||||
return sorted(objs, key=key, reverse=not self.ascending)
|
||||
|
||||
def __repr__(self):
|
||||
return '<{0}: {1}{2}>'.format(
|
||||
return '<{}: {}{}>'.format(
|
||||
type(self).__name__,
|
||||
self.field,
|
||||
'+' if self.ascending else '-',
|
||||
|
@ -895,7 +885,7 @@ class FieldSort(Sort):
|
|||
return hash((self.field, self.ascending))
|
||||
|
||||
def __eq__(self, other):
|
||||
return super(FieldSort, self).__eq__(other) and \
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and \
|
||||
self.ascending == other.ascending
|
||||
|
||||
|
@ -913,7 +903,7 @@ class FixedFieldSort(FieldSort):
|
|||
'ELSE {0} END)'.format(self.field)
|
||||
else:
|
||||
field = self.field
|
||||
return "{0} {1}".format(field, order)
|
||||
return f"{field} {order}"
|
||||
|
||||
|
||||
class SlowFieldSort(FieldSort):
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is part of beets.
|
||||
# Copyright 2016, Adrian Sampson.
|
||||
#
|
||||
|
@ -15,12 +14,10 @@
|
|||
|
||||
"""Parsing of strings into DBCore queries.
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import re
|
||||
import itertools
|
||||
from . import query
|
||||
import beets
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
|
@ -89,7 +86,7 @@ def parse_query_part(part, query_classes={}, prefixes={},
|
|||
assert match # Regex should always match
|
||||
negate = bool(match.group(1))
|
||||
key = match.group(2)
|
||||
term = match.group(3).replace('\:', ':')
|
||||
term = match.group(3).replace('\\:', ':')
|
||||
|
||||
# Check whether there's a prefix in the query and use the
|
||||
# corresponding query type.
|
||||
|
@ -119,12 +116,13 @@ def construct_query_part(model_cls, prefixes, query_part):
|
|||
if not query_part:
|
||||
return query.TrueQuery()
|
||||
|
||||
# Use `model_cls` to build up a map from field names to `Query`
|
||||
# classes.
|
||||
# Use `model_cls` to build up a map from field (or query) names to
|
||||
# `Query` classes.
|
||||
query_classes = {}
|
||||
for k, t in itertools.chain(model_cls._fields.items(),
|
||||
model_cls._types.items()):
|
||||
query_classes[k] = t.query
|
||||
query_classes.update(model_cls._queries) # Non-field queries.
|
||||
|
||||
# Parse the string.
|
||||
key, pattern, query_class, negate = \
|
||||
|
@ -137,26 +135,27 @@ def construct_query_part(model_cls, prefixes, query_part):
|
|||
# The query type matches a specific field, but none was
|
||||
# specified. So we use a version of the query that matches
|
||||
# any field.
|
||||
q = query.AnyFieldQuery(pattern, model_cls._search_fields,
|
||||
query_class)
|
||||
if negate:
|
||||
return query.NotQuery(q)
|
||||
else:
|
||||
return q
|
||||
out_query = query.AnyFieldQuery(pattern, model_cls._search_fields,
|
||||
query_class)
|
||||
else:
|
||||
# Non-field query type.
|
||||
if negate:
|
||||
return query.NotQuery(query_class(pattern))
|
||||
else:
|
||||
return query_class(pattern)
|
||||
out_query = query_class(pattern)
|
||||
|
||||
# Otherwise, this must be a `FieldQuery`. Use the field name to
|
||||
# construct the query object.
|
||||
key = key.lower()
|
||||
q = query_class(key.lower(), pattern, key in model_cls._fields)
|
||||
# Field queries get constructed according to the name of the field
|
||||
# they are querying.
|
||||
elif issubclass(query_class, query.FieldQuery):
|
||||
key = key.lower()
|
||||
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
|
||||
|
||||
# Non-field (named) query.
|
||||
else:
|
||||
out_query = query_class(pattern)
|
||||
|
||||
# Apply negation.
|
||||
if negate:
|
||||
return query.NotQuery(q)
|
||||
return q
|
||||
return query.NotQuery(out_query)
|
||||
else:
|
||||
return out_query
|
||||
|
||||
|
||||
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
||||
|
@ -172,11 +171,13 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
|||
return query_cls(subqueries)
|
||||
|
||||
|
||||
def construct_sort_part(model_cls, part):
|
||||
def construct_sort_part(model_cls, part, case_insensitive=True):
|
||||
"""Create a `Sort` from a single string criterion.
|
||||
|
||||
`model_cls` is the `Model` being queried. `part` is a single string
|
||||
ending in ``+`` or ``-`` indicating the sort.
|
||||
ending in ``+`` or ``-`` indicating the sort. `case_insensitive`
|
||||
indicates whether or not the sort should be performed in a case
|
||||
sensitive manner.
|
||||
"""
|
||||
assert part, "part must be a field name and + or -"
|
||||
field = part[:-1]
|
||||
|
@ -185,7 +186,6 @@ def construct_sort_part(model_cls, part):
|
|||
assert direction in ('+', '-'), "part must end with + or -"
|
||||
is_ascending = direction == '+'
|
||||
|
||||
case_insensitive = beets.config['sort_case_insensitive'].get(bool)
|
||||
if field in model_cls._sorts:
|
||||
sort = model_cls._sorts[field](model_cls, is_ascending,
|
||||
case_insensitive)
|
||||
|
@ -197,21 +197,23 @@ def construct_sort_part(model_cls, part):
|
|||
return sort
|
||||
|
||||
|
||||
def sort_from_strings(model_cls, sort_parts):
|
||||
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
|
||||
"""Create a `Sort` from a list of sort criteria (strings).
|
||||
"""
|
||||
if not sort_parts:
|
||||
sort = query.NullSort()
|
||||
elif len(sort_parts) == 1:
|
||||
sort = construct_sort_part(model_cls, sort_parts[0])
|
||||
sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive)
|
||||
else:
|
||||
sort = query.MultipleSort()
|
||||
for part in sort_parts:
|
||||
sort.add_sort(construct_sort_part(model_cls, part))
|
||||
sort.add_sort(construct_sort_part(model_cls, part,
|
||||
case_insensitive))
|
||||
return sort
|
||||
|
||||
|
||||
def parse_sorted_query(model_cls, parts, prefixes={}):
|
||||
def parse_sorted_query(model_cls, parts, prefixes={},
|
||||
case_insensitive=True):
|
||||
"""Given a list of strings, create the `Query` and `Sort` that they
|
||||
represent.
|
||||
"""
|
||||
|
@ -222,8 +224,8 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
|
|||
# Split up query in to comma-separated subqueries, each representing
|
||||
# an AndQuery, which need to be joined together in one OrQuery
|
||||
subquery_parts = []
|
||||
for part in parts + [u',']:
|
||||
if part.endswith(u','):
|
||||
for part in parts + [',']:
|
||||
if part.endswith(','):
|
||||
# Ensure we can catch "foo, bar" as well as "foo , bar"
|
||||
last_subquery_part = part[:-1]
|
||||
if last_subquery_part:
|
||||
|
@ -237,8 +239,8 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
|
|||
else:
|
||||
# Sort parts (1) end in + or -, (2) don't have a field, and
|
||||
# (3) consist of more than just the + or -.
|
||||
if part.endswith((u'+', u'-')) \
|
||||
and u':' not in part \
|
||||
if part.endswith(('+', '-')) \
|
||||
and ':' not in part \
|
||||
and len(part) > 1:
|
||||
sort_parts.append(part)
|
||||
else:
|
||||
|
@ -246,5 +248,5 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
|
|||
|
||||
# Avoid needlessly wrapping single statements in an OR
|
||||
q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0]
|
||||
s = sort_from_strings(model_cls, sort_parts)
|
||||
s = sort_from_strings(model_cls, sort_parts, case_insensitive)
|
||||
return q, s
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This file is part of beets.
|
||||
# Copyright 2016, Adrian Sampson.
|
||||
#
|
||||
|
@ -15,25 +14,20 @@
|
|||
|
||||
"""Representation of type information for DBCore model fields.
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
from . import query
|
||||
from beets.util import str2bool
|
||||
import six
|
||||
|
||||
if not six.PY2:
|
||||
buffer = memoryview # sqlite won't accept memoryview in python 2
|
||||
|
||||
|
||||
# Abstract base.
|
||||
|
||||
class Type(object):
|
||||
class Type:
|
||||
"""An object encapsulating the type of a model field. Includes
|
||||
information about how to store, query, format, and parse a given
|
||||
field.
|
||||
"""
|
||||
|
||||
sql = u'TEXT'
|
||||
sql = 'TEXT'
|
||||
"""The SQLite column type for the value.
|
||||
"""
|
||||
|
||||
|
@ -41,7 +35,7 @@ class Type(object):
|
|||
"""The `Query` subclass to be used when querying the field.
|
||||
"""
|
||||
|
||||
model_type = six.text_type
|
||||
model_type = str
|
||||
"""The Python type that is used to represent the value in the model.
|
||||
|
||||
The model is guaranteed to return a value of this type if the field
|
||||
|
@ -63,11 +57,11 @@ class Type(object):
|
|||
value = self.null
|
||||
# `self.null` might be `None`
|
||||
if value is None:
|
||||
value = u''
|
||||
value = ''
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8', 'ignore')
|
||||
|
||||
return six.text_type(value)
|
||||
return str(value)
|
||||
|
||||
def parse(self, string):
|
||||
"""Parse a (possibly human-written) string and return the
|
||||
|
@ -97,16 +91,16 @@ class Type(object):
|
|||
For fixed fields the type of `value` is determined by the column
|
||||
type affinity given in the `sql` property and the SQL to Python
|
||||
mapping of the database adapter. For more information see:
|
||||
http://www.sqlite.org/datatype3.html
|
||||
https://www.sqlite.org/datatype3.html
|
||||
https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types
|
||||
|
||||
Flexible fields have the type affinity `TEXT`. This means the
|
||||
`sql_value` is either a `buffer`/`memoryview` or a `unicode` object`
|
||||
`sql_value` is either a `memoryview` or a `unicode` object`
|
||||
and the method must handle these in addition.
|
||||
"""
|
||||
if isinstance(sql_value, buffer):
|
||||
if isinstance(sql_value, memoryview):
|
||||
sql_value = bytes(sql_value).decode('utf-8', 'ignore')
|
||||
if isinstance(sql_value, six.text_type):
|
||||
if isinstance(sql_value, str):
|
||||
return self.parse(sql_value)
|
||||
else:
|
||||
return self.normalize(sql_value)
|
||||
|
@ -127,10 +121,18 @@ class Default(Type):
|
|||
class Integer(Type):
|
||||
"""A basic integer type.
|
||||
"""
|
||||
sql = u'INTEGER'
|
||||
sql = 'INTEGER'
|
||||
query = query.NumericQuery
|
||||
model_type = int
|
||||
|
||||
def normalize(self, value):
|
||||
try:
|
||||
return self.model_type(round(float(value)))
|
||||
except ValueError:
|
||||
return self.null
|
||||
except TypeError:
|
||||
return self.null
|
||||
|
||||
|
||||
class PaddedInt(Integer):
|
||||
"""An integer field that is formatted with a given number of digits,
|
||||
|
@ -140,19 +142,25 @@ class PaddedInt(Integer):
|
|||
self.digits = digits
|
||||
|
||||
def format(self, value):
|
||||
return u'{0:0{1}d}'.format(value or 0, self.digits)
|
||||
return '{0:0{1}d}'.format(value or 0, self.digits)
|
||||
|
||||
|
||||
class NullPaddedInt(PaddedInt):
|
||||
"""Same as `PaddedInt`, but does not normalize `None` to `0.0`.
|
||||
"""
|
||||
null = None
|
||||
|
||||
|
||||
class ScaledInt(Integer):
|
||||
"""An integer whose formatting operation scales the number by a
|
||||
constant and adds a suffix. Good for units with large magnitudes.
|
||||
"""
|
||||
def __init__(self, unit, suffix=u''):
|
||||
def __init__(self, unit, suffix=''):
|
||||
self.unit = unit
|
||||
self.suffix = suffix
|
||||
|
||||
def format(self, value):
|
||||
return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)
|
||||
return '{}{}'.format((value or 0) // self.unit, self.suffix)
|
||||
|
||||
|
||||
class Id(Integer):
|
||||
|
@ -163,18 +171,22 @@ class Id(Integer):
|
|||
|
||||
def __init__(self, primary=True):
|
||||
if primary:
|
||||
self.sql = u'INTEGER PRIMARY KEY'
|
||||
self.sql = 'INTEGER PRIMARY KEY'
|
||||
|
||||
|
||||
class Float(Type):
|
||||
"""A basic floating-point type.
|
||||
"""A basic floating-point type. The `digits` parameter specifies how
|
||||
many decimal places to use in the human-readable representation.
|
||||
"""
|
||||
sql = u'REAL'
|
||||
sql = 'REAL'
|
||||
query = query.NumericQuery
|
||||
model_type = float
|
||||
|
||||
def __init__(self, digits=1):
|
||||
self.digits = digits
|
||||
|
||||
def format(self, value):
|
||||
return u'{0:.1f}'.format(value or 0.0)
|
||||
return '{0:.{1}f}'.format(value or 0, self.digits)
|
||||
|
||||
|
||||
class NullFloat(Float):
|
||||
|
@ -186,19 +198,25 @@ class NullFloat(Float):
|
|||
class String(Type):
|
||||
"""A Unicode string type.
|
||||
"""
|
||||
sql = u'TEXT'
|
||||
sql = 'TEXT'
|
||||
query = query.SubstringQuery
|
||||
|
||||
def normalize(self, value):
|
||||
if value is None:
|
||||
return self.null
|
||||
else:
|
||||
return self.model_type(value)
|
||||
|
||||
|
||||
class Boolean(Type):
|
||||
"""A boolean type.
|
||||
"""
|
||||
sql = u'INTEGER'
|
||||
sql = 'INTEGER'
|
||||
query = query.BooleanQuery
|
||||
model_type = bool
|
||||
|
||||
def format(self, value):
|
||||
return six.text_type(bool(value))
|
||||
return str(bool(value))
|
||||
|
||||
def parse(self, string):
|
||||
return str2bool(string)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue