Move common libs to libs/common

Labrys of Knossos 2018-12-16 13:30:24 -05:00
commit 1f4bd41bcc
1612 changed files with 962 additions and 10 deletions


@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from __future__ import division, absolute_import, print_function
import os
from beets.util import confit
__version__ = u'1.4.7'
__author__ = u'Adrian Sampson <adrian@radbox.org>'
class IncludeLazyConfig(confit.LazyConfig):
"""A version of Confit's LazyConfig that also merges in data from
YAML files specified in an `include` setting.
"""
def read(self, user=True, defaults=True):
super(IncludeLazyConfig, self).read(user, defaults)
try:
for view in self['include']:
filename = view.as_filename()
if os.path.isfile(filename):
self.set_file(filename)
except confit.NotFoundError:
pass
config = IncludeLazyConfig('beets', __name__)
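A quick sketch of what the `include` directive enables, assuming beets is installed; the two YAML file names here are hypothetical:

# ~/.config/beets/config.yaml might contain:
#
#     include:
#         - secrets.yaml
#         - aliases.yaml
#
from beets import config
config.read()  # merges secrets.yaml and aliases.yaml into the main views
print(config['directory'].get())  # lookups now see the merged configuration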


@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2017, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The __main__ module lets you run the beets CLI interface by typing
`python -m beets`.
"""
from __future__ import division, absolute_import, print_function
import sys
from .ui import main
if __name__ == "__main__":
main(sys.argv[1:])
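The same entry point can be exercised in-process; a minimal sketch (`version` is a stock beets subcommand):

from beets.ui import main
main(['version'])  # equivalent to `beet version` or `python -m beets version`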

libs/common/beets/art.py

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""High-level utilities for manipulating image files associated with
music and items' embedded album art.
"""
from __future__ import division, absolute_import, print_function
import subprocess
import platform
from tempfile import NamedTemporaryFile
import os
from beets.util import displayable_path, syspath, bytestring_path
from beets.util.artresizer import ArtResizer
from beets import mediafile
def mediafile_image(image_path, maxwidth=None):
"""Return a `mediafile.Image` object for the path.
"""
with open(syspath(image_path), 'rb') as f:
data = f.read()
return mediafile.Image(data, type=mediafile.ImageType.front)
def get_art(log, item):
# Extract the art.
try:
mf = mediafile.MediaFile(syspath(item.path))
except mediafile.UnreadableFileError as exc:
log.warning(u'Could not extract art from {0}: {1}',
displayable_path(item.path), exc)
return
return mf.art
def embed_item(log, item, imagepath, maxwidth=None, itempath=None,
compare_threshold=0, ifempty=False, as_album=False):
"""Embed an image into the item's media file.
"""
# Conditions and filters.
if compare_threshold:
if not check_art_similarity(log, item, imagepath, compare_threshold):
log.info(u'Image not similar; skipping.')
return
if ifempty and get_art(log, item):
log.info(u'media file already contained art')
return
if maxwidth and not as_album:
imagepath = resize_image(log, imagepath, maxwidth)
# Get the `Image` object from the file.
try:
log.debug(u'embedding {0}', displayable_path(imagepath))
image = mediafile_image(imagepath, maxwidth)
except IOError as exc:
log.warning(u'could not read image file: {0}', exc)
return
# Make sure the image kind is safe (some formats only support PNG
# and JPEG).
if image.mime_type not in ('image/jpeg', 'image/png'):
log.info('not embedding image of unsupported type: {}',
image.mime_type)
return
item.try_write(path=itempath, tags={'images': [image]})
def embed_album(log, album, maxwidth=None, quiet=False,
compare_threshold=0, ifempty=False):
"""Embed album art into all of the album's items.
"""
imagepath = album.artpath
if not imagepath:
log.info(u'No album art present for {0}', album)
return
if not os.path.isfile(syspath(imagepath)):
log.info(u'Album art not found at {0} for {1}',
displayable_path(imagepath), album)
return
if maxwidth:
imagepath = resize_image(log, imagepath, maxwidth)
log.info(u'Embedding album art into {0}', album)
for item in album.items():
embed_item(log, item, imagepath, maxwidth, None,
compare_threshold, ifempty, as_album=True)
def resize_image(log, imagepath, maxwidth):
"""Returns path to an image resized to maxwidth.
"""
log.debug(u'Resizing album art to {0} pixels wide', maxwidth)
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
return imagepath
def check_art_similarity(log, item, imagepath, compare_threshold):
"""A boolean indicating if an image is similar to embedded item art.
"""
with NamedTemporaryFile(delete=True) as f:
art = extract(log, f.name, item)
if art:
is_windows = platform.system() == "Windows"
# Converting images to grayscale tends to minimize the weight
# of colors in the diff score. So we first convert both images
# to grayscale and then pipe them into the `compare` command.
# On Windows, ImageMagick doesn't support the magic \\?\ prefix
# on paths, so we pass `prefix=False` to `syspath`.
convert_cmd = ['convert', syspath(imagepath, prefix=False),
syspath(art, prefix=False),
'-colorspace', 'gray', 'MIFF:-']
compare_cmd = ['compare', '-metric', 'PHASH', '-', 'null:']
log.debug(u'comparing images with pipeline {} | {}',
convert_cmd, compare_cmd)
convert_proc = subprocess.Popen(
convert_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=not is_windows,
)
compare_proc = subprocess.Popen(
compare_cmd,
stdin=convert_proc.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=not is_windows,
)
# Check the convert output. We're not interested in the
# standard output; that gets piped to the next stage.
convert_proc.stdout.close()
convert_stderr = convert_proc.stderr.read()
convert_proc.stderr.close()
convert_proc.wait()
if convert_proc.returncode:
log.debug(
u'ImageMagick convert failed with status {}: {!r}',
convert_proc.returncode,
convert_stderr,
)
return
# Check the compare output.
stdout, stderr = compare_proc.communicate()
if compare_proc.returncode:
if compare_proc.returncode != 1:
log.debug(u'ImageMagick compare failed: {0}, {1}',
displayable_path(imagepath),
displayable_path(art))
return
out_str = stderr
else:
out_str = stdout
try:
phash_diff = float(out_str)
except ValueError:
log.debug(u'IM output is not a number: {0!r}', out_str)
return
log.debug(u'ImageMagick compare score: {0}', phash_diff)
return phash_diff <= compare_threshold
return True
def extract(log, outpath, item):
art = get_art(log, item)
outpath = bytestring_path(outpath)
if not art:
log.info(u'No album art present in {0}, skipping.', item)
return
# Add an extension to the filename.
ext = mediafile.image_extension(art)
if not ext:
log.warning(u'Unknown image type in {0}.',
displayable_path(item.path))
return
outpath += bytestring_path('.' + ext)
log.info(u'Extracting album art from: {0} to: {1}',
item, displayable_path(outpath))
with open(syspath(outpath), 'wb') as f:
f.write(art)
return outpath
def extract_first(log, outpath, items):
for item in items:
real_path = extract(log, outpath, item)
if real_path:
return real_path
def clear(log, lib, query):
items = lib.items(query)
log.info(u'Clearing album art from {0} items', len(items))
for item in items:
log.debug(u'Clearing art for {0}', item)
item.try_write(tags={'images': None})
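For reference, the grayscale pipeline that `check_art_similarity` assembles above is equivalent to this one-liner; a Python 3 sketch assuming ImageMagick is on `PATH` and `a.jpg`/`b.jpg` are stand-in paths (`compare` prints the PHASH metric on stderr and exits with 1 when the images differ):

import subprocess
cmd = ('convert a.jpg b.jpg -colorspace gray MIFF:- '
       '| compare -metric PHASH - null:')
proc = subprocess.run(cmd, shell=True, capture_output=True)
print(float(proc.stderr.strip()))  # the PHASH difference score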


@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Facilities for automatically determining files' correct metadata.
"""
from __future__ import division, absolute_import, print_function
from beets import logging
from beets import config
# Parts of external interface.
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa
from .match import tag_item, tag_album, Proposal # noqa
from .match import Recommendation # noqa
# Global logger.
log = logging.getLogger('beets')
# Additional utilities for the main interface.
def apply_item_metadata(item, track_info):
"""Set an item's metadata from its matched TrackInfo object.
"""
item.artist = track_info.artist
item.artist_sort = track_info.artist_sort
item.artist_credit = track_info.artist_credit
item.title = track_info.title
item.mb_trackid = track_info.track_id
item.mb_releasetrackid = track_info.release_track_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
if track_info.data_source:
item.data_source = track_info.data_source
if track_info.lyricist is not None:
item.lyricist = track_info.lyricist
if track_info.composer is not None:
item.composer = track_info.composer
if track_info.composer_sort is not None:
item.composer_sort = track_info.composer_sort
if track_info.arranger is not None:
item.arranger = track_info.arranger
# At the moment, the other metadata is left intact (including album
# and track number). Perhaps these should be emptied?
def apply_metadata(album_info, mapping):
"""Set the items' metadata to match an AlbumInfo object using a
mapping from Items to TrackInfo objects.
"""
for item, track_info in mapping.items():
# Artist or artist credit.
if config['artist_credit']:
item.artist = (track_info.artist_credit or
track_info.artist or
album_info.artist_credit or
album_info.artist)
item.albumartist = (album_info.artist_credit or
album_info.artist)
else:
item.artist = (track_info.artist or album_info.artist)
item.albumartist = album_info.artist
# Album.
item.album = album_info.album
# Artist sort and credit names.
item.artist_sort = track_info.artist_sort or album_info.artist_sort
item.artist_credit = (track_info.artist_credit or
album_info.artist_credit)
item.albumartist_sort = album_info.artist_sort
item.albumartist_credit = album_info.artist_credit
# Release date.
for prefix in '', 'original_':
if config['original_date'] and not prefix:
# Ignore specific release date.
continue
for suffix in 'year', 'month', 'day':
key = prefix + suffix
value = getattr(album_info, key) or 0
# If we don't even have a year, apply nothing.
if suffix == 'year' and not value:
break
# Otherwise, set the fetched value (or 0 for the month
# and day if not available).
item[key] = value
# If we're using original release date for both fields,
# also set item.year = info.original_year, etc.
if config['original_date']:
item[suffix] = value
# Title.
item.title = track_info.title
if config['per_disc_numbering']:
# We want to let the track number be zero, but if the medium index
# is not provided we need to fall back to the overall index.
if track_info.medium_index is not None:
item.track = track_info.medium_index
else:
item.track = track_info.index
item.tracktotal = track_info.medium_total or len(album_info.tracks)
else:
item.track = track_info.index
item.tracktotal = len(album_info.tracks)
# Disc and disc count.
item.disc = track_info.medium
item.disctotal = album_info.mediums
# MusicBrainz IDs.
item.mb_trackid = track_info.track_id
item.mb_releasetrackid = track_info.release_track_id
item.mb_albumid = album_info.album_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
else:
item.mb_artistid = album_info.artist_id
item.mb_albumartistid = album_info.artist_id
item.mb_releasegroupid = album_info.releasegroup_id
# Compilation flag.
item.comp = album_info.va
# Miscellaneous metadata.
for field in ('albumtype',
'label',
'asin',
'catalognum',
'script',
'language',
'country',
'albumstatus',
'albumdisambig',
'data_source',):
value = getattr(album_info, field)
if value is not None:
item[field] = value
if track_info.disctitle is not None:
item.disctitle = track_info.disctitle
if track_info.media is not None:
item.media = track_info.media
if track_info.lyricist is not None:
item.lyricist = track_info.lyricist
if track_info.composer is not None:
item.composer = track_info.composer
if track_info.composer_sort is not None:
item.composer_sort = track_info.composer_sort
if track_info.arranger is not None:
item.arranger = track_info.arranger
item.track_alt = track_info.track_alt
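A condensed, standalone sketch of the release-date rules in `apply_metadata` above (plain dicts stand in for `Item` and `AlbumInfo`; `original_date` mirrors the config flag):

def sketch_dates(info, original_date=False):
    item = {}
    for prefix in ('', 'original_'):
        if original_date and not prefix:
            continue  # ignore the release-specific date entirely
        for suffix in ('year', 'month', 'day'):
            value = info.get(prefix + suffix) or 0
            if suffix == 'year' and not value:
                break  # no year at all: apply nothing for this prefix
            item[prefix + suffix] = value
            if original_date:
                item[suffix] = value  # mirror original_* into year/month/day
    return item

print(sketch_dates({'year': 2011, 'original_year': 1973}, original_date=True))
# {'original_year': 1973, 'year': 1973, 'original_month': 0, 'month': 0,
#  'original_day': 0, 'day': 0} -- the 2011 reissue date is ignored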


@@ -0,0 +1,631 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Glue between metadata sources and the matching logic."""
from __future__ import division, absolute_import, print_function
from collections import namedtuple
from functools import total_ordering
import re
from beets import logging
from beets import plugins
from beets import config
from beets.util import as_string
from beets.autotag import mb
from jellyfish import levenshtein_distance
from unidecode import unidecode
import six
log = logging.getLogger('beets')
# Classes used to represent candidate options.
class AlbumInfo(object):
"""Describes a canonical release that may be used to match a release
in the library. Consists of these data members:
- ``album``: the release title
- ``album_id``: MusicBrainz ID; UUID fragment only
- ``artist``: name of the release's primary artist
- ``artist_id``
- ``tracks``: list of TrackInfo objects making up the release
- ``asin``: Amazon ASIN
- ``albumtype``: string describing the kind of release
- ``va``: boolean: whether the release has "various artists"
- ``year``: release year
- ``month``: release month
- ``day``: release day
- ``label``: music label responsible for the release
- ``mediums``: the number of discs in this release
- ``artist_sort``: name of the release's artist for sorting
- ``releasegroup_id``: MBID for the album's release group
- ``catalognum``: the label's catalog number for the release
- ``script``: character set used for metadata
- ``language``: human language of the metadata
- ``country``: the release country
- ``albumstatus``: MusicBrainz release status (Official, etc.)
- ``media``: delivery mechanism (Vinyl, etc.)
- ``albumdisambig``: MusicBrainz release disambiguation comment
- ``artist_credit``: Release-specific artist name
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
- ``data_url``: The data source release URL.
The fields up through ``tracks`` are required. The others are
optional and may be None.
"""
def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
albumtype=None, va=False, year=None, month=None, day=None,
label=None, mediums=None, artist_sort=None,
releasegroup_id=None, catalognum=None, script=None,
language=None, country=None, albumstatus=None, media=None,
albumdisambig=None, artist_credit=None, original_year=None,
original_month=None, original_day=None, data_source=None,
data_url=None):
self.album = album
self.album_id = album_id
self.artist = artist
self.artist_id = artist_id
self.tracks = tracks
self.asin = asin
self.albumtype = albumtype
self.va = va
self.year = year
self.month = month
self.day = day
self.label = label
self.mediums = mediums
self.artist_sort = artist_sort
self.releasegroup_id = releasegroup_id
self.catalognum = catalognum
self.script = script
self.language = language
self.country = country
self.albumstatus = albumstatus
self.media = media
self.albumdisambig = albumdisambig
self.artist_credit = artist_credit
self.original_year = original_year
self.original_month = original_month
self.original_day = original_day
self.data_source = data_source
self.data_url = data_url
# Work around a bug in python-musicbrainz-ngs that causes some
# strings to be bytes rather than Unicode.
# https://github.com/alastair/python-musicbrainz-ngs/issues/85
def decode(self, codec='utf-8'):
"""Ensure that all string attributes on this object, and the
constituent `TrackInfo` objects, are decoded to Unicode.
"""
for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
'catalognum', 'script', 'language', 'country',
'albumstatus', 'albumdisambig', 'artist_credit', 'media']:
value = getattr(self, fld)
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
if self.tracks:
for track in self.tracks:
track.decode(codec)
class TrackInfo(object):
"""Describes a canonical track present on a release. Appears as part
of an AlbumInfo's ``tracks`` list. Consists of these data members:
- ``title``: name of the track
- ``track_id``: MusicBrainz ID; UUID fragment only
- ``release_track_id``: MusicBrainz ID respective to a track on a
particular release; UUID fragment only
- ``artist``: individual track artist name
- ``artist_id``
- ``length``: float: duration of the track in seconds
- ``index``: position on the entire release
- ``media``: delivery mechanism (Vinyl, etc.)
- ``medium``: the disc number this track appears on in the album
- ``medium_index``: the track's position on the disc
- ``medium_total``: the number of tracks on the item's disc
- ``artist_sort``: name of the track artist for sorting
- ``disctitle``: name of the individual medium (subtitle)
- ``artist_credit``: Recording-specific artist name
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
- ``data_url``: The data source release URL.
- ``lyricist``: individual track lyricist name
- ``composer``: individual track composer name
- ``composer_sort``: individual track composer sort name
    - ``arranger``: individual track arranger name
- ``track_alt``: alternative track number (tape, vinyl, etc.)
Only ``title`` and ``track_id`` are required. The rest of the fields
may be None. The indices ``index``, ``medium``, and ``medium_index``
are all 1-based.
"""
def __init__(self, title, track_id, release_track_id=None, artist=None,
artist_id=None, length=None, index=None, medium=None,
medium_index=None, medium_total=None, artist_sort=None,
disctitle=None, artist_credit=None, data_source=None,
data_url=None, media=None, lyricist=None, composer=None,
composer_sort=None, arranger=None, track_alt=None):
self.title = title
self.track_id = track_id
self.release_track_id = release_track_id
self.artist = artist
self.artist_id = artist_id
self.length = length
self.index = index
self.media = media
self.medium = medium
self.medium_index = medium_index
self.medium_total = medium_total
self.artist_sort = artist_sort
self.disctitle = disctitle
self.artist_credit = artist_credit
self.data_source = data_source
self.data_url = data_url
self.lyricist = lyricist
self.composer = composer
self.composer_sort = composer_sort
self.arranger = arranger
self.track_alt = track_alt
# As above, work around a bug in python-musicbrainz-ngs.
def decode(self, codec='utf-8'):
"""Ensure that all string attributes on this object are decoded
to Unicode.
"""
for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
'artist_credit', 'media']:
value = getattr(self, fld)
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ['the', 'a', 'an']
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r'^the ', 0.1),
(r'[\[\(]?(ep|single)[\]\)]?', 0.0),
(r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
(r'\(.*?\)', 0.3),
(r'\[.*?\]', 0.3),
(r'(, )?(pt\.|part) .+', 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r'&', 'and'),
]
def _string_dist_basic(str1, str2):
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, six.text_type)
assert isinstance(str2, six.text_type)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word) - 2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, '', str1)
case_str2 = re.sub(pat, '', str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
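# Illustrative sketch (not part of the module): expected values of
# `string_dist`, following from the default SD_* tables above.
from beets.autotag.hooks import string_dist
print(string_dist(u'The Something', u'Something, The'))  # 0.0 (comma move)
print(string_dist(u'Foo', u'Foo (EP)'))  # 0.0 ("(ep)" carries zero weight)
print(string_dist(u'One', u'Two'))       # 1.0 (nothing in common)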
class LazyClassProperty(object):
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
def __get__(self, obj, owner):
if not self.computed:
self.value = self.getter(owner)
self.computed = True
return self.value
@total_ordering
@six.python_2_unicode_compatible
class Distance(object):
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self):
self._penalties = {}
@LazyClassProperty
def _weights(cls): # noqa
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self):
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self):
"""Return the maximum distance penalty (normalization factor).
"""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self):
"""Return the raw (denormalized) distance.
"""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self):
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
        # Sort by negated distance so the items with the biggest
        # distance come first, falling back to ascending key order
        # when penalties are equal.
return sorted(
list_,
key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self):
return id(self)
def __eq__(self, other):
return self.distance == other
# Behave like a float.
def __lt__(self, other):
return self.distance < other
def __float__(self):
return self.distance
def __sub__(self, other):
return self.distance - other
def __rsub__(self, other):
return other - self.distance
def __str__(self):
return "{0:.2f}".format(self.distance)
# Behave like a dict.
def __getitem__(self, key):
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self):
return iter(self.items())
def __len__(self):
return len(self.items())
def keys(self):
return [key for key, _ in self.items()]
def update(self, dist):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
raise ValueError(
u'`dist` must be a Distance object, not {0}'.format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1, value2):
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re._pattern_type):
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
u'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
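# Illustrative sketch (not part of the module): the arithmetic `Distance`
# performs, with hypothetical weights standing in for the configured
# `match.distance_weights` values.
weights = {'album': 3.0, 'year': 1.0}
penalties = {'album': [0.5], 'year': [1.0]}
raw = sum(sum(p) * weights[k] for k, p in penalties.items())       # 2.5
max_dist = sum(len(p) * weights[k] for k, p in penalties.items())  # 4.0
print(raw / max_dist)  # 0.625: what `Distance.distance` would report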
# Structures that compose all the information for a candidate match.
AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
'extra_items', 'extra_tracks'])
TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
# Aggregation of sources.
def album_for_mbid(release_id):
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
if the ID is not found.
"""
try:
album = mb.album_for_id(release_id)
if album:
plugins.send(u'albuminfo_received', info=album)
return album
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def track_for_mbid(recording_id):
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
"""
try:
track = mb.track_for_id(recording_id)
if track:
plugins.send(u'trackinfo_received', info=track)
return track
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def albums_for_id(album_id):
"""Get a list of albums for an ID."""
a = album_for_mbid(album_id)
if a:
yield a
for a in plugins.album_for_id(album_id):
if a:
plugins.send(u'albuminfo_received', info=a)
yield a
def tracks_for_id(track_id):
"""Get a list of tracks for an ID."""
t = track_for_mbid(track_id)
if t:
yield t
for t in plugins.track_for_id(track_id):
if t:
plugins.send(u'trackinfo_received', info=t)
yield t
@plugins.notify_info_yielded(u'albuminfo_received')
def album_candidates(items, artist, album, va_likely):
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
entered by the user. ``va_likely`` is a boolean indicating whether
the album is likely to be a "various artists" release.
"""
# Base candidates if we have album and artist to match.
if artist and album:
try:
for candidate in mb.match_album(artist, album, len(items)):
yield candidate
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Also add VA matches from MusicBrainz where appropriate.
if va_likely and album:
try:
for candidate in mb.match_album(None, album, len(items)):
yield candidate
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Candidates from plugins.
for candidate in plugins.candidates(items, artist, album, va_likely):
yield candidate
@plugins.notify_info_yielded(u'trackinfo_received')
def item_candidates(item, artist, title):
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
are specified by the user.
"""
# MusicBrainz candidates.
if artist and title:
try:
for candidate in mb.match_track(artist, title):
yield candidate
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Plugin candidates.
for candidate in plugins.item_candidates(item, artist, title):
yield candidate


@@ -0,0 +1,521 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Matches existing metadata with canonical information to identify
releases and tracks.
"""
from __future__ import division, absolute_import, print_function
import datetime
import re
from munkres import Munkres
from collections import namedtuple
from beets import logging
from beets import plugins
from beets import config
from beets.util import plurality
from beets.autotag import hooks
from beets.util.enumeration import OrderedEnum
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to remove the penalty for
# differing artists.
VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
# Global logger.
log = logging.getLogger('beets')
# Recommendation enumeration.
class Recommendation(OrderedEnum):
"""Indicates a qualitative suggestion to the user about what should
be done with a given match.
"""
none = 0
low = 1
medium = 2
strong = 3
# A structure for holding a set of possible matches to choose between. This
# consists of a list of possible candidates (i.e., AlbumInfo or TrackInfo
# objects) and a recommendation value.
Proposal = namedtuple('Proposal', ('candidates', 'recommendation'))
# Primary matching functionality.
def current_metadata(items):
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
- Whether each field's value was unanimous (values are booleans).
"""
assert items # Must be nonempty.
likelies = {}
consensus = {}
fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
'mb_albumid', 'label', 'catalognum', 'country', 'media',
'albumdisambig']
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
consensus[field] = (freq == len(values))
# If there's an album artist consensus, use this for the artist.
if consensus['albumartist'] and likelies['albumartist']:
likelies['artist'] = likelies['albumartist']
return likelies, consensus
def assign_items(items, tracks):
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
    objects, a sorted list of extra Items, and a sorted list of extra TrackInfo
objects. These "extra" objects occur when there is an unequal number
of objects of the two types.
"""
# Construct the cost matrix.
costs = []
for item in items:
row = []
for i, track in enumerate(tracks):
row.append(track_distance(item, track))
costs.append(row)
# Find a minimum-cost bipartite matching.
log.debug('Computing track assignment...')
matching = Munkres().compute(costs)
log.debug('...done.')
# Produce the output matching.
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
extra_items = list(set(items) - set(mapping.keys()))
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
extra_tracks.sort(key=lambda t: (t.index, t.title))
return mapping, extra_items, extra_tracks
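# Illustrative sketch (not part of the module): the Hungarian-algorithm
# assignment step with a hand-made cost matrix of hypothetical distances.
from munkres import Munkres
costs = [[0.1, 0.9],   # item 0 is cheap to pair with track 0
         [0.8, 0.2]]   # item 1 is cheap to pair with track 1
print(Munkres().compute(costs))  # [(0, 0), (1, 1)]: item i -> track j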
def track_index_changed(item, track_info):
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(item, track_info, incl_artist=False):
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
"""
dist = hooks.Distance()
# Length.
if track_info.length:
diff = abs(item.length - track_info.length) - \
config['match']['track_length_grace'].as_number()
dist.add_ratio('track_length', diff,
config['match']['track_length_max'].as_number())
# Title.
dist.add_string('track_title', item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if incl_artist and track_info.artist and \
item.artist.lower() not in VA_ARTISTS:
dist.add_string('track_artist', item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr('track_index', track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr('track_id', item.mb_trackid != track_info.track_id)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(items, album_info, mapping):
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = current_metadata(items)
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string('artist', likelies['artist'], album_info.artist)
# Album.
dist.add_string('album', likelies['album'], album_info.album)
# Current or preferred media.
if album_info.media:
# Preferred media options.
patterns = config['match']['preferred']['media'].as_str_seq()
options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
if options:
dist.add_priority('media', album_info.media, options)
# Current media.
elif likelies['media']:
dist.add_equality('media', album_info.media, likelies['media'])
# Mediums.
if likelies['disctotal'] and album_info.mediums:
dist.add_number('mediums', likelies['disctotal'], album_info.mediums)
# Prefer earliest release.
if album_info.year and config['match']['preferred']['original_year']:
        # Assume 1889 (the year of the first gramophone discs) if we
        # don't know the original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio('year', diff, diff_max)
# Year.
elif likelies['year'] and album_info.year:
if likelies['year'] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add('year', 0.0)
elif album_info.original_year:
            # Prefer matches closest to the release year.
diff = abs(likelies['year'] - album_info.year)
diff_max = abs(datetime.date.today().year -
album_info.original_year)
dist.add_ratio('year', diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add('year', 1.0)
# Preferred countries.
patterns = config['match']['preferred']['countries'].as_str_seq()
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
dist.add_priority('country', album_info.country, options)
# Country.
elif likelies['country'] and album_info.country:
dist.add_string('country', likelies['country'], album_info.country)
# Label.
if likelies['label'] and album_info.label:
dist.add_string('label', likelies['label'], album_info.label)
# Catalog number.
if likelies['catalognum'] and album_info.catalognum:
dist.add_string('catalognum', likelies['catalognum'],
album_info.catalognum)
# Disambiguation.
if likelies['albumdisambig'] and album_info.albumdisambig:
dist.add_string('albumdisambig', likelies['albumdisambig'],
album_info.albumdisambig)
# Album ID.
if likelies['mb_albumid']:
dist.add_equality('album_id', likelies['mb_albumid'],
album_info.album_id)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add('tracks', dist.tracks[track].distance)
# Missing tracks.
for i in range(len(album_info.tracks) - len(mapping)):
dist.add('missing_tracks', 1.0)
# Unmatched tracks.
for i in range(len(items) - len(mapping)):
dist.add('unmatched_tracks', 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
def match_by_id(items):
"""If the items are tagged with a MusicBrainz album ID, returns an
AlbumInfo object for the corresponding album. Otherwise, returns
None.
"""
albumids = (item.mb_albumid for item in items if item.mb_albumid)
# Did any of the items have an MB album ID?
try:
first = next(albumids)
except StopIteration:
log.debug(u'No album ID found.')
return None
# Is there a consensus on the MB album ID?
for other in albumids:
if other != first:
log.debug(u'No album ID consensus.')
return None
# If all album IDs are equal, look up the album.
log.debug(u'Searching for discovered album ID: {0}', first)
return hooks.album_for_mbid(first)
def _recommendation(results):
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
recommendation based on the results' distances.
If the recommendation is higher than the configured maximum for
an applied penalty, the recommendation will be downgraded to the
configured maximum for that penalty.
"""
if not results:
# No candidates: no recommendation.
return Recommendation.none
# Basic distance thresholding.
min_dist = results[0].distance
if min_dist < config['match']['strong_rec_thresh'].as_number():
# Strong recommendation level.
rec = Recommendation.strong
elif min_dist <= config['match']['medium_rec_thresh'].as_number():
# Medium recommendation level.
rec = Recommendation.medium
elif len(results) == 1:
# Only a single candidate.
rec = Recommendation.low
elif results[1].distance - min_dist >= \
config['match']['rec_gap_thresh'].as_number():
# Gap between first two candidates is large.
rec = Recommendation.low
else:
# No conclusion. Return immediately. Can't be downgraded any further.
return Recommendation.none
# Downgrade to the max rec if it is lower than the current rec for an
# applied penalty.
keys = set(min_dist.keys())
if isinstance(results[0], hooks.AlbumMatch):
for track_dist in min_dist.tracks.values():
keys.update(list(track_dist.keys()))
max_rec_view = config['match']['max_rec']
for key in keys:
if key in list(max_rec_view.keys()):
max_rec = max_rec_view[key].as_choice({
'strong': Recommendation.strong,
'medium': Recommendation.medium,
'low': Recommendation.low,
'none': Recommendation.none,
})
rec = min(rec, max_rec)
return rec
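# Illustrative sketch (not part of the module): the downgrade above works
# because `Recommendation` is an OrderedEnum, so `min` picks the weaker
# level.
from beets.autotag.match import Recommendation
print(min(Recommendation.strong, Recommendation.medium))
# -> Recommendation.medium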
def _sort_candidates(candidates):
"""Sort candidates by distance."""
return sorted(candidates, key=lambda match: match.distance)
def _add_candidate(items, results, info):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
duplicates, and calculating the distance.
"""
log.debug(u'Candidate: {0} - {1} ({2})',
info.artist, info.album, info.album_id)
# Discard albums with zero tracks.
if not info.tracks:
log.debug(u'No tracks.')
return
# Don't duplicate.
if info.album_id in results:
log.debug(u'Duplicate.')
return
# Discard matches without required tags.
for req_tag in config['match']['required'].as_str_seq():
if getattr(info, req_tag) is None:
log.debug(u'Ignored. Missing required tag: {0}', req_tag)
return
# Find mapping between the items and the track info.
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
# Get the change distance.
dist = distance(items, info, mapping)
# Skip matches with ignored penalties.
penalties = [key for key, _ in dist]
for penalty in config['match']['ignored'].as_str_seq():
if penalty in penalties:
log.debug(u'Ignored. Penalty: {0}', penalty)
return
log.debug(u'Success. Distance: {0}', dist)
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
extra_items, extra_tracks)
def tag_album(items, search_artist=None, search_album=None,
search_ids=[]):
"""Return a tuple of the current artist name, the current album
name, and a `Proposal` containing `AlbumMatch` candidates.
The artist and album are the most common values of these fields
among `items`.
The `AlbumMatch` objects are generated by searching the metadata
backends. By default, the metadata of the items is used for the
search. This can be customized by setting the parameters.
`search_ids` is a list of metadata backend IDs: if specified,
it will restrict the candidates to those IDs, ignoring
    `search_artist` and `search_album`. The `mapping` field of the
album has the matched `items` as keys.
The recommendation is calculated from the match quality of the
candidates.
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = likelies['artist']
cur_album = likelies['album']
log.debug(u'Tagging {0} - {1}', cur_artist, cur_album)
# The output result (distance, AlbumInfo) tuples (keyed by MB album
# ID).
candidates = {}
# Search by explicit ID.
if search_ids:
for search_id in search_ids:
log.debug(u'Searching for album ID: {0}', search_id)
for id_candidate in hooks.albums_for_id(search_id):
_add_candidate(items, candidates, id_candidate)
# Use existing metadata or text search.
else:
# Try search based on current ID.
id_info = match_by_id(items)
if id_info:
_add_candidate(items, candidates, id_info)
rec = _recommendation(list(candidates.values()))
log.debug(u'Album ID match recommendation is {0}', rec)
if candidates and not config['import']['timid']:
# If we have a very good MBID match, return immediately.
# Otherwise, this match will compete against metadata-based
# matches.
if rec == Recommendation.strong:
log.debug(u'ID match.')
return cur_artist, cur_album, \
Proposal(list(candidates.values()), rec)
# Search terms.
if not (search_artist and search_album):
# No explicit search terms -- use current metadata.
search_artist, search_album = cur_artist, cur_album
log.debug(u'Search terms: {0} - {1}', search_artist, search_album)
# Is this album likely to be a "various artist" release?
va_likely = ((not consensus['artist']) or
(search_artist.lower() in VA_ARTISTS) or
any(item.comp for item in items))
log.debug(u'Album might be VA: {0}', va_likely)
# Get the results from the data sources.
for matched_candidate in hooks.album_candidates(items,
search_artist,
search_album,
va_likely):
_add_candidate(items, candidates, matched_candidate)
log.debug(u'Evaluating {0} candidates.', len(candidates))
# Sort and get the recommendation.
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
return cur_artist, cur_album, Proposal(candidates, rec)
def tag_item(item, search_artist=None, search_title=None,
search_ids=[]):
"""Find metadata for a single track. Return a `Proposal` consisting
of `TrackMatch` objects.
`search_artist` and `search_title` may be used
    to override the current metadata for the purposes of the MusicBrainz
    search. `search_ids` may be used for restricting the search to a list
of metadata backend IDs.
"""
# Holds candidates found so far: keys are MBIDs; values are
# (distance, TrackInfo) pairs.
candidates = {}
# First, try matching by MusicBrainz ID.
trackids = search_ids or [t for t in [item.mb_trackid] if t]
if trackids:
for trackid in trackids:
log.debug(u'Searching for track ID: {0}', trackid)
for track_info in hooks.tracks_for_id(trackid):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = \
hooks.TrackMatch(dist, track_info)
# If this is a good match, then don't keep searching.
rec = _recommendation(_sort_candidates(candidates.values()))
if rec == Recommendation.strong and \
not config['import']['timid']:
log.debug(u'Track ID match.')
return Proposal(_sort_candidates(candidates.values()), rec)
# If we're searching by ID, don't proceed.
if search_ids:
if candidates:
return Proposal(_sort_candidates(candidates.values()), rec)
else:
return Proposal([], Recommendation.none)
# Search terms.
if not (search_artist and search_title):
search_artist, search_title = item.artist, item.title
log.debug(u'Item search terms: {0} - {1}', search_artist, search_title)
# Get and evaluate candidate metadata.
for track_info in hooks.item_candidates(item, search_artist, search_title):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
# Sort by distance and return with recommendation.
log.debug(u'Found {0} candidates.', len(candidates))
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
return Proposal(candidates, rec)
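An end-to-end usage sketch, assuming a configured beets installation; `item` is a hypothetical existing library Item:

from beets.autotag import tag_item
proposal = tag_item(item)  # `item`: a beets library Item (hypothetical)
print(proposal.recommendation)
for match in proposal.candidates[:3]:
    print(match.distance, match.info.artist, match.info.title)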


@@ -0,0 +1,516 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Searches for albums in the MusicBrainz database.
"""
from __future__ import division, absolute_import, print_function
import musicbrainzngs
import re
import traceback
from six.moves.urllib.parse import urljoin
from beets import logging
import beets.autotag.hooks
import beets
from beets import util
from beets import config
import six
VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
if util.SNI_SUPPORTED:
BASE_URL = 'https://musicbrainz.org/'
else:
BASE_URL = 'http://musicbrainz.org/'
SKIPPED_TRACKS = ['[data track]']
musicbrainzngs.set_useragent('beets', beets.__version__,
'http://beets.io/')
class MusicBrainzAPIError(util.HumanReadableException):
"""An error while talking to MusicBrainz. The `query` field is the
parameter to the action and may have any type.
"""
def __init__(self, reason, verb, query, tb=None):
self.query = query
if isinstance(reason, musicbrainzngs.WebServiceError):
reason = u'MusicBrainz not reachable'
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
def get_message(self):
return u'{0} in {1} with query {2}'.format(
self._reasonstr(), self.verb, repr(self.query)
)
log = logging.getLogger('beets')
RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
'labels', 'artist-credits', 'aliases',
'recording-level-rels', 'work-rels',
'work-level-rels', 'artist-rels']
TRACK_INCLUDES = ['artists', 'aliases']
if 'work-level-rels' in musicbrainzngs.VALID_INCLUDES['recording']:
TRACK_INCLUDES += ['work-level-rels', 'artist-rels']
def track_url(trackid):
return urljoin(BASE_URL, 'recording/' + trackid)
def album_url(albumid):
return urljoin(BASE_URL, 'release/' + albumid)
def configure():
"""Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup.
"""
hostname = config['musicbrainz']['host'].as_str()
musicbrainzngs.set_hostname(hostname)
musicbrainzngs.set_rate_limit(
config['musicbrainz']['ratelimit_interval'].as_number(),
config['musicbrainz']['ratelimit'].get(int),
)
def _preferred_alias(aliases):
"""Given an list of alias structures for an artist credit, select
and return the user's preferred alias alias or None if no matching
alias is found.
"""
if not aliases:
return
# Only consider aliases that have locales set.
aliases = [a for a in aliases if 'locale' in a]
# Search configured locales in order.
for locale in config['import']['languages'].as_str_seq():
# Find matching primary aliases for this locale.
matches = [a for a in aliases
if a['locale'] == locale and 'primary' in a]
# Skip to the next locale if we have no matches
if not matches:
continue
return matches[0]
def _preferred_release_event(release):
"""Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found.
"""
countries = config['match']['preferred']['countries'].as_str_seq()
for country in countries:
for event in release.get('release-event-list', {}):
try:
if country in event['area']['iso-3166-1-code-list']:
return country, event['date']
except KeyError:
pass
return release.get('country'), release.get('date')
def _flatten_artist_credit(credit):
"""Given a list representing an ``artist-credit`` block, flatten the
data into a triple of joined artist name strings: canonical, sort, and
credit.
"""
artist_parts = []
artist_sort_parts = []
artist_credit_parts = []
for el in credit:
if isinstance(el, six.string_types):
# Join phrase.
artist_parts.append(el)
artist_credit_parts.append(el)
artist_sort_parts.append(el)
else:
alias = _preferred_alias(el['artist'].get('alias-list', ()))
# An artist.
if alias:
cur_artist_name = alias['alias']
else:
cur_artist_name = el['artist']['name']
artist_parts.append(cur_artist_name)
# Artist sort name.
if alias:
artist_sort_parts.append(alias['sort-name'])
elif 'sort-name' in el['artist']:
artist_sort_parts.append(el['artist']['sort-name'])
else:
artist_sort_parts.append(cur_artist_name)
# Artist credit.
if 'name' in el:
artist_credit_parts.append(el['name'])
else:
artist_credit_parts.append(cur_artist_name)
return (
''.join(artist_parts),
''.join(artist_sort_parts),
''.join(artist_credit_parts),
)
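# Illustrative sketch (not part of the module): an ``artist-credit``
# block interleaves artist structs with join phrases; with no aliases
# present, the raw names pass straight through.
credit = [
    {'artist': {'name': u'Queen', 'sort-name': u'Queen'}},
    u' & ',
    {'artist': {'name': u'David Bowie', 'sort-name': u'Bowie, David'}},
]
print(_flatten_artist_credit(credit))
# ('Queen & David Bowie', 'Queen & Bowie, David', 'Queen & David Bowie')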
def track_info(recording, index=None, medium=None, medium_index=None,
medium_total=None):
"""Translates a MusicBrainz recording result dictionary into a beets
``TrackInfo`` object. Three parameters are optional and are used
only for tracks that appear on releases (non-singletons): ``index``,
the overall track number; ``medium``, the disc number;
``medium_index``, the track's index on its medium; ``medium_total``,
the number of tracks on the medium. Each number is a 1-based index.
"""
info = beets.autotag.hooks.TrackInfo(
recording['title'],
recording['id'],
index=index,
medium=medium,
medium_index=medium_index,
medium_total=medium_total,
data_source=u'MusicBrainz',
data_url=track_url(recording['id']),
)
if recording.get('artist-credit'):
# Get the artist names.
info.artist, info.artist_sort, info.artist_credit = \
_flatten_artist_credit(recording['artist-credit'])
# Get the ID and sort name of the first artist.
artist = recording['artist-credit'][0]['artist']
info.artist_id = artist['id']
if recording.get('length'):
info.length = int(recording['length']) / (1000.0)
lyricist = []
composer = []
composer_sort = []
for work_relation in recording.get('work-relation-list', ()):
if work_relation['type'] != 'performance':
continue
for artist_relation in work_relation['work'].get(
'artist-relation-list', ()):
if 'type' in artist_relation:
type = artist_relation['type']
if type == 'lyricist':
lyricist.append(artist_relation['artist']['name'])
elif type == 'composer':
composer.append(artist_relation['artist']['name'])
composer_sort.append(
artist_relation['artist']['sort-name'])
if lyricist:
info.lyricist = u', '.join(lyricist)
if composer:
info.composer = u', '.join(composer)
info.composer_sort = u', '.join(composer_sort)
arranger = []
for artist_relation in recording.get('artist-relation-list', ()):
if 'type' in artist_relation:
type = artist_relation['type']
if type == 'arranger':
arranger.append(artist_relation['artist']['name'])
if arranger:
info.arranger = u', '.join(arranger)
info.decode()
return info
def _set_date_str(info, date_str, original=False):
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
object, set the object's release date fields appropriately. If
`original`, then set the original_year, etc., fields.
"""
if date_str:
date_parts = date_str.split('-')
for key in ('year', 'month', 'day'):
if date_parts:
date_part = date_parts.pop(0)
try:
date_num = int(date_part)
except ValueError:
continue
if original:
key = 'original_' + key
setattr(info, key, date_num)
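# Illustrative sketch (not part of the module): `_set_date_str` accepts
# partial dates; `_Info` is a stand-in for AlbumInfo.
class _Info(object):
    pass

info = _Info()
_set_date_str(info, '1994-06')              # year and month only
print(info.year, info.month)                # -> 1994 6
_set_date_str(info, '1973', original=True)
print(info.original_year)                   # -> 1973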
def album_info(release):
"""Takes a MusicBrainz release result dictionary and returns a beets
AlbumInfo object containing the interesting data about that release.
"""
# Get artist name using join phrases.
artist_name, artist_sort_name, artist_credit_name = \
_flatten_artist_credit(release['artist-credit'])
# Basic info.
track_infos = []
index = 0
for medium in release['medium-list']:
disctitle = medium.get('title')
format = medium.get('format')
if format in config['match']['ignored_media'].as_str_seq():
continue
all_tracks = medium['track-list']
if 'data-track-list' in medium:
all_tracks += medium['data-track-list']
track_count = len(all_tracks)
if 'pregap' in medium:
all_tracks.insert(0, medium['pregap'])
for track in all_tracks:
if ('title' in track['recording'] and
track['recording']['title'] in SKIPPED_TRACKS):
continue
if ('video' in track['recording'] and
track['recording']['video'] == 'true' and
config['match']['ignore_video_tracks']):
continue
# Basic information from the recording.
index += 1
ti = track_info(
track['recording'],
index,
int(medium['position']),
int(track['position']),
track_count,
)
ti.release_track_id = track['id']
ti.disctitle = disctitle
ti.media = format
ti.track_alt = track['number']
# Prefer track data, where present, over recording data.
if track.get('title'):
ti.title = track['title']
if track.get('artist-credit'):
# Get the artist names.
ti.artist, ti.artist_sort, ti.artist_credit = \
_flatten_artist_credit(track['artist-credit'])
ti.artist_id = track['artist-credit'][0]['artist']['id']
if track.get('length'):
ti.length = int(track['length']) / (1000.0)
track_infos.append(ti)
info = beets.autotag.hooks.AlbumInfo(
release['title'],
release['id'],
artist_name,
release['artist-credit'][0]['artist']['id'],
track_infos,
mediums=len(release['medium-list']),
artist_sort=artist_sort_name,
artist_credit=artist_credit_name,
data_source=u'MusicBrainz',
data_url=album_url(release['id']),
)
info.va = info.artist_id == VARIOUS_ARTISTS_ID
if info.va:
info.artist = config['va_name'].as_str()
info.asin = release.get('asin')
info.releasegroup_id = release['release-group']['id']
info.albumstatus = release.get('status')
# Build up the disambiguation string from the release group and release.
disambig = []
if release['release-group'].get('disambiguation'):
disambig.append(release['release-group'].get('disambiguation'))
if release.get('disambiguation'):
disambig.append(release.get('disambiguation'))
info.albumdisambig = u', '.join(disambig)
# Get the "classic" Release type. This data comes from a legacy API
# feature before MusicBrainz supported multiple release types.
if 'type' in release['release-group']:
reltype = release['release-group']['type']
if reltype:
info.albumtype = reltype.lower()
# Log the new-style "primary" and "secondary" release types.
# Eventually, we'd like to actually store this data, but we just log
# it for now to help understand the differences.
if 'primary-type' in release['release-group']:
rel_primarytype = release['release-group']['primary-type']
if rel_primarytype:
log.debug('primary MB release type: ' + rel_primarytype.lower())
if 'secondary-type-list' in release['release-group']:
if release['release-group']['secondary-type-list']:
log.debug('secondary MB release type(s): ' + ', '.join(
[secondarytype.lower() for secondarytype in
release['release-group']['secondary-type-list']]))
# Release events.
info.country, release_date = _preferred_release_event(release)
release_group_date = release['release-group'].get('first-release-date')
if not release_date:
# Fall back if release-specific date is not available.
release_date = release_group_date
_set_date_str(info, release_date, False)
_set_date_str(info, release_group_date, True)
# Label name.
if release.get('label-info-list'):
label_info = release['label-info-list'][0]
if label_info.get('label'):
label = label_info['label']['name']
if label != '[no label]':
info.label = label
info.catalognum = label_info.get('catalog-number')
# Text representation data.
if release.get('text-representation'):
rep = release['text-representation']
info.script = rep.get('script')
info.language = rep.get('language')
# Media (format).
if release['medium-list']:
first_medium = release['medium-list'][0]
info.media = first_medium.get('format')
info.decode()
return info
def match_album(artist, album, tracks=None):
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
MusicBrainzAPIError.
The query consists of an artist name, an album name, and,
optionally, a number of tracks on the album.
"""
# Build search criteria.
criteria = {'release': album.lower().strip()}
if artist is not None:
criteria['artist'] = artist.lower().strip()
else:
# Various Artists search.
criteria['arid'] = VARIOUS_ARTISTS_ID
if tracks is not None:
criteria['tracks'] = six.text_type(tracks)
# Abort if we have no search terms.
if not any(criteria.values()):
return
try:
log.debug(u'Searching for MusicBrainz releases with: {!r}', criteria)
res = musicbrainzngs.search_releases(
limit=config['musicbrainz']['searchlimit'].get(int), **criteria)
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'release search', criteria,
traceback.format_exc())
for release in res['release-list']:
# The search result is missing some data (namely, the tracks),
# so we just use the ID and fetch the rest of the information.
albuminfo = album_for_id(release['id'])
if albuminfo is not None:
yield albuminfo
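# Usage sketch (commented out because it hits the MusicBrainz web
# service; the artist/album values are made up):
#
#     for candidate in match_album(u'Pink Floyd', u'Animals', tracks=5):
#         print(candidate.album, candidate.album_id)
#
# Each yielded object is a fully populated AlbumInfo fetched by its ID.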
def match_track(artist, title):
"""Searches for a single track and returns an iterable of TrackInfo
objects. May raise a MusicBrainzAPIError.
"""
criteria = {
'artist': artist.lower().strip(),
'recording': title.lower().strip(),
}
if not any(criteria.values()):
return
try:
res = musicbrainzngs.search_recordings(
limit=config['musicbrainz']['searchlimit'].get(int), **criteria)
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'recording search', criteria,
traceback.format_exc())
for recording in res['recording-list']:
yield track_info(recording)
def _parse_id(s):
"""Search for a MusicBrainz ID in the given string and return it. If
no ID can be found, return None.
"""
# Find the first thing that looks like a UUID/MBID.
match = re.search(u'[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
if match:
return match.group()
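# For example (illustrative), the MBID is found anywhere in the string,
# so a full URL works too:
#
#     >>> _parse_id(u'https://musicbrainz.org/release/'
#     ...           u'5c15eab1-b6f6-4b44-8e21-e1b9bb0a4e36')
#     u'5c15eab1-b6f6-4b44-8e21-e1b9bb0a4e36'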
def album_for_id(releaseid):
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
object or None if the album is not found. May raise a
MusicBrainzAPIError.
"""
log.debug(u'Requesting MusicBrainz release {}', releaseid)
albumid = _parse_id(releaseid)
if not albumid:
log.debug(u'Invalid MBID ({0}).', releaseid)
return
try:
res = musicbrainzngs.get_release_by_id(albumid,
RELEASE_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Album ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, u'get release by ID', albumid,
traceback.format_exc())
return album_info(res['release'])
def track_for_id(releaseid):
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
or None if no track is found. May raise a MusicBrainzAPIError.
"""
trackid = _parse_id(releaseid)
if not trackid:
log.debug(u'Invalid MBID ({0}).', releaseid)
return
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Track ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, u'get recording by ID', trackid,
traceback.format_exc())
return track_info(res['recording'])

View file

@ -0,0 +1,134 @@
library: library.db
directory: ~/Music
import:
write: yes
copy: yes
move: no
link: no
hardlink: no
delete: no
resume: ask
incremental: no
incremental_skip_later: no
from_scratch: no
quiet_fallback: skip
none_rec_action: ask
timid: no
log:
autotag: yes
quiet: no
singletons: no
default_action: apply
languages: []
detail: no
flat: no
group_albums: no
pretend: false
search_ids: []
duplicate_action: ask
bell: no
set_fields: {}
clutter: ["Thumbs.DB", ".DS_Store"]
ignore: [".*", "*~", "System Volume Information", "lost+found"]
ignore_hidden: yes
replace:
'[\\/]': _
'^\.': _
'[\x00-\x1f]': _
'[<>:"\?\*\|]': _
'\.$': _
'\s+$': ''
'^\s+': ''
'^-': _
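# Illustrative note (comment only, not a setting): with the rules above,
# a tag value like "AC/DC: Live?" is sanitized to "AC_DC_ Live_" before
# being used in a filename, since the slash, colon, and question mark
# each match a pattern and are replaced by "_".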
path_sep_replace: _
asciify_paths: false
art_filename: cover
max_filename_length: 0
plugins: []
pluginpath: []
threaded: yes
timeout: 5.0
per_disc_numbering: no
verbose: 0
terminal_encoding:
original_date: no
artist_credit: no
id3v23: no
va_name: "Various Artists"
ui:
terminal_width: 80
length_diff_thresh: 10.0
color: yes
colors:
text_success: green
text_warning: yellow
text_error: red
text_highlight: red
text_highlight_minor: lightgray
action_default: turquoise
action: blue
format_item: $artist - $album - $title
format_album: $albumartist - $album
time_format: '%Y-%m-%d %H:%M:%S'
format_raw_length: no
sort_album: albumartist+ album+
sort_item: artist+ album+ disc+ track+
sort_case_insensitive: yes
paths:
default: $albumartist/$album%aunique{}/$track $title
singleton: Non-Album/$artist/$title
comp: Compilations/$album%aunique{}/$track $title
statefile: state.pickle
musicbrainz:
host: musicbrainz.org
ratelimit: 1
ratelimit_interval: 1.0
searchlimit: 5
match:
strong_rec_thresh: 0.04
medium_rec_thresh: 0.25
rec_gap_thresh: 0.25
max_rec:
missing_tracks: medium
unmatched_tracks: medium
distance_weights:
source: 2.0
artist: 3.0
album: 3.0
media: 1.0
mediums: 1.0
year: 1.0
country: 0.5
label: 0.5
catalognum: 0.5
albumdisambig: 0.5
album_id: 5.0
tracks: 2.0
missing_tracks: 0.9
unmatched_tracks: 0.6
track_title: 3.0
track_artist: 2.0
track_index: 1.0
track_length: 2.0
track_id: 5.0
preferred:
countries: []
media: []
original_year: no
ignored: []
required: []
ignored_media: []
ignore_video_tracks: yes
track_length_grace: 10
track_length_max: 30

View file

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""DBCore is an abstract database package that forms the basis for beets'
Library.
"""
from __future__ import division, absolute_import, print_function
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings
from .queryparse import parse_sorted_query
from .query import InvalidQueryError
# flake8: noqa

View file

@ -0,0 +1,910 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The central Model and Database constructs for DBCore.
"""
from __future__ import division, absolute_import, print_function
import time
import os
from collections import defaultdict
import threading
import sqlite3
import contextlib
import collections
import beets
from beets.util.functemplate import Template
from beets.util import py3_path
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery
import six
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
This can happen when trying to read or write the database when, for
example, the database file is deleted or otherwise disappears. There
is probably no way to recover from this error.
"""
class FormattedMapping(collections.Mapping):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
`model[key]` as a unicode string.
If `for_path` is true, all path separators in the formatted values
are replaced.
"""
def __init__(self, model, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __iter__(self):
return iter(self.model_keys)
def __len__(self):
return len(self.model_keys)
def get(self, key, default=None):
if default is None:
default = self.model._type(key).format(None)
return super(FormattedMapping, self).get(key, default)
def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf-8', 'ignore')
if self.for_path:
sep_repl = beets.config['path_sep_replace'].as_str()
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
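# A small usage sketch (assumes a Model instance `item` with an 'artist'
# field): FormattedMapping(item)['artist'] returns the value rendered as
# a unicode string, and FormattedMapping(item, for_path=True) also
# rewrites os.sep/os.altsep occurrences using the path_sep_replace
# setting, making the value safe inside a single path component.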
# Abstract base for model classes.
class Model(object):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
available:
* **Fixed attributes** come from a predetermined list of field
names. These fields correspond to SQLite table columns and are
thus fast to read, write, and query.
* **Flexible attributes** are free-form and do not need to be listed
ahead of time.
* **Computed attributes** are read-only fields computed by a getter
function provided by a plugin.
Access to all three field types is uniform: ``obj.field`` works the
same regardless of whether ``field`` is fixed, flexible, or
computed.
Model objects can optionally be associated with a `Library` object,
in which case they can be loaded and stored from the database. Dirty
flags are used to track which fields need to be stored.
"""
# Abstract components (to be provided by subclasses).
_table = None
"""The main SQLite table name.
"""
_flex_table = None
"""The flex field SQLite table name.
"""
_fields = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are `Type` objects.
"""
_search_fields = ()
"""The fields that should be queried by default by unqualified query
terms.
"""
_types = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
_sorts = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""
_always_dirty = False
"""By default, fields only become "dirty" when their value actually
changes. Enabling this flag marks fields as dirty even when the new
value is the same as the old value (e.g., `o.f = o.f`).
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
"""
# We could cache this if it becomes a performance problem to
# gather the getter mapping every time.
raise NotImplementedError()
def _template_funcs(self):
"""Return a mapping from function names to text-transformer
functions.
"""
# As above: we could consider caching this result.
raise NotImplementedError()
# Basic operation.
def __init__(self, db=None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
self._db = db
self._dirty = set()
self._values_fixed = {}
self._values_flex = {}
# Initial contents.
self.update(values)
self.clear_dirty()
@classmethod
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
ordinary construction are bypassed.
"""
obj = cls(db)
for key, value in fixed_values.items():
obj._values_fixed[key] = cls._type(key).from_sql(value)
for key, value in flex_values.items():
obj._values_flex[key] = cls._type(key).from_sql(value)
return obj
def __repr__(self):
return '{0}({1})'.format(
type(self).__name__,
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
)
def clear_dirty(self):
"""Mark all fields as *clean* (i.e., not needing to be stored to
the database).
"""
self._dirty = set()
def _check_db(self, need_id=True):
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
"""
if not self._db:
raise ValueError(
u'{0} has no database'.format(type(self).__name__)
)
if need_id and not self.id:
raise ValueError(u'{0} has no id'.format(type(self).__name__))
def copy(self):
"""Create a copy of the model object.
The field values and other state is duplicated, but the new copy
remains associated with the same database as the old object.
(A simple `copy.deepcopy` will not work because it would try to
duplicate the SQLite connection.)
"""
new = self.__class__()
new._db = self._db
new._values_fixed = self._values_fixed.copy()
new._values_flex = self._values_flex.copy()
new._dirty = self._dirty.copy()
return new
# Essential field accessors.
@classmethod
def _type(cls, key):
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
which does no conversion.
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""
getters = self._getters()
if key in getters: # Computed.
return getters[key](self)
elif key in self._fields: # Fixed.
return self._values_fixed.get(key, self._type(key).null)
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
else:
raise KeyError(key)
def _setitem(self, key, value):
"""Assign the value for a field, return whether new and old value
differ.
"""
# Choose where to place the value.
if key in self._fields:
source = self._values_fixed
else:
source = self._values_flex
# If the field has a type, filter the value.
value = self._type(key).normalize(value)
# Assign value and possibly mark as dirty.
old_value = source.get(key)
source[key] = value
changed = old_value != value
if self._always_dirty or changed:
self._dirty.add(key)
return changed
def __setitem__(self, key, value):
"""Assign the value for a field.
"""
self._setitem(key, value)
def __delitem__(self, key):
"""Remove a flexible attribute from the model.
"""
if key in self._values_flex: # Flexible.
del self._values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._getters(): # Computed.
raise KeyError(u'computed field {0} cannot be deleted'.format(key))
elif key in self._fields: # Fixed.
raise KeyError(u'fixed field {0} cannot be deleted'.format(key))
else:
raise KeyError(u'no such field {0}'.format(key))
def keys(self, computed=False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
"""
base_keys = list(self._fields) + list(self._values_flex.keys())
if computed:
return base_keys + list(self._getters().keys())
else:
return base_keys
@classmethod
def all_keys(cls):
"""Get a list of available keys for objects of this type.
Includes fixed and computed fields.
"""
return list(cls._fields) + list(cls._getters().keys())
# Act like a dictionary.
def update(self, values):
"""Assign all values in the given dict.
"""
for key, value in values.items():
self[key] = value
def items(self):
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default
def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(True)
def __iter__(self):
"""Iterate over the available field names (excluding computed
fields).
"""
return iter(self.keys())
# Convenient attribute access.
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(u'model has no attribute {0!r}'.format(key))
else:
try:
return self[key]
except KeyError:
raise AttributeError(u'no such field {0!r}'.format(key))
def __setattr__(self, key, value):
if key.startswith('_'):
super(Model, self).__setattr__(key, value)
else:
self[key] = value
def __delattr__(self, key):
if key.startswith('_'):
super(Model, self).__delattr__(key)
else:
del self[key]
# Database interaction (CRUD methods).
def store(self, fields=None):
"""Save the object's metadata into the library database.
:param fields: the fields to be stored. If not specified, all fields
will be.
"""
if fields is None:
fields = self._fields
self._check_db()
# Build assignments for query.
assignments = []
subvars = []
for key in fields:
if key != 'id' and key in self._dirty:
self._dirty.remove(key)
assignments.append(key + '=?')
value = self._type(key).to_sql(self[key])
subvars.append(value)
assignments = ','.join(assignments)
with self._db.transaction() as tx:
# Main table update.
if assignments:
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
self._table, assignments
)
subvars.append(self.id)
tx.mutate(query, subvars)
# Modified/added flexible attributes.
for key, value in self._values_flex.items():
if key in self._dirty:
self._dirty.remove(key)
tx.mutate(
'INSERT INTO {0} '
'(entity_id, key, value) '
'VALUES (?, ?, ?);'.format(self._flex_table),
(self.id, key, value),
)
# Deleted flexible attributes.
for key in self._dirty:
tx.mutate(
'DELETE FROM {0} '
'WHERE entity_id=? AND key=?'.format(self._flex_table),
(self.id, key)
)
self.clear_dirty()
def load(self):
"""Refresh the object's metadata from the library database.
"""
self._check_db()
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
self._values_fixed = {}
self._values_flex = {}
self.update(dict(stored_obj))
self.clear_dirty()
def remove(self):
"""Remove the object's associated rows from the database.
"""
self._check_db()
with self._db.transaction() as tx:
tx.mutate(
'DELETE FROM {0} WHERE id=?'.format(self._table),
(self.id,)
)
tx.mutate(
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
(self.id,)
)
def add(self, db=None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
The object's `id` and `added` fields are set along with any
current field values.
"""
if db:
self._db = db
self._check_db(False)
with self._db.transaction() as tx:
new_id = tx.mutate(
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
)
self.id = new_id
self.added = time.time()
# Mark every non-null field as dirty and store.
for key in self:
if self[key] is not None:
self._dirty.add(key)
self.store()
# Formatting and templating.
_formatter = FormattedMapping
def formatted(self, for_path=False):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, for_path)
def evaluate_template(self, template, for_path=False):
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
"""
# Perform substitution.
if isinstance(template, six.string_types):
template = Template(template)
return template.substitute(self.formatted(for_path),
self._template_funcs())
# Parsing.
@classmethod
def _parse(cls, key, string):
"""Parse a string as a value for the given key.
"""
if not isinstance(string, six.string_types):
raise TypeError(u"_parse() argument must be a string")
return cls._type(key).parse(string)
def set_parse(self, key, string):
"""Set the object's key to a value represented by a string.
"""
self[key] = self._parse(key, string)
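# A minimal concrete subclass might look like this (hypothetical names;
# real subclasses such as beets' Item also provide getters and template
# functions from plugins):
#
#     class Note(Model):
#         _table = 'notes'
#         _flex_table = 'note_attributes'
#         _fields = {'id': types.PRIMARY_ID, 'title': types.STRING}
#         _search_fields = ('title',)
#
#         @classmethod
#         def _getters(cls):
#             return {}
#
#         def _template_funcs(self):
#             return {}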
# Database controller and supporting interfaces.
class Results(object):
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, query=None, sort=None):
"""Create a result set that will construct objects of type
`model_class`.
`model_class` is a subclass of `LibModel` that will be
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
If `query` is provided, it is used as a predicate to filter the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first
one.
"""
self.model_class = model_class
self.rows = rows
self.db = db
self.query = query
self.sort = sort
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
# rows.
self._rows = rows
self._row_count = len(rows)
# The materialized objects corresponding to rows that have been
# consumed.
self._objects = []
def _get_objects(self):
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
For performance, this generator caches materialized objects to
avoid constructing them more than once. This way, iterating over
a `Results` object a second time should be much faster than the
first.
"""
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
# Are there previously-materialized objects to produce?
if index < len(self._objects):
yield self._objects[index]
index += 1
# Otherwise, we consume another row, materialize its object
# and produce it.
else:
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
# If there is a slow-query predicate, ensure that the
# object passes it.
if not self.query or self.query.match(obj):
self._objects.append(obj)
index += 1
yield obj
break
def __iter__(self):
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
if self.sort:
# Slow sort. Must build the full list first.
objects = self.sort.sort(list(self._get_objects()))
return iter(objects)
else:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _make_model(self, row):
# Get the flexible attributes for the object.
with self.db.transaction() as tx:
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id=?'.format(
self.model_class._flex_table
),
(row['id'],)
)
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k.startswith('flex'))
flex_values = dict((row['key'], row['value']) for row in flex_rows)
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
def __len__(self):
"""Get the number of matching objects.
"""
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
count += 1
return count
else:
# A fast query. Just count the rows.
return self._row_count
def __nonzero__(self):
"""Does this result contain any objects?
"""
return self.__bool__()
def __bool__(self):
"""Does this result contain any objects?
"""
return bool(len(self))
def __getitem__(self, n):
"""Get the nth item in this result set. This is inefficient: all
items up to n are materialized and thrown away.
"""
if not self._rows and not self.sort:
# Fully materialized and already in order. Just look up the
# object.
return self._objects[n]
it = iter(self)
try:
for i in range(n):
next(it)
return next(it)
except StopIteration:
raise IndexError(u'result index {0} out of range'.format(n))
def get(self):
"""Return the first matching object, or None if no objects
match.
"""
it = iter(self)
try:
return next(it)
except StopIteration:
return None
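# Typical consumption pattern (illustrative): iteration lazily
# materializes Model objects, results.get() returns the first match or
# None, and len(results) is cheap for pure-SQL queries but must build
# every object when a slow-query predicate is attached.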
class Transaction(object):
"""A context manager for safe, concurrent access to the database.
All SQL commands should be executed through a transaction.
"""
def __init__(self, db):
self.db = db
def __enter__(self):
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
with self.db._tx_stack() as stack:
first = not stack
stack.append(self)
if first:
# Beginning a "root" transaction, which corresponds to an
# SQLite transaction.
self.db._db_lock.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Complete a transaction. This must be the most recently
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
with self.db._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
# Ending a "root" transaction. End the SQLite transaction.
self.db._connection().commit()
self.db._db_lock.release()
def query(self, statement, subvals=()):
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
def mutate(self, statement, subvals=()):
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
try:
cursor = self.db._connection().execute(statement, subvals)
return cursor.lastrowid
except sqlite3.OperationalError as e:
# In two specific cases, SQLite reports an error while accessing
# the underlying database file. We surface these exceptions as
# DBAccessError so the application can abort.
if e.args[0] in ("attempt to write a readonly database",
"unable to open database file"):
raise DBAccessError(e.args[0])
else:
raise
def script(self, statements):
"""Execute a string containing multiple SQL statements."""
self.db._connection().executescript(statements)
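# Usage sketch (illustrative table name): every statement runs inside a
# transaction so the per-database lock is held around it:
#
#     with db.transaction() as tx:
#         rows = tx.query('SELECT COUNT(*) FROM notes')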
class Database(object):
"""A container for Model objects that wraps an SQLite database as
the backend.
"""
_models = ()
"""The Model subclasses representing tables in this database.
"""
def __init__(self, path, timeout=5.0):
self.path = path
self.timeout = timeout
self._connections = {}
self._tx_stacks = defaultdict(list)
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
self._shared_map_lock = threading.Lock()
# A lock to protect access to the database itself. SQLite does
# allow multiple threads to access the database at the same
# time, but many users were experiencing crashes related to this
# capability: where SQLite was compiled without HAVE_USLEEP, its
# backoff algorithm in the case of contention was causing
# whole-second sleeps (!) that would trigger its internal
# timeout. Using this lock ensures only one SQLite transaction
# is active at a time.
self._db_lock = threading.Lock()
# Set up database schema.
for model_cls in self._models:
self._make_table(model_cls._table, model_cls._fields)
self._make_attribute_table(model_cls._flex_table)
# Primitive access control: connections and transactions.
def _connection(self):
"""Get a SQLite connection object to the underlying database.
One connection object is created per thread.
"""
thread_id = threading.current_thread().ident
with self._shared_map_lock:
if thread_id in self._connections:
return self._connections[thread_id]
else:
conn = self._create_connection()
self._connections[thread_id] = conn
return conn
def _create_connection(self):
"""Create a SQLite connection to the underlying database.
Makes a new connection every time. If you need to configure the
connection settings (e.g., add custom functions), override this
method.
"""
# Make a new connection. The `sqlite3` module can't use
# bytestring paths here on Python 3, so we need to
# provide a `str` using `py3_path`.
conn = sqlite3.connect(
py3_path(self.path), timeout=self.timeout
)
# Access SELECT results like dictionaries.
conn.row_factory = sqlite3.Row
return conn
def _close(self):
"""Close the all connections to the underlying SQLite database
from all threads. This does not render the database object
unusable; new connections can still be opened on demand.
"""
with self._shared_map_lock:
self._connections.clear()
@contextlib.contextmanager
def _tx_stack(self):
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
"""
thread_id = threading.current_thread().ident
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
def transaction(self):
"""Get a :class:`Transaction` object for interacting directly
with the underlying SQLite database.
"""
return Transaction(self)
# Schema setup and migration.
def _make_table(self, table, fields):
"""Set up the schema of the database. `fields` is a mapping
from field names to `Type`s. Columns are added if necessary.
"""
# Get current schema.
with self.transaction() as tx:
rows = tx.query('PRAGMA table_info(%s)' % table)
current_fields = set([row[1] for row in rows])
field_names = set(fields.keys())
if current_fields.issuperset(field_names):
# Table exists and has all the required columns.
return
if not current_fields:
# No table exists.
columns = []
for name, typ in fields.items():
columns.append('{0} {1}'.format(name, typ.sql))
setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
', '.join(columns))
else:
# Table exists but does not match the field set.
setup_sql = ''
for name, typ in fields.items():
if name in current_fields:
continue
setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
table, name, typ.sql
)
with self.transaction() as tx:
tx.script(setup_sql)
def _make_attribute_table(self, flex_table):
"""Create a table and associated index for flexible attributes
for the given entity (if they don't exist).
"""
with self.transaction() as tx:
tx.script("""
CREATE TABLE IF NOT EXISTS {0} (
id INTEGER PRIMARY KEY,
entity_id INTEGER,
key TEXT,
value TEXT,
UNIQUE(entity_id, key) ON CONFLICT REPLACE);
CREATE INDEX IF NOT EXISTS {0}_by_entity
ON {0} (entity_id);
""".format(flex_table))
# Querying.
def _fetch(self, model_cls, query=None, sort=None):
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is an
`Sort` object.
"""
query = query or TrueQuery() # A null query.
sort = sort or NullSort() # Unsorted.
where, subvals = query.clause()
order_by = sort.order_clause()
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
model_cls._table,
where or '1',
"ORDER BY {0}".format(order_by) if order_by else '',
)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
return Results(
model_cls, rows, self,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls, id):
"""Get a Model object by its id or None if the id does not
exist.
"""
return self._fetch(model_cls, MatchQuery('id', id)).get()
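# A concrete database binds its Model subclasses via `_models`, which the
# constructor uses to build the schema. A sketch reusing the hypothetical
# Note model from above:
#
#     class NoteDatabase(Database):
#         _models = (Note,)
#
#     db = NoteDatabase(':memory:')
#     note = Note(db, title=u'hello')
#     note.add()  # Inserts a row and assigns note.id.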

View file

@ -0,0 +1,944 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The Query type hierarchy for DBCore.
"""
from __future__ import division, absolute_import, print_function
import re
from operator import mul
from beets import util
from datetime import datetime, timedelta
import unicodedata
from functools import reduce
import six
if not six.PY2:
buffer = memoryview  # Python 2's sqlite3 needs buffer; Python 3 accepts memoryview.
class ParsingError(ValueError):
"""Abstract class for any unparseable user-requested album/query
specification.
"""
class InvalidQueryError(ParsingError):
"""Represent any kind of invalid query.
The query should be a unicode string or a list, which will be space-joined.
"""
def __init__(self, query, explanation):
if isinstance(query, list):
query = " ".join(query)
message = u"'{0}': {1}".format(query, explanation)
super(InvalidQueryError, self).__init__(message)
class InvalidQueryArgumentValueError(ParsingError):
"""Represent a query argument that could not be converted as expected.
It exists to be caught at higher stack levels so that a meaningful
InvalidQueryError (i.e., one that includes the query) can be raised.
"""
def __init__(self, what, expected, detail=None):
message = u"'{0}' is not {1}".format(what, expected)
if detail:
message = u"{0}: {1}".format(message, detail)
super(InvalidQueryArgumentValueError, self).__init__(message)
class Query(object):
"""An abstract class representing a query into the item database.
"""
def clause(self):
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
"""
return None, ()
def match(self, item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
"""
raise NotImplementedError
def __repr__(self):
return "{0.__class__.__name__}()".format(self)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return 0
class FieldQuery(Query):
"""An abstract query that searches in a specific field for a
pattern. Subclasses must provide a `value_match` class method, which
determines whether a certain pattern string matches a certain value
string. Subclasses may also provide `col_clause` to implement the
same matching functionality in SQLite.
"""
def __init__(self, field, pattern, fast=True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self):
return None, ()
def clause(self):
if self.fast:
return self.col_clause()
else:
# Matching a flexattr. This is a slow query.
return None, ()
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
raise NotImplementedError()
def match(self, item):
return self.value_match(self.pattern, item.get(self.field))
def __repr__(self):
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
"{0.fast})".format(self))
def __eq__(self, other):
return super(FieldQuery, self).__eq__(other) and \
self.field == other.field and self.pattern == other.pattern
def __hash__(self):
return hash((self.field, hash(self.pattern)))
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
def col_clause(self):
return self.field + " = ?", [self.pattern]
@classmethod
def value_match(cls, pattern, value):
return pattern == value
class NoneQuery(FieldQuery):
"""A query that checks whether a field is null."""
def __init__(self, field, fast=True):
super(NoneQuery, self).__init__(field, None, fast)
def col_clause(self):
return self.field + " IS NULL", ()
def match(self, item):
try:
return item[self.field] is None
except KeyError:
return True
def __repr__(self):
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
class StringFieldQuery(FieldQuery):
"""A FieldQuery that converts values to strings before matching
them.
"""
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
return cls.string_match(pattern, util.as_string(value))
@classmethod
def string_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings. Subclasses implement this method.
"""
raise NotImplementedError()
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
def col_clause(self):
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
.replace('_', '\\_'))
search = '%' + pattern + '%'
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@classmethod
def string_match(cls, pattern, value):
return pattern.lower() in value.lower()
class RegexpQuery(StringFieldQuery):
"""A query that matches a regular expression in a specific item
field.
Raises InvalidQueryError when the pattern is not a valid regular
expression.
"""
def __init__(self, field, pattern, fast=True):
super(RegexpQuery, self).__init__(field, pattern, fast)
pattern = self._normalize(pattern)
try:
self.pattern = re.compile(pattern)
except re.error as exc:
# Invalid regular expression.
raise InvalidQueryArgumentValueError(pattern,
u"a regular expression",
format(exc))
@staticmethod
def _normalize(s):
"""Normalize a Unicode string's representation (used on both
patterns and matched values).
"""
return unicodedata.normalize('NFC', s)
@classmethod
def string_match(cls, pattern, value):
return pattern.search(cls._normalize(value)) is not None
class BooleanQuery(MatchQuery):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
def __init__(self, field, pattern, fast=True):
super(BooleanQuery, self).__init__(field, pattern, fast)
if isinstance(pattern, six.string_types):
self.pattern = util.str2bool(pattern)
self.pattern = int(self.pattern)
class BytesQuery(MatchQuery):
"""Match a raw bytes field (i.e., a path). This is a necessary hack
to work around the `sqlite3` module's desire to treat `bytes` and
`unicode` equivalently in Python 2. Always use this query instead of
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field, pattern):
super(BytesQuery, self).__init__(field, pattern)
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
if isinstance(self.pattern, (six.text_type, bytes)):
if isinstance(self.pattern, six.text_type):
self.pattern = self.pattern.encode('utf-8')
self.buf_pattern = buffer(self.pattern)
elif isinstance(self.pattern, buffer):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
def col_clause(self):
return self.field + " = ?", [self.buf_pattern]
class NumericQuery(FieldQuery):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
Raises InvalidQueryError when the pattern does not represent an int or
a float.
"""
def _convert(self, s):
"""Convert a string to a numeric type (float or int).
Return None if `s` is empty.
Raise an InvalidQueryError if the string cannot be converted.
"""
# This is really just a bit of fun premature optimization.
if not s:
return None
try:
return int(s)
except ValueError:
try:
return float(s)
except ValueError:
raise InvalidQueryArgumentValueError(s, u"an int or a float")
def __init__(self, field, pattern, fast=True):
super(NumericQuery, self).__init__(field, pattern, fast)
parts = pattern.split('..', 1)
if len(parts) == 1:
# No range.
self.point = self._convert(parts[0])
self.rangemin = None
self.rangemax = None
else:
# One- or two-sided range.
self.point = None
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item):
if self.field not in item:
return False
value = item[self.field]
if isinstance(value, six.string_types):
value = self._convert(value)
if self.point is not None:
return value == self.point
else:
if self.rangemin is not None and value < self.rangemin:
return False
if self.rangemax is not None and value > self.rangemax:
return False
return True
def col_clause(self):
if self.point is not None:
return self.field + '=?', (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
(self.rangemin, self.rangemax))
elif self.rangemin is not None:
return u'{0} >= ?'.format(self.field), (self.rangemin,)
elif self.rangemax is not None:
return u'{0} <= ?'.format(self.field), (self.rangemax,)
else:
return u'1', ()
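# For example, NumericQuery('year', u'2001..') parses to a one-sided
# range (rangemin=2001, rangemax=None), so col_clause() returns
# ('year >= ?', (2001,)); u'1..5' would produce both bounds.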
class CollectionQuery(Query):
"""An abstract query class that aggregates other queries. Can be
indexed like a list to access the sub-queries.
"""
def __init__(self, subqueries=()):
self.subqueries = subqueries
# Act like a sequence.
def __len__(self):
return len(self.subqueries)
def __getitem__(self, key):
return self.subqueries[key]
def __iter__(self):
return iter(self.subqueries)
def __contains__(self, item):
return item in self.subqueries
def clause_with_joiner(self, joiner):
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
clause_parts = []
subvals = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
# Fall back to slow query.
return None, ()
clause_parts.append('(' + subq_clause + ')')
subvals += subq_subvals
clause = (' ' + joiner + ' ').join(clause_parts)
return clause, subvals
def __repr__(self):
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
def __eq__(self, other):
return super(CollectionQuery, self).__eq__(other) and \
self.subqueries == other.subqueries
def __hash__(self):
"""Since subqueries are mutable, this object should not be hashable.
However, for convenience's sake, it can be hashed.
"""
return reduce(mul, map(hash, self.subqueries), 1)
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
def __init__(self, pattern, fields, cls):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
super(AnyFieldQuery, self).__init__(subqueries)
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
for subq in self.subqueries:
if subq.match(item):
return True
return False
def __repr__(self):
return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
"{0.query_class.__name__})".format(self))
def __eq__(self, other):
return super(AnyFieldQuery, self).__eq__(other) and \
self.query_class == other.query_class
def __hash__(self):
return hash((self.pattern, tuple(self.fields), self.query_class))
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
"""
def __setitem__(self, key, value):
self.subqueries[key] = value
def __delitem__(self, key):
del self.subqueries[key]
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self):
return self.clause_with_joiner('and')
def match(self, item):
return all([q.match(item) for q in self.subqueries])
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
return any([q.match(item) for q in self.subqueries])
class NotQuery(Query):
"""A query that matches the negation of its `subquery`, as a shorcut for
performing `not(subquery)` without using regular expressions.
"""
def __init__(self, subquery):
self.subquery = subquery
def clause(self):
clause, subvals = self.subquery.clause()
if clause:
return 'not ({0})'.format(clause), subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
return clause, subvals
def match(self, item):
return not self.subquery.match(item)
def __repr__(self):
return "{0.__class__.__name__}({0.subquery!r})".format(self)
def __eq__(self, other):
return super(NotQuery, self).__eq__(other) and \
self.subquery == other.subquery
def __hash__(self):
return hash(('not', hash(self.subquery)))
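# For example (illustrative):
#
#     >>> NotQuery(MatchQuery('artist', u'Beatles')).clause()
#     ('not (artist = ?)', [u'Beatles'])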
class TrueQuery(Query):
"""A query that always matches."""
def clause(self):
return '1', ()
def match(self, item):
return True
class FalseQuery(Query):
"""A query that never matches."""
def clause(self):
return '0', ()
def match(self, item):
return False
# Time/date queries.
def _to_epoch_time(date):
"""Convert a `datetime` object to an integer number of seconds since
the (local) Unix epoch.
"""
if hasattr(date, 'timestamp'):
# The `timestamp` method exists on Python 3.3+.
return int(date.timestamp())
else:
epoch = datetime.fromtimestamp(0)
delta = date - epoch
return int(delta.total_seconds())
def _parse_periods(pattern):
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
parts = pattern.split('..', 1)
if len(parts) == 1:
instant = Period.parse(parts[0])
return (instant, instant)
else:
start = Period.parse(parts[0])
end = Period.parse(parts[1])
return (start, end)
class Period(object):
"""A period of time given by a date, time and precision.
Example: 2014-01-01 10:50:30 with precision 'month' represents all
instants of time during January 2014.
"""
precisions = ('year', 'month', 'day', 'hour', 'minute', 'second')
date_formats = (
('%Y',), # year
('%Y-%m',), # month
('%Y-%m-%d',), # day
('%Y-%m-%dT%H', '%Y-%m-%d %H'), # hour
('%Y-%m-%dT%H:%M', '%Y-%m-%d %H:%M'), # minute
('%Y-%m-%dT%H:%M:%S', '%Y-%m-%d %H:%M:%S') # second
)
relative_units = {'y': 365, 'm': 30, 'w': 7, 'd': 1}
relative_re = '(?P<sign>[+-]?)(?P<quantity>[0-9]+)' + \
'(?P<timespan>[ymwd])'
def __init__(self, date, precision):
"""Create a period with the given date (a `datetime` object) and
precision (a string, one of "year", "month", "day", "hour", "minute",
or "second").
"""
if precision not in Period.precisions:
raise ValueError(u'Invalid precision {0}'.format(precision))
self.date = date
self.precision = precision
@classmethod
def parse(cls, string):
"""Parse a date and return a `Period` object or `None` if the
string is empty, or raise an InvalidQueryArgumentValueError if
the string cannot be parsed to a date.
The date may be absolute or relative. Absolute dates look like
`YYYY`, or `YYYY-MM-DD`, or `YYYY-MM-DD HH:MM:SS`, etc. Relative
dates have three parts:
- Optionally, a ``+`` or ``-`` sign indicating the future or the
past. The default is the future.
- A number: how much to add or subtract.
- A letter indicating the unit: days, weeks, months or years
(``d``, ``w``, ``m`` or ``y``). A "month" is exactly 30 days
and a "year" is exactly 365 days.
"""
def find_date_and_format(string):
for ord, format in enumerate(cls.date_formats):
for format_option in format:
try:
date = datetime.strptime(string, format_option)
return date, ord
except ValueError:
# Parsing failed.
pass
return (None, None)
if not string:
return None
# Check for a relative date.
match_dq = re.match(cls.relative_re, string)
if match_dq:
sign = match_dq.group('sign')
quantity = match_dq.group('quantity')
timespan = match_dq.group('timespan')
# Add or subtract the given amount of time from the current
# date.
multiplier = -1 if sign == '-' else 1
days = cls.relative_units[timespan]
date = datetime.now() + \
timedelta(days=int(quantity) * days) * multiplier
return cls(date, cls.precisions[5])
# Check for an absolute date.
date, ordinal = find_date_and_format(string)
if date is None:
raise InvalidQueryArgumentValueError(string,
'a valid date/time string')
precision = cls.precisions[ordinal]
return cls(date, precision)
def open_right_endpoint(self):
"""Based on the precision, convert the period to a precise
`datetime` for use as a right endpoint in a right-open interval.
"""
precision = self.precision
date = self.date
if 'year' == self.precision:
return date.replace(year=date.year + 1, month=1)
elif 'month' == precision:
if (date.month < 12):
return date.replace(month=date.month + 1)
else:
return date.replace(year=date.year + 1, month=1)
elif 'day' == precision:
return date + timedelta(days=1)
elif 'hour' == precision:
return date + timedelta(hours=1)
elif 'minute' == precision:
return date + timedelta(minutes=1)
elif 'second' == precision:
return date + timedelta(seconds=1)
else:
raise ValueError(u'unhandled precision {0}'.format(precision))
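# For example, Period.parse('2014-01') yields a month-precision period
# whose open_right_endpoint() is 2014-02-01 00:00, while the relative
# form '-2w' yields a second-precision period 14 days in the past.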
class DateInterval(object):
"""A closed-open interval of dates.
A left endpoint of None means since the beginning of time.
A right endpoint of None means towards infinity.
"""
def __init__(self, start, end):
if start is not None and end is not None and not start < end:
raise ValueError(u"start date {0} is not before end date {1}"
.format(start, end))
self.start = start
self.end = end
@classmethod
def from_periods(cls, start, end):
"""Create an interval with two Periods as the endpoints.
"""
end_date = end.open_right_endpoint() if end is not None else None
start_date = start.date if start is not None else None
return cls(start_date, end_date)
def contains(self, date):
if self.start is not None and date < self.start:
return False
if self.end is not None and date >= self.end:
return False
return True
def __str__(self):
return '[{0}, {1})'.format(self.start, self.end)
class DateQuery(FieldQuery):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
is mandatory.
The value of a date field can be matched against a date interval by
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field, pattern, fast=True):
super(DateQuery, self).__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, item):
if self.field not in item:
return False
timestamp = float(item[self.field])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
_clause_tmpl = "{0} {1} ?"
def col_clause(self):
clause_parts = []
subvals = []
if self.interval.start:
clause_parts.append(self._clause_tmpl.format(self.field, ">="))
subvals.append(_to_epoch_time(self.interval.start))
if self.interval.end:
clause_parts.append(self._clause_tmpl.format(self.field, "<"))
subvals.append(_to_epoch_time(self.interval.end))
if clause_parts:
# One- or two-sided interval.
clause = ' AND '.join(clause_parts)
else:
# Match any date.
clause = '1'
return clause, subvals
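# For example, DateQuery('added', u'2015..2016') builds the closed-open
# interval [2015-01-01, 2017-01-01) and col_clause() compares the stored
# epoch seconds against those endpoints with >= and <.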
class DurationQuery(NumericQuery):
"""NumericQuery that allow human-friendly (M:SS) time interval formats.
Converts the range(s) to a float value, and delegates on NumericQuery.
Raises InvalidQueryError when the pattern does not represent an int, float
or M:SS time interval.
"""
def _convert(self, s):
"""Convert a M:SS or numeric string to a float.
Return None if `s` is empty.
Raise an InvalidQueryError if the string cannot be converted.
"""
if not s:
return None
try:
return util.raw_seconds_short(s)
except ValueError:
try:
return float(s)
except ValueError:
raise InvalidQueryArgumentValueError(
s,
u"a M:SS string or a float")
# Sorting.
class Sort(object):
"""An abstract class representing a sort operation for a query into
the item database.
"""
def order_clause(self):
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
return None
def sort(self, items):
"""Sort the list of objects and return a list.
"""
return sorted(items)
def is_slow(self):
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False
def __hash__(self):
return 0
def __eq__(self, other):
return type(self) == type(other)
class MultipleSort(Sort):
"""Sort that encapsulates multiple sub-sorts.
"""
def __init__(self, sorts=None):
self.sorts = sorts or []
def add_sort(self, sort):
self.sorts.append(sort)
def _sql_sorts(self):
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
A contiguous suffix of fast (SQL-capable) sub-sorts are
executable in SQL. The remaining, even if they are fast
independently, must be executed slowly.
"""
sql_sorts = []
for sort in reversed(self.sorts):
if sort.order_clause() is not None:
sql_sorts.append(sort)
else:
break
sql_sorts.reverse()
return sql_sorts
def order_clause(self):
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
order_strings.append(order)
return ", ".join(order_strings)
def is_slow(self):
for sort in self.sorts:
if sort.is_slow():
return True
return False
def sort(self, items):
slow_sorts = []
switch_slow = False
for sort in reversed(self.sorts):
if switch_slow:
slow_sorts.append(sort)
elif sort.order_clause() is None:
switch_slow = True
slow_sorts.append(sort)
else:
pass
for sort in slow_sorts:
items = sort.sort(items)
return items
def __repr__(self):
return 'MultipleSort({!r})'.format(self.sorts)
def __hash__(self):
return hash(tuple(self.sorts))
def __eq__(self, other):
return super(MultipleSort, self).__eq__(other) and \
self.sorts == other.sorts
class FieldSort(Sort):
"""An abstract sort criterion that orders by a specific field (of
any kind).
"""
def __init__(self, field, ascending=True, case_insensitive=True):
self.field = field
self.ascending = ascending
self.case_insensitive = case_insensitive
def sort(self, objs):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
def key(item):
field_val = item.get(self.field, '')
if self.case_insensitive and isinstance(field_val, six.text_type):
field_val = field_val.lower()
return field_val
return sorted(objs, key=key, reverse=not self.ascending)
def __repr__(self):
return '<{0}: {1}{2}>'.format(
type(self).__name__,
self.field,
'+' if self.ascending else '-',
)
def __hash__(self):
return hash((self.field, self.ascending))
def __eq__(self, other):
return super(FieldSort, self).__eq__(other) and \
self.field == other.field and \
self.ascending == other.ascending
class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
def order_clause(self):
order = "ASC" if self.ascending else "DESC"
if self.case_insensitive:
field = '(CASE ' \
'WHEN TYPEOF({0})="text" THEN LOWER({0}) ' \
'WHEN TYPEOF({0})="blob" THEN LOWER({0}) ' \
'ELSE {0} END)'.format(self.field)
else:
field = self.field
return "{0} {1}".format(field, order)
class SlowFieldSort(FieldSort):
"""A sort criterion by some model field other than a fixed field:
i.e., a computed or flexible field.
"""
def is_slow(self):
return True
class NullSort(Sort):
"""No sorting. Leave results unsorted."""
def sort(self, items):
return items
def __nonzero__(self):
return self.__bool__()
def __bool__(self):
return False
def __eq__(self, other):
return type(self) == type(other) or other is None
def __hash__(self):
return 0

libs/common/beets/dbcore/queryparse.py Normal file

@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Parsing of strings into DBCore queries.
"""
from __future__ import division, absolute_import, print_function
import re
import itertools
from . import query
import beets
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r'(-|\^)?' # Negation prefixes.
r'(?:'
r'(\S+?)' # The field key.
r'(?<!\\):' # Unescaped :
r')?'
r'(.*)', # The term itself.
re.I # Case-insensitive.
)
def parse_query_part(part, query_classes={}, prefixes={},
default_class=query.SubstringQuery):
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.
A query part is a string consisting of:
- A *pattern*: the value to look for.
- Optionally, a *field name* preceding the pattern, separated by a
colon. So in `foo:bar`, `foo` is the field name and `bar` is the
pattern.
- Optionally, a *query prefix* just before the pattern (and after the
optional colon) indicating the type of query that should be used. For
example, in `~foo`, `~` might be a prefix. (The set of prefixes to
look for is given in the `prefixes` parameter.)
- Optionally, a negation indicator, `-` or `^`, at the very beginning.
Both prefixes and the separating `:` character may be escaped with a
backslash to avoid their normal meaning.
The function returns a tuple consisting of:
- The field name: a string or None if it's not present.
- The pattern, a string.
- The query class to use, which inherits from the base
:class:`Query` type.
- A negation flag, a bool.
The three optional parameters determine which query class is used (i.e.,
the third return value). They are:
- `query_classes`, which maps field names to query classes. These
are used when no explicit prefix is present.
- `prefixes`, which maps prefix strings to query classes.
- `default_class`, the fallback when neither the field nor a prefix
indicates a query class.
So the precedence for determining which query class to return is:
prefix, followed by field, and finally the default.
For example, assuming the `:` prefix is used for `RegexpQuery`:
- `'stapler'` -> `(None, 'stapler', SubstringQuery, False)`
- `'color:red'` -> `('color', 'red', SubstringQuery, False)`
- `':^Quiet'` -> `(None, '^Quiet', RegexpQuery, False)`, because
the `^` follows the `:`
- `'color::b..e'` -> `('color', 'b..e', RegexpQuery, False)`
- `'-color:red'` -> `('color', 'red', SubstringQuery, True)`
"""
# Apply the regular expression and extract the components.
part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part)
assert match # Regex should always match
negate = bool(match.group(1))
key = match.group(2)
    term = match.group(3).replace(r'\:', ':')
# Check whether there's a prefix in the query and use the
# corresponding query type.
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class, negate
# No matching prefix, so use either the query class determined by
# the field or the default as a fallback.
query_class = query_classes.get(key, default_class)
return key, term, query_class, negate
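# Illustrative sketch (assuming a `:` prefix mapped to RegexpQuery, as
# in the docstring above):
#
#     >>> parse_query_part('color::b..e',
#     ...                  prefixes={':': query.RegexpQuery}) == \
#     ...     ('color', 'b..e', query.RegexpQuery, False)
#     True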
def construct_query_part(model_cls, prefixes, query_part):
"""Parse a *query part* string and return a :class:`Query` object.
:param model_cls: The :class:`Model` class that this is a query for.
This is used to determine the appropriate query types for the
model's fields.
:param prefixes: A map from prefix strings to :class:`Query` types.
:param query_part: The string to parse.
See the documentation for `parse_query_part` for more information on
query part syntax.
"""
# A shortcut for empty query parts.
if not query_part:
return query.TrueQuery()
# Use `model_cls` to build up a map from field names to `Query`
# classes.
query_classes = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query
# Parse the string.
key, pattern, query_class, negate = \
parse_query_part(query_part, query_classes, prefixes)
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
if issubclass(query_class, query.FieldQuery):
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
q = query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
if negate:
return query.NotQuery(q)
else:
return q
else:
# Non-field query type.
if negate:
return query.NotQuery(query_class(pattern))
else:
return query_class(pattern)
# Otherwise, this must be a `FieldQuery`. Use the field name to
# construct the query object.
key = key.lower()
    q = query_class(key, pattern, key in model_cls._fields)
if negate:
return query.NotQuery(q)
return q
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
"""
subqueries = []
for part in query_parts:
subqueries.append(construct_query_part(model_cls, prefixes, part))
if not subqueries: # No terms in query.
subqueries = [query.TrueQuery()]
return query_cls(subqueries)
def construct_sort_part(model_cls, part):
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
ending in ``+`` or ``-`` indicating the sort.
"""
assert part, "part must be a field name and + or -"
field = part[:-1]
assert field, "field is missing"
direction = part[-1]
assert direction in ('+', '-'), "part must end with + or -"
is_ascending = direction == '+'
case_insensitive = beets.config['sort_case_insensitive'].get(bool)
if field in model_cls._sorts:
sort = model_cls._sorts[field](model_cls, is_ascending,
case_insensitive)
elif field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending, case_insensitive)
else:
# Flexible or computed.
sort = query.SlowFieldSort(field, is_ascending, case_insensitive)
return sort
def sort_from_strings(model_cls, sort_parts):
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
sort = query.NullSort()
elif len(sort_parts) == 1:
sort = construct_sort_part(model_cls, sort_parts[0])
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part))
return sort
def parse_sorted_query(model_cls, parts, prefixes={}):
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""
# Separate query token and sort token.
query_parts = []
sort_parts = []
    # Split the query into comma-separated subqueries, each representing
    # an AndQuery; these are then joined together in one OrQuery.
subquery_parts = []
for part in parts + [u',']:
if part.endswith(u','):
# Ensure we can catch "foo, bar" as well as "foo , bar"
last_subquery_part = part[:-1]
if last_subquery_part:
subquery_parts.append(last_subquery_part)
            # Parse the subquery into a single AndQuery
# TODO: Avoid needlessly wrapping AndQueries containing 1 subquery?
query_parts.append(query_from_strings(
query.AndQuery, model_cls, prefixes, subquery_parts
))
del subquery_parts[:]
else:
# Sort parts (1) end in + or -, (2) don't have a field, and
# (3) consist of more than just the + or -.
if part.endswith((u'+', u'-')) \
and u':' not in part \
and len(part) > 1:
sort_parts.append(part)
else:
subquery_parts.append(part)
# Avoid needlessly wrapping single statements in an OR
q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0]
s = sort_from_strings(model_cls, sort_parts)
return q, s
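# Usage sketch (illustrative; assumes beets' Item model is importable):
#
#     from beets.library import Item
#     q, s = parse_sorted_query(Item, [u'beatles', u'year+'])
#     # q is an AndQuery matching 'beatles' in any search field;
#     # s is an ascending FixedFieldSort on the fixed `year` field.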

libs/common/beets/dbcore/types.py Normal file

@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Representation of type information for DBCore model fields.
"""
from __future__ import division, absolute_import, print_function
from . import query
from beets.util import str2bool
import six
if not six.PY2:
    # Python 3 has no `buffer` builtin; alias it to `memoryview`.
    # (On Python 2, SQLite returns `buffer` objects, so the name is
    # kept there.)
    buffer = memoryview
# Abstract base.
class Type(object):
"""An object encapsulating the type of a model field. Includes
information about how to store, query, format, and parse a given
field.
"""
sql = u'TEXT'
"""The SQLite column type for the value.
"""
query = query.SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""
model_type = six.text_type
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
is accessed. To this end, the constructor is used by the `normalize`
and `from_sql` methods and the `default` property.
"""
@property
def null(self):
"""The value to be exposed when the underlying value is None.
"""
return self.model_type()
def format(self, value):
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
if value is None:
value = self.null
# `self.null` might be `None`
if value is None:
value = u''
if isinstance(value, bytes):
value = value.decode('utf-8', 'ignore')
return six.text_type(value)
def parse(self, string):
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
try:
return self.model_type(string)
except ValueError:
return self.null
def normalize(self, value):
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
if value is None:
return self.null
else:
# TODO This should eventually be replaced by
# `self.model_type(value)`
return value
def from_sql(self, sql_value):
"""Receives the value stored in the SQL backend and return the
value to be stored in the model.
For fixed fields the type of `value` is determined by the column
type affinity given in the `sql` property and the SQL to Python
mapping of the database adapter. For more information see:
http://www.sqlite.org/datatype3.html
https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types
        Flexible fields have the type affinity `TEXT`. This means
        `sql_value` is either a `buffer`/`memoryview` or a `unicode`
        object, and the method must handle both cases.
"""
if isinstance(sql_value, buffer):
sql_value = bytes(sql_value).decode('utf-8', 'ignore')
if isinstance(sql_value, six.text_type):
return self.parse(sql_value)
else:
return self.normalize(sql_value)
def to_sql(self, model_value):
"""Convert a value as stored in the model object to a value used
by the database adapter.
"""
return model_value
# Reusable types.
class Default(Type):
null = None
class Integer(Type):
"""A basic integer type.
"""
sql = u'INTEGER'
query = query.NumericQuery
model_type = int
class PaddedInt(Integer):
"""An integer field that is formatted with a given number of digits,
padded with zeroes.
"""
def __init__(self, digits):
self.digits = digits
def format(self, value):
return u'{0:0{1}d}'.format(value or 0, self.digits)
class ScaledInt(Integer):
"""An integer whose formatting operation scales the number by a
constant and adds a suffix. Good for units with large magnitudes.
"""
def __init__(self, unit, suffix=u''):
self.unit = unit
self.suffix = suffix
def format(self, value):
return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)
class Id(Integer):
"""An integer used as the row id or a foreign key in a SQLite table.
This type is nullable: None values are not translated to zero.
"""
null = None
def __init__(self, primary=True):
if primary:
self.sql = u'INTEGER PRIMARY KEY'
class Float(Type):
"""A basic floating-point type.
"""
sql = u'REAL'
query = query.NumericQuery
model_type = float
def format(self, value):
return u'{0:.1f}'.format(value or 0.0)
class NullFloat(Float):
"""Same as `Float`, but does not normalize `None` to `0.0`.
"""
null = None
class String(Type):
"""A Unicode string type.
"""
sql = u'TEXT'
query = query.SubstringQuery
class Boolean(Type):
"""A boolean type.
"""
sql = u'INTEGER'
query = query.BooleanQuery
model_type = bool
def format(self, value):
return six.text_type(bool(value))
def parse(self, string):
return str2bool(string)
# Shared instances of common types.
DEFAULT = Default()
INTEGER = Integer()
PRIMARY_ID = Id(True)
FOREIGN_ID = Id(False)
FLOAT = Float()
NULL_FLOAT = NullFloat()
STRING = String()
BOOLEAN = Boolean()
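# Illustrative formatting sketches for the types above:
#
#     >>> PaddedInt(4).format(7) == u'0007'
#     True
#     >>> ScaledInt(1024, u'kB').format(3072) == u'3kB'
#     True
#     >>> BOOLEAN.parse(u'yes')
#     True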

File diff suppressed because it is too large

1609 libs/common/beets/library.py Normal file

File diff suppressed because it is too large

libs/common/beets/logging.py Normal file

@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""A drop-in replacement for the standard-library `logging` module that
allows {}-style log formatting on Python 2 and 3.
Provides everything the "logging" module does. The only difference is
that when getLogger(name) instantiates a logger, that logger uses
{}-style formatting.
"""
from __future__ import division, absolute_import, print_function
from copy import copy
from logging import * # noqa
import subprocess
import threading
import six
def logsafe(val):
"""Coerce a potentially "problematic" value so it can be formatted
in a Unicode log string.
This works around a number of pitfalls when logging objects in
Python 2:
- Logging path names, which must be byte strings, requires
conversion for output.
- Some objects, including some exceptions, will crash when you call
`unicode(v)` while `str(v)` works fine. CalledProcessError is an
example.
"""
# Already Unicode.
if isinstance(val, six.text_type):
return val
# Bytestring: needs decoding.
elif isinstance(val, bytes):
# Blindly convert with UTF-8. Eventually, it would be nice to
# (a) only do this for paths, if they can be given a distinct
# type, and (b) warn the developer if they do this for other
# bytestrings.
return val.decode('utf-8', 'replace')
# A "problem" object: needs a workaround.
elif isinstance(val, subprocess.CalledProcessError):
try:
return six.text_type(val)
except UnicodeDecodeError:
# An object with a broken __unicode__ formatter. Use __str__
# instead.
return str(val).decode('utf-8', 'replace')
# Other objects are used as-is so field access, etc., still works in
# the format string.
else:
return val
class StrFormatLogger(Logger):
"""A version of `Logger` that uses `str.format`-style formatting
instead of %-style formatting.
"""
class _LogMessage(object):
def __init__(self, msg, args, kwargs):
self.msg = msg
self.args = args
self.kwargs = kwargs
def __str__(self):
args = [logsafe(a) for a in self.args]
kwargs = dict((k, logsafe(v)) for (k, v) in self.kwargs.items())
return self.msg.format(*args, **kwargs)
def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
"""Log msg.format(*args, **kwargs)"""
m = self._LogMessage(msg, args, kwargs)
return super(StrFormatLogger, self)._log(level, m, (), exc_info, extra)
class ThreadLocalLevelLogger(Logger):
"""A version of `Logger` whose level is thread-local instead of shared.
"""
def __init__(self, name, level=NOTSET):
self._thread_level = threading.local()
self.default_level = NOTSET
super(ThreadLocalLevelLogger, self).__init__(name, level)
@property
def level(self):
try:
return self._thread_level.level
except AttributeError:
self._thread_level.level = self.default_level
            return self._thread_level.level
@level.setter
def level(self, value):
self._thread_level.level = value
def set_global_level(self, level):
"""Set the level on the current thread + the default value for all
threads.
"""
self.default_level = level
self.setLevel(level)
class BeetsLogger(ThreadLocalLevelLogger, StrFormatLogger):
pass
my_manager = copy(Logger.manager)
my_manager.loggerClass = BeetsLogger
def getLogger(name=None): # noqa
if name:
return my_manager.getLogger(name)
else:
return Logger.root
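# Usage sketch (illustrative): loggers obtained here accept
# str.format-style templates with positional and keyword fields.
#
#     log = getLogger('beets.demo')
#     log.set_global_level(DEBUG)
#     log.debug(u'processing {0} ({count} items)', u'Abbey Road', count=17)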

File diff suppressed because it is too large

libs/common/beets/plugins.py Normal file

@ -0,0 +1,565 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Support for beets plugins."""
from __future__ import division, absolute_import, print_function
import inspect
import traceback
import re
from collections import defaultdict
from functools import wraps
import beets
from beets import logging
from beets import mediafile
import six
PLUGIN_NAMESPACE = 'beetsplug'
# Plugins using the Last.fm API can share the same API key.
LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'
# Global logger.
log = logging.getLogger('beets')
class PluginConflictException(Exception):
"""Indicates that the services provided by one plugin conflict with
those of another.
For example two plugins may define different types for flexible fields.
"""
class PluginLogFilter(logging.Filter):
"""A logging filter that identifies the plugin that emitted a log
message.
"""
def __init__(self, plugin):
self.prefix = u'{0}: '.format(plugin.name)
def filter(self, record):
if hasattr(record.msg, 'msg') and isinstance(record.msg.msg,
six.string_types):
# A _LogMessage from our hacked-up Logging replacement.
record.msg.msg = self.prefix + record.msg.msg
elif isinstance(record.msg, six.string_types):
record.msg = self.prefix + record.msg
return True
# Managing the plugins themselves.
class BeetsPlugin(object):
"""The base class for all beets plugins. Plugins provide
functionality by defining a subclass of BeetsPlugin and overriding
the abstract methods defined here.
"""
def __init__(self, name=None):
"""Perform one-time plugin setup.
"""
self.name = name or self.__module__.split('.')[-1]
self.config = beets.config[self.name]
if not self.template_funcs:
self.template_funcs = {}
if not self.template_fields:
self.template_fields = {}
if not self.album_template_fields:
self.album_template_fields = {}
self.early_import_stages = []
self.import_stages = []
self._log = log.getChild(self.name)
self._log.setLevel(logging.NOTSET) # Use `beets` logger level.
if not any(isinstance(f, PluginLogFilter) for f in self._log.filters):
self._log.addFilter(PluginLogFilter(self))
def commands(self):
"""Should return a list of beets.ui.Subcommand objects for
commands that should be added to beets' CLI.
"""
return ()
def _set_stage_log_level(self, stages):
"""Adjust all the stages in `stages` to WARNING logging level.
"""
return [self._set_log_level_and_params(logging.WARNING, stage)
for stage in stages]
def get_early_import_stages(self):
"""Return a list of functions that should be called as importer
        pipeline stages early in the pipeline.
The callables are wrapped versions of the functions in
`self.early_import_stages`. Wrapping provides some bookkeeping for the
plugin: specifically, the logging level is adjusted to WARNING.
"""
return self._set_stage_log_level(self.early_import_stages)
def get_import_stages(self):
"""Return a list of functions that should be called as importer
        pipeline stages.
The callables are wrapped versions of the functions in
`self.import_stages`. Wrapping provides some bookkeeping for the
plugin: specifically, the logging level is adjusted to WARNING.
"""
return self._set_stage_log_level(self.import_stages)
def _set_log_level_and_params(self, base_log_level, func):
"""Wrap `func` to temporarily set this plugin's logger level to
`base_log_level` + config options (and restore it to its previous
value after the function returns). Also determines which params may not
be sent for backwards-compatibility.
"""
argspec = inspect.getargspec(func)
@wraps(func)
def wrapper(*args, **kwargs):
assert self._log.level == logging.NOTSET
verbosity = beets.config['verbose'].get(int)
log_level = max(logging.DEBUG, base_log_level - 10 * verbosity)
self._log.setLevel(log_level)
try:
try:
return func(*args, **kwargs)
except TypeError as exc:
if exc.args[0].startswith(func.__name__):
# caused by 'func' and not stuff internal to 'func'
kwargs = dict((arg, val) for arg, val in kwargs.items()
if arg in argspec.args)
return func(*args, **kwargs)
else:
raise
finally:
self._log.setLevel(logging.NOTSET)
return wrapper
def queries(self):
"""Should return a dict mapping prefixes to Query subclasses.
"""
return {}
def track_distance(self, item, info):
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
return beets.autotag.hooks.Distance()
def album_distance(self, items, album_info, mapping):
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
return beets.autotag.hooks.Distance()
def candidates(self, items, artist, album, va_likely):
"""Should return a sequence of AlbumInfo objects that match the
album whose items are provided.
"""
return ()
def item_candidates(self, item, artist, title):
"""Should return a sequence of TrackInfo objects that match the
item provided.
"""
return ()
def album_for_id(self, album_id):
"""Return an AlbumInfo object or None if no matching release was
found.
"""
return None
def track_for_id(self, track_id):
"""Return a TrackInfo object or None if no matching release was
found.
"""
return None
def add_media_field(self, name, descriptor):
"""Add a field that is synchronized between media files and items.
When a media field is added ``item.write()`` will set the name
property of the item's MediaFile to ``item[name]`` and save the
changes. Similarly ``item.read()`` will set ``item[name]`` to
the value of the name property of the media file.
``descriptor`` must be an instance of ``mediafile.MediaField``.
"""
        # Defer import to prevent a circular dependency.
from beets import library
mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name)
_raw_listeners = None
listeners = None
def register_listener(self, event, func):
"""Add a function as a listener for the specified event.
"""
wrapped_func = self._set_log_level_and_params(logging.WARNING, func)
cls = self.__class__
if cls.listeners is None or cls._raw_listeners is None:
cls._raw_listeners = defaultdict(list)
cls.listeners = defaultdict(list)
if func not in cls._raw_listeners[event]:
cls._raw_listeners[event].append(func)
cls.listeners[event].append(wrapped_func)
template_funcs = None
template_fields = None
album_template_fields = None
@classmethod
def template_func(cls, name):
"""Decorator that registers a path template function. The
function will be invoked as ``%name{}`` from path format
strings.
"""
def helper(func):
if cls.template_funcs is None:
cls.template_funcs = {}
cls.template_funcs[name] = func
return func
return helper
@classmethod
def template_field(cls, name):
"""Decorator that registers a path template field computation.
The value will be referenced as ``$name`` from path format
strings. The function must accept a single parameter, the Item
being formatted.
"""
def helper(func):
if cls.template_fields is None:
cls.template_fields = {}
cls.template_fields[name] = func
return func
return helper
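# Minimal plugin sketch (illustrative only; a real plugin would live in
# the `beetsplug` namespace package, and the event name is assumed):
#
#     class HelloPlugin(BeetsPlugin):
#         def __init__(self):
#             super(HelloPlugin, self).__init__()
#             self.register_listener('import_begin', self.greet)
#
#         def greet(self, **kwargs):
#             self._log.info(u'import started')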
_classes = set()
def load_plugins(names=()):
"""Imports the modules for a sequence of plugin names. Each name
must be the name of a Python module under the "beetsplug" namespace
package in sys.path; the module indicated should contain the
BeetsPlugin subclasses desired.
"""
for name in names:
modname = '{0}.{1}'.format(PLUGIN_NAMESPACE, name)
try:
try:
namespace = __import__(modname, None, None)
except ImportError as exc:
# Again, this is hacky:
if exc.args[0].endswith(' ' + name):
log.warning(u'** plugin {0} not found', name)
else:
raise
else:
for obj in getattr(namespace, name).__dict__.values():
if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \
and obj != BeetsPlugin and obj not in _classes:
_classes.add(obj)
except Exception:
log.warning(
u'** error loading plugin {}:\n{}',
name,
traceback.format_exc(),
)
_instances = {}
def find_plugins():
"""Returns a list of BeetsPlugin subclass instances from all
currently loaded beets plugins. Loads the default plugin set
first.
"""
load_plugins()
plugins = []
for cls in _classes:
# Only instantiate each plugin class once.
if cls not in _instances:
_instances[cls] = cls()
plugins.append(_instances[cls])
return plugins
# Communication with plugins.
def commands():
"""Returns a list of Subcommand objects from all loaded plugins.
"""
out = []
for plugin in find_plugins():
out += plugin.commands()
return out
def queries():
"""Returns a dict mapping prefix strings to Query subclasses all loaded
plugins.
"""
out = {}
for plugin in find_plugins():
out.update(plugin.queries())
return out
def types(model_cls):
# Gives us `item_types` and `album_types`
attr_name = '{0}_types'.format(model_cls.__name__.lower())
types = {}
for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types:
if field in types and plugin_types[field] != types[field]:
raise PluginConflictException(
u'Plugin {0} defines flexible field {1} '
u'which has already been defined with '
u'another type.'.format(plugin.name, field)
)
types.update(plugin_types)
return types
def track_distance(item, info):
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.track_distance(item, info))
return dist
def album_distance(items, album_info, mapping):
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
return dist
def candidates(items, artist, album, va_likely):
"""Gets MusicBrainz candidates for an album from each plugin.
"""
for plugin in find_plugins():
for candidate in plugin.candidates(items, artist, album, va_likely):
yield candidate
def item_candidates(item, artist, title):
"""Gets MusicBrainz candidates for an item from the plugins.
"""
for plugin in find_plugins():
for item_candidate in plugin.item_candidates(item, artist, title):
yield item_candidate
def album_for_id(album_id):
"""Get AlbumInfo objects for a given ID string.
"""
for plugin in find_plugins():
album = plugin.album_for_id(album_id)
if album:
yield album
def track_for_id(track_id):
"""Get TrackInfo objects for a given ID string.
"""
for plugin in find_plugins():
track = plugin.track_for_id(track_id)
if track:
yield track
def template_funcs():
"""Get all the template functions declared by plugins as a
dictionary.
"""
funcs = {}
for plugin in find_plugins():
if plugin.template_funcs:
funcs.update(plugin.template_funcs)
return funcs
def early_import_stages():
"""Get a list of early import stage functions defined by plugins."""
stages = []
for plugin in find_plugins():
stages += plugin.get_early_import_stages()
return stages
def import_stages():
"""Get a list of import stage functions defined by plugins."""
stages = []
for plugin in find_plugins():
stages += plugin.get_import_stages()
return stages
# New-style (lazy) plugin-provided fields.
def item_field_getters():
"""Get a dictionary mapping field names to unary functions that
compute the field's value.
"""
funcs = {}
for plugin in find_plugins():
if plugin.template_fields:
funcs.update(plugin.template_fields)
return funcs
def album_field_getters():
"""As above, for album fields.
"""
funcs = {}
for plugin in find_plugins():
if plugin.album_template_fields:
funcs.update(plugin.album_template_fields)
return funcs
# Event dispatch.
def event_handlers():
"""Find all event handlers from plugins as a dictionary mapping
event names to sequences of callables.
"""
all_handlers = defaultdict(list)
for plugin in find_plugins():
if plugin.listeners:
for event, handlers in plugin.listeners.items():
all_handlers[event] += handlers
return all_handlers
def send(event, **arguments):
"""Send an event to all assigned event listeners.
`event` is the name of the event to send, all other named arguments
are passed along to the handlers.
Return a list of non-None values returned from the handlers.
"""
log.debug(u'Sending event: {0}', event)
results = []
for handler in event_handlers()[event]:
result = handler(**arguments)
if result is not None:
results.append(result)
return results
def feat_tokens(for_artist=True):
"""Return a regular expression that matches phrases like "featuring"
that separate a main artist or a song title from secondary artists.
The `for_artist` option determines whether the regex should be
suitable for matching artist fields (the default) or title fields.
"""
feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
if for_artist:
feat_words += ['with', 'vs', 'and', 'con', '&']
return '(?<=\s)(?:{0})(?=\s)'.format(
'|'.join(re.escape(x) for x in feat_words)
)
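# Illustrative sketch: the lookaround assertions keep the surrounding
# whitespace when splitting on the pattern.
#
#     >>> re.split(feat_tokens(), u'Main Artist feat. Guest') == \
#     ...     [u'Main Artist ', u' Guest']
#     True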
def sanitize_choices(choices, choices_all):
"""Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order.
"""
seen = set()
others = [x for x in choices_all if x not in choices]
res = []
for s in choices:
if s not in seen:
            if s in choices_all:
res.append(s)
elif s == '*':
res.extend(others)
seen.add(s)
return res
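# Illustrative sketch: duplicates are dropped and '*' expands to the
# remaining choices in `choices_all` order.
#
#     >>> sanitize_choices(['a', '*', 'a'], ['a', 'b', 'c'])
#     ['a', 'b', 'c']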
def sanitize_pairs(pairs, pairs_all):
"""Clean up a single-element mapping configuration attribute as returned
by `confit`'s `Pairs` template: keep only two-element tuples present in
pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*')
wildcards while keeping the original order. Note that ('*', '*') and
('*', 'whatever') have the same effect.
For example,
>>> sanitize_pairs(
... [('foo', 'baz bar'), ('key', '*'), ('*', '*')],
... [('foo', 'bar'), ('foo', 'baz'), ('foo', 'foobar'),
... ('key', 'value')]
... )
[('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')]
"""
pairs_all = list(pairs_all)
seen = set()
others = [x for x in pairs_all if x not in pairs]
res = []
for k, values in pairs:
for v in values.split():
x = (k, v)
if x in pairs_all:
if x not in seen:
seen.add(x)
res.append(x)
elif k == '*':
new = [o for o in others if o not in seen]
seen.update(new)
res.extend(new)
elif v == '*':
new = [o for o in others if o not in seen and o[0] == k]
seen.update(new)
res.extend(new)
return res
def notify_info_yielded(event):
"""Makes a generator send the event 'event' every time it yields.
This decorator is supposed to decorate a generator, but any function
returning an iterable should work.
Each yielded value is passed to plugins using the 'info' parameter of
'send'.
"""
def decorator(generator):
def decorated(*args, **kwargs):
for v in generator(*args, **kwargs):
send(event, info=v)
yield v
return decorated
return decorator

File diff suppressed because it is too large

File diff suppressed because it is too large

libs/common/beets/ui/completion_base.sh Normal file

@ -0,0 +1,162 @@
# This file is part of beets.
# Copyright (c) 2014, Thomas Scholtes.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# Completion for the `beet` command
# =================================
#
# Load this script to complete beets subcommands, options, and
# queries.
#
# If a beets command is found on the command line it completes filenames and
# the subcommand's options. Otherwise it will complete global options and
# subcommands. If the previous option on the command line expects an argument,
# it also completes filenames or directories. Options are only
# completed if '-' has already been typed on the command line.
#
# Note that completion of plugin commands only works for those plugins
# that were enabled when running `beet completion`. It does not check
# plugins dynamically.
#
# Currently, only Bash 3.2 and newer are supported, and the
# `bash-completion` package is required.
#
# TODO
# ----
#
# * There are some issues with arguments that are quoted on the command line.
#
# * Complete arguments for the `--format` option by expanding field variables.
#
# beet ls -f "$tit[TAB]
# beet ls -f "$title
#
# * Support long options with `=`, e.g. `--config=file`. Debian's bash
# completion package can handle this.
#
# Determines the beets subcommand and dispatches the completion
# accordingly.
_beet_dispatch() {
local cur prev cmd=
COMPREPLY=()
_get_comp_words_by_ref -n : cur prev
# Look for the beets subcommand
local arg
for (( i=1; i < COMP_CWORD; i++ )); do
arg="${COMP_WORDS[i]}"
if _list_include_item "${opts___global}" $arg; then
((i++))
elif [[ "$arg" != -* ]]; then
cmd="$arg"
break
fi
done
# Replace command shortcuts
if [[ -n $cmd ]] && _list_include_item "$aliases" "$cmd"; then
eval "cmd=\$alias__${cmd//-/_}"
fi
case $cmd in
help)
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
;;
list|remove|move|update|write|stats)
_beet_complete_query
;;
"")
_beet_complete_global
;;
*)
_beet_complete
;;
esac
}
# Adds option and file completion to COMPREPLY for the subcommand $cmd
_beet_complete() {
if [[ $cur == -* ]]; then
local opts flags completions
eval "opts=\$opts__${cmd//-/_}"
eval "flags=\$flags__${cmd//-/_}"
completions="${flags___common} ${opts} ${flags}"
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
else
_filedir
fi
}
# Add global options and subcommands to the completion
_beet_complete_global() {
case $prev in
-h|--help)
# Complete commands
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
return
;;
-l|--library|-c|--config)
# Filename completion
_filedir
return
;;
-d|--directory)
# Directory completion
_filedir -d
return
;;
esac
if [[ $cur == -* ]]; then
local completions="$opts___global $flags___global"
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
elif [[ -n $cur ]] && _list_include_item "$aliases" "$cur"; then
local cmd
eval "cmd=\$alias__${cur//-/_}"
COMPREPLY+=( "$cmd" )
else
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
fi
}
_beet_complete_query() {
local opts
eval "opts=\$opts__${cmd//-/_}"
if [[ $cur == -* ]] || _list_include_item "$opts" "$prev"; then
_beet_complete
elif [[ $cur != \'* && $cur != \"* &&
$cur != *:* ]]; then
        # Do not complete quoted queries or those that already have a
        # field set.
compopt -o nospace
COMPREPLY+=( $(compgen -S : -W "$fields" -- $cur) )
return 0
fi
}
# Returns true if the space separated list $1 includes $2
_list_include_item() {
[[ " $1 " == *[[:space:]]$2[[:space:]]* ]]
}
# This is where beets dynamically adds the _beet function. This
# function sets the variables $flags, $opts, $commands, and $aliases.
complete -o filenames -F _beet beet

File diff suppressed because it is too large

libs/common/beets/util/artresizer.py Normal file

@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Fabrice Laporte
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Abstraction layer to resize images using PIL, ImageMagick, or a
public resizing proxy if neither is available.
"""
from __future__ import division, absolute_import, print_function
import subprocess
import os
import re
from tempfile import NamedTemporaryFile
from six.moves.urllib.parse import urlencode
from beets import logging
from beets import util
import six
# Resizing methods
PIL = 1
IMAGEMAGICK = 2
WEBPROXY = 3
if util.SNI_SUPPORTED:
PROXY_URL = 'https://images.weserv.nl/'
else:
PROXY_URL = 'http://images.weserv.nl/'
log = logging.getLogger('beets')
def resize_url(url, maxwidth):
"""Return a proxied image URL that resizes the original image to
maxwidth (preserving aspect ratio).
"""
return '{0}?{1}'.format(PROXY_URL, urlencode({
'url': url.replace('http://', ''),
'w': maxwidth,
}))
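# Illustrative sketch (query-parameter order and URL scheme depend on
# dict ordering and SNI support, so this output is indicative only):
#
#     >>> resize_url('http://example.com/cover.jpg', 300)  # doctest: +SKIP
#     'https://images.weserv.nl/?url=example.com%2Fcover.jpg&w=300'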
def temp_file_for(path):
"""Return an unused filename with the same extension as the
specified path.
"""
ext = os.path.splitext(path)[1]
with NamedTemporaryFile(suffix=util.py3_path(ext), delete=False) as f:
return util.bytestring_path(f.name)
def pil_resize(maxwidth, path_in, path_out=None):
"""Resize using Python Imaging Library (PIL). Return the output path
of resized image.
"""
path_out = path_out or temp_file_for(path_in)
from PIL import Image
log.debug(u'artresizer: PIL resizing {0} to {1}',
util.displayable_path(path_in), util.displayable_path(path_out))
try:
im = Image.open(util.syspath(path_in))
size = maxwidth, maxwidth
im.thumbnail(size, Image.ANTIALIAS)
im.save(path_out)
return path_out
except IOError:
log.error(u"PIL cannot create thumbnail for '{0}'",
util.displayable_path(path_in))
return path_in
def im_resize(maxwidth, path_in, path_out=None):
"""Resize using ImageMagick's ``convert`` tool.
    Return the output path of the resized image.
"""
path_out = path_out or temp_file_for(path_in)
log.debug(u'artresizer: ImageMagick resizing {0} to {1}',
util.displayable_path(path_in), util.displayable_path(path_out))
# "-resize WIDTHx>" shrinks images with the width larger
# than the given width while maintaining the aspect ratio
# with regards to the height.
try:
util.command_output([
'convert', util.syspath(path_in, prefix=False),
'-resize', '{0}x>'.format(maxwidth),
util.syspath(path_out, prefix=False),
])
except subprocess.CalledProcessError:
log.warning(u'artresizer: IM convert failed for {0}',
util.displayable_path(path_in))
return path_in
return path_out
BACKEND_FUNCS = {
PIL: pil_resize,
IMAGEMAGICK: im_resize,
}
def pil_getsize(path_in):
from PIL import Image
try:
im = Image.open(util.syspath(path_in))
return im.size
except IOError as exc:
log.error(u"PIL could not read file {}: {}",
util.displayable_path(path_in), exc)
def im_getsize(path_in):
cmd = ['identify', '-format', '%w %h',
util.syspath(path_in, prefix=False)]
try:
out = util.command_output(cmd)
except subprocess.CalledProcessError as exc:
log.warning(u'ImageMagick size query failed')
log.debug(
            u'`identify` exited with (status {}) when '
u'getting size with command {}:\n{}',
exc.returncode, cmd, exc.output.strip()
)
return
try:
return tuple(map(int, out.split(b' ')))
    except (ValueError, IndexError):
log.warning(u'Could not understand IM output: {0!r}', out)
BACKEND_GET_SIZE = {
PIL: pil_getsize,
IMAGEMAGICK: im_getsize,
}
class Shareable(type):
"""A pseudo-singleton metaclass that allows both shared and
non-shared instances. The ``MyClass.shared`` property holds a
lazily-created shared instance of ``MyClass`` while calling
``MyClass()`` to construct a new object works as usual.
"""
def __init__(cls, name, bases, dict):
super(Shareable, cls).__init__(name, bases, dict)
cls._instance = None
@property
def shared(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
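# Illustrative sketch: `shared` memoizes one instance per class while
# direct construction still creates fresh objects.
#
#     >>> class Thing(six.with_metaclass(Shareable, object)):
#     ...     pass
#     >>> Thing.shared is Thing.shared
#     True
#     >>> Thing() is Thing.shared
#     False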
class ArtResizer(six.with_metaclass(Shareable, object)):
"""A singleton class that performs image resizes.
"""
def __init__(self):
"""Create a resizer object with an inferred method.
"""
self.method = self._check_method()
log.debug(u"artresizer: method is {0}", self.method)
self.can_compare = self._can_compare()
def resize(self, maxwidth, path_in, path_out=None):
"""Manipulate an image file according to the method, returning a
        new path. For PIL or IMAGEMAGICK methods, resizes the image to a
temporary file. For WEBPROXY, returns `path_in` unmodified.
"""
if self.local:
func = BACKEND_FUNCS[self.method[0]]
return func(maxwidth, path_in, path_out)
else:
return path_in
def proxy_url(self, maxwidth, url):
"""Modifies an image URL according the method, returning a new
URL. For WEBPROXY, a URL on the proxy server is returned.
Otherwise, the URL is returned unmodified.
"""
if self.local:
return url
else:
return resize_url(url, maxwidth)
@property
def local(self):
"""A boolean indicating whether the resizing method is performed
locally (i.e., PIL or ImageMagick).
"""
return self.method[0] in BACKEND_FUNCS
def get_size(self, path_in):
"""Return the size of an image file as an int couple (width, height)
in pixels.
Only available locally
"""
if self.local:
func = BACKEND_GET_SIZE[self.method[0]]
return func(path_in)
def _can_compare(self):
"""A boolean indicating whether image comparison is available"""
return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
@staticmethod
def _check_method():
"""Return a tuple indicating an available method and its version."""
version = get_im_version()
if version:
return IMAGEMAGICK, version
version = get_pil_version()
if version:
return PIL, version
        return WEBPROXY, (0,)
def get_im_version():
"""Return Image Magick version or None if it is unavailable
Try invoking ImageMagick's "convert".
"""
try:
out = util.command_output(['convert', '--version'])
if b'imagemagick' in out.lower():
pattern = br".+ (\d+)\.(\d+)\.(\d+).*"
match = re.search(pattern, out)
if match:
return (int(match.group(1)),
int(match.group(2)),
int(match.group(3)))
return (0,)
except (subprocess.CalledProcessError, OSError) as exc:
log.debug(u'ImageMagick check `convert --version` failed: {}', exc)
return None
def get_pil_version():
"""Return Image Magick version or None if it is unavailable
Try importing PIL."""
try:
__import__('PIL', fromlist=[str('Image')])
return (0,)
except ImportError:
return None

libs/common/beets/util/bluelet.py Normal file

@ -0,0 +1,638 @@
# -*- coding: utf-8 -*-
"""Extremely simple pure-Python implementation of coroutine-style
asynchronous socket I/O. Inspired by, but inferior to, Eventlet.
Bluelet can also be thought of as a less-terrible replacement for
asyncore.
Bluelet: easy concurrency without all the messy parallelism.
"""
from __future__ import division, absolute_import, print_function
import six
import socket
import select
import sys
import types
import errno
import traceback
import time
import collections
# Basic events used for thread scheduling.
class Event(object):
"""Just a base class identifying Bluelet events. An event is an
object yielded from a Bluelet thread coroutine to suspend operation
and communicate with the scheduler.
"""
pass
class WaitableEvent(Event):
"""A waitable event is one encapsulating an action that can be
waited for using a select() call. That is, it's an event with an
associated file descriptor.
"""
def waitables(self):
"""Return "waitable" objects to pass to select(). Should return
three iterables for input readiness, output readiness, and
exceptional conditions (i.e., the three lists passed to
select()).
"""
return (), (), ()
def fire(self):
"""Called when an associated file descriptor becomes ready
(i.e., is returned from a select() call).
"""
pass
class ValueEvent(Event):
"""An event that does nothing but return a fixed value."""
def __init__(self, value):
self.value = value
class ExceptionEvent(Event):
"""Raise an exception at the yield point. Used internally."""
def __init__(self, exc_info):
self.exc_info = exc_info
class SpawnEvent(Event):
"""Add a new coroutine thread to the scheduler."""
def __init__(self, coro):
self.spawned = coro
class JoinEvent(Event):
"""Suspend the thread until the specified child thread has
completed.
"""
def __init__(self, child):
self.child = child
class KillEvent(Event):
"""Unschedule a child thread."""
def __init__(self, child):
self.child = child
class DelegationEvent(Event):
"""Suspend execution of the current thread, start a new thread and,
    once the child thread finishes, return control to the parent
thread.
"""
def __init__(self, coro):
self.spawned = coro
class ReturnEvent(Event):
"""Return a value the current thread's delegator at the point of
delegation. Ends the current (delegate) thread.
"""
def __init__(self, value):
self.value = value
class SleepEvent(WaitableEvent):
"""Suspend the thread for a given duration.
"""
def __init__(self, duration):
self.wakeup_time = time.time() + duration
def time_left(self):
return max(self.wakeup_time - time.time(), 0.0)
class ReadEvent(WaitableEvent):
"""Reads from a file-like object."""
def __init__(self, fd, bufsize):
self.fd = fd
self.bufsize = bufsize
def waitables(self):
return (self.fd,), (), ()
def fire(self):
return self.fd.read(self.bufsize)
class WriteEvent(WaitableEvent):
"""Writes to a file-like object."""
def __init__(self, fd, data):
self.fd = fd
self.data = data
    def waitables(self):
return (), (self.fd,), ()
def fire(self):
self.fd.write(self.data)
# Core logic for executing and scheduling threads.
def _event_select(events):
"""Perform a select() over all the Events provided, returning the
ones ready to be fired. Only WaitableEvents (including SleepEvents)
matter here; all other events are ignored (and thus postponed).
"""
# Gather waitables and wakeup times.
waitable_to_event = {}
rlist, wlist, xlist = [], [], []
earliest_wakeup = None
for event in events:
if isinstance(event, SleepEvent):
if not earliest_wakeup:
earliest_wakeup = event.wakeup_time
else:
earliest_wakeup = min(earliest_wakeup, event.wakeup_time)
elif isinstance(event, WaitableEvent):
r, w, x = event.waitables()
rlist += r
wlist += w
xlist += x
for waitable in r:
waitable_to_event[('r', waitable)] = event
for waitable in w:
waitable_to_event[('w', waitable)] = event
for waitable in x:
waitable_to_event[('x', waitable)] = event
    # If we have any sleeping threads, determine how long to sleep.
if earliest_wakeup:
timeout = max(earliest_wakeup - time.time(), 0.0)
else:
timeout = None
# Perform select() if we have any waitables.
if rlist or wlist or xlist:
rready, wready, xready = select.select(rlist, wlist, xlist, timeout)
else:
rready, wready, xready = (), (), ()
if timeout:
time.sleep(timeout)
# Gather ready events corresponding to the ready waitables.
ready_events = set()
for ready in rready:
ready_events.add(waitable_to_event[('r', ready)])
for ready in wready:
ready_events.add(waitable_to_event[('w', ready)])
for ready in xready:
ready_events.add(waitable_to_event[('x', ready)])
# Gather any finished sleeps.
for event in events:
if isinstance(event, SleepEvent) and event.time_left() == 0.0:
ready_events.add(event)
return ready_events
class ThreadException(Exception):
def __init__(self, coro, exc_info):
self.coro = coro
self.exc_info = exc_info
def reraise(self):
six.reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
class Delegated(Event):
"""Placeholder indicating that a thread has delegated execution to a
different thread.
"""
def __init__(self, child):
self.child = child
def run(root_coro):
"""Schedules a coroutine, running it to completion. This
encapsulates the Bluelet scheduler, which the root coroutine can
add to by spawning new coroutines.
"""
# The "threads" dictionary keeps track of all the currently-
# executing and suspended coroutines. It maps coroutines to their
# currently "blocking" event. The event value may be SUSPENDED if
# the coroutine is waiting on some other condition: namely, a
# delegated coroutine or a joined coroutine. In this case, the
# coroutine should *also* appear as a value in one of the below
# dictionaries `delegators` or `joiners`.
threads = {root_coro: ValueEvent(None)}
# Maps child coroutines to delegating parents.
delegators = {}
# Maps child coroutines to joining (exit-waiting) parents.
joiners = collections.defaultdict(list)
def complete_thread(coro, return_value):
"""Remove a coroutine from the scheduling pool, awaking
delegators and joiners as necessary and returning the specified
value to any delegating parent.
"""
del threads[coro]
# Resume delegator.
if coro in delegators:
threads[delegators[coro]] = ValueEvent(return_value)
del delegators[coro]
# Resume joiners.
if coro in joiners:
for parent in joiners[coro]:
threads[parent] = ValueEvent(None)
del joiners[coro]
def advance_thread(coro, value, is_exc=False):
"""After an event is fired, run a given coroutine associated with
it in the threads dict until it yields again. If the coroutine
exits, then the thread is removed from the pool. If the coroutine
raises an exception, it is reraised in a ThreadException. If
is_exc is True, then the value must be an exc_info tuple and the
exception is thrown into the coroutine.
"""
try:
if is_exc:
next_event = coro.throw(*value)
else:
next_event = coro.send(value)
except StopIteration:
# Thread is done.
complete_thread(coro, None)
except BaseException:
# Thread raised some other exception.
del threads[coro]
raise ThreadException(coro, sys.exc_info())
else:
if isinstance(next_event, types.GeneratorType):
# Automatically invoke sub-coroutines. (Shorthand for
# explicit bluelet.call().)
next_event = DelegationEvent(next_event)
threads[coro] = next_event
def kill_thread(coro):
"""Unschedule this thread and its (recursive) delegates.
"""
# Collect all coroutines in the delegation stack.
coros = [coro]
while isinstance(threads[coro], Delegated):
coro = threads[coro].child
coros.append(coro)
# Complete each coroutine from the top to the bottom of the
# stack.
for coro in reversed(coros):
complete_thread(coro, None)
# Continue advancing threads until root thread exits.
exit_te = None
while threads:
try:
# Look for events that can be run immediately. Continue
# running immediate events until nothing is ready.
while True:
have_ready = False
for coro, event in list(threads.items()):
if isinstance(event, SpawnEvent):
threads[event.spawned] = ValueEvent(None) # Spawn.
advance_thread(coro, None)
have_ready = True
elif isinstance(event, ValueEvent):
advance_thread(coro, event.value)
have_ready = True
elif isinstance(event, ExceptionEvent):
advance_thread(coro, event.exc_info, True)
have_ready = True
elif isinstance(event, DelegationEvent):
threads[coro] = Delegated(event.spawned) # Suspend.
threads[event.spawned] = ValueEvent(None) # Spawn.
delegators[event.spawned] = coro
have_ready = True
elif isinstance(event, ReturnEvent):
# Thread is done.
complete_thread(coro, event.value)
have_ready = True
elif isinstance(event, JoinEvent):
threads[coro] = SUSPENDED # Suspend.
joiners[event.child].append(coro)
have_ready = True
elif isinstance(event, KillEvent):
threads[coro] = ValueEvent(None)
kill_thread(event.child)
have_ready = True
# Only start the select when nothing else is ready.
if not have_ready:
break
# Wait and fire.
event2coro = dict((v, k) for k, v in threads.items())
for event in _event_select(threads.values()):
# Run the IO operation, but catch socket errors.
try:
value = event.fire()
except socket.error as exc:
if isinstance(exc.args, tuple) and \
exc.args[0] == errno.EPIPE:
# Broken pipe. Remote host disconnected.
pass
else:
traceback.print_exc()
# Abort the coroutine.
threads[event2coro[event]] = ReturnEvent(None)
else:
advance_thread(event2coro[event], value)
except ThreadException as te:
# Exception raised from inside a thread.
event = ExceptionEvent(te.exc_info)
if te.coro in delegators:
# The thread is a delegate. Raise exception in its
# delegator.
threads[delegators[te.coro]] = event
del delegators[te.coro]
else:
# The thread is root-level. Raise in client code.
exit_te = te
break
except BaseException:
# For instance, KeyboardInterrupt during select(). Raise
# into root thread and terminate others.
threads = {root_coro: ExceptionEvent(sys.exc_info())}
# If any threads still remain, kill them.
for coro in threads:
coro.close()
# If we're exiting with an exception, raise it in the client.
if exit_te:
exit_te.reraise()
# Sockets and their associated events.
class SocketClosedError(Exception):
pass
class Listener(object):
"""A socket wrapper object for listening sockets.
"""
def __init__(self, host, port):
"""Create a listening socket on the given hostname and port.
"""
self._closed = False
self.host = host
self.port = port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((host, port))
self.sock.listen(5)
def accept(self):
"""An event that waits for a connection on the listening socket.
When a connection is made, the event returns a Connection
object.
"""
if self._closed:
raise SocketClosedError()
return AcceptEvent(self)
def close(self):
"""Immediately close the listening socket. (Not an event.)
"""
self._closed = True
self.sock.close()
class Connection(object):
"""A socket wrapper object for connected sockets.
"""
def __init__(self, sock, addr):
self.sock = sock
self.addr = addr
self._buf = b''
self._closed = False
def close(self):
"""Close the connection."""
self._closed = True
self.sock.close()
def recv(self, size):
"""Read at most size bytes of data from the socket."""
if self._closed:
raise SocketClosedError()
if self._buf:
# We already have data read previously.
out = self._buf[:size]
self._buf = self._buf[size:]
return ValueEvent(out)
else:
return ReceiveEvent(self, size)
def send(self, data):
"""Sends data on the socket, returning the number of bytes
successfully sent.
"""
if self._closed:
raise SocketClosedError()
return SendEvent(self, data)
def sendall(self, data):
"""Send all of data on the socket."""
if self._closed:
raise SocketClosedError()
return SendEvent(self, data, True)
def readline(self, terminator=b"\n", bufsize=1024):
"""Reads a line (delimited by terminator) from the socket."""
if self._closed:
raise SocketClosedError()
while True:
if terminator in self._buf:
line, self._buf = self._buf.split(terminator, 1)
line += terminator
yield ReturnEvent(line)
break
data = yield ReceiveEvent(self, bufsize)
if data:
self._buf += data
else:
line = self._buf
self._buf = b''
yield ReturnEvent(line)
break
class AcceptEvent(WaitableEvent):
"""An event for Listener objects (listening sockets) that suspends
execution until the socket gets a connection.
"""
def __init__(self, listener):
self.listener = listener
def waitables(self):
return (self.listener.sock,), (), ()
def fire(self):
sock, addr = self.listener.sock.accept()
return Connection(sock, addr)
class ReceiveEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously reading data.
"""
def __init__(self, conn, bufsize):
self.conn = conn
self.bufsize = bufsize
def waitables(self):
return (self.conn.sock,), (), ()
def fire(self):
return self.conn.sock.recv(self.bufsize)
class SendEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously writing data.
"""
def __init__(self, conn, data, sendall=False):
self.conn = conn
self.data = data
self.sendall = sendall
def waitables(self):
return (), (self.conn.sock,), ()
def fire(self):
if self.sendall:
return self.conn.sock.sendall(self.data)
else:
return self.conn.sock.send(self.data)
# Public interface for threads; each returns an event object that
# can immediately be "yield"ed.
def null():
"""Event: yield to the scheduler without doing anything special.
"""
return ValueEvent(None)
def spawn(coro):
"""Event: add another coroutine to the scheduler. Both the parent
and child coroutines run concurrently.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError(u'%s is not a coroutine' % coro)
return SpawnEvent(coro)
def call(coro):
"""Event: delegate to another coroutine. The current coroutine
is resumed once the sub-coroutine finishes. If the sub-coroutine
returns a value using end(), then this event returns that value.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError(u'%s is not a coroutine' % coro)
return DelegationEvent(coro)
def end(value=None):
"""Event: ends the coroutine and returns a value to its
delegator.
"""
return ReturnEvent(value)
def read(fd, bufsize=None):
"""Event: read from a file descriptor asynchronously."""
if bufsize is None:
# Read all.
def reader():
buf = []
while True:
data = yield read(fd, 1024)
if not data:
break
buf.append(data)
yield ReturnEvent(''.join(buf))
return DelegationEvent(reader())
else:
return ReadEvent(fd, bufsize)
def write(fd, data):
"""Event: write to a file descriptor asynchronously."""
return WriteEvent(fd, data)
def connect(host, port):
"""Event: connect to a network address and return a Connection
object for communicating on the socket.
"""
addr = (host, port)
sock = socket.create_connection(addr)
return ValueEvent(Connection(sock, addr))
def sleep(duration):
"""Event: suspend the thread for ``duration`` seconds.
"""
return SleepEvent(duration)
def join(coro):
"""Suspend the thread until another, previously `spawn`ed thread
completes.
"""
return JoinEvent(coro)
def kill(coro):
"""Halt the execution of a different `spawn`ed thread.
"""
return KillEvent(coro)
# Convenience function for running socket servers.
def server(host, port, func):
"""A coroutine that runs a network server. Host and port specify the
listening address. func should be a coroutine that takes a single
parameter, a Connection object. The coroutine is invoked for every
incoming connection on the listening socket.
"""
def handler(conn):
try:
yield func(conn)
finally:
conn.close()
listener = Listener(host, port)
try:
while True:
conn = yield listener.accept()
yield spawn(handler(conn))
except KeyboardInterrupt:
pass
finally:
listener.close()
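# Usage sketch (added for illustration; not part of the original
# module): a line-echo server built from the primitives above. It
# assumes the scheduler entry point ``run()`` defined earlier in this
# module; the port number is an arbitrary example.
if __name__ == '__main__':
    def echo(conn):
        while True:
            # Delegate to the readline sub-coroutine via call().
            line = yield call(conn.readline())
            if not line:
                break
            yield conn.sendall(line)
    # Serve until interrupted; server() closes the listener on exit.
    run(server('127.0.0.1', 4915, echo))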

File diff suppressed because it is too large

View file

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from __future__ import division, absolute_import, print_function
from enum import Enum
class OrderedEnum(Enum):
"""
An Enum subclass that allows comparison of members.
"""
def __ge__(self, other):
if self.__class__ is other.__class__:
return self.value >= other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __le__(self, other):
if self.__class__ is other.__class__:
return self.value <= other.value
return NotImplemented
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
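# A brief usage sketch (added for illustration; not part of the
# original module): members compare according to their values.
if __name__ == '__main__':
    class Level(OrderedEnum):
        LOW = 1
        MEDIUM = 2
        HIGH = 3
    assert Level.LOW < Level.HIGH
    assert Level.MEDIUM >= Level.LOW
    # Ordering across different Enum classes returns NotImplemented,
    # so Python raises TypeError for such comparisons.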

View file

@ -0,0 +1,623 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This module implements a string formatter based on the standard PEP
292 string.Template class extended with function calls. Variables, as
with string.Template, are indicated with $ and functions are delimited
with %.
This module assumes that everything is Unicode: the template and the
substitution values. Bytestrings are not supported. Also, the templates
always behave like the ``safe_substitute`` method in the standard
library: unknown symbols are left intact.
This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""
from __future__ import division, absolute_import, print_function
import re
import ast
import dis
import types
import sys
import six
SYMBOL_DELIM = u'$'
FUNC_DELIM = u'%'
GROUP_OPEN = u'{'
GROUP_CLOSE = u'}'
ARG_SEP = u','
ESCAPE_CHAR = u'$'
VARIABLE_PREFIX = '__var_'
FUNCTION_PREFIX = '__func_'
class Environment(object):
"""Contains the values and functions to be substituted into a
template.
"""
def __init__(self, values, functions):
self.values = values
self.functions = functions
# Code generation helpers.
def ex_lvalue(name):
"""A variable load expression."""
return ast.Name(name, ast.Store())
def ex_rvalue(name):
"""A variable store expression."""
return ast.Name(name, ast.Load())
def ex_literal(val):
    """An int, bool, string, or None literal with the given value.
    """
    if val is None or isinstance(val, bool):
        # Check bool before the integer types: bool is a subclass of
        # int. On Python 3, True/False/None are keywords rather than
        # names, so use a NameConstant node there.
        if six.PY2:
            return ast.Name(str(val), ast.Load())
        return ast.NameConstant(val)
    elif isinstance(val, six.integer_types):
        return ast.Num(val)
    elif isinstance(val, six.string_types):
        return ast.Str(val)
    raise TypeError(u'no literal for {0}'.format(type(val)))
def ex_varassign(name, expr):
"""Assign an expression into a single variable. The expression may
either be an `ast.expr` object or a value to be used as a literal.
"""
if not isinstance(expr, ast.expr):
expr = ex_literal(expr)
return ast.Assign([ex_lvalue(name)], expr)
def ex_call(func, args):
"""A function-call expression with only positional parameters. The
function may be an expression or the name of a function. Each
argument may be an expression or a value to be used as a literal.
"""
if isinstance(func, six.string_types):
func = ex_rvalue(func)
args = list(args)
for i in range(len(args)):
if not isinstance(args[i], ast.expr):
args[i] = ex_literal(args[i])
if sys.version_info[:2] < (3, 5):
return ast.Call(func, args, [], None, None)
else:
return ast.Call(func, args, [])
def compile_func(arg_names, statements, name='_the_func', debug=False):
"""Compile a list of statements as the body of a function and return
the resulting Python function. If `debug`, then print out the
bytecode of the compiled function.
"""
if six.PY2:
func_def = ast.FunctionDef(
name=name.encode('utf-8'),
args=ast.arguments(
args=[ast.Name(n, ast.Param()) for n in arg_names],
vararg=None,
kwarg=None,
defaults=[ex_literal(None) for _ in arg_names],
),
body=statements,
decorator_list=[],
)
else:
func_def = ast.FunctionDef(
name=name,
args=ast.arguments(
args=[ast.arg(arg=n, annotation=None) for n in arg_names],
kwonlyargs=[],
kw_defaults=[],
defaults=[ex_literal(None) for _ in arg_names],
),
body=statements,
decorator_list=[],
)
mod = ast.Module([func_def])
ast.fix_missing_locations(mod)
prog = compile(mod, '<generated>', 'exec')
# Debug: show bytecode.
if debug:
dis.dis(prog)
for const in prog.co_consts:
if isinstance(const, types.CodeType):
dis.dis(const)
the_locals = {}
exec(prog, {}, the_locals)
return the_locals[name]
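def _demo_compile_func():
    """Illustrative sketch (added; not part of the original module):
    build a two-argument adder with compile_func and the AST helpers
    above.
    """
    body = [ast.Return(ast.BinOp(ex_rvalue('a'), ast.Add(),
                                 ex_rvalue('b')))]
    add = compile_func(['a', 'b'], body, name='add')
    assert add(1, 2) == 3
    return add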
# AST nodes for the template language.
class Symbol(object):
"""A variable-substitution symbol in a template."""
def __init__(self, ident, original):
self.ident = ident
self.original = original
def __repr__(self):
return u'Symbol(%s)' % repr(self.ident)
def evaluate(self, env):
"""Evaluate the symbol in the environment, returning a Unicode
string.
"""
if self.ident in env.values:
# Substitute for a value.
return env.values[self.ident]
else:
# Keep original text.
return self.original
def translate(self):
"""Compile the variable lookup."""
if six.PY2:
ident = self.ident.encode('utf-8')
else:
ident = self.ident
expr = ex_rvalue(VARIABLE_PREFIX + ident)
return [expr], set([ident]), set()
class Call(object):
"""A function call in a template."""
def __init__(self, ident, args, original):
self.ident = ident
self.args = args
self.original = original
def __repr__(self):
return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args),
repr(self.original))
def evaluate(self, env):
"""Evaluate the function call in the environment, returning a
Unicode string.
"""
if self.ident in env.functions:
arg_vals = [expr.evaluate(env) for expr in self.args]
try:
out = env.functions[self.ident](*arg_vals)
except Exception as exc:
# Function raised exception! Maybe inlining the name of
# the exception will help debug.
return u'<%s>' % six.text_type(exc)
return six.text_type(out)
else:
return self.original
def translate(self):
"""Compile the function call."""
varnames = set()
if six.PY2:
ident = self.ident.encode('utf-8')
else:
ident = self.ident
funcnames = set([ident])
arg_exprs = []
for arg in self.args:
subexprs, subvars, subfuncs = arg.translate()
varnames.update(subvars)
funcnames.update(subfuncs)
# Create a subexpression that joins the result components of
# the arguments.
arg_exprs.append(ex_call(
ast.Attribute(ex_literal(u''), 'join', ast.Load()),
[ex_call(
'map',
[
ex_rvalue(six.text_type.__name__),
ast.List(subexprs, ast.Load()),
]
)],
))
subexpr_call = ex_call(
FUNCTION_PREFIX + ident,
arg_exprs
)
return [subexpr_call], varnames, funcnames
class Expression(object):
"""Top-level template construct: contains a list of text blobs,
Symbols, and Calls.
"""
def __init__(self, parts):
self.parts = parts
def __repr__(self):
return u'Expression(%s)' % (repr(self.parts))
def evaluate(self, env):
"""Evaluate the entire expression in the environment, returning
a Unicode string.
"""
out = []
for part in self.parts:
if isinstance(part, six.string_types):
out.append(part)
else:
out.append(part.evaluate(env))
return u''.join(map(six.text_type, out))
def translate(self):
"""Compile the expression to a list of Python AST expressions, a
set of variable names used, and a set of function names.
"""
expressions = []
varnames = set()
funcnames = set()
for part in self.parts:
if isinstance(part, six.string_types):
expressions.append(ex_literal(part))
else:
e, v, f = part.translate()
expressions.extend(e)
varnames.update(v)
funcnames.update(f)
return expressions, varnames, funcnames
# Parser.
class ParseError(Exception):
pass
class Parser(object):
"""Parses a template expression string. Instantiate the class with
the template source and call ``parse_expression``. The ``pos`` field
will indicate the character after the expression finished and
``parts`` will contain a list of Unicode strings, Symbols, and Calls
reflecting the concatenated portions of the expression.
This is a terrible, ad-hoc parser implementation based on a
left-to-right scan with no lexing step to speak of; it's probably
both inefficient and incorrect. Maybe this should eventually be
replaced with a real, accepted parsing technique (PEG, parser
generator, etc.).
"""
def __init__(self, string, in_argument=False):
""" Create a new parser.
        :param in_argument: boolean that indicates the parser is to be
        used for parsing function arguments, i.e. considering commas
        (`ARG_SEP`) a special character
"""
self.string = string
self.in_argument = in_argument
self.pos = 0
self.parts = []
# Common parsing resources.
special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
ESCAPE_CHAR)
special_char_re = re.compile(r'[%s]|\Z' %
u''.join(re.escape(c) for c in special_chars))
escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
terminator_chars = (GROUP_CLOSE,)
def parse_expression(self):
"""Parse a template expression starting at ``pos``. Resulting
components (Unicode strings, Symbols, and Calls) are added to
the ``parts`` field, a list. The ``pos`` field is updated to be
the next character after the expression.
"""
# Append comma (ARG_SEP) to the list of special characters only when
# parsing function arguments.
extra_special_chars = ()
special_char_re = self.special_char_re
if self.in_argument:
extra_special_chars = (ARG_SEP,)
special_char_re = re.compile(
r'[%s]|\Z' % u''.join(
re.escape(c) for c in
self.special_chars + extra_special_chars
)
)
text_parts = []
while self.pos < len(self.string):
char = self.string[self.pos]
if char not in self.special_chars + extra_special_chars:
# A non-special character. Skip to the next special
# character, treating the interstice as literal text.
next_pos = (
special_char_re.search(
self.string[self.pos:]).start() + self.pos
)
text_parts.append(self.string[self.pos:next_pos])
self.pos = next_pos
continue
if self.pos == len(self.string) - 1:
# The last character can never begin a structure, so we
# just interpret it as a literal character (unless it
# terminates the expression, as with , and }).
if char not in self.terminator_chars + extra_special_chars:
text_parts.append(char)
self.pos += 1
break
next_char = self.string[self.pos + 1]
if char == ESCAPE_CHAR and next_char in (self.escapable_chars +
extra_special_chars):
# An escaped special character ($$, $}, etc.). Note that
# ${ is not an escape sequence: this is ambiguous with
# the start of a symbol and it's not necessary (just
# using { suffices in all cases).
text_parts.append(next_char)
self.pos += 2 # Skip the next character.
continue
# Shift all characters collected so far into a single string.
if text_parts:
self.parts.append(u''.join(text_parts))
text_parts = []
if char == SYMBOL_DELIM:
# Parse a symbol.
self.parse_symbol()
elif char == FUNC_DELIM:
# Parse a function call.
self.parse_call()
elif char in self.terminator_chars + extra_special_chars:
# Template terminated.
break
elif char == GROUP_OPEN:
                # Start of a group has no meaning here; just pass
# through the character.
text_parts.append(char)
self.pos += 1
else:
assert False
# If any parsed characters remain, shift them into a string.
if text_parts:
self.parts.append(u''.join(text_parts))
def parse_symbol(self):
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
starting at ``pos``. Possibly appends a Symbol object (or,
failing that, text) to the ``parts`` field and updates ``pos``.
The character at ``pos`` must, as a precondition, be ``$``.
"""
assert self.pos < len(self.string)
assert self.string[self.pos] == SYMBOL_DELIM
if self.pos == len(self.string) - 1:
# Last character.
self.parts.append(SYMBOL_DELIM)
self.pos += 1
return
next_char = self.string[self.pos + 1]
start_pos = self.pos
self.pos += 1
if next_char == GROUP_OPEN:
# A symbol like ${this}.
self.pos += 1 # Skip opening.
closer = self.string.find(GROUP_CLOSE, self.pos)
if closer == -1 or closer == self.pos:
# No closing brace found or identifier is empty.
self.parts.append(self.string[start_pos:self.pos])
else:
# Closer found.
ident = self.string[self.pos:closer]
self.pos = closer + 1
self.parts.append(Symbol(ident,
self.string[start_pos:self.pos]))
else:
# A bare-word symbol.
ident = self._parse_ident()
if ident:
# Found a real symbol.
self.parts.append(Symbol(ident,
self.string[start_pos:self.pos]))
else:
# A standalone $.
self.parts.append(SYMBOL_DELIM)
def parse_call(self):
"""Parse a function call (like ``%foo{bar,baz}``) starting at
        ``pos``. Possibly appends a Call object to ``parts`` and updates
``pos``. The character at ``pos`` must be ``%``.
"""
assert self.pos < len(self.string)
assert self.string[self.pos] == FUNC_DELIM
start_pos = self.pos
self.pos += 1
ident = self._parse_ident()
if not ident:
# No function name.
self.parts.append(FUNC_DELIM)
return
if self.pos >= len(self.string):
# Identifier terminates string.
self.parts.append(self.string[start_pos:self.pos])
return
if self.string[self.pos] != GROUP_OPEN:
# Argument list not opened.
self.parts.append(self.string[start_pos:self.pos])
return
# Skip past opening brace and try to parse an argument list.
self.pos += 1
args = self.parse_argument_list()
if self.pos >= len(self.string) or \
self.string[self.pos] != GROUP_CLOSE:
# Arguments unclosed.
self.parts.append(self.string[start_pos:self.pos])
return
self.pos += 1 # Move past closing brace.
self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
def parse_argument_list(self):
"""Parse a list of arguments starting at ``pos``, returning a
list of Expression objects. Does not modify ``parts``. Should
leave ``pos`` pointing to a } character or the end of the
string.
"""
# Try to parse a subexpression in a subparser.
expressions = []
while self.pos < len(self.string):
subparser = Parser(self.string[self.pos:], in_argument=True)
subparser.parse_expression()
# Extract and advance past the parsed expression.
expressions.append(Expression(subparser.parts))
self.pos += subparser.pos
if self.pos >= len(self.string) or \
self.string[self.pos] == GROUP_CLOSE:
# Argument list terminated by EOF or closing brace.
break
            # The only other way to terminate an expression is with ','.
# Continue to the next argument.
assert self.string[self.pos] == ARG_SEP
self.pos += 1
return expressions
def _parse_ident(self):
"""Parse an identifier and return it (possibly an empty string).
Updates ``pos``.
"""
remainder = self.string[self.pos:]
ident = re.match(r'\w*', remainder).group(0)
self.pos += len(ident)
return ident
def _parse(template):
"""Parse a top-level template string Expression. Any extraneous text
is considered literal text.
"""
parser = Parser(template)
parser.parse_expression()
parts = parser.parts
remainder = parser.string[parser.pos:]
if remainder:
parts.append(remainder)
return Expression(parts)
# External interface.
class Template(object):
"""A string template, including text, Symbols, and Calls.
"""
def __init__(self, template):
self.expr = _parse(template)
self.original = template
self.compiled = self.translate()
def __eq__(self, other):
return self.original == other.original
def interpret(self, values={}, functions={}):
"""Like `substitute`, but forces the interpreter (rather than
the compiled version) to be used. The interpreter includes
exception-handling code for missing variables and buggy template
functions but is much slower.
"""
return self.expr.evaluate(Environment(values, functions))
def substitute(self, values={}, functions={}):
"""Evaluate the template given the values and functions.
"""
try:
res = self.compiled(values, functions)
except Exception: # Handle any exceptions thrown by compiled version.
res = self.interpret(values, functions)
return res
def translate(self):
"""Compile the template to a Python function."""
expressions, varnames, funcnames = self.expr.translate()
argnames = []
for varname in varnames:
argnames.append(VARIABLE_PREFIX + varname)
for funcname in funcnames:
argnames.append(FUNCTION_PREFIX + funcname)
func = compile_func(
argnames,
[ast.Return(ast.List(expressions, ast.Load()))],
)
def wrapper_func(values={}, functions={}):
args = {}
for varname in varnames:
args[VARIABLE_PREFIX + varname] = values[varname]
for funcname in funcnames:
args[FUNCTION_PREFIX + funcname] = functions[funcname]
parts = func(**args)
return u''.join(parts)
return wrapper_func
# Performance tests.
if __name__ == '__main__':
import timeit
_tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar')
_vars = {'bar': 'qux'}
_funcs = {'baz': six.text_type.upper}
interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
print(interp_time)
comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)',
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
print(comp_time)
print(u'Speedup:', interp_time / comp_time)
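    # A short demonstration of the template language itself (added for
    # illustration): ${...} delimits a symbol, %func{...} calls a
    # function, and $$ escapes the dollar sign.
    _demo = Template(u'${bar}le %baz{$bar} $$x')
    print(_demo.substitute(_vars, _funcs))  # -> quxle QUX $x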

View file

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Simple library to work out if a file is hidden on different platforms."""
from __future__ import division, absolute_import, print_function
import os
import stat
import ctypes
import sys
import beets.util
def _is_hidden_osx(path):
"""Return whether or not a file is hidden on OS X.
This uses os.lstat to work out if a file has the "hidden" flag.
"""
file_stat = os.lstat(beets.util.syspath(path))
if hasattr(file_stat, 'st_flags') and hasattr(stat, 'UF_HIDDEN'):
return bool(file_stat.st_flags & stat.UF_HIDDEN)
else:
return False
def _is_hidden_win(path):
"""Return whether or not a file is hidden on Windows.
This uses GetFileAttributes to work out if a file has the "hidden" flag
(FILE_ATTRIBUTE_HIDDEN).
"""
# FILE_ATTRIBUTE_HIDDEN = 2 (0x2) from GetFileAttributes documentation.
hidden_mask = 2
# Retrieve the attributes for the file.
attrs = ctypes.windll.kernel32.GetFileAttributesW(beets.util.syspath(path))
    # Ensure we have valid attributes and compare them against the mask.
return attrs >= 0 and attrs & hidden_mask
def _is_hidden_dot(path):
"""Return whether or not a file starts with a dot.
Files starting with a dot are seen as "hidden" files on Unix-based OSes.
"""
return os.path.basename(path).startswith(b'.')
def is_hidden(path):
"""Return whether or not a file is hidden. `path` should be a
bytestring filename.
    This function works differently depending on the platform it is
    called on.
    On OS X, it uses both the result of `_is_hidden_osx` and
    `_is_hidden_dot` to work out if a file is hidden.
    On Windows, it uses the result of `_is_hidden_win` to work out if a
    file is hidden.
    On any other operating system (i.e. Linux), it uses `_is_hidden_dot`
    to work out if a file is hidden.
"""
# Run platform specific functions depending on the platform
if sys.platform == 'darwin':
return _is_hidden_osx(path) or _is_hidden_dot(path)
elif sys.platform == 'win32':
return _is_hidden_win(path)
else:
return _is_hidden_dot(path)
__all__ = ['is_hidden']
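# Usage sketch (added for illustration; not part of the original
# module). Paths are passed as bytestrings, per beets' convention.
# The OS X and Windows checks stat the file, so the path should exist.
if __name__ == '__main__':
    import tempfile
    with tempfile.NamedTemporaryFile(prefix='.') as tmp:
        # A dotfile: reported hidden on Unix-like systems.
        print(is_hidden(tmp.name.encode('utf-8')))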

View file

@ -0,0 +1,530 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Simple but robust implementation of generator/coroutine-based
pipelines in Python. The pipelines may be run either sequentially
(single-threaded) or in parallel (one thread per pipeline stage).
This implementation supports pipeline bubbles (indications that the
processing for a certain item should abort). To use them, yield the
BUBBLE constant from any stage coroutine except the last.
In the parallel case, the implementation transparently handles thread
shutdown when the processing is complete and when a stage raises an
exception. KeyboardInterrupts (^C) are also handled.
When running a parallel pipeline, it is also possible to use
multiple coroutines for the same pipeline stage; this lets you speed
up a bottleneck stage by dividing its work among multiple threads.
To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
from __future__ import division, absolute_import, print_function
from six.moves import queue
from threading import Thread, Lock
import sys
import six
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def _invalidate_queue(q, val=None, sync=True):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size. get()ing from the queue returns `val`,
which defaults to None. `sync` controls whether a lock is
required (because it's not reentrant!).
"""
def _qsize(len=len):
return 1
def _put(item):
pass
def _get():
return val
if sync:
q.mutex.acquire()
try:
# Originally, we set `maxsize` to 0 here, which is supposed to mean
# an unlimited queue size. However, there is a race condition since
# Python 3.2 when this attribute is changed while another thread is
# waiting in put()/get() due to a full/empty queue.
# Setting it to 2 is still hacky because Python does not give any
# guarantee what happens if Queue methods/attributes are overwritten
# when it is already in use. However, because of our dummy _put()
# and _get() methods, it provides a workaround to let the queue appear
# to be never empty or full.
# See issue https://github.com/beetbox/beets/issues/2078
q.maxsize = 2
q._qsize = _qsize
q._put = _put
q._get = _get
q.not_empty.notifyAll()
q.not_full.notifyAll()
finally:
if sync:
q.mutex.release()
class CountedQueue(queue.Queue):
"""A queue that keeps track of the number of threads that are
still feeding into it. The queue is poisoned when all threads are
finished with the queue.
"""
def __init__(self, maxsize=0):
queue.Queue.__init__(self, maxsize)
self.nthreads = 0
self.poisoned = False
def acquire(self):
"""Indicate that a thread will start putting into this queue.
Should not be called after the queue is already poisoned.
"""
with self.mutex:
assert not self.poisoned
assert self.nthreads >= 0
self.nthreads += 1
def release(self):
"""Indicate that a thread that was putting into this queue has
exited. If this is the last thread using the queue, the queue
is poisoned.
"""
with self.mutex:
self.nthreads -= 1
assert self.nthreads >= 0
if self.nthreads == 0:
# All threads are done adding to this queue. Poison it
# when it becomes empty.
self.poisoned = True
# Replacement _get invalidates when no items remain.
_old_get = self._get
def _get():
out = _old_get()
if not self.queue:
_invalidate_queue(self, POISON, False)
return out
if self.queue:
# Items remain.
self._get = _get
else:
# No items. Invalidate immediately.
_invalidate_queue(self, POISON, False)
class MultiMessage(object):
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, messages):
self.messages = messages
def multiple(messages):
"""Yield multiple([message, ..]) from a pipeline stage to send
multiple values to the next pipeline stage.
"""
return MultiMessage(messages)
def stage(func):
"""Decorate a function to become a simple stage.
>>> @stage
... def add(n, i):
... return i + n
>>> pipe = Pipeline([
... iter([1, 2, 3]),
... add(2),
... ])
>>> list(pipe.pull())
[3, 4, 5]
"""
def coro(*args):
task = None
while True:
task = yield task
task = func(*(args + (task,)))
return coro
def mutator_stage(func):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
>>> @mutator_stage
... def setkey(key, item):
... item[key] = True
>>> pipe = Pipeline([
... iter([{'x': False}, {'a': False}]),
... setkey('x'),
... ])
>>> list(pipe.pull())
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args):
task = None
while True:
task = yield task
func(*(args + (task,)))
return coro
def _allmsgs(obj):
"""Returns a list of all the messages encapsulated in obj. If obj
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
returns an empty list. Otherwise, returns a list containing obj.
"""
if isinstance(obj, MultiMessage):
return obj.messages
elif obj == BUBBLE:
return []
else:
return [obj]
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self, all_threads):
super(PipelineThread, self).__init__()
self.abort_lock = Lock()
self.abort_flag = False
self.all_threads = all_threads
self.exc_info = None
def abort(self):
"""Shut down the thread at the next chance possible.
"""
with self.abort_lock:
self.abort_flag = True
# Ensure that we are not blocking on a queue read or write.
if hasattr(self, 'in_queue'):
_invalidate_queue(self.in_queue, POISON)
if hasattr(self, 'out_queue'):
_invalidate_queue(self.out_queue, POISON)
def abort_all(self, exc_info):
"""Abort all other threads in the system for an exception.
"""
self.exc_info = exc_info
for thread in self.all_threads:
thread.abort()
class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
"""
def __init__(self, coro, out_queue, all_threads):
super(FirstPipelineThread, self).__init__(all_threads)
self.coro = coro
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the value from the generator.
try:
msg = next(self.coro)
except StopIteration:
break
# Send messages to the next stage.
for msg in _allmsgs(msg):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
# Generator finished; shut down the pipeline.
self.out_queue.release()
class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
"""
def __init__(self, coro, in_queue, out_queue, all_threads):
super(MiddlePipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
# Prime the coroutine.
next(self.coro)
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Invoke the current stage.
out = self.coro.send(msg)
# Send messages to next stage.
for msg in _allmsgs(out):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
# Pipeline is shutting down normally.
self.out_queue.release()
class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
"""
def __init__(self, coro, in_queue, all_threads):
super(LastPipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
def run(self):
# Prime the coroutine.
next(self.coro)
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Send to consumer.
self.coro.send(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
class Pipeline(object):
"""Represents a staged pattern of work. Each stage in the pipeline
is a coroutine that receives messages from the previous stage and
yields messages to be sent to the next stage.
"""
def __init__(self, stages):
"""Makes a new pipeline from a list of coroutines. There must
be at least two stages.
"""
if len(stages) < 2:
raise ValueError(u'pipeline must have at least two stages')
self.stages = []
for stage in stages:
if isinstance(stage, (list, tuple)):
self.stages.append(stage)
else:
# Default to one thread per stage.
self.stages.append((stage,))
def run_sequential(self):
"""Run the pipeline sequentially in the current thread. The
stages are run one after the other. Only the first coroutine
in each stage is used.
"""
list(self.pull())
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
"""Run the pipeline in parallel using one thread per stage. The
messages between the stages are stored in queues of the given
size.
"""
queue_count = len(self.stages) - 1
queues = [CountedQueue(queue_size) for i in range(queue_count)]
threads = []
# Set up first stage.
for coro in self.stages[0]:
threads.append(FirstPipelineThread(coro, queues[0], threads))
# Middle stages.
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(MiddlePipelineThread(
coro, queues[i - 1], queues[i], threads
))
# Last stage.
for coro in self.stages[-1]:
threads.append(
LastPipelineThread(coro, queues[-1], threads)
)
# Start threads.
for thread in threads:
thread.start()
# Wait for termination. The final thread lasts the longest.
try:
# Using a timeout allows us to receive KeyboardInterrupt
# exceptions during the join().
while threads[-1].is_alive():
threads[-1].join(1)
except BaseException:
# Stop all the threads immediately.
for thread in threads:
thread.abort()
raise
finally:
# Make completely sure that all the threads have finished
# before we return. They should already be either finished,
# in normal operation, or aborted, in case of an exception.
for thread in threads[:-1]:
thread.join()
for thread in threads:
exc_info = thread.exc_info
if exc_info:
# Make the exception appear as it was raised originally.
six.reraise(exc_info[0], exc_info[1], exc_info[2])
def pull(self):
"""Yield elements from the end of the pipeline. Runs the stages
sequentially until the last yields some messages. Each of the messages
        is then yielded by the resulting generator. If the pipeline has
        a consumer, that is, the last stage does not yield any messages,
        then pull will not yield any messages. Only the first coroutine
        in each stage is used.
"""
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in coros[1:]:
next(coro)
# Begin the pipeline.
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
for msg in msgs:
yield msg
# Smoke test.
if __name__ == '__main__':
import time
# Test a normally-terminating pipeline both in sequence and
# in parallel.
def produce():
for i in range(5):
print(u'generating %i' % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
print(u'processing %i' % num)
time.sleep(2)
num = yield num * 2
def consume():
while True:
num = yield
time.sleep(1)
print(u'received %i' % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
ts_seq = time.time()
Pipeline([produce(), work(), consume()]).run_parallel()
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
print(u'Sequential time:', ts_seq - ts_start)
print(u'Parallel time:', ts_par - ts_seq)
print(u'Multiply-parallel time:', ts_end - ts_par)
print()
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
print(u'generating %i' % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
print(u'processing %i' % num)
time.sleep(3)
if num == 3:
raise Exception()
num = yield num * 2
def exc_consume():
while True:
num = yield
print(u'received %i' % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
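    # Demonstrate pipeline bubbles (added for illustration): yielding
    # BUBBLE from a stage drops the current item instead of passing it
    # on to the next stage.
    @stage
    def drop_odd(n):
        return BUBBLE if n % 2 else n
    print(list(Pipeline([iter([1, 2, 3, 4]), drop_odd()]).pull()))
    # -> [2, 4]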

53
libs/common/beets/vfs.py Normal file
View file

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""A simple utility for constructing filesystem-like trees from beets
libraries.
"""
from __future__ import division, absolute_import, print_function
from collections import namedtuple
from beets import util
Node = namedtuple('Node', ['files', 'dirs'])
def _insert(node, path, itemid):
"""Insert an item into a virtual filesystem node."""
if len(path) == 1:
# Last component. Insert file.
node.files[path[0]] = itemid
else:
# In a directory.
dirname = path[0]
rest = path[1:]
if dirname not in node.dirs:
node.dirs[dirname] = Node({}, {})
_insert(node.dirs[dirname], rest, itemid)
def libtree(lib):
"""Generates a filesystem-like directory tree for the files
contained in `lib`. Filesystem nodes are (files, dirs) named
tuples in which both components are dictionaries. The first
maps filenames to Item ids. The second maps directory names to
child node tuples.
"""
root = Node({}, {})
for item in lib.items():
dest = item.destination(fragment=True)
parts = util.components(dest)
_insert(root, parts, item.id)
return root
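def print_tree(node, indent=0):
    """Illustrative helper (added; not part of the original module):
    recursively print a Node tree as indented text. Names may be
    bytestrings, per beets' path convention, so they are rendered with
    util.displayable_path.
    """
    for name, child in sorted(node.dirs.items()):
        print(u'%s%s/' % (u' ' * indent, util.displayable_path(name)))
        print_tree(child, indent + 2)
    for name, itemid in sorted(node.files.items()):
        print(u'%s%s (item %s)' % (u' ' * indent,
                                   util.displayable_path(name), itemid))
# e.g.: print_tree(libtree(lib)) for an open beets Library `lib`.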