Update vendored beets to 1.6.0

Updates colorama to 0.4.6
Adds confuse version 1.7.0
Updates jellyfish to 0.9.0
Adds mediafile 0.10.1
Updates munkres to 1.1.4
Updates musicbrainzngs to 0.7.1
Updates mutagen to 1.46.0
Updates pyyaml to 6.0
Updates unidecode to 1.3.6
Labrys of Knossos 2022-11-28 18:02:40 -05:00
commit 56c6773c6b
385 changed files with 25143 additions and 18080 deletions


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -13,30 +12,29 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from __future__ import division, absolute_import, print_function
import os
import confuse
from sys import stderr
from beets.util import confit
__version__ = u'1.4.7'
__author__ = u'Adrian Sampson <adrian@radbox.org>'
__version__ = '1.6.0'
__author__ = 'Adrian Sampson <adrian@radbox.org>'
class IncludeLazyConfig(confit.LazyConfig):
"""A version of Confit's LazyConfig that also merges in data from
class IncludeLazyConfig(confuse.LazyConfig):
"""A version of Confuse's LazyConfig that also merges in data from
YAML files specified in an `include` setting.
"""
def read(self, user=True, defaults=True):
super(IncludeLazyConfig, self).read(user, defaults)
super().read(user, defaults)
try:
for view in self['include']:
filename = view.as_filename()
if os.path.isfile(filename):
self.set_file(filename)
except confit.NotFoundError:
self.set_file(view.as_filename())
except confuse.NotFoundError:
pass
except confuse.ConfigReadError as err:
stderr.write("configuration `import` failed: {}"
.format(err.reason))
config = IncludeLazyConfig('beets', __name__)
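A quick sketch of what the include mechanism does in practice under confuse (the included file name here is hypothetical):

# Sketch only. Assumes a user config.yaml containing:
#   include: [secrets.yaml]     # hypothetical extra YAML file
from beets import config

config.read()    # IncludeLazyConfig merges secrets.yaml behind the main config
host = config['musicbrainz']['host'].as_str()    # values resolve across sources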


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2017, Adrian Sampson.
#
@ -17,7 +16,6 @@
`python -m beets`.
"""
from __future__ import division, absolute_import, print_function
import sys
from .ui import main


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -17,7 +16,6 @@
music and items' embedded album art.
"""
from __future__ import division, absolute_import, print_function
import subprocess
import platform
@ -26,7 +24,7 @@ import os
from beets.util import displayable_path, syspath, bytestring_path
from beets.util.artresizer import ArtResizer
from beets import mediafile
import mediafile
def mediafile_image(image_path, maxwidth=None):
@ -43,7 +41,7 @@ def get_art(log, item):
try:
mf = mediafile.MediaFile(syspath(item.path))
except mediafile.UnreadableFileError as exc:
log.warning(u'Could not extract art from {0}: {1}',
log.warning('Could not extract art from {0}: {1}',
displayable_path(item.path), exc)
return
@ -51,26 +49,27 @@ def get_art(log, item):
def embed_item(log, item, imagepath, maxwidth=None, itempath=None,
compare_threshold=0, ifempty=False, as_album=False):
compare_threshold=0, ifempty=False, as_album=False, id3v23=None,
quality=0):
"""Embed an image into the item's media file.
"""
# Conditions and filters.
if compare_threshold:
if not check_art_similarity(log, item, imagepath, compare_threshold):
log.info(u'Image not similar; skipping.')
log.info('Image not similar; skipping.')
return
if ifempty and get_art(log, item):
log.info(u'media file already contained art')
return
log.info('media file already contained art')
return
if maxwidth and not as_album:
imagepath = resize_image(log, imagepath, maxwidth)
imagepath = resize_image(log, imagepath, maxwidth, quality)
# Get the `Image` object from the file.
try:
log.debug(u'embedding {0}', displayable_path(imagepath))
log.debug('embedding {0}', displayable_path(imagepath))
image = mediafile_image(imagepath, maxwidth)
except IOError as exc:
log.warning(u'could not read image file: {0}', exc)
except OSError as exc:
log.warning('could not read image file: {0}', exc)
return
# Make sure the image kind is safe (some formats only support PNG
@ -80,36 +79,39 @@ def embed_item(log, item, imagepath, maxwidth=None, itempath=None,
image.mime_type)
return
item.try_write(path=itempath, tags={'images': [image]})
item.try_write(path=itempath, tags={'images': [image]}, id3v23=id3v23)
def embed_album(log, album, maxwidth=None, quiet=False,
compare_threshold=0, ifempty=False):
def embed_album(log, album, maxwidth=None, quiet=False, compare_threshold=0,
ifempty=False, quality=0):
"""Embed album art into all of the album's items.
"""
imagepath = album.artpath
if not imagepath:
log.info(u'No album art present for {0}', album)
log.info('No album art present for {0}', album)
return
if not os.path.isfile(syspath(imagepath)):
log.info(u'Album art not found at {0} for {1}',
log.info('Album art not found at {0} for {1}',
displayable_path(imagepath), album)
return
if maxwidth:
imagepath = resize_image(log, imagepath, maxwidth)
imagepath = resize_image(log, imagepath, maxwidth, quality)
log.info(u'Embedding album art into {0}', album)
log.info('Embedding album art into {0}', album)
for item in album.items():
embed_item(log, item, imagepath, maxwidth, None,
compare_threshold, ifempty, as_album=True)
embed_item(log, item, imagepath, maxwidth, None, compare_threshold,
ifempty, as_album=True, quality=quality)
def resize_image(log, imagepath, maxwidth):
"""Returns path to an image resized to maxwidth.
def resize_image(log, imagepath, maxwidth, quality):
"""Returns path to an image resized to maxwidth and encoded with the
specified quality level.
"""
log.debug(u'Resizing album art to {0} pixels wide', maxwidth)
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
log.debug('Resizing album art to {0} pixels wide and encoding at quality \
level {1}', maxwidth, quality)
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath),
quality=quality)
return imagepath
@ -131,7 +133,7 @@ def check_art_similarity(log, item, imagepath, compare_threshold):
syspath(art, prefix=False),
'-colorspace', 'gray', 'MIFF:-']
compare_cmd = ['compare', '-metric', 'PHASH', '-', 'null:']
log.debug(u'comparing images with pipeline {} | {}',
log.debug('comparing images with pipeline {} | {}',
convert_cmd, compare_cmd)
convert_proc = subprocess.Popen(
convert_cmd,
@ -155,7 +157,7 @@ def check_art_similarity(log, item, imagepath, compare_threshold):
convert_proc.wait()
if convert_proc.returncode:
log.debug(
u'ImageMagick convert failed with status {}: {!r}',
'ImageMagick convert failed with status {}: {!r}',
convert_proc.returncode,
convert_stderr,
)
@ -165,7 +167,7 @@ def check_art_similarity(log, item, imagepath, compare_threshold):
stdout, stderr = compare_proc.communicate()
if compare_proc.returncode:
if compare_proc.returncode != 1:
log.debug(u'ImageMagick compare failed: {0}, {1}',
log.debug('ImageMagick compare failed: {0}, {1}',
displayable_path(imagepath),
displayable_path(art))
return
@ -176,10 +178,10 @@ def check_art_similarity(log, item, imagepath, compare_threshold):
try:
phash_diff = float(out_str)
except ValueError:
log.debug(u'IM output is not a number: {0!r}', out_str)
log.debug('IM output is not a number: {0!r}', out_str)
return
log.debug(u'ImageMagick compare score: {0}', phash_diff)
log.debug('ImageMagick compare score: {0}', phash_diff)
return phash_diff <= compare_threshold
return True
@ -189,18 +191,18 @@ def extract(log, outpath, item):
art = get_art(log, item)
outpath = bytestring_path(outpath)
if not art:
log.info(u'No album art present in {0}, skipping.', item)
log.info('No album art present in {0}, skipping.', item)
return
# Add an extension to the filename.
ext = mediafile.image_extension(art)
if not ext:
log.warning(u'Unknown image type in {0}.',
log.warning('Unknown image type in {0}.',
displayable_path(item.path))
return
outpath += bytestring_path('.' + ext)
log.info(u'Extracting album art from: {0} to: {1}',
log.info('Extracting album art from: {0} to: {1}',
item, displayable_path(outpath))
with open(syspath(outpath), 'wb') as f:
f.write(art)
@ -216,7 +218,7 @@ def extract_first(log, outpath, items):
def clear(log, lib, query):
items = lib.items(query)
log.info(u'Clearing album art from {0} items', len(items))
log.info('Clearing album art from {0} items', len(items))
for item in items:
log.debug(u'Clearing art for {0}', item)
log.debug('Clearing art for {0}', item)
item.try_write(tags={'images': None})
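For reference, a hedged usage sketch of the new quality parameter on embed_album; the empty in-memory library just keeps the example self-contained:

# Sketch: embedding album art re-encoded at a given quality level.
from beets import art, logging
from beets.library import Library

log = logging.getLogger('beets')
lib = Library(':memory:')    # demo library; real callers already have albums
for album in lib.albums():
    # Resize to 600 pixels wide and re-encode at quality 75 before embedding.
    art.embed_album(log, album, maxwidth=600, quality=75)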


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -16,19 +15,59 @@
"""Facilities for automatically determining files' correct metadata.
"""
from __future__ import division, absolute_import, print_function
from beets import logging
from beets import config
# Parts of external interface.
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa
from .hooks import ( # noqa
AlbumInfo,
TrackInfo,
AlbumMatch,
TrackMatch,
Distance,
)
from .match import tag_item, tag_album, Proposal # noqa
from .match import Recommendation # noqa
# Global logger.
log = logging.getLogger('beets')
# Metadata fields that are already hardcoded, or where the tag name changes.
SPECIAL_FIELDS = {
'album': (
'va',
'releasegroup_id',
'artist_id',
'album_id',
'mediums',
'tracks',
'year',
'month',
'day',
'artist',
'artist_credit',
'artist_sort',
'data_url'
),
'track': (
'track_alt',
'artist_id',
'release_track_id',
'medium',
'index',
'medium_index',
'title',
'artist_credit',
'artist_sort',
'artist',
'track_id',
'medium_total',
'data_url',
'length'
)
}
# Additional utilities for the main interface.
@ -43,17 +82,14 @@ def apply_item_metadata(item, track_info):
item.mb_releasetrackid = track_info.release_track_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
if track_info.data_source:
item.data_source = track_info.data_source
if track_info.lyricist is not None:
item.lyricist = track_info.lyricist
if track_info.composer is not None:
item.composer = track_info.composer
if track_info.composer_sort is not None:
item.composer_sort = track_info.composer_sort
if track_info.arranger is not None:
item.arranger = track_info.arranger
for field, value in track_info.items():
# We only overwrite fields that are not already hardcoded.
if field in SPECIAL_FIELDS['track']:
continue
if value is None:
continue
item[field] = value
# At the moment, the other metadata is left intact (including album
# and track number). Perhaps these should be emptied?
@ -142,33 +178,24 @@ def apply_metadata(album_info, mapping):
# Compilation flag.
item.comp = album_info.va
# Miscellaneous metadata.
for field in ('albumtype',
'label',
'asin',
'catalognum',
'script',
'language',
'country',
'albumstatus',
'albumdisambig',
'data_source',):
value = getattr(album_info, field)
if value is not None:
item[field] = value
if track_info.disctitle is not None:
item.disctitle = track_info.disctitle
if track_info.media is not None:
item.media = track_info.media
if track_info.lyricist is not None:
item.lyricist = track_info.lyricist
if track_info.composer is not None:
item.composer = track_info.composer
if track_info.composer_sort is not None:
item.composer_sort = track_info.composer_sort
if track_info.arranger is not None:
item.arranger = track_info.arranger
# Track alt.
item.track_alt = track_info.track_alt
# Don't overwrite fields with empty values unless the
# field is explicitly allowed to be overwritten
for field, value in album_info.items():
if field in SPECIAL_FIELDS['album']:
continue
clobber = field in config['overwrite_null']['album'].as_str_seq()
if value is None and not clobber:
continue
item[field] = value
for field, value in track_info.items():
if field in SPECIAL_FIELDS['track']:
continue
clobber = field in config['overwrite_null']['track'].as_str_seq()
value = getattr(track_info, field)
if value is None and not clobber:
continue
item[field] = value
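The effect of the new generic copy loops can be sketched standalone; plain dicts stand in for Item and TrackInfo, and the field names are only illustrative:

# Sketch of the new copy semantics with plain dicts in place of Item/TrackInfo.
SPECIAL = {'title', 'artist'}    # fields handled elsewhere (cf. SPECIAL_FIELDS)
CLOBBER = {'lyricist'}           # cf. the overwrite_null config lists

def apply_fields(item, info):
    for field, value in info.items():
        if field in SPECIAL:
            continue    # hardcoded fields are never copied here
        if value is None and field not in CLOBBER:
            continue    # missing values don't wipe existing data...
        item[field] = value    # ...unless the field may be clobbered

item = {'composer': 'old', 'lyricist': 'old'}
apply_fields(item, {'composer': None, 'lyricist': None, 'bpm': 120, 'title': 'x'})
assert item == {'composer': 'old', 'lyricist': None, 'bpm': 120}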


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -14,7 +13,6 @@
# included in all copies or substantial portions of the Software.
"""Glue between metadata sources and the matching logic."""
from __future__ import division, absolute_import, print_function
from collections import namedtuple
from functools import total_ordering
@ -27,14 +25,36 @@ from beets.util import as_string
from beets.autotag import mb
from jellyfish import levenshtein_distance
from unidecode import unidecode
import six
log = logging.getLogger('beets')
# The name of the type for patterns in re changed in Python 3.7.
try:
Pattern = re._pattern_type
except AttributeError:
Pattern = re.Pattern
# Classes used to represent candidate options.
class AttrDict(dict):
"""A dictionary that supports attribute ("dot") access, so `d.field`
is equivalent to `d['field']`.
"""
class AlbumInfo(object):
def __getattr__(self, attr):
if attr in self:
return self.get(attr)
else:
raise AttributeError
def __setattr__(self, key, value):
self.__setitem__(key, value)
def __hash__(self):
return id(self)
class AlbumInfo(AttrDict):
"""Describes a canonical release that may be used to match a release
in the library. Consists of these data members:
@ -43,38 +63,22 @@ class AlbumInfo(object):
- ``artist``: name of the release's primary artist
- ``artist_id``
- ``tracks``: list of TrackInfo objects making up the release
- ``asin``: Amazon ASIN
- ``albumtype``: string describing the kind of release
- ``va``: boolean: whether the release has "various artists"
- ``year``: release year
- ``month``: release month
- ``day``: release day
- ``label``: music label responsible for the release
- ``mediums``: the number of discs in this release
- ``artist_sort``: name of the release's artist for sorting
- ``releasegroup_id``: MBID for the album's release group
- ``catalognum``: the label's catalog number for the release
- ``script``: character set used for metadata
- ``language``: human language of the metadata
- ``country``: the release country
- ``albumstatus``: MusicBrainz release status (Official, etc.)
- ``media``: delivery mechanism (Vinyl, etc.)
- ``albumdisambig``: MusicBrainz release disambiguation comment
- ``artist_credit``: Release-specific artist name
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
- ``data_url``: The data source release URL.
The fields up through ``tracks`` are required. The others are
optional and may be None.
``mediums`` along with the fields up through ``tracks`` are required.
The others are optional and may be None.
"""
def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
albumtype=None, va=False, year=None, month=None, day=None,
label=None, mediums=None, artist_sort=None,
releasegroup_id=None, catalognum=None, script=None,
language=None, country=None, albumstatus=None, media=None,
albumdisambig=None, artist_credit=None, original_year=None,
original_month=None, original_day=None, data_source=None,
data_url=None):
def __init__(self, tracks, album=None, album_id=None, artist=None,
artist_id=None, asin=None, albumtype=None, va=False,
year=None, month=None, day=None, label=None, mediums=None,
artist_sort=None, releasegroup_id=None, catalognum=None,
script=None, language=None, country=None, style=None,
genre=None, albumstatus=None, media=None, albumdisambig=None,
releasegroupdisambig=None, artist_credit=None,
original_year=None, original_month=None,
original_day=None, data_source=None, data_url=None,
discogs_albumid=None, discogs_labelid=None,
discogs_artistid=None, **kwargs):
self.album = album
self.album_id = album_id
self.artist = artist
@ -94,15 +98,22 @@ class AlbumInfo(object):
self.script = script
self.language = language
self.country = country
self.style = style
self.genre = genre
self.albumstatus = albumstatus
self.media = media
self.albumdisambig = albumdisambig
self.releasegroupdisambig = releasegroupdisambig
self.artist_credit = artist_credit
self.original_year = original_year
self.original_month = original_month
self.original_day = original_day
self.data_source = data_source
self.data_url = data_url
self.discogs_albumid = discogs_albumid
self.discogs_labelid = discogs_labelid
self.discogs_artistid = discogs_artistid
self.update(kwargs)
# Work around a bug in python-musicbrainz-ngs that causes some
# strings to be bytes rather than Unicode.
@ -112,54 +123,46 @@ class AlbumInfo(object):
constituent `TrackInfo` objects, are decoded to Unicode.
"""
for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
'catalognum', 'script', 'language', 'country',
'albumstatus', 'albumdisambig', 'artist_credit', 'media']:
'catalognum', 'script', 'language', 'country', 'style',
'genre', 'albumstatus', 'albumdisambig',
'releasegroupdisambig', 'artist_credit',
'media', 'discogs_albumid', 'discogs_labelid',
'discogs_artistid']:
value = getattr(self, fld)
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
if self.tracks:
for track in self.tracks:
track.decode(codec)
for track in self.tracks:
track.decode(codec)
def copy(self):
dupe = AlbumInfo([])
dupe.update(self)
dupe.tracks = [track.copy() for track in self.tracks]
return dupe
class TrackInfo(object):
class TrackInfo(AttrDict):
"""Describes a canonical track present on a release. Appears as part
of an AlbumInfo's ``tracks`` list. Consists of these data members:
- ``title``: name of the track
- ``track_id``: MusicBrainz ID; UUID fragment only
- ``release_track_id``: MusicBrainz ID respective to a track on a
particular release; UUID fragment only
- ``artist``: individual track artist name
- ``artist_id``
- ``length``: float: duration of the track in seconds
- ``index``: position on the entire release
- ``media``: delivery mechanism (Vinyl, etc.)
- ``medium``: the disc number this track appears on in the album
- ``medium_index``: the track's position on the disc
- ``medium_total``: the number of tracks on the item's disc
- ``artist_sort``: name of the track artist for sorting
- ``disctitle``: name of the individual medium (subtitle)
- ``artist_credit``: Recording-specific artist name
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
- ``data_url``: The data source release URL.
- ``lyricist``: individual track lyricist name
- ``composer``: individual track composer name
- ``composer_sort``: individual track composer sort name
- ``arranger``: individual track arranger name
- ``track_alt``: alternative track number (tape, vinyl, etc.)
Only ``title`` and ``track_id`` are required. The rest of the fields
may be None. The indices ``index``, ``medium``, and ``medium_index``
are all 1-based.
"""
def __init__(self, title, track_id, release_track_id=None, artist=None,
artist_id=None, length=None, index=None, medium=None,
medium_index=None, medium_total=None, artist_sort=None,
disctitle=None, artist_credit=None, data_source=None,
data_url=None, media=None, lyricist=None, composer=None,
composer_sort=None, arranger=None, track_alt=None):
def __init__(self, title=None, track_id=None, release_track_id=None,
artist=None, artist_id=None, length=None, index=None,
medium=None, medium_index=None, medium_total=None,
artist_sort=None, disctitle=None, artist_credit=None,
data_source=None, data_url=None, media=None, lyricist=None,
composer=None, composer_sort=None, arranger=None,
track_alt=None, work=None, mb_workid=None,
work_disambig=None, bpm=None, initial_key=None, genre=None,
**kwargs):
self.title = title
self.track_id = track_id
self.release_track_id = release_track_id
@ -181,6 +184,13 @@ class TrackInfo(object):
self.composer_sort = composer_sort
self.arranger = arranger
self.track_alt = track_alt
self.work = work
self.mb_workid = mb_workid
self.work_disambig = work_disambig
self.bpm = bpm
self.initial_key = initial_key
self.genre = genre
self.update(kwargs)
# As above, work around a bug in python-musicbrainz-ngs.
def decode(self, codec='utf-8'):
@ -193,6 +203,11 @@ class TrackInfo(object):
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
def copy(self):
dupe = TrackInfo()
dupe.update(self)
return dupe
# Candidate distance scoring.
@ -220,8 +235,8 @@ def _string_dist_basic(str1, str2):
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, six.text_type)
assert isinstance(str2, six.text_type)
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
@ -249,9 +264,9 @@ def string_dist(str1, str2):
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word) - 2])
str1 = '{} {}'.format(word, str1[:-len(word) - 2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word) - 2])
str2 = '{} {}'.format(word, str2[:-len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
@ -289,11 +304,12 @@ def string_dist(str1, str2):
return base_dist + penalty
class LazyClassProperty(object):
class LazyClassProperty:
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
@ -306,17 +322,17 @@ class LazyClassProperty(object):
@total_ordering
@six.python_2_unicode_compatible
class Distance(object):
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self):
self._penalties = {}
@LazyClassProperty
def _weights(cls): # noqa
def _weights(cls): # noqa: N805
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
@ -394,7 +410,7 @@ class Distance(object):
return other - self.distance
def __str__(self):
return "{0:.2f}".format(self.distance)
return f"{self.distance:.2f}"
# Behave like a dict.
@ -421,7 +437,7 @@ class Distance(object):
"""
if not isinstance(dist, Distance):
raise ValueError(
u'`dist` must be a Distance object, not {0}'.format(type(dist))
'`dist` must be a Distance object, not {}'.format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
@ -433,7 +449,7 @@ class Distance(object):
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re._pattern_type):
if isinstance(value1, Pattern):
return bool(value1.match(value2))
return value1 == value2
@ -445,7 +461,7 @@ class Distance(object):
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
u'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
f'`dist` must be between 0.0 and 1.0, not {dist}'
)
self._penalties.setdefault(key, []).append(dist)
@ -541,7 +557,7 @@ def album_for_mbid(release_id):
try:
album = mb.album_for_id(release_id)
if album:
plugins.send(u'albuminfo_received', info=album)
plugins.send('albuminfo_received', info=album)
return album
except mb.MusicBrainzAPIError as exc:
exc.log(log)
@ -554,7 +570,7 @@ def track_for_mbid(recording_id):
try:
track = mb.track_for_id(recording_id)
if track:
plugins.send(u'trackinfo_received', info=track)
plugins.send('trackinfo_received', info=track)
return track
except mb.MusicBrainzAPIError as exc:
exc.log(log)
@ -567,7 +583,7 @@ def albums_for_id(album_id):
yield a
for a in plugins.album_for_id(album_id):
if a:
plugins.send(u'albuminfo_received', info=a)
plugins.send('albuminfo_received', info=a)
yield a
@ -578,40 +594,43 @@ def tracks_for_id(track_id):
yield t
for t in plugins.track_for_id(track_id):
if t:
plugins.send(u'trackinfo_received', info=t)
plugins.send('trackinfo_received', info=t)
yield t
@plugins.notify_info_yielded(u'albuminfo_received')
def album_candidates(items, artist, album, va_likely):
@plugins.notify_info_yielded('albuminfo_received')
def album_candidates(items, artist, album, va_likely, extra_tags):
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
entered by the user. ``va_likely`` is a boolean indicating whether
the album is likely to be a "various artists" release.
the album is likely to be a "various artists" release. ``extra_tags``
is an optional dictionary of additional tags used to further
constrain the search.
"""
# Base candidates if we have album and artist to match.
if artist and album:
try:
for candidate in mb.match_album(artist, album, len(items)):
yield candidate
yield from mb.match_album(artist, album, len(items),
extra_tags)
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Also add VA matches from MusicBrainz where appropriate.
if va_likely and album:
try:
for candidate in mb.match_album(None, album, len(items)):
yield candidate
yield from mb.match_album(None, album, len(items),
extra_tags)
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Candidates from plugins.
for candidate in plugins.candidates(items, artist, album, va_likely):
yield candidate
yield from plugins.candidates(items, artist, album, va_likely,
extra_tags)
@plugins.notify_info_yielded(u'trackinfo_received')
@plugins.notify_info_yielded('trackinfo_received')
def item_candidates(item, artist, title):
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
@ -621,11 +640,9 @@ def item_candidates(item, artist, title):
# MusicBrainz candidates.
if artist and title:
try:
for candidate in mb.match_track(artist, title):
yield candidate
yield from mb.match_track(artist, title)
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Plugin candidates.
for candidate in plugins.item_candidates(item, artist, title):
yield candidate
yield from plugins.item_candidates(item, artist, title)
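Because AlbumInfo and TrackInfo now derive from AttrDict, plugin-supplied fields round-trip through both access styles; a minimal sketch:

# Sketch of the AttrDict behavior that AlbumInfo/TrackInfo now inherit.
from beets.autotag.hooks import TrackInfo

t = TrackInfo(title='Song', track_id='1234')
t.my_plugin_field = 'x'              # attribute write lands in the dict...
assert t['my_plugin_field'] == 'x'   # ...and is readable via subscript
assert t.title == 'Song'
dup = t.copy()                       # shallow copy via dict.update
assert dup == t and dup is not t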


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -17,7 +16,6 @@
releases and tracks.
"""
from __future__ import division, absolute_import, print_function
import datetime
import re
@ -35,7 +33,7 @@ from beets.util.enumeration import OrderedEnum
# album level to determine whether a given release is likely a VA
# release and also on the track level to remove the penalty for
# differing artists.
VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
VA_ARTISTS = ('', 'various artists', 'various', 'va', 'unknown')
# Global logger.
log = logging.getLogger('beets')
@ -108,7 +106,7 @@ def assign_items(items, tracks):
log.debug('...done.')
# Produce the output matching.
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
mapping = {items[i]: tracks[j] for (i, j) in matching}
extra_items = list(set(items) - set(mapping.keys()))
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
@ -276,16 +274,16 @@ def match_by_id(items):
try:
first = next(albumids)
except StopIteration:
log.debug(u'No album ID found.')
log.debug('No album ID found.')
return None
# Is there a consensus on the MB album ID?
for other in albumids:
if other != first:
log.debug(u'No album ID consensus.')
log.debug('No album ID consensus.')
return None
# If all album IDs are equal, look up the album.
log.debug(u'Searching for discovered album ID: {0}', first)
log.debug('Searching for discovered album ID: {0}', first)
return hooks.album_for_mbid(first)
@ -351,23 +349,23 @@ def _add_candidate(items, results, info):
checking the track count, ordering the items, checking for
duplicates, and calculating the distance.
"""
log.debug(u'Candidate: {0} - {1} ({2})',
log.debug('Candidate: {0} - {1} ({2})',
info.artist, info.album, info.album_id)
# Discard albums with zero tracks.
if not info.tracks:
log.debug(u'No tracks.')
log.debug('No tracks.')
return
# Don't duplicate.
if info.album_id in results:
log.debug(u'Duplicate.')
log.debug('Duplicate.')
return
# Discard matches without required tags.
for req_tag in config['match']['required'].as_str_seq():
if getattr(info, req_tag) is None:
log.debug(u'Ignored. Missing required tag: {0}', req_tag)
log.debug('Ignored. Missing required tag: {0}', req_tag)
return
# Find mapping between the items and the track info.
@ -380,10 +378,10 @@ def _add_candidate(items, results, info):
penalties = [key for key, _ in dist]
for penalty in config['match']['ignored'].as_str_seq():
if penalty in penalties:
log.debug(u'Ignored. Penalty: {0}', penalty)
log.debug('Ignored. Penalty: {0}', penalty)
return
log.debug(u'Success. Distance: {0}', dist)
log.debug('Success. Distance: {0}', dist)
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
extra_items, extra_tracks)
@ -411,7 +409,7 @@ def tag_album(items, search_artist=None, search_album=None,
likelies, consensus = current_metadata(items)
cur_artist = likelies['artist']
cur_album = likelies['album']
log.debug(u'Tagging {0} - {1}', cur_artist, cur_album)
log.debug('Tagging {0} - {1}', cur_artist, cur_album)
# The output result (distance, AlbumInfo) tuples (keyed by MB album
# ID).
@ -420,7 +418,7 @@ def tag_album(items, search_artist=None, search_album=None,
# Search by explicit ID.
if search_ids:
for search_id in search_ids:
log.debug(u'Searching for album ID: {0}', search_id)
log.debug('Searching for album ID: {0}', search_id)
for id_candidate in hooks.albums_for_id(search_id):
_add_candidate(items, candidates, id_candidate)
@ -431,13 +429,13 @@ def tag_album(items, search_artist=None, search_album=None,
if id_info:
_add_candidate(items, candidates, id_info)
rec = _recommendation(list(candidates.values()))
log.debug(u'Album ID match recommendation is {0}', rec)
log.debug('Album ID match recommendation is {0}', rec)
if candidates and not config['import']['timid']:
# If we have a very good MBID match, return immediately.
# Otherwise, this match will compete against metadata-based
# matches.
if rec == Recommendation.strong:
log.debug(u'ID match.')
log.debug('ID match.')
return cur_artist, cur_album, \
Proposal(list(candidates.values()), rec)
@ -445,22 +443,29 @@ def tag_album(items, search_artist=None, search_album=None,
if not (search_artist and search_album):
# No explicit search terms -- use current metadata.
search_artist, search_album = cur_artist, cur_album
log.debug(u'Search terms: {0} - {1}', search_artist, search_album)
log.debug('Search terms: {0} - {1}', search_artist, search_album)
extra_tags = None
if config['musicbrainz']['extra_tags']:
tag_list = config['musicbrainz']['extra_tags'].get()
extra_tags = {k: v for (k, v) in likelies.items() if k in tag_list}
log.debug('Additional search terms: {0}', extra_tags)
# Is this album likely to be a "various artist" release?
va_likely = ((not consensus['artist']) or
(search_artist.lower() in VA_ARTISTS) or
any(item.comp for item in items))
log.debug(u'Album might be VA: {0}', va_likely)
log.debug('Album might be VA: {0}', va_likely)
# Get the results from the data sources.
for matched_candidate in hooks.album_candidates(items,
search_artist,
search_album,
va_likely):
va_likely,
extra_tags):
_add_candidate(items, candidates, matched_candidate)
log.debug(u'Evaluating {0} candidates.', len(candidates))
log.debug('Evaluating {0} candidates.', len(candidates))
# Sort and get the recommendation.
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
@ -485,7 +490,7 @@ def tag_item(item, search_artist=None, search_title=None,
trackids = search_ids or [t for t in [item.mb_trackid] if t]
if trackids:
for trackid in trackids:
log.debug(u'Searching for track ID: {0}', trackid)
log.debug('Searching for track ID: {0}', trackid)
for track_info in hooks.tracks_for_id(trackid):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = \
@ -494,7 +499,7 @@ def tag_item(item, search_artist=None, search_title=None,
rec = _recommendation(_sort_candidates(candidates.values()))
if rec == Recommendation.strong and \
not config['import']['timid']:
log.debug(u'Track ID match.')
log.debug('Track ID match.')
return Proposal(_sort_candidates(candidates.values()), rec)
# If we're searching by ID, don't proceed.
@ -507,7 +512,7 @@ def tag_item(item, search_artist=None, search_title=None,
# Search terms.
if not (search_artist and search_title):
search_artist, search_title = item.artist, item.title
log.debug(u'Item search terms: {0} - {1}', search_artist, search_title)
log.debug('Item search terms: {0} - {1}', search_artist, search_title)
# Get and evaluate candidate metadata.
for track_info in hooks.item_candidates(item, search_artist, search_title):
@ -515,7 +520,7 @@ def tag_item(item, search_artist=None, search_title=None,
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
# Sort by distance and return with recommendation.
log.debug(u'Found {0} candidates.', len(candidates))
log.debug('Found {0} candidates.', len(candidates))
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
return Proposal(candidates, rec)
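How the extra_tags search terms get assembled can be shown in isolation (a plain dict stands in for the likelies current-metadata mapping):

# Sketch of the extra_tags assembly in tag_album, standalone.
tag_list = ['year', 'catalognum']    # what musicbrainz.extra_tags might hold
likelies = {'artist': 'Foo', 'album': 'Bar', 'year': 2001, 'catalognum': 'CAT-1'}

extra_tags = {k: v for (k, v) in likelies.items() if k in tag_list}
assert extra_tags == {'year': 2001, 'catalognum': 'CAT-1'}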


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,57 +14,72 @@
"""Searches for albums in the MusicBrainz database.
"""
from __future__ import division, absolute_import, print_function
import musicbrainzngs
import re
import traceback
from six.moves.urllib.parse import urljoin
from beets import logging
from beets import plugins
import beets.autotag.hooks
import beets
from beets import util
from beets import config
import six
from collections import Counter
from urllib.parse import urljoin
VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
if util.SNI_SUPPORTED:
BASE_URL = 'https://musicbrainz.org/'
else:
BASE_URL = 'http://musicbrainz.org/'
BASE_URL = 'https://musicbrainz.org/'
SKIPPED_TRACKS = ['[data track]']
FIELDS_TO_MB_KEYS = {
'catalognum': 'catno',
'country': 'country',
'label': 'label',
'media': 'format',
'year': 'date',
}
musicbrainzngs.set_useragent('beets', beets.__version__,
'http://beets.io/')
'https://beets.io/')
class MusicBrainzAPIError(util.HumanReadableException):
"""An error while talking to MusicBrainz. The `query` field is the
parameter to the action and may have any type.
"""
def __init__(self, reason, verb, query, tb=None):
self.query = query
if isinstance(reason, musicbrainzngs.WebServiceError):
reason = u'MusicBrainz not reachable'
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
reason = 'MusicBrainz not reachable'
super().__init__(reason, verb, tb)
def get_message(self):
return u'{0} in {1} with query {2}'.format(
return '{} in {} with query {}'.format(
self._reasonstr(), self.verb, repr(self.query)
)
log = logging.getLogger('beets')
RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
'labels', 'artist-credits', 'aliases',
'recording-level-rels', 'work-rels',
'work-level-rels', 'artist-rels']
TRACK_INCLUDES = ['artists', 'aliases']
'work-level-rels', 'artist-rels', 'isrcs']
BROWSE_INCLUDES = ['artist-credits', 'work-rels',
'artist-rels', 'recording-rels', 'release-rels']
if "work-level-rels" in musicbrainzngs.VALID_BROWSE_INCLUDES['recording']:
BROWSE_INCLUDES.append("work-level-rels")
BROWSE_CHUNKSIZE = 100
BROWSE_MAXTRACKS = 500
TRACK_INCLUDES = ['artists', 'aliases', 'isrcs']
if 'work-level-rels' in musicbrainzngs.VALID_INCLUDES['recording']:
TRACK_INCLUDES += ['work-level-rels', 'artist-rels']
if 'genres' in musicbrainzngs.VALID_INCLUDES['recording']:
RELEASE_INCLUDES += ['genres']
def track_url(trackid):
@ -81,7 +95,11 @@ def configure():
from the beets configuration. This should be called at startup.
"""
hostname = config['musicbrainz']['host'].as_str()
musicbrainzngs.set_hostname(hostname)
https = config['musicbrainz']['https'].get(bool)
# Only call set_hostname when a custom server is configured, since
# musicbrainz-ngs connects to musicbrainz.org over HTTPS by default.
if hostname != "musicbrainz.org":
musicbrainzngs.set_hostname(hostname, https)
musicbrainzngs.set_rate_limit(
config['musicbrainz']['ratelimit_interval'].as_number(),
config['musicbrainz']['ratelimit'].get(int),
@ -138,7 +156,7 @@ def _flatten_artist_credit(credit):
artist_sort_parts = []
artist_credit_parts = []
for el in credit:
if isinstance(el, six.string_types):
if isinstance(el, str):
# Join phrase.
artist_parts.append(el)
artist_credit_parts.append(el)
@ -185,13 +203,13 @@ def track_info(recording, index=None, medium=None, medium_index=None,
the number of tracks on the medium. Each number is a 1-based index.
"""
info = beets.autotag.hooks.TrackInfo(
recording['title'],
recording['id'],
title=recording['title'],
track_id=recording['id'],
index=index,
medium=medium,
medium_index=medium_index,
medium_total=medium_total,
data_source=u'MusicBrainz',
data_source='MusicBrainz',
data_url=track_url(recording['id']),
)
@ -207,12 +225,22 @@ def track_info(recording, index=None, medium=None, medium_index=None,
if recording.get('length'):
info.length = int(recording['length']) / (1000.0)
info.trackdisambig = recording.get('disambiguation')
if recording.get('isrc-list'):
info.isrc = ';'.join(recording['isrc-list'])
lyricist = []
composer = []
composer_sort = []
for work_relation in recording.get('work-relation-list', ()):
if work_relation['type'] != 'performance':
continue
info.work = work_relation['work']['title']
info.mb_workid = work_relation['work']['id']
if 'disambiguation' in work_relation['work']:
info.work_disambig = work_relation['work']['disambiguation']
for artist_relation in work_relation['work'].get(
'artist-relation-list', ()):
if 'type' in artist_relation:
@ -224,10 +252,10 @@ def track_info(recording, index=None, medium=None, medium_index=None,
composer_sort.append(
artist_relation['artist']['sort-name'])
if lyricist:
info.lyricist = u', '.join(lyricist)
info.lyricist = ', '.join(lyricist)
if composer:
info.composer = u', '.join(composer)
info.composer_sort = u', '.join(composer_sort)
info.composer = ', '.join(composer)
info.composer_sort = ', '.join(composer_sort)
arranger = []
for artist_relation in recording.get('artist-relation-list', ()):
@ -236,7 +264,12 @@ def track_info(recording, index=None, medium=None, medium_index=None,
if type == 'arranger':
arranger.append(artist_relation['artist']['name'])
if arranger:
info.arranger = u', '.join(arranger)
info.arranger = ', '.join(arranger)
# Supplementary fields provided by plugins
extra_trackdatas = plugins.send('mb_track_extract', data=recording)
for extra_trackdata in extra_trackdatas:
info.update(extra_trackdata)
info.decode()
return info
@ -270,6 +303,26 @@ def album_info(release):
artist_name, artist_sort_name, artist_credit_name = \
_flatten_artist_credit(release['artist-credit'])
ntracks = sum(len(m['track-list']) for m in release['medium-list'])
# The MusicBrainz API omits 'artist-relation-list' and 'work-relation-list'
# when the release has more than 500 tracks. So we use browse_recordings
# on chunks of tracks to recover the same information in this case.
if ntracks > BROWSE_MAXTRACKS:
log.debug('Album {} has too many tracks', release['id'])
recording_list = []
for i in range(0, ntracks, BROWSE_CHUNKSIZE):
log.debug('Retrieving tracks starting at {}', i)
recording_list.extend(musicbrainzngs.browse_recordings(
release=release['id'], limit=BROWSE_CHUNKSIZE,
includes=BROWSE_INCLUDES,
offset=i)['recording-list'])
track_map = {r['id']: r for r in recording_list}
for medium in release['medium-list']:
for recording in medium['track-list']:
recording_info = track_map[recording['recording']['id']]
recording['recording'] = recording_info
# Basic info.
track_infos = []
index = 0
@ -281,7 +334,8 @@ def album_info(release):
continue
all_tracks = medium['track-list']
if 'data-track-list' in medium:
if ('data-track-list' in medium
and not config['match']['ignore_data_tracks']):
all_tracks += medium['data-track-list']
track_count = len(all_tracks)
@ -327,15 +381,15 @@ def album_info(release):
track_infos.append(ti)
info = beets.autotag.hooks.AlbumInfo(
release['title'],
release['id'],
artist_name,
release['artist-credit'][0]['artist']['id'],
track_infos,
album=release['title'],
album_id=release['id'],
artist=artist_name,
artist_id=release['artist-credit'][0]['artist']['id'],
tracks=track_infos,
mediums=len(release['medium-list']),
artist_sort=artist_sort_name,
artist_credit=artist_credit_name,
data_source=u'MusicBrainz',
data_source='MusicBrainz',
data_url=album_url(release['id']),
)
info.va = info.artist_id == VARIOUS_ARTISTS_ID
@ -345,13 +399,12 @@ def album_info(release):
info.releasegroup_id = release['release-group']['id']
info.albumstatus = release.get('status')
# Build up the disambiguation string from the release group and release.
disambig = []
# Get the disambiguation strings at the release and release group level.
if release['release-group'].get('disambiguation'):
disambig.append(release['release-group'].get('disambiguation'))
info.releasegroupdisambig = \
release['release-group'].get('disambiguation')
if release.get('disambiguation'):
disambig.append(release.get('disambiguation'))
info.albumdisambig = u', '.join(disambig)
info.albumdisambig = release.get('disambiguation')
# Get the "classic" Release type. This data comes from a legacy API
# feature before MusicBrainz supported multiple release types.
@ -360,18 +413,17 @@ def album_info(release):
if reltype:
info.albumtype = reltype.lower()
# Log the new-style "primary" and "secondary" release types.
# Eventually, we'd like to actually store this data, but we just log
# it for now to help understand the differences.
# Set the new-style "primary" and "secondary" release types.
albumtypes = []
if 'primary-type' in release['release-group']:
rel_primarytype = release['release-group']['primary-type']
if rel_primarytype:
log.debug('primary MB release type: ' + rel_primarytype.lower())
albumtypes.append(rel_primarytype.lower())
if 'secondary-type-list' in release['release-group']:
if release['release-group']['secondary-type-list']:
log.debug('secondary MB release type(s): ' + ', '.join(
[secondarytype.lower() for secondarytype in
release['release-group']['secondary-type-list']]))
for sec_type in release['release-group']['secondary-type-list']:
albumtypes.append(sec_type.lower())
info.albumtypes = '; '.join(albumtypes)
# Release events.
info.country, release_date = _preferred_release_event(release)
@ -402,17 +454,33 @@ def album_info(release):
first_medium = release['medium-list'][0]
info.media = first_medium.get('format')
if config['musicbrainz']['genres']:
sources = [
release['release-group'].get('genre-list', []),
release.get('genre-list', []),
]
genres = Counter()
for source in sources:
for genreitem in source:
genres[genreitem['name']] += int(genreitem['count'])
info.genre = '; '.join(g[0] for g in sorted(genres.items(),
key=lambda g: -g[1]))
extra_albumdatas = plugins.send('mb_album_extract', data=release)
for extra_albumdata in extra_albumdatas:
info.update(extra_albumdata)
info.decode()
return info
def match_album(artist, album, tracks=None):
def match_album(artist, album, tracks=None, extra_tags=None):
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
MusicBrainzAPIError.
The query consists of an artist name, an album name, and,
optionally, a number of tracks on the album.
optionally, a number of tracks on the album and any other extra tags.
"""
# Build search criteria.
criteria = {'release': album.lower().strip()}
@ -422,14 +490,24 @@ def match_album(artist, album, tracks=None):
# Various Artists search.
criteria['arid'] = VARIOUS_ARTISTS_ID
if tracks is not None:
criteria['tracks'] = six.text_type(tracks)
criteria['tracks'] = str(tracks)
# Additional search cues from existing metadata.
if extra_tags:
for tag in extra_tags:
key = FIELDS_TO_MB_KEYS[tag]
value = str(extra_tags.get(tag, '')).lower().strip()
if key == 'catno':
value = value.replace(' ', '')
if value:
criteria[key] = value
# Abort if we have no search terms.
if not any(criteria.values()):
return
try:
log.debug(u'Searching for MusicBrainz releases with: {!r}', criteria)
log.debug('Searching for MusicBrainz releases with: {!r}', criteria)
res = musicbrainzngs.search_releases(
limit=config['musicbrainz']['searchlimit'].get(int), **criteria)
except musicbrainzngs.MusicBrainzError as exc:
@ -470,7 +548,7 @@ def _parse_id(s):
no ID can be found, return None.
"""
# Find the first thing that looks like a UUID/MBID.
match = re.search(u'[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
if match:
return match.group()
@ -480,19 +558,19 @@ def album_for_id(releaseid):
object or None if the album is not found. May raise a
MusicBrainzAPIError.
"""
log.debug(u'Requesting MusicBrainz release {}', releaseid)
log.debug('Requesting MusicBrainz release {}', releaseid)
albumid = _parse_id(releaseid)
if not albumid:
log.debug(u'Invalid MBID ({0}).', releaseid)
log.debug('Invalid MBID ({0}).', releaseid)
return
try:
res = musicbrainzngs.get_release_by_id(albumid,
RELEASE_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Album ID match failed.')
log.debug('Album ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, u'get release by ID', albumid,
raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
traceback.format_exc())
return album_info(res['release'])
@ -503,14 +581,14 @@ def track_for_id(releaseid):
"""
trackid = _parse_id(releaseid)
if not trackid:
log.debug(u'Invalid MBID ({0}).', releaseid)
log.debug('Invalid MBID ({0}).', releaseid)
return
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Track ID match failed.')
log.debug('Track ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, u'get recording by ID', trackid,
raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,
traceback.format_exc())
return track_info(res['recording'])
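A standalone sketch of how match_album translates extra_tags into MusicBrainz search criteria, mirroring the code above:

# Sketch: mapping beets field names to MB search keys, as in match_album.
FIELDS_TO_MB_KEYS = {'catalognum': 'catno', 'country': 'country',
                     'label': 'label', 'media': 'format', 'year': 'date'}

def criteria_from(extra_tags):
    criteria = {}
    for tag, raw in extra_tags.items():
        key = FIELDS_TO_MB_KEYS[tag]
        value = str(raw).lower().strip()
        if key == 'catno':
            value = value.replace(' ', '')    # catalog numbers match unspaced
        if value:
            criteria[key] = value
    return criteria

assert criteria_from({'catalognum': 'CAT 001', 'year': 2001}) == \
    {'catno': 'cat001', 'date': '2001'}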


@ -7,6 +7,7 @@ import:
move: no
link: no
hardlink: no
reflink: no
delete: no
resume: ask
incremental: no
@ -44,10 +45,20 @@ replace:
'^\s+': ''
'^-': _
path_sep_replace: _
drive_sep_replace: _
asciify_paths: false
art_filename: cover
max_filename_length: 0
aunique:
keys: albumartist album
disambiguators: albumtype year label catalognum albumdisambig releasegroupdisambig
bracket: '[]'
overwrite_null:
album: []
track: []
plugins: []
pluginpath: []
threaded: yes
@ -91,9 +102,12 @@ statefile: state.pickle
musicbrainz:
host: musicbrainz.org
https: no
ratelimit: 1
ratelimit_interval: 1.0
searchlimit: 5
extra_tags: []
genres: no
match:
strong_rec_thresh: 0.04
@ -129,6 +143,7 @@ match:
ignored: []
required: []
ignored_media: []
ignore_data_tracks: yes
ignore_video_tracks: yes
track_length_grace: 10
track_length_max: 30
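The new keys are read through confuse views; a brief sketch mirroring the accessor calls visible in the diffs above:

# Sketch: reading the new configuration options via confuse views.
from beets import config

extra = config['musicbrainz']['extra_tags'].get()          # list, defaults to []
genres = config['musicbrainz']['genres'].get(bool)         # bool, defaults to False
clobber = config['overwrite_null']['album'].as_str_seq()   # list of field names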


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -16,7 +15,6 @@
"""DBCore is an abstract database package that forms the basis for beets'
Library.
"""
from __future__ import division, absolute_import, print_function
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery


@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,22 +14,21 @@
"""The central Model and Database constructs for DBCore.
"""
from __future__ import division, absolute_import, print_function
import time
import os
import re
from collections import defaultdict
import threading
import sqlite3
import contextlib
import collections
import beets
from beets.util.functemplate import Template
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery
import six
from collections.abc import Mapping
class DBAccessError(Exception):
@ -42,20 +40,30 @@ class DBAccessError(Exception):
"""
class FormattedMapping(collections.Mapping):
class FormattedMapping(Mapping):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
`model[key]` as a unicode string.
The `included_keys` parameter allows filtering the fields that are
returned. By default all fields are returned. Limiting to specific keys can
avoid expensive per-item database queries.
If `for_path` is true, all path separators in the formatted values
are replaced.
"""
def __init__(self, model, for_path=False):
ALL_KEYS = '*'
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
if included_keys == self.ALL_KEYS:
# Performance note: this triggers a database query.
self.model_keys = self.model.keys(True)
else:
self.model_keys = included_keys
def __getitem__(self, key):
if key in self.model_keys:
@ -72,7 +80,7 @@ class FormattedMapping(collections.Mapping):
def get(self, key, default=None):
if default is None:
default = self.model._type(key).format(None)
return super(FormattedMapping, self).get(key, default)
return super().get(key, default)
def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
@ -81,6 +89,11 @@ class FormattedMapping(collections.Mapping):
if self.for_path:
sep_repl = beets.config['path_sep_replace'].as_str()
sep_drive = beets.config['drive_sep_replace'].as_str()
if re.match(r'^\w:', value):
value = re.sub(r'(?<=^\w):', sep_drive, value)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
@ -88,11 +101,105 @@ class FormattedMapping(collections.Mapping):
return value
class LazyConvertDict:
"""Lazily convert types for attributes fetched from the database
"""
def __init__(self, model_cls):
"""Initialize the object empty
"""
self.data = {}
self.model_cls = model_cls
self._converted = {}
def init(self, data):
"""Set the base data that should be lazily converted
"""
self.data = data
def _convert(self, key, value):
"""Convert the attribute type according the the SQL type
"""
return self.model_cls._type(key).from_sql(value)
def __setitem__(self, key, value):
"""Set an attribute value, assume it's already converted
"""
self._converted[key] = value
def __getitem__(self, key):
"""Get an attribute value, converting the type on demand
if needed
"""
if key in self._converted:
return self._converted[key]
elif key in self.data:
value = self._convert(key, self.data[key])
self._converted[key] = value
return value
def __delitem__(self, key):
"""Delete both converted and base data
"""
if key in self._converted:
del self._converted[key]
if key in self.data:
del self.data[key]
def keys(self):
"""Get a list of available field names for this object.
"""
return list(self._converted.keys()) + list(self.data.keys())
def copy(self):
"""Create a copy of the object.
"""
new = self.__class__(self.model_cls)
new.data = self.data.copy()
new._converted = self._converted.copy()
return new
# Act like a dictionary.
def update(self, values):
"""Assign all values in the given dict.
"""
for key, value in values.items():
self[key] = value
def items(self):
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default
def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys()
def __iter__(self):
"""Iterate over the available field names (excluding computed
fields).
"""
return iter(self.keys())
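A standalone sketch of the lazy conversion, with stub classes (hypothetical, not part of beets) standing in for the model class and its SQL types:

# Sketch of LazyConvertDict's laziness; StubType/StubModel are illustrative only.
from beets.dbcore.db import LazyConvertDict

class StubType:
    def from_sql(self, value):
        print('converting', value)    # runs at most once per key
        return int(value)

class StubModel:
    @classmethod
    def _type(cls, key):
        return StubType()

d = LazyConvertDict(StubModel)
d.init({'bpm': '120'})      # raw SQL values; nothing converted yet
assert d['bpm'] == 120      # first access converts and caches
assert d['bpm'] == 120      # second access hits the cache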
# Abstract base for model classes.
class Model(object):
class Model:
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., the allow subscript access like
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
available:
@ -143,12 +250,22 @@ class Model(object):
are subclasses of `Sort`.
"""
_queries = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
_always_dirty = False
"""By default, fields only become "dirty" when their value actually
changes. Enabling this flag marks fields as dirty even when the new
value is the same as the old value (e.g., `o.f = o.f`).
"""
_revision = -1
"""A revision number from when the model was loaded from or written
to the database.
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
@ -172,8 +289,8 @@ class Model(object):
"""
self._db = db
self._dirty = set()
self._values_fixed = {}
self._values_flex = {}
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
# Initial contents.
self.update(values)
@ -187,23 +304,25 @@ class Model(object):
ordinary construction are bypassed.
"""
obj = cls(db)
for key, value in fixed_values.items():
obj._values_fixed[key] = cls._type(key).from_sql(value)
for key, value in flex_values.items():
obj._values_flex[key] = cls._type(key).from_sql(value)
obj._values_fixed.init(fixed_values)
obj._values_flex.init(flex_values)
return obj
def __repr__(self):
return '{0}({1})'.format(
return '{}({})'.format(
type(self).__name__,
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
)
def clear_dirty(self):
"""Mark all fields as *clean* (i.e., not needing to be stored to
the database).
the database). Also update the revision.
"""
self._dirty = set()
if self._db:
self._revision = self._db.revision
def _check_db(self, need_id=True):
"""Ensure that this object is associated with a database row: it
@ -212,10 +331,10 @@ class Model(object):
"""
if not self._db:
raise ValueError(
u'{0} has no database'.format(type(self).__name__)
'{} has no database'.format(type(self).__name__)
)
if need_id and not self.id:
raise ValueError(u'{0} has no id'.format(type(self).__name__))
raise ValueError('{} has no id'.format(type(self).__name__))
def copy(self):
"""Create a copy of the model object.
@ -243,19 +362,32 @@ class Model(object):
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
def _get(self, key, default=None, raise_=False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
getters = self._getters()
if key in getters: # Computed.
return getters[key](self)
elif key in self._fields: # Fixed.
return self._values_fixed.get(key, self._type(key).null)
if key in self._values_fixed:
return self._values_fixed[key]
else:
return self._type(key).null
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
else:
elif raise_:
raise KeyError(key)
else:
return default
get = _get
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""
return self._get(key, raise_=True)
def _setitem(self, key, value):
"""Assign the value for a field, return whether new and old value
@ -290,12 +422,12 @@ class Model(object):
if key in self._values_flex: # Flexible.
del self._values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._fields: # Fixed
setattr(self, key, self._type(key).null)
elif key in self._getters(): # Computed.
raise KeyError(u'computed field {0} cannot be deleted'.format(key))
elif key in self._fields: # Fixed.
raise KeyError(u'fixed field {0} cannot be deleted'.format(key))
raise KeyError(f'computed field {key} cannot be deleted')
else:
raise KeyError(u'no such field {0}'.format(key))
raise KeyError(f'no such field {key}')
def keys(self, computed=False):
"""Get a list of available field names for this object. The
@ -330,19 +462,10 @@ class Model(object):
for key in self:
yield key, self[key]
def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default
def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(True)
return key in self.keys(computed=True)
def __iter__(self):
"""Iterate over the available field names (excluding computed
@ -354,22 +477,22 @@ class Model(object):
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(u'model has no attribute {0!r}'.format(key))
raise AttributeError(f'model has no attribute {key!r}')
else:
try:
return self[key]
except KeyError:
raise AttributeError(u'no such field {0!r}'.format(key))
raise AttributeError(f'no such field {key!r}')
def __setattr__(self, key, value):
if key.startswith('_'):
super(Model, self).__setattr__(key, value)
super().__setattr__(key, value)
else:
self[key] = value
def __delattr__(self, key):
if key.startswith('_'):
super(Model, self).__delattr__(key)
super().__delattr__(key)
else:
del self[key]
@ -398,7 +521,7 @@ class Model(object):
with self._db.transaction() as tx:
# Main table update.
if assignments:
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
query = 'UPDATE {} SET {} WHERE id=?'.format(
self._table, assignments
)
subvars.append(self.id)
@ -409,7 +532,7 @@ class Model(object):
if key in self._dirty:
self._dirty.remove(key)
tx.mutate(
'INSERT INTO {0} '
'INSERT INTO {} '
'(entity_id, key, value) '
'VALUES (?, ?, ?);'.format(self._flex_table),
(self.id, key, value),
@ -418,7 +541,7 @@ class Model(object):
# Deleted flexible attributes.
for key in self._dirty:
tx.mutate(
'DELETE FROM {0} '
'DELETE FROM {} '
'WHERE entity_id=? AND key=?'.format(self._flex_table),
(self.id, key)
)
@ -427,12 +550,18 @@ class Model(object):
def load(self):
"""Refresh the object's metadata from the library database.
The database is only queried when a transaction has been
committed since the object was last loaded.
"""
self._check_db()
if not self._dirty and self._db.revision == self._revision:
# Exit early
return
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
self._values_fixed = {}
self._values_flex = {}
assert stored_obj is not None, f"object {self.id} not in DB"
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
self.update(dict(stored_obj))
self.clear_dirty()
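A minimal sketch of the early-exit behavior above, assuming an in-memory library (names are illustrative):

    from beets.library import Library, Item

    lib = Library(':memory:')
    item = Item(title='Foo')
    lib.add(item)
    item.load()   # queries the database; clear_dirty() records lib.revision
    item.load()   # clean, and no commit since the last load: returns early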
@ -442,11 +571,11 @@ class Model(object):
self._check_db()
with self._db.transaction() as tx:
tx.mutate(
'DELETE FROM {0} WHERE id=?'.format(self._table),
f'DELETE FROM {self._table} WHERE id=?',
(self.id,)
)
tx.mutate(
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
f'DELETE FROM {self._flex_table} WHERE entity_id=?',
(self.id,)
)
@ -464,7 +593,7 @@ class Model(object):
with self._db.transaction() as tx:
new_id = tx.mutate(
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
f'INSERT INTO {self._table} DEFAULT VALUES'
)
self.id = new_id
self.added = time.time()
@ -479,11 +608,11 @@ class Model(object):
_formatter = FormattedMapping
def formatted(self, for_path=False):
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, for_path)
return self._formatter(self, included_keys, for_path)
def evaluate_template(self, template, for_path=False):
"""Evaluate a template (a string or a `Template` object) using
@ -491,9 +620,9 @@ class Model(object):
separators will be added to the template.
"""
# Perform substitution.
if isinstance(template, six.string_types):
template = Template(template)
return template.substitute(self.formatted(for_path),
if isinstance(template, str):
template = functemplate.template(template)
return template.substitute(self.formatted(for_path=for_path),
self._template_funcs())
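For example, template evaluation against the formatted mapping looks like this (a sketch; the tag values are arbitrary):

    from beets.library import Item

    item = Item(artist='Nina', title='Song')
    item.evaluate_template('$artist - $title')   # -> 'Nina - Song'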
# Parsing.
@ -502,8 +631,8 @@ class Model(object):
def _parse(cls, key, string):
"""Parse a string as a value for the given key.
"""
if not isinstance(string, six.string_types):
raise TypeError(u"_parse() argument must be a string")
if not isinstance(string, str):
raise TypeError("_parse() argument must be a string")
return cls._type(key).parse(string)
@ -515,11 +644,13 @@ class Model(object):
# Database controller and supporting interfaces.
class Results(object):
class Results:
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, query=None, sort=None):
def __init__(self, model_class, rows, db, flex_rows,
query=None, sort=None):
"""Create a result set that will construct objects of type
`model_class`.
@ -539,6 +670,7 @@ class Results(object):
self.db = db
self.query = query
self.sort = sort
self.flex_rows = flex_rows
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
@ -560,6 +692,10 @@ class Results(object):
a `Results` object a second time should be much faster than the
first.
"""
# Index flexible attributes by the item ID, so we have easier access
flex_attrs = self._get_indexed_flex_attrs()
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
# Are there previously-materialized objects to produce?
@ -572,7 +708,7 @@ class Results(object):
else:
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
obj = self._make_model(row, flex_attrs.get(row['id'], {}))
# If there is a slow-query predicate, ensure that the
# object passes it.
if not self.query or self.query.match(obj):
@ -594,20 +730,24 @@ class Results(object):
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _make_model(self, row):
# Get the flexible attributes for the object.
with self.db.transaction() as tx:
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id=?'.format(
self.model_class._flex_table
),
(row['id'],)
)
def _get_indexed_flex_attrs(self):
""" Index flexible attributes by the entity id they belong to
"""
flex_values = {}
for row in self.flex_rows:
if row['entity_id'] not in flex_values:
flex_values[row['entity_id']] = {}
flex_values[row['entity_id']][row['key']] = row['value']
return flex_values
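The indexing step amounts to a one-pass grouping of flex rows by entity ID; a self-contained sketch with illustrative rows:

    rows = [
        {'entity_id': 1, 'key': 'mood', 'value': 'calm'},
        {'entity_id': 1, 'key': 'rating', 'value': '5'},
        {'entity_id': 2, 'key': 'mood', 'value': 'loud'},
    ]
    flex_values = {}
    for row in rows:
        flex_values.setdefault(row['entity_id'], {})[row['key']] = row['value']
    # flex_values == {1: {'mood': 'calm', 'rating': '5'}, 2: {'mood': 'loud'}}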
def _make_model(self, row, flex_values={}):
""" Create a Model object for the given row
"""
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k[:4] == 'flex')
flex_values = dict((row['key'], row['value']) for row in flex_rows)
values = {k: v for (k, v) in cols.items()
if not k[:4] == 'flex'}
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
@ -656,7 +796,7 @@ class Results(object):
next(it)
return next(it)
except StopIteration:
raise IndexError(u'result index {0} out of range'.format(n))
raise IndexError(f'result index {n} out of range')
def get(self):
"""Return the first matching object, or None if no objects
@ -669,10 +809,16 @@ class Results(object):
return None
class Transaction(object):
class Transaction:
"""A context manager for safe, concurrent access to the database.
All SQL commands should be executed through a transaction.
"""
_mutated = False
"""A flag storing whether a mutation has been executed in the
current transaction.
"""
def __init__(self, db):
self.db = db
@ -694,12 +840,15 @@ class Transaction(object):
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
# Beware of races; currently secured by db._db_lock
self.db.revision += self._mutated
with self.db._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
# Ending a "root" transaction. End the SQLite transaction.
self.db._connection().commit()
self._mutated = False
self.db._db_lock.release()
def query(self, statement, subvals=()):
@ -715,7 +864,6 @@ class Transaction(object):
"""
try:
cursor = self.db._connection().execute(statement, subvals)
return cursor.lastrowid
except sqlite3.OperationalError as e:
# In two specific cases, SQLite reports an error while accessing
# the underlying database file. We surface these exceptions as
@ -725,26 +873,41 @@ class Transaction(object):
raise DBAccessError(e.args[0])
else:
raise
else:
self._mutated = True
return cursor.lastrowid
def script(self, statements):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
self.db._connection().executescript(statements)
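The net effect of the `_mutated` flag: the revision counter moves only when a root transaction that actually wrote something exits. A sketch against an in-memory library (the table name is illustrative):

    from beets.library import Library

    lib = Library(':memory:')
    before = lib.revision
    with lib.transaction() as tx:
        tx.mutate('CREATE TABLE demo (id INTEGER);')
    assert lib.revision == before + 1   # one mutation, one revision bump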
class Database(object):
class Database:
"""A container for Model objects that wraps an SQLite database as
the backend.
"""
_models = ()
"""The Model subclasses representing tables in this database.
"""
supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
"""Whether or not the current version of SQLite supports extensions"""
revision = 0
"""The current revision of the database. To be increased whenever
data is written in a transaction.
"""
def __init__(self, path, timeout=5.0):
self.path = path
self.timeout = timeout
self._connections = {}
self._tx_stacks = defaultdict(list)
self._extensions = []
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
@ -794,6 +957,13 @@ class Database(object):
py3_path(self.path), timeout=self.timeout
)
if self.supports_extensions:
conn.enable_load_extension(True)
# Load any extensions that are already loaded for other connections.
for path in self._extensions:
conn.load_extension(path)
# Access SELECT results like dictionaries.
conn.row_factory = sqlite3.Row
return conn
@ -822,6 +992,18 @@ class Database(object):
"""
return Transaction(self)
def load_extension(self, path):
"""Load an SQLite extension into all open connections."""
if not self.supports_extensions:
raise ValueError(
'this sqlite3 installation does not support extensions')
self._extensions.append(path)
# Load the extension into every open connection.
for conn in self._connections.values():
conn.load_extension(path)
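Hypothetical usage of the new hook (the extension path is an assumption, not something beets ships):

    import os
    from beets.library import Library

    lib = Library(':memory:')
    ext = '/usr/lib/sqlite3/spellfix.so'   # illustrative path
    if lib.supports_extensions and os.path.exists(ext):
        lib.load_extension(ext)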
# Schema setup and migration.
def _make_table(self, table, fields):
@ -831,7 +1013,7 @@ class Database(object):
# Get current schema.
with self.transaction() as tx:
rows = tx.query('PRAGMA table_info(%s)' % table)
current_fields = set([row[1] for row in rows])
current_fields = {row[1] for row in rows}
field_names = set(fields.keys())
if current_fields.issuperset(field_names):
@ -842,9 +1024,9 @@ class Database(object):
# No table exists.
columns = []
for name, typ in fields.items():
columns.append('{0} {1}'.format(name, typ.sql))
setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
', '.join(columns))
columns.append(f'{name} {typ.sql}')
setup_sql = 'CREATE TABLE {} ({});\n'.format(table,
', '.join(columns))
else:
# Table exists but does not match the field set.
@ -852,7 +1034,7 @@ class Database(object):
for name, typ in fields.items():
if name in current_fields:
continue
setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
setup_sql += 'ALTER TABLE {} ADD COLUMN {} {};\n'.format(
table, name, typ.sql
)
@ -888,17 +1070,31 @@ class Database(object):
where, subvals = query.clause()
order_by = sort.order_clause()
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
sql = ("SELECT * FROM {} WHERE {} {}").format(
model_cls._table,
where or '1',
"ORDER BY {0}".format(order_by) if order_by else '',
f"ORDER BY {order_by}" if order_by else '',
)
# Fetch flexible attributes for items matching the main query.
# Doing the per-item filtering in Python is faster than issuing
# one query per item to SQLite.
flex_sql = ("""
SELECT * FROM {} WHERE entity_id IN
(SELECT id FROM {} WHERE {});
""".format(
model_cls._flex_table,
model_cls._table,
where or '1',
)
)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)
return Results(
model_cls, rows, self,
model_cls, rows, self, flex_rows,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,7 +14,6 @@
"""The Query type hierarchy for DBCore.
"""
from __future__ import division, absolute_import, print_function
import re
from operator import mul
@ -23,10 +21,6 @@ from beets import util
from datetime import datetime, timedelta
import unicodedata
from functools import reduce
import six
if not six.PY2:
buffer = memoryview # sqlite won't accept memoryview in python 2
class ParsingError(ValueError):
@ -44,8 +38,8 @@ class InvalidQueryError(ParsingError):
def __init__(self, query, explanation):
if isinstance(query, list):
query = " ".join(query)
message = u"'{0}': {1}".format(query, explanation)
super(InvalidQueryError, self).__init__(message)
message = f"'{query}': {explanation}"
super().__init__(message)
class InvalidQueryArgumentValueError(ParsingError):
@ -56,13 +50,13 @@ class InvalidQueryArgumentValueError(ParsingError):
"""
def __init__(self, what, expected, detail=None):
message = u"'{0}' is not {1}".format(what, expected)
message = f"'{what}' is not {expected}"
if detail:
message = u"{0}: {1}".format(message, detail)
super(InvalidQueryArgumentValueError, self).__init__(message)
message = f"{message}: {detail}"
super().__init__(message)
class Query(object):
class Query:
"""An abstract class representing a query into the item database.
"""
@ -82,7 +76,7 @@ class Query(object):
raise NotImplementedError
def __repr__(self):
return "{0.__class__.__name__}()".format(self)
return f"{self.__class__.__name__}()"
def __eq__(self, other):
return type(self) == type(other)
@ -129,7 +123,7 @@ class FieldQuery(Query):
"{0.fast})".format(self))
def __eq__(self, other):
return super(FieldQuery, self).__eq__(other) and \
return super().__eq__(other) and \
self.field == other.field and self.pattern == other.pattern
def __hash__(self):
@ -151,17 +145,13 @@ class NoneQuery(FieldQuery):
"""A query that checks whether a field is null."""
def __init__(self, field, fast=True):
super(NoneQuery, self).__init__(field, None, fast)
super().__init__(field, None, fast)
def col_clause(self):
return self.field + " IS NULL", ()
@classmethod
def match(cls, item):
try:
return item[cls.field] is None
except KeyError:
return True
def match(self, item):
return item.get(self.field) is None
def __repr__(self):
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
@ -214,14 +204,14 @@ class RegexpQuery(StringFieldQuery):
"""
def __init__(self, field, pattern, fast=True):
super(RegexpQuery, self).__init__(field, pattern, fast)
super().__init__(field, pattern, fast)
pattern = self._normalize(pattern)
try:
self.pattern = re.compile(self.pattern)
except re.error as exc:
# Invalid regular expression.
raise InvalidQueryArgumentValueError(pattern,
u"a regular expression",
"a regular expression",
format(exc))
@staticmethod
@ -242,8 +232,8 @@ class BooleanQuery(MatchQuery):
"""
def __init__(self, field, pattern, fast=True):
super(BooleanQuery, self).__init__(field, pattern, fast)
if isinstance(pattern, six.string_types):
super().__init__(field, pattern, fast)
if isinstance(pattern, str):
self.pattern = util.str2bool(pattern)
self.pattern = int(self.pattern)
@ -256,16 +246,16 @@ class BytesQuery(MatchQuery):
"""
def __init__(self, field, pattern):
super(BytesQuery, self).__init__(field, pattern)
super().__init__(field, pattern)
# Use a memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
if isinstance(self.pattern, (six.text_type, bytes)):
if isinstance(self.pattern, six.text_type):
if isinstance(self.pattern, (str, bytes)):
if isinstance(self.pattern, str):
self.pattern = self.pattern.encode('utf-8')
self.buf_pattern = buffer(self.pattern)
elif isinstance(self.pattern, buffer):
self.buf_pattern = memoryview(self.pattern)
elif isinstance(self.pattern, memoryview):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
@ -297,10 +287,10 @@ class NumericQuery(FieldQuery):
try:
return float(s)
except ValueError:
raise InvalidQueryArgumentValueError(s, u"an int or a float")
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field, pattern, fast=True):
super(NumericQuery, self).__init__(field, pattern, fast)
super().__init__(field, pattern, fast)
parts = pattern.split('..', 1)
if len(parts) == 1:
@ -318,7 +308,7 @@ class NumericQuery(FieldQuery):
if self.field not in item:
return False
value = item[self.field]
if isinstance(value, six.string_types):
if isinstance(value, str):
value = self._convert(value)
if self.point is not None:
@ -335,14 +325,14 @@ class NumericQuery(FieldQuery):
return self.field + '=?', (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
return ('{0} >= ? AND {0} <= ?'.format(self.field),
(self.rangemin, self.rangemax))
elif self.rangemin is not None:
return u'{0} >= ?'.format(self.field), (self.rangemin,)
return f'{self.field} >= ?', (self.rangemin,)
elif self.rangemax is not None:
return u'{0} <= ?'.format(self.field), (self.rangemax,)
return f'{self.field} <= ?', (self.rangemax,)
else:
return u'1', ()
return '1', ()
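For instance, a two-dot pattern yields a range closed on both ends (a sketch; the field and bounds are arbitrary):

    from beets.dbcore.query import NumericQuery

    q = NumericQuery('year', '1990..1999')
    clause, subvals = q.col_clause()
    # clause == 'year >= ? AND year <= ?'; subvals carries both endpoints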
class CollectionQuery(Query):
@ -387,7 +377,7 @@ class CollectionQuery(Query):
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
def __eq__(self, other):
return super(CollectionQuery, self).__eq__(other) and \
return super().__eq__(other) and \
self.subqueries == other.subqueries
def __hash__(self):
@ -411,7 +401,7 @@ class AnyFieldQuery(CollectionQuery):
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
super(AnyFieldQuery, self).__init__(subqueries)
super().__init__(subqueries)
def clause(self):
return self.clause_with_joiner('or')
@ -427,7 +417,7 @@ class AnyFieldQuery(CollectionQuery):
"{0.query_class.__name__})".format(self))
def __eq__(self, other):
return super(AnyFieldQuery, self).__eq__(other) and \
return super().__eq__(other) and \
self.query_class == other.query_class
def __hash__(self):
@ -453,7 +443,7 @@ class AndQuery(MutableCollectionQuery):
return self.clause_with_joiner('and')
def match(self, item):
return all([q.match(item) for q in self.subqueries])
return all(q.match(item) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
@ -463,7 +453,7 @@ class OrQuery(MutableCollectionQuery):
return self.clause_with_joiner('or')
def match(self, item):
return any([q.match(item) for q in self.subqueries])
return any(q.match(item) for q in self.subqueries)
class NotQuery(Query):
@ -477,7 +467,7 @@ class NotQuery(Query):
def clause(self):
clause, subvals = self.subquery.clause()
if clause:
return 'not ({0})'.format(clause), subvals
return f'not ({clause})', subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
@ -490,7 +480,7 @@ class NotQuery(Query):
return "{0.__class__.__name__}({0.subquery!r})".format(self)
def __eq__(self, other):
return super(NotQuery, self).__eq__(other) and \
return super().__eq__(other) and \
self.subquery == other.subquery
def __hash__(self):
@ -546,7 +536,7 @@ def _parse_periods(pattern):
return (start, end)
class Period(object):
class Period:
"""A period of time given by a date, time and precision.
Example: 2014-01-01 10:50:30 with precision 'month' represents all
@ -572,7 +562,7 @@ class Period(object):
or "second").
"""
if precision not in Period.precisions:
raise ValueError(u'Invalid precision {0}'.format(precision))
raise ValueError(f'Invalid precision {precision}')
self.date = date
self.precision = precision
@ -653,10 +643,10 @@ class Period(object):
elif 'second' == precision:
return date + timedelta(seconds=1)
else:
raise ValueError(u'unhandled precision {0}'.format(precision))
raise ValueError(f'unhandled precision {precision}')
class DateInterval(object):
class DateInterval:
"""A closed-open interval of dates.
A left endpoint of None means since the beginning of time.
@ -665,7 +655,7 @@ class DateInterval(object):
def __init__(self, start, end):
if start is not None and end is not None and not start < end:
raise ValueError(u"start date {0} is not before end date {1}"
raise ValueError("start date {} is not before end date {}"
.format(start, end))
self.start = start
self.end = end
@ -686,7 +676,7 @@ class DateInterval(object):
return True
def __str__(self):
return '[{0}, {1})'.format(self.start, self.end)
return f'[{self.start}, {self.end})'
class DateQuery(FieldQuery):
@ -700,7 +690,7 @@ class DateQuery(FieldQuery):
"""
def __init__(self, field, pattern, fast=True):
super(DateQuery, self).__init__(field, pattern, fast)
super().__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
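A sketch of the resulting interval, using an arbitrary month-precision range:

    from beets.dbcore.query import DateQuery

    q = DateQuery('added', '2014-01..2014-02')
    # q.interval spans 2014-01-01 00:00 (inclusive) to 2014-03-01 00:00
    # (exclusive): the right endpoint is the month *after* the end period,
    # since intervals are closed-open.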
@ -759,12 +749,12 @@ class DurationQuery(NumericQuery):
except ValueError:
raise InvalidQueryArgumentValueError(
s,
u"a M:SS string or a float")
"a M:SS string or a float")
# Sorting.
class Sort(object):
class Sort:
"""An abstract class representing a sort operation for a query into
the item database.
"""
@ -851,13 +841,13 @@ class MultipleSort(Sort):
return items
def __repr__(self):
return 'MultipleSort({!r})'.format(self.sorts)
return f'MultipleSort({self.sorts!r})'
def __hash__(self):
return hash(tuple(self.sorts))
def __eq__(self, other):
return super(MultipleSort, self).__eq__(other) and \
return super().__eq__(other) and \
self.sorts == other.sorts
@ -878,14 +868,14 @@ class FieldSort(Sort):
def key(item):
field_val = item.get(self.field, '')
if self.case_insensitive and isinstance(field_val, six.text_type):
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
return field_val
return sorted(objs, key=key, reverse=not self.ascending)
def __repr__(self):
return '<{0}: {1}{2}>'.format(
return '<{}: {}{}>'.format(
type(self).__name__,
self.field,
'+' if self.ascending else '-',
@ -895,7 +885,7 @@ class FieldSort(Sort):
return hash((self.field, self.ascending))
def __eq__(self, other):
return super(FieldSort, self).__eq__(other) and \
return super().__eq__(other) and \
self.field == other.field and \
self.ascending == other.ascending
@ -913,7 +903,7 @@ class FixedFieldSort(FieldSort):
'ELSE {0} END)'.format(self.field)
else:
field = self.field
return "{0} {1}".format(field, order)
return f"{field} {order}"
class SlowFieldSort(FieldSort):

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,12 +14,10 @@
"""Parsing of strings into DBCore queries.
"""
from __future__ import division, absolute_import, print_function
import re
import itertools
from . import query
import beets
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
@ -89,7 +86,7 @@ def parse_query_part(part, query_classes={}, prefixes={},
assert match # Regex should always match
negate = bool(match.group(1))
key = match.group(2)
term = match.group(3).replace('\:', ':')
term = match.group(3).replace('\\:', ':')
# Check whether there's a prefix in the query and use the
# corresponding query type.
@ -119,12 +116,13 @@ def construct_query_part(model_cls, prefixes, query_part):
if not query_part:
return query.TrueQuery()
# Use `model_cls` to build up a map from field names to `Query`
# classes.
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query
query_classes.update(model_cls._queries) # Non-field queries.
# Parse the string.
key, pattern, query_class, negate = \
@ -137,26 +135,27 @@ def construct_query_part(model_cls, prefixes, query_part):
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
q = query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
if negate:
return query.NotQuery(q)
else:
return q
out_query = query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
else:
# Non-field query type.
if negate:
return query.NotQuery(query_class(pattern))
else:
return query_class(pattern)
out_query = query_class(pattern)
# Otherwise, this must be a `FieldQuery`. Use the field name to
# construct the query object.
key = key.lower()
q = query_class(key.lower(), pattern, key in model_cls._fields)
# Field queries get constructed according to the name of the field
# they are querying.
elif issubclass(query_class, query.FieldQuery):
key = key.lower()
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Non-field (named) query.
else:
out_query = query_class(pattern)
# Apply negation.
if negate:
return query.NotQuery(q)
return q
return query.NotQuery(out_query)
else:
return out_query
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
@ -172,11 +171,13 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
return query_cls(subqueries)
def construct_sort_part(model_cls, part):
def construct_sort_part(model_cls, part, case_insensitive=True):
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
ending in ``+`` or ``-`` indicating the sort.
ending in ``+`` or ``-`` indicating the sort. `case_insensitive`
indicates whether the sort should be performed in a
case-insensitive manner.
"""
assert part, "part must be a field name and + or -"
field = part[:-1]
@ -185,7 +186,6 @@ def construct_sort_part(model_cls, part):
assert direction in ('+', '-'), "part must end with + or -"
is_ascending = direction == '+'
case_insensitive = beets.config['sort_case_insensitive'].get(bool)
if field in model_cls._sorts:
sort = model_cls._sorts[field](model_cls, is_ascending,
case_insensitive)
@ -197,21 +197,23 @@ def construct_sort_part(model_cls, part):
return sort
def sort_from_strings(model_cls, sort_parts):
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
sort = query.NullSort()
elif len(sort_parts) == 1:
sort = construct_sort_part(model_cls, sort_parts[0])
sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive)
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part))
sort.add_sort(construct_sort_part(model_cls, part,
case_insensitive))
return sort
def parse_sorted_query(model_cls, parts, prefixes={}):
def parse_sorted_query(model_cls, parts, prefixes={},
case_insensitive=True):
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""
@ -222,8 +224,8 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
# Split up query in to comma-separated subqueries, each representing
# an AndQuery, which need to be joined together in one OrQuery
subquery_parts = []
for part in parts + [u',']:
if part.endswith(u','):
for part in parts + [',']:
if part.endswith(','):
# Ensure we can catch "foo, bar" as well as "foo , bar"
last_subquery_part = part[:-1]
if last_subquery_part:
@ -237,8 +239,8 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
else:
# Sort parts (1) end in + or -, (2) don't have a field, and
# (3) consist of more than just the + or -.
if part.endswith((u'+', u'-')) \
and u':' not in part \
if part.endswith(('+', '-')) \
and ':' not in part \
and len(part) > 1:
sort_parts.append(part)
else:
@ -246,5 +248,5 @@ def parse_sorted_query(model_cls, parts, prefixes={}):
# Avoid needlessly wrapping single statements in an OR
q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0]
s = sort_from_strings(model_cls, sort_parts)
s = sort_from_strings(model_cls, sort_parts, case_insensitive)
return q, s
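Putting the comma and sort-suffix rules together (a sketch; the query terms are arbitrary):

    from beets.library import Item
    from beets.dbcore.queryparse import parse_sorted_query

    q, s = parse_sorted_query(
        Item, ['artist:dylan', ',', 'year:1965', 'year+'])
    # q is an OrQuery over the two comma-separated AndQuery halves;
    # s sorts ascending by year.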

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,25 +14,20 @@
"""Representation of type information for DBCore model fields.
"""
from __future__ import division, absolute_import, print_function
from . import query
from beets.util import str2bool
import six
if not six.PY2:
buffer = memoryview # sqlite won't accept memoryview in python 2
# Abstract base.
class Type(object):
class Type:
"""An object encapsulating the type of a model field. Includes
information about how to store, query, format, and parse a given
field.
"""
sql = u'TEXT'
sql = 'TEXT'
"""The SQLite column type for the value.
"""
@ -41,7 +35,7 @@ class Type(object):
"""The `Query` subclass to be used when querying the field.
"""
model_type = six.text_type
model_type = str
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
@ -63,11 +57,11 @@ class Type(object):
value = self.null
# `self.null` might be `None`
if value is None:
value = u''
value = ''
if isinstance(value, bytes):
value = value.decode('utf-8', 'ignore')
return six.text_type(value)
return str(value)
def parse(self, string):
"""Parse a (possibly human-written) string and return the
@ -97,16 +91,16 @@ class Type(object):
For fixed fields the type of `value` is determined by the column
type affinity given in the `sql` property and the SQL to Python
mapping of the database adapter. For more information see:
http://www.sqlite.org/datatype3.html
https://www.sqlite.org/datatype3.html
https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
Flexible fields have the type affinity `TEXT`. This means the
`sql_value` is either a `buffer`/`memoryview` or a `unicode` object`
`sql_value` is either a `memoryview` or a `str` object,
and the method must handle both.
"""
if isinstance(sql_value, buffer):
if isinstance(sql_value, memoryview):
sql_value = bytes(sql_value).decode('utf-8', 'ignore')
if isinstance(sql_value, six.text_type):
if isinstance(sql_value, str):
return self.parse(sql_value)
else:
return self.normalize(sql_value)
@ -127,10 +121,18 @@ class Default(Type):
class Integer(Type):
"""A basic integer type.
"""
sql = u'INTEGER'
sql = 'INTEGER'
query = query.NumericQuery
model_type = int
def normalize(self, value):
try:
return self.model_type(round(float(value)))
except ValueError:
return self.null
except TypeError:
return self.null
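The new `normalize` rounds rather than truncates and degrades to the type's null value on bad input; for example:

    from beets.dbcore.types import Integer

    Integer().normalize('3.7')   # -> 4 (rounded, not truncated)
    Integer().normalize(None)    # -> 0, the Integer null value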
class PaddedInt(Integer):
"""An integer field that is formatted with a given number of digits,
@ -140,19 +142,25 @@ class PaddedInt(Integer):
self.digits = digits
def format(self, value):
return u'{0:0{1}d}'.format(value or 0, self.digits)
return '{0:0{1}d}'.format(value or 0, self.digits)
class NullPaddedInt(PaddedInt):
"""Same as `PaddedInt`, but does not normalize `None` to `0.0`.
"""
null = None
class ScaledInt(Integer):
"""An integer whose formatting operation scales the number by a
constant and adds a suffix. Good for units with large magnitudes.
"""
def __init__(self, unit, suffix=u''):
def __init__(self, unit, suffix=''):
self.unit = unit
self.suffix = suffix
def format(self, value):
return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)
return '{}{}'.format((value or 0) // self.unit, self.suffix)
class Id(Integer):
@ -163,18 +171,22 @@ class Id(Integer):
def __init__(self, primary=True):
if primary:
self.sql = u'INTEGER PRIMARY KEY'
self.sql = 'INTEGER PRIMARY KEY'
class Float(Type):
"""A basic floating-point type.
"""A basic floating-point type. The `digits` parameter specifies how
many decimal places to use in the human-readable representation.
"""
sql = u'REAL'
sql = 'REAL'
query = query.NumericQuery
model_type = float
def __init__(self, digits=1):
self.digits = digits
def format(self, value):
return u'{0:.1f}'.format(value or 0.0)
return '{0:.{1}f}'.format(value or 0, self.digits)
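The new `digits` knob only affects formatting, e.g. (values are arbitrary):

    from beets.dbcore.types import Float

    Float().format(3.14159)          # -> '3.1' (one decimal by default)
    Float(digits=4).format(3.14159)  # -> '3.1416'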
class NullFloat(Float):
@ -186,19 +198,25 @@ class NullFloat(Float):
class String(Type):
"""A Unicode string type.
"""
sql = u'TEXT'
sql = 'TEXT'
query = query.SubstringQuery
def normalize(self, value):
if value is None:
return self.null
else:
return self.model_type(value)
class Boolean(Type):
"""A boolean type.
"""
sql = u'INTEGER'
sql = 'INTEGER'
query = query.BooleanQuery
model_type = bool
def format(self, value):
return six.text_type(bool(value))
return str(bool(value))
def parse(self, string):
return str2bool(string)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -13,7 +12,6 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from __future__ import division, absolute_import, print_function
"""Provides the basic, interface-agnostic workflow for importing and
autotagging music files.
@ -40,7 +38,7 @@ from beets import config
from beets.util import pipeline, sorted_walk, ancestry, MoveOperation
from beets.util import syspath, normpath, displayable_path
from enum import Enum
from beets import mediafile
import mediafile
action = Enum('action',
['SKIP', 'ASIS', 'TRACKS', 'APPLY', 'ALBUMS', 'RETAG'])
@ -75,7 +73,7 @@ def _open_state():
# unpickling, including ImportError. We use a catch-all
# exception to avoid enumerating them all (the docs don't even have a
# full list!).
log.debug(u'state file could not be read: {0}', exc)
log.debug('state file could not be read: {0}', exc)
return {}
@ -84,8 +82,8 @@ def _save_state(state):
try:
with open(config['statefile'].as_filename(), 'wb') as f:
pickle.dump(state, f)
except IOError as exc:
log.error(u'state file could not be written: {0}', exc)
except OSError as exc:
log.error('state file could not be written: {0}', exc)
# Utilities for reading and writing the beets progress file, which
@ -174,10 +172,11 @@ def history_get():
# Abstract session class.
class ImportSession(object):
class ImportSession:
"""Controls an import action. Subclasses should implement methods to
communicate with the user or otherwise make decisions.
"""
def __init__(self, lib, loghandler, paths, query):
"""Create a session. `lib` is a Library object. `loghandler` is a
logging.Handler. Either `paths` or `query` is non-null and indicates
@ -187,7 +186,7 @@ class ImportSession(object):
self.logger = self._setup_logging(loghandler)
self.paths = paths
self.query = query
self._is_resuming = dict()
self._is_resuming = {}
self._merged_items = set()
self._merged_dirs = set()
@ -222,19 +221,31 @@ class ImportSession(object):
iconfig['resume'] = False
iconfig['incremental'] = False
# Copy, move, link, and hardlink are mutually exclusive.
if iconfig['reflink']:
iconfig['reflink'] = iconfig['reflink'] \
.as_choice(['auto', True, False])
# Copy, move, reflink, link, and hardlink are mutually exclusive.
if iconfig['move']:
iconfig['copy'] = False
iconfig['link'] = False
iconfig['hardlink'] = False
iconfig['reflink'] = False
elif iconfig['link']:
iconfig['copy'] = False
iconfig['move'] = False
iconfig['hardlink'] = False
iconfig['reflink'] = False
elif iconfig['hardlink']:
iconfig['copy'] = False
iconfig['move'] = False
iconfig['link'] = False
iconfig['reflink'] = False
elif iconfig['reflink']:
iconfig['copy'] = False
iconfig['move'] = False
iconfig['link'] = False
iconfig['hardlink'] = False
# Only delete when copying.
if not iconfig['copy']:
@ -246,7 +257,7 @@ class ImportSession(object):
"""Log a message about a given album to the importer log. The status
should reflect the reason the album couldn't be tagged.
"""
self.logger.info(u'{0} {1}', status, displayable_path(paths))
self.logger.info('{0} {1}', status, displayable_path(paths))
def log_choice(self, task, duplicate=False):
"""Logs the task's current choice if it should be logged. If
@ -257,17 +268,17 @@ class ImportSession(object):
if duplicate:
# Duplicate: log all three choices (skip, keep both, and trump).
if task.should_remove_duplicates:
self.tag_log(u'duplicate-replace', paths)
self.tag_log('duplicate-replace', paths)
elif task.choice_flag in (action.ASIS, action.APPLY):
self.tag_log(u'duplicate-keep', paths)
self.tag_log('duplicate-keep', paths)
elif task.choice_flag is (action.SKIP):
self.tag_log(u'duplicate-skip', paths)
self.tag_log('duplicate-skip', paths)
else:
# Non-duplicate: log "skip" and "asis" choices.
if task.choice_flag is action.ASIS:
self.tag_log(u'asis', paths)
self.tag_log('asis', paths)
elif task.choice_flag is action.SKIP:
self.tag_log(u'skip', paths)
self.tag_log('skip', paths)
def should_resume(self, path):
raise NotImplementedError
@ -284,7 +295,7 @@ class ImportSession(object):
def run(self):
"""Run the import task.
"""
self.logger.info(u'import started {0}', time.asctime())
self.logger.info('import started {0}', time.asctime())
self.set_config(config['import'])
# Set up the pipeline.
@ -368,8 +379,8 @@ class ImportSession(object):
"""Mark paths and directories as merged for future reimport tasks.
"""
self._merged_items.update(paths)
dirs = set([os.path.dirname(path) if os.path.isfile(path) else path
for path in paths])
dirs = {os.path.dirname(path) if os.path.isfile(path) else path
for path in paths}
self._merged_dirs.update(dirs)
def is_resuming(self, toppath):
@ -389,7 +400,7 @@ class ImportSession(object):
# Either accept immediately or prompt for input to decide.
if self.want_resume is True or \
self.should_resume(toppath):
log.warning(u'Resuming interrupted import of {0}',
log.warning('Resuming interrupted import of {0}',
util.displayable_path(toppath))
self._is_resuming[toppath] = True
else:
@ -399,11 +410,12 @@ class ImportSession(object):
# The importer task class.
class BaseImportTask(object):
class BaseImportTask:
"""An abstract base class for importer tasks.
Tasks flow through the importer pipeline. Each stage can update
them. """
def __init__(self, toppath, paths, items):
"""Create a task. The primary fields that define a task are:
@ -457,8 +469,9 @@ class ImportTask(BaseImportTask):
* `finalize()` Update the import progress and cleanup the file
system.
"""
def __init__(self, toppath, paths, items):
super(ImportTask, self).__init__(toppath, paths, items)
super().__init__(toppath, paths, items)
self.choice_flag = None
self.cur_album = None
self.cur_artist = None
@ -550,28 +563,34 @@ class ImportTask(BaseImportTask):
def remove_duplicates(self, lib):
duplicate_items = self.duplicate_items(lib)
log.debug(u'removing {0} old duplicated items', len(duplicate_items))
log.debug('removing {0} old duplicated items', len(duplicate_items))
for item in duplicate_items:
item.remove()
if lib.directory in util.ancestry(item.path):
log.debug(u'deleting duplicate {0}',
log.debug('deleting duplicate {0}',
util.displayable_path(item.path))
util.remove(item.path)
util.prune_dirs(os.path.dirname(item.path),
lib.directory)
def set_fields(self):
def set_fields(self, lib):
"""Sets the fields given at CLI or configuration to the specified
values.
values, for both the album and all its items.
"""
items = self.imported_items()
for field, view in config['import']['set_fields'].items():
value = view.get()
log.debug(u'Set field {1}={2} for {0}',
log.debug('Set field {1}={2} for {0}',
displayable_path(self.paths),
field,
value)
self.album[field] = value
self.album.store()
for item in items:
item[field] = value
with lib.transaction():
for item in items:
item.store()
self.album.store()
def finalize(self, session):
"""Save progress, clean up files, and emit plugin event.
@ -655,7 +674,7 @@ class ImportTask(BaseImportTask):
return []
duplicates = []
task_paths = set(i.path for i in self.items if i)
task_paths = {i.path for i in self.items if i}
duplicate_query = dbcore.AndQuery((
dbcore.MatchQuery('albumartist', artist),
dbcore.MatchQuery('album', album),
@ -665,7 +684,7 @@ class ImportTask(BaseImportTask):
# Check whether the album paths are all present in the task
# i.e. album is being completely re-imported by the task,
# in which case it is not a duplicate (will be replaced).
album_paths = set(i.path for i in album.items())
album_paths = {i.path for i in album.items()}
if not (album_paths <= task_paths):
duplicates.append(album)
return duplicates
@ -707,7 +726,7 @@ class ImportTask(BaseImportTask):
item.update(changes)
def manipulate_files(self, operation=None, write=False, session=None):
""" Copy, move, link or hardlink (depending on `operation`) the files
""" Copy, move, link, hardlink or reflink (depending on `operation`) the files
as well as write metadata.
`operation` should be an instance of `util.MoveOperation`.
@ -754,6 +773,8 @@ class ImportTask(BaseImportTask):
self.record_replaced(lib)
self.remove_replaced(lib)
self.album = lib.add_album(self.imported_items())
if 'data_source' in self.imported_items()[0]:
self.album.data_source = self.imported_items()[0].data_source
self.reimport_metadata(lib)
def record_replaced(self, lib):
@ -772,7 +793,7 @@ class ImportTask(BaseImportTask):
if (not dup_item.album_id or
dup_item.album_id in replaced_album_ids):
continue
replaced_album = dup_item.get_album()
replaced_album = dup_item._cached_album
if replaced_album:
replaced_album_ids.add(dup_item.album_id)
self.replaced_albums[replaced_album.path] = replaced_album
@ -789,8 +810,8 @@ class ImportTask(BaseImportTask):
self.album.artpath = replaced_album.artpath
self.album.store()
log.debug(
u'Reimported album: added {0}, flexible '
u'attributes {1} from album {2} for {3}',
'Reimported album: added {0}, flexible '
'attributes {1} from album {2} for {3}',
self.album.added,
replaced_album._values_flex.keys(),
replaced_album.id,
@ -803,16 +824,16 @@ class ImportTask(BaseImportTask):
if dup_item.added and dup_item.added != item.added:
item.added = dup_item.added
log.debug(
u'Reimported item added {0} '
u'from item {1} for {2}',
'Reimported item added {0} '
'from item {1} for {2}',
item.added,
dup_item.id,
displayable_path(item.path)
)
item.update(dup_item._values_flex)
log.debug(
u'Reimported item flexible attributes {0} '
u'from item {1} for {2}',
'Reimported item flexible attributes {0} '
'from item {1} for {2}',
dup_item._values_flex.keys(),
dup_item.id,
displayable_path(item.path)
@ -825,10 +846,10 @@ class ImportTask(BaseImportTask):
"""
for item in self.imported_items():
for dup_item in self.replaced_items[item]:
log.debug(u'Replacing item {0}: {1}',
log.debug('Replacing item {0}: {1}',
dup_item.id, displayable_path(item.path))
dup_item.remove()
log.debug(u'{0} of {1} items replaced',
log.debug('{0} of {1} items replaced',
sum(bool(l) for l in self.replaced_items.values()),
len(self.imported_items()))
@ -866,7 +887,7 @@ class SingletonImportTask(ImportTask):
"""
def __init__(self, toppath, item):
super(SingletonImportTask, self).__init__(toppath, [item.path], [item])
super().__init__(toppath, [item.path], [item])
self.item = item
self.is_album = False
self.paths = [item.path]
@ -932,13 +953,13 @@ class SingletonImportTask(ImportTask):
def reload(self):
self.item.load()
def set_fields(self):
def set_fields(self, lib):
"""Sets the fields given at CLI or configuration to the specified
values.
values, for the singleton item.
"""
for field, view in config['import']['set_fields'].items():
value = view.get()
log.debug(u'Set field {1}={2} for {0}',
log.debug('Set field {1}={2} for {0}',
displayable_path(self.paths),
field,
value)
@ -959,7 +980,7 @@ class SentinelImportTask(ImportTask):
"""
def __init__(self, toppath, paths):
super(SentinelImportTask, self).__init__(toppath, paths, ())
super().__init__(toppath, paths, ())
# TODO Remove the remaining attributes eventually
self.should_remove_duplicates = False
self.is_album = True
@ -1003,7 +1024,7 @@ class ArchiveImportTask(SentinelImportTask):
"""
def __init__(self, toppath):
super(ArchiveImportTask, self).__init__(toppath, ())
super().__init__(toppath, ())
self.extracted = False
@classmethod
@ -1032,14 +1053,20 @@ class ArchiveImportTask(SentinelImportTask):
cls._handlers = []
from zipfile import is_zipfile, ZipFile
cls._handlers.append((is_zipfile, ZipFile))
from tarfile import is_tarfile, TarFile
cls._handlers.append((is_tarfile, TarFile))
import tarfile
cls._handlers.append((tarfile.is_tarfile, tarfile.open))
try:
from rarfile import is_rarfile, RarFile
except ImportError:
pass
else:
cls._handlers.append((is_rarfile, RarFile))
try:
from py7zr import is_7zfile, SevenZipFile
except ImportError:
pass
else:
cls._handlers.append((is_7zfile, SevenZipFile))
return cls._handlers
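Each handler pair is a `(test, opener)` tuple tried in order; a self-contained sketch of the dispatch (the archive path is illustrative):

    import os
    import tarfile
    from zipfile import is_zipfile, ZipFile

    handlers = [(is_zipfile, ZipFile), (tarfile.is_tarfile, tarfile.open)]
    path = 'album.zip'   # illustrative
    if os.path.exists(path):
        for test, opener in handlers:
            if test(path):
                with opener(path, mode='r') as archive:
                    archive.extractall('extracted')
                break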
@ -1047,7 +1074,7 @@ class ArchiveImportTask(SentinelImportTask):
"""Removes the temporary directory the archive was extracted to.
"""
if self.extracted:
log.debug(u'Removing extracted directory: {0}',
log.debug('Removing extracted directory: {0}',
displayable_path(self.toppath))
shutil.rmtree(self.toppath)
@ -1059,9 +1086,9 @@ class ArchiveImportTask(SentinelImportTask):
if path_test(util.py3_path(self.toppath)):
break
extract_to = mkdtemp()
archive = handler_class(util.py3_path(self.toppath), mode='r')
try:
extract_to = mkdtemp()
archive = handler_class(util.py3_path(self.toppath), mode='r')
archive.extractall(extract_to)
finally:
archive.close()
@ -1069,10 +1096,11 @@ class ArchiveImportTask(SentinelImportTask):
self.toppath = extract_to
class ImportTaskFactory(object):
class ImportTaskFactory:
"""Generate album and singleton import tasks for all media files
indicated by a path.
"""
def __init__(self, toppath, session):
"""Create a new task factory.
@ -1110,14 +1138,12 @@ class ImportTaskFactory(object):
if self.session.config['singletons']:
for path in paths:
tasks = self._create(self.singleton(path))
for task in tasks:
yield task
yield from tasks
yield self.sentinel(dirs)
else:
tasks = self._create(self.album(paths, dirs))
for task in tasks:
yield task
yield from tasks
# Produce the final sentinel for this toppath to indicate that
# it is finished. This is usually just a SentinelImportTask, but
@ -1165,7 +1191,7 @@ class ImportTaskFactory(object):
"""Return a `SingletonImportTask` for the music file.
"""
if self.session.already_imported(self.toppath, [path]):
log.debug(u'Skipping previously-imported path: {0}',
log.debug('Skipping previously-imported path: {0}',
displayable_path(path))
self.skipped += 1
return None
@ -1186,10 +1212,10 @@ class ImportTaskFactory(object):
return None
if dirs is None:
dirs = list(set(os.path.dirname(p) for p in paths))
dirs = list({os.path.dirname(p) for p in paths})
if self.session.already_imported(self.toppath, dirs):
log.debug(u'Skipping previously-imported path: {0}',
log.debug('Skipping previously-imported path: {0}',
displayable_path(dirs))
self.skipped += 1
return None
@ -1219,22 +1245,22 @@ class ImportTaskFactory(object):
if not (self.session.config['move'] or
self.session.config['copy']):
log.warning(u"Archive importing requires either "
u"'copy' or 'move' to be enabled.")
log.warning("Archive importing requires either "
"'copy' or 'move' to be enabled.")
return
log.debug(u'Extracting archive: {0}',
log.debug('Extracting archive: {0}',
displayable_path(self.toppath))
archive_task = ArchiveImportTask(self.toppath)
try:
archive_task.extract()
except Exception as exc:
log.error(u'extraction failed: {0}', exc)
log.error('extraction failed: {0}', exc)
return
# Now read albums from the extracted directory.
self.toppath = archive_task.toppath
log.debug(u'Archive extracted to: {0}', self.toppath)
log.debug('Archive extracted to: {0}', self.toppath)
return archive_task
def read_item(self, path):
@ -1250,9 +1276,9 @@ class ImportTaskFactory(object):
# Silently ignore non-music files.
pass
elif isinstance(exc.reason, mediafile.UnreadableFileError):
log.warning(u'unreadable file: {0}', displayable_path(path))
log.warning('unreadable file: {0}', displayable_path(path))
else:
log.error(u'error reading {0}: {1}',
log.error('error reading {0}: {1}',
displayable_path(path), exc)
@ -1291,17 +1317,16 @@ def read_tasks(session):
# Generate tasks.
task_factory = ImportTaskFactory(toppath, session)
for t in task_factory.tasks():
yield t
yield from task_factory.tasks()
skipped += task_factory.skipped
if not task_factory.imported:
log.warning(u'No files imported from {0}',
log.warning('No files imported from {0}',
displayable_path(toppath))
# Show skipped directories (due to incremental/resume).
if skipped:
log.info(u'Skipped {0} paths.', skipped)
log.info('Skipped {0} paths.', skipped)
def query_tasks(session):
@ -1319,7 +1344,7 @@ def query_tasks(session):
else:
# Search for albums.
for album in session.lib.albums(session.query):
log.debug(u'yielding album {0}: {1} - {2}',
log.debug('yielding album {0}: {1} - {2}',
album.id, album.albumartist, album.album)
items = list(album.items())
_freshen_items(items)
@ -1342,7 +1367,7 @@ def lookup_candidates(session, task):
return
plugins.send('import_task_start', session=session, task=task)
log.debug(u'Looking up: {0}', displayable_path(task.paths))
log.debug('Looking up: {0}', displayable_path(task.paths))
# Restrict the initial lookup to IDs specified by the user via the -m
# option. Currently all the IDs are passed onto the tasks directly.
@ -1381,8 +1406,7 @@ def user_query(session, task):
def emitter(task):
for item in task.items:
task = SingletonImportTask(task.toppath, item)
for new_task in task.handle_created(session):
yield new_task
yield from task.handle_created(session)
yield SentinelImportTask(task.toppath, task.paths)
return _extend_pipeline(emitter(task),
@ -1428,30 +1452,30 @@ def resolve_duplicates(session, task):
if task.choice_flag in (action.ASIS, action.APPLY, action.RETAG):
found_duplicates = task.find_duplicates(session.lib)
if found_duplicates:
log.debug(u'found duplicates: {}'.format(
log.debug('found duplicates: {}'.format(
[o.id for o in found_duplicates]
))
# Get the default action to follow from config.
duplicate_action = config['import']['duplicate_action'].as_choice({
u'skip': u's',
u'keep': u'k',
u'remove': u'r',
u'merge': u'm',
u'ask': u'a',
'skip': 's',
'keep': 'k',
'remove': 'r',
'merge': 'm',
'ask': 'a',
})
log.debug(u'default action for duplicates: {0}', duplicate_action)
log.debug('default action for duplicates: {0}', duplicate_action)
if duplicate_action == u's':
if duplicate_action == 's':
# Skip new.
task.set_choice(action.SKIP)
elif duplicate_action == u'k':
elif duplicate_action == 'k':
# Keep both. Do nothing; leave the choice intact.
pass
elif duplicate_action == u'r':
elif duplicate_action == 'r':
# Remove old.
task.should_remove_duplicates = True
elif duplicate_action == u'm':
elif duplicate_action == 'm':
# Merge duplicates together
task.should_merge_duplicates = True
else:
@ -1471,7 +1495,7 @@ def import_asis(session, task):
if task.skip:
return
log.info(u'{}', displayable_path(task.paths))
log.info('{}', displayable_path(task.paths))
task.set_choice(action.ASIS)
apply_choice(session, task)
@ -1496,7 +1520,7 @@ def apply_choice(session, task):
# because then the ``ImportTask`` won't have an `album` for which
# it can set the fields.
if config['import']['set_fields']:
task.set_fields()
task.set_fields(session.lib)
@pipeline.mutator_stage
@ -1534,6 +1558,8 @@ def manipulate_files(session, task):
operation = MoveOperation.LINK
elif session.config['hardlink']:
operation = MoveOperation.HARDLINK
elif session.config['reflink']:
operation = MoveOperation.REFLINK
else:
operation = None
@ -1552,11 +1578,11 @@ def log_files(session, task):
"""A coroutine (pipeline stage) to log each file to be imported.
"""
if isinstance(task, SingletonImportTask):
log.info(u'Singleton: {0}', displayable_path(task.item['path']))
log.info('Singleton: {0}', displayable_path(task.item['path']))
elif task.items:
log.info(u'Album: {0}', displayable_path(task.paths[0]))
log.info('Album: {0}', displayable_path(task.paths[0]))
for item in task.items:
log.info(u' {0}', displayable_path(item['path']))
log.info(' {0}', displayable_path(item['path']))
def group_albums(session):

File diff suppressed because it is too large

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -21,13 +20,11 @@ that when getLogger(name) instantiates a logger that logger uses
{}-style formatting.
"""
from __future__ import division, absolute_import, print_function
from copy import copy
from logging import * # noqa
import subprocess
import threading
import six
def logsafe(val):
@ -43,7 +40,7 @@ def logsafe(val):
example.
"""
# Already Unicode.
if isinstance(val, six.text_type):
if isinstance(val, str):
return val
# Bytestring: needs decoding.
@ -57,7 +54,7 @@ def logsafe(val):
# A "problem" object: needs a workaround.
elif isinstance(val, subprocess.CalledProcessError):
try:
return six.text_type(val)
return str(val)
except UnicodeDecodeError:
# An object with a broken __unicode__ formatter. Use __str__
# instead.
@ -74,7 +71,7 @@ class StrFormatLogger(Logger):
instead of %-style formatting.
"""
class _LogMessage(object):
class _LogMessage:
def __init__(self, msg, args, kwargs):
self.msg = msg
self.args = args
@ -82,22 +79,23 @@ class StrFormatLogger(Logger):
def __str__(self):
args = [logsafe(a) for a in self.args]
kwargs = dict((k, logsafe(v)) for (k, v) in self.kwargs.items())
kwargs = {k: logsafe(v) for (k, v) in self.kwargs.items()}
return self.msg.format(*args, **kwargs)
def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
"""Log msg.format(*args, **kwargs)"""
m = self._LogMessage(msg, args, kwargs)
return super(StrFormatLogger, self)._log(level, m, (), exc_info, extra)
return super()._log(level, m, (), exc_info, extra)
class ThreadLocalLevelLogger(Logger):
"""A version of `Logger` whose level is thread-local instead of shared.
"""
def __init__(self, name, level=NOTSET):
self._thread_level = threading.local()
self.default_level = NOTSET
super(ThreadLocalLevelLogger, self).__init__(name, level)
super().__init__(name, level)
@property
def level(self):

File diff suppressed because it is too large

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,19 +14,19 @@
"""Support for beets plugins."""
from __future__ import division, absolute_import, print_function
import inspect
import traceback
import re
import inspect
import abc
from collections import defaultdict
from functools import wraps
import beets
from beets import logging
from beets import mediafile
import six
import mediafile
PLUGIN_NAMESPACE = 'beetsplug'
@ -50,26 +49,28 @@ class PluginLogFilter(logging.Filter):
"""A logging filter that identifies the plugin that emitted a log
message.
"""
def __init__(self, plugin):
self.prefix = u'{0}: '.format(plugin.name)
self.prefix = f'{plugin.name}: '
def filter(self, record):
if hasattr(record.msg, 'msg') and isinstance(record.msg.msg,
six.string_types):
str):
# A _LogMessage from our hacked-up Logging replacement.
record.msg.msg = self.prefix + record.msg.msg
elif isinstance(record.msg, six.string_types):
elif isinstance(record.msg, str):
record.msg = self.prefix + record.msg
return True
# Managing the plugins themselves.
class BeetsPlugin(object):
class BeetsPlugin:
"""The base class for all beets plugins. Plugins provide
functionality by defining a subclass of BeetsPlugin and overriding
the abstract methods defined here.
"""
def __init__(self, name=None):
"""Perform one-time plugin setup.
"""
@ -127,27 +128,24 @@ class BeetsPlugin(object):
value after the function returns). Also filters out keyword arguments
the wrapped function does not accept, for backwards compatibility.
"""
argspec = inspect.getargspec(func)
argspec = inspect.getfullargspec(func)
@wraps(func)
def wrapper(*args, **kwargs):
assert self._log.level == logging.NOTSET
verbosity = beets.config['verbose'].get(int)
log_level = max(logging.DEBUG, base_log_level - 10 * verbosity)
self._log.setLevel(log_level)
if argspec.varkw is None:
kwargs = {k: v for k, v in kwargs.items()
if k in argspec.args}
try:
try:
return func(*args, **kwargs)
except TypeError as exc:
if exc.args[0].startswith(func.__name__):
# caused by 'func' and not stuff internal to 'func'
kwargs = dict((arg, val) for arg, val in kwargs.items()
if arg in argspec.args)
return func(*args, **kwargs)
else:
raise
return func(*args, **kwargs)
finally:
self._log.setLevel(logging.NOTSET)
return wrapper
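The filtering step can be demonstrated in isolation; a minimal sketch of the same technique (names are invented for illustration):

    import inspect
    from functools import wraps

    def tolerant(func):
        spec = inspect.getfullargspec(func)

        @wraps(func)
        def wrapper(**kwargs):
            if spec.varkw is None:   # handler has no **kwargs
                kwargs = {k: v for k, v in kwargs.items()
                          if k in spec.args}
            return func(**kwargs)
        return wrapper

    @tolerant
    def on_event(lib):               # an older handler signature
        return lib

    on_event(lib='library', item='ignored')   # 'item' is dropped silently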
def queries(self):
@ -167,7 +165,7 @@ class BeetsPlugin(object):
"""
return beets.autotag.hooks.Distance()
def candidates(self, items, artist, album, va_likely):
def candidates(self, items, artist, album, va_likely, extra_tags=None):
"""Should return a sequence of AlbumInfo objects that match the
album whose items are provided.
"""
@ -201,7 +199,7 @@ class BeetsPlugin(object):
``descriptor`` must be an instance of ``mediafile.MediaField``.
"""
# Defer impor to prevent circular dependency
# Defer import to prevent circular dependency
from beets import library
mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name)
@ -264,14 +262,14 @@ def load_plugins(names=()):
BeetsPlugin subclasses desired.
"""
for name in names:
modname = '{0}.{1}'.format(PLUGIN_NAMESPACE, name)
modname = f'{PLUGIN_NAMESPACE}.{name}'
try:
try:
namespace = __import__(modname, None, None)
except ImportError as exc:
# Again, this is hacky:
if exc.args[0].endswith(' ' + name):
log.warning(u'** plugin {0} not found', name)
log.warning('** plugin {0} not found', name)
else:
raise
else:
@ -282,7 +280,7 @@ def load_plugins(names=()):
except Exception:
log.warning(
u'** error loading plugin {}:\n{}',
'** error loading plugin {}:\n{}',
name,
traceback.format_exc(),
)
@ -296,6 +294,11 @@ def find_plugins():
currently loaded beets plugins. Loads the default plugin set
first.
"""
if _instances:
# After the first call, use cached instances for performance reasons.
# See https://github.com/beetbox/beets/pull/3810
return list(_instances.values())
load_plugins()
plugins = []
for cls in _classes:
@ -329,21 +332,31 @@ def queries():
def types(model_cls):
# Gives us `item_types` and `album_types`
attr_name = '{0}_types'.format(model_cls.__name__.lower())
attr_name = f'{model_cls.__name__.lower()}_types'
types = {}
for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types:
if field in types and plugin_types[field] != types[field]:
raise PluginConflictException(
u'Plugin {0} defines flexible field {1} '
u'which has already been defined with '
u'another type.'.format(plugin.name, field)
'Plugin {} defines flexible field {} '
'which has already been defined with '
'another type.'.format(plugin.name, field)
)
types.update(plugin_types)
return types
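# The conflict check above in isolation, as a hedged sketch using
# hypothetical plugin data rather than real beets plugins:
def merge_types(existing, new, plugin_name):
    for field, typ in new.items():
        if field in existing and existing[field] != typ:
            raise ValueError(
                'Plugin {} defines flexible field {} '
                'which has already been defined with '
                'another type.'.format(plugin_name, field))
    existing.update(new)
    return existing

fields = merge_types({}, {'play_count': int}, 'plays')
# merge_types(fields, {'play_count': float}, 'other')  # would raise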
def named_queries(model_cls):
# Gather `item_queries` and `album_queries` from the plugins.
attr_name = f'{model_cls.__name__.lower()}_queries'
queries = {}
for plugin in find_plugins():
plugin_queries = getattr(plugin, attr_name, {})
queries.update(plugin_queries)
return queries
def track_distance(item, info):
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
@ -364,20 +377,19 @@ def album_distance(items, album_info, mapping):
return dist
def candidates(items, artist, album, va_likely):
def candidates(items, artist, album, va_likely, extra_tags=None):
"""Gets MusicBrainz candidates for an album from each plugin.
"""
for plugin in find_plugins():
for candidate in plugin.candidates(items, artist, album, va_likely):
yield candidate
yield from plugin.candidates(items, artist, album, va_likely,
extra_tags)
def item_candidates(item, artist, title):
"""Gets MusicBrainz candidates for an item from the plugins.
"""
for plugin in find_plugins():
for item_candidate in plugin.item_candidates(item, artist, title):
yield item_candidate
yield from plugin.item_candidates(item, artist, title)
def album_for_id(album_id):
@ -470,7 +482,7 @@ def send(event, **arguments):
Return a list of non-None values returned from the handlers.
"""
log.debug(u'Sending event: {0}', event)
log.debug('Sending event: {0}', event)
results = []
for handler in event_handlers()[event]:
result = handler(**arguments)
@ -488,7 +500,7 @@ def feat_tokens(for_artist=True):
feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
if for_artist:
feat_words += ['with', 'vs', 'and', 'con', '&']
return '(?<=\s)(?:{0})(?=\s)'.format(
return r'(?<=\s)(?:{})(?=\s)'.format(
'|'.join(re.escape(x) for x in feat_words)
)
@ -513,7 +525,7 @@ def sanitize_choices(choices, choices_all):
def sanitize_pairs(pairs, pairs_all):
"""Clean up a single-element mapping configuration attribute as returned
by `confit`'s `Pairs` template: keep only two-element tuples present in
by Confuse's `Pairs` template: keep only two-element tuples present in
pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*')
wildcards while keeping the original order. Note that ('*', '*') and
('*', 'whatever') have the same effect.
@ -563,3 +575,188 @@ def notify_info_yielded(event):
yield v
return decorated
return decorator
def get_distance(config, data_source, info):
"""Returns the ``data_source`` weight and the maximum source weight
for albums or individual tracks.
"""
dist = beets.autotag.Distance()
if info.data_source == data_source:
dist.add('source', config['source_weight'].as_number())
return dist
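# The idea behind get_distance(), sketched with plain numbers instead
# of beets' Distance object: a candidate from an external source gets
# a fixed penalty so that, all else equal, MusicBrainz results win.
# The 0.5 default mirrors the source_weight default set below.
def source_penalty(info_source, data_source, source_weight=0.5):
    return source_weight if info_source == data_source else 0.0

source_penalty('Spotify', 'Spotify')      # 0.5 (penalised)
source_penalty('MusicBrainz', 'Spotify')  # 0.0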
def apply_item_changes(lib, item, move, pretend, write):
"""Store, move, and write the item according to the arguments.
:param lib: beets library.
:type lib: beets.library.Library
:param item: Item whose changes to apply.
:type item: beets.library.Item
:param move: Move the item if it's in the library.
:type move: bool
:param pretend: Return without moving, writing, or storing the item's
metadata.
:type pretend: bool
:param write: Write the item's metadata to its media file.
:type write: bool
"""
if pretend:
return
from beets import util
# Move the item if it's in the library.
if move and lib.directory in util.ancestry(item.path):
item.move(with_album=False)
if write:
item.try_write()
item.store()
class MetadataSourcePlugin(metaclass=abc.ABCMeta):
def __init__(self):
super().__init__()
self.config.add({'source_weight': 0.5})
@abc.abstractproperty
def id_regex(self):
raise NotImplementedError
@abc.abstractproperty
def data_source(self):
raise NotImplementedError
@abc.abstractproperty
def search_url(self):
raise NotImplementedError
@abc.abstractproperty
def album_url(self):
raise NotImplementedError
@abc.abstractproperty
def track_url(self):
raise NotImplementedError
@abc.abstractmethod
def _search_api(self, query_type, filters, keywords=''):
raise NotImplementedError
@abc.abstractmethod
def album_for_id(self, album_id):
raise NotImplementedError
@abc.abstractmethod
def track_for_id(self, track_id=None, track_data=None):
raise NotImplementedError
@staticmethod
def get_artist(artists, id_key='id', name_key='name'):
"""Returns an artist string (all artists) and an artist_id (the main
artist) for a list of artist object dicts.
For each artist, this function moves articles (such as 'a', 'an',
and 'the') to the front and strips trailing disambiguation numbers. It
returns a tuple containing the comma-separated string of all
normalized artists and the ``id`` of the main/first artist.
:param artists: Iterable of artist dicts or lists returned by API.
:type artists: list[dict] or list[list]
:param id_key: Key or index corresponding to the value of ``id`` for
the main/first artist. Defaults to 'id'.
:type id_key: str or int
:param name_key: Key or index corresponding to values of names
to concatenate for the artist string (containing all artists).
Defaults to 'name'.
:type name_key: str or int
:return: Normalized artist string.
:rtype: str
"""
artist_id = None
artist_names = []
for artist in artists:
if not artist_id:
artist_id = artist[id_key]
name = artist[name_key]
# Strip disambiguation number.
name = re.sub(r' \(\d+\)$', '', name)
# Move articles to the front.
name = re.sub(r'^(.*?), (a|an|the)$', r'\2 \1', name, flags=re.I)
artist_names.append(name)
artist = ', '.join(artist_names).replace(' ,', ',') or None
return artist, artist_id
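# A standalone sketch of get_artist()'s normalization with made-up
# API data: disambiguation numbers are stripped, trailing articles
# are moved to the front, and the first artist supplies the id.
import re

artists = [{'id': 'a1', 'name': 'Beatles, The'},
           {'id': 'a2', 'name': 'Someone (2)'}]
names = []
for a in artists:
    name = re.sub(r' \(\d+\)$', '', a['name'])  # strip "(2)"
    name = re.sub(r'^(.*?), (a|an|the)$', r'\2 \1', name, flags=re.I)
    names.append(name)
', '.join(names), artists[0]['id']  # ('The Beatles, Someone', 'a1')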
def _get_id(self, url_type, id_):
"""Parse an ID from its URL if necessary.
:param url_type: Type of URL. Either 'album' or 'track'.
:type url_type: str
:param id_: Album/track ID or URL.
:type id_: str
:return: Album/track ID.
:rtype: str
"""
self._log.debug(
"Searching {} for {} '{}'", self.data_source, url_type, id_
)
match = re.search(self.id_regex['pattern'].format(url_type), str(id_))
if match:
id_ = match.group(self.id_regex['match_group'])
if id_:
return id_
return None
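# A hedged example of the shape _get_id() expects from id_regex, with
# a made-up URL scheme: 'pattern' carries a {} slot for the URL type
# and 'match_group' selects the captured ID.
import re

id_regex = {'pattern': r'open\.example\.com/{}/(\w+)',
            'match_group': 1}
url = 'https://open.example.com/album/abc123'
match = re.search(id_regex['pattern'].format('album'), url)
match.group(id_regex['match_group'])  # 'abc123'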
def candidates(self, items, artist, album, va_likely, extra_tags=None):
"""Returns a list of AlbumInfo objects for Search API results
matching an ``album`` and ``artist`` (if not various).
:param items: List of items comprised by an album to be matched.
:type items: list[beets.library.Item]
:param artist: The artist of the album to be matched.
:type artist: str
:param album: The name of the album to be matched.
:type album: str
:param va_likely: True if the album to be matched likely has
Various Artists.
:type va_likely: bool
:return: Candidate AlbumInfo objects.
:rtype: list[beets.autotag.hooks.AlbumInfo]
"""
query_filters = {'album': album}
if not va_likely:
query_filters['artist'] = artist
results = self._search_api(query_type='album', filters=query_filters)
albums = [self.album_for_id(album_id=r['id']) for r in results]
return [a for a in albums if a is not None]
def item_candidates(self, item, artist, title):
"""Returns a list of TrackInfo objects for Search API results
matching ``title`` and ``artist``.
:param item: Singleton item to be matched.
:type item: beets.library.Item
:param artist: The artist of the track to be matched.
:type artist: str
:param title: The title of the track to be matched.
:type title: str
:return: Candidate TrackInfo objects.
:rtype: list[beets.autotag.hooks.TrackInfo]
"""
tracks = self._search_api(
query_type='track', keywords=title, filters={'artist': artist}
)
return [self.track_for_id(track_data=track) for track in tracks]
def album_distance(self, items, album_info, mapping):
return get_distance(
data_source=self.data_source, info=album_info, config=self.config
)
def track_distance(self, item, track_info):
return get_distance(
data_source=self.data_source, info=track_info, config=self.config
)

libs/common/beets/random.py Normal file
View file

@ -0,0 +1,113 @@
# This file is part of beets.
# Copyright 2016, Philippe Mongeau.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Get a random song or album from the library.
"""
import random
from operator import attrgetter
from itertools import groupby
def _length(obj, album):
"""Get the duration of an item or album.
"""
if album:
return sum(i.length for i in obj.items())
else:
return obj.length
def _equal_chance_permutation(objs, field='albumartist', random_gen=None):
"""Generate (lazily) a permutation of the objects where every group
with equal values for `field` have an equal chance of appearing in
any given position.
"""
rand = random_gen or random
# Group the objects by artist so we can sample from them.
key = attrgetter(field)
objs.sort(key=key)
objs_by_artists = {}
for artist, v in groupby(objs, key):
objs_by_artists[artist] = list(v)
# While we still have artists with music to choose from, pick one
# randomly and pick a track from that artist.
while objs_by_artists:
# Choose an artist and an object for that artist, removing
# this choice from the pool.
artist = rand.choice(list(objs_by_artists.keys()))
objs_from_artist = objs_by_artists[artist]
i = rand.randint(0, len(objs_from_artist) - 1)
yield objs_from_artist.pop(i)
# Remove the artist if we've used up all of its objects.
if not objs_from_artist:
del objs_by_artists[artist]
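# A toy demonstration of the equal-chance idea, with hypothetical
# tracks: because an artist is chosen before a track, an artist with
# eight tracks is picked no more often than one with a single track.
import random
from collections import namedtuple

Track = namedtuple('Track', 'albumartist title')
tracks = [Track('A', f'a{i}') for i in range(8)] + [Track('B', 'b0')]
by_artist = {}
for t in tracks:
    by_artist.setdefault(t.albumartist, []).append(t)
random.choice(list(by_artist))  # 'A' and 'B' are equally likely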
def _take(iter, num):
"""Return a list containing the first `num` values in `iter` (or
fewer, if the iterable ends early).
"""
out = []
for val in iter:
out.append(val)
num -= 1
if num <= 0:
break
return out
def _take_time(iter, secs, album):
"""Return a list containing the first values in `iter`, which should
be Item or Album objects, that add up to the given amount of time in
seconds.
"""
out = []
total_time = 0.0
for obj in iter:
length = _length(obj, album)
if total_time + length <= secs:
out.append(obj)
total_time += length
return out
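# A worked example of the greedy selection above: with durations of
# 120s, 200s and 100s against a 300-second budget, the 200s object is
# skipped and the other two are kept (220s total).
out, total = [], 0
for length in [120, 200, 100]:
    if total + length <= 300:
        out.append(length)
        total += length
out, total  # ([120, 100], 220)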
def random_objs(objs, album, number=1, time=None, equal_chance=False,
random_gen=None):
"""Get a random subset of the provided `objs`.
If `number` is provided, produce that many matches. Otherwise, if
`time` is provided, instead select a list whose total time is close
to that number of minutes. If `equal_chance` is true, give each
artist an equal chance of being included so that artists with more
songs are not represented disproportionately.
"""
rand = random_gen or random
# Permute the objects either in a straightforward way or an
# artist-balanced way.
if equal_chance:
perm = _equal_chance_permutation(objs)
else:
perm = objs
rand.shuffle(perm) # N.B. This shuffles the original list.
# Select objects by time or count.
if time:
return _take_time(perm, time * 60, album)
else:
return _take(perm, number)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -18,7 +17,6 @@ interface. To invoke the CLI, just call beets.ui.main(). The actual
CLI commands are implemented in the ui.commands module.
"""
from __future__ import division, absolute_import, print_function
import optparse
import textwrap
@ -30,19 +28,18 @@ import re
import struct
import traceback
import os.path
from six.moves import input
from beets import logging
from beets import library
from beets import plugins
from beets import util
from beets.util.functemplate import Template
from beets.util.functemplate import template
from beets import config
from beets.util import confit, as_string
from beets.util import as_string
from beets.autotag import mb
from beets.dbcore import query as db_query
from beets.dbcore import db
import six
import confuse
# On Windows platforms, use colorama to support "ANSI" terminal colors.
if sys.platform == 'win32':
@ -61,8 +58,8 @@ log.propagate = False # Don't propagate to root handler.
PF_KEY_QUERIES = {
'comp': u'comp:true',
'singleton': u'singleton:true',
'comp': 'comp:true',
'singleton': 'singleton:true',
}
@ -112,10 +109,7 @@ def decargs(arglist):
"""Given a list of command-line argument bytestrings, attempts to
decode them to Unicode strings when running under Python 2.
"""
if six.PY2:
return [s.decode(util.arg_encoding()) for s in arglist]
else:
return arglist
return arglist
def print_(*strings, **kwargs):
@ -130,30 +124,25 @@ def print_(*strings, **kwargs):
(it defaults to a newline).
"""
if not strings:
strings = [u'']
assert isinstance(strings[0], six.text_type)
strings = ['']
assert isinstance(strings[0], str)
txt = u' '.join(strings)
txt += kwargs.get('end', u'\n')
txt = ' '.join(strings)
txt += kwargs.get('end', '\n')
# Encode the string and write it to stdout.
if six.PY2:
# On Python 2, sys.stdout expects bytes.
# On Python 3, sys.stdout expects text strings and uses the
# exception-throwing encoding error policy. To avoid throwing
# errors and use our configurable encoding override, we use the
# underlying bytes buffer instead.
if hasattr(sys.stdout, 'buffer'):
out = txt.encode(_out_encoding(), 'replace')
sys.stdout.write(out)
sys.stdout.buffer.write(out)
sys.stdout.buffer.flush()
else:
# On Python 3, sys.stdout expects text strings and uses the
# exception-throwing encoding error policy. To avoid throwing
# errors and use our configurable encoding override, we use the
# underlying bytes buffer instead.
if hasattr(sys.stdout, 'buffer'):
out = txt.encode(_out_encoding(), 'replace')
sys.stdout.buffer.write(out)
sys.stdout.buffer.flush()
else:
# In our test harnesses (e.g., DummyOut), sys.stdout.buffer
# does not exist. We instead just record the text string.
sys.stdout.write(txt)
# In our test harnesses (e.g., DummyOut), sys.stdout.buffer
# does not exist. We instead just record the text string.
sys.stdout.write(txt)
# Configuration wrappers.
@ -203,19 +192,16 @@ def input_(prompt=None):
"""
# raw_input incorrectly sends prompts to stderr, not stdout, so we
# use print_() explicitly to display prompts.
# http://bugs.python.org/issue1927
# https://bugs.python.org/issue1927
if prompt:
print_(prompt, end=u' ')
print_(prompt, end=' ')
try:
resp = input()
except EOFError:
raise UserError(u'stdin stream ended while input required')
raise UserError('stdin stream ended while input required')
if six.PY2:
return resp.decode(_in_encoding(), 'ignore')
else:
return resp
return resp
def input_options(options, require=False, prompt=None, fallback_prompt=None,
@ -259,7 +245,7 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
found_letter = letter
break
else:
raise ValueError(u'no unambiguous lettering found')
raise ValueError('no unambiguous lettering found')
letters[found_letter.lower()] = option
index = option.index(found_letter)
@ -267,7 +253,7 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
# Mark the option's shortcut letter for display.
if not require and (
(default is None and not numrange and first) or
(isinstance(default, six.string_types) and
(isinstance(default, str) and
found_letter.lower() == default.lower())):
# The first option is the default; mark it.
show_letter = '[%s]' % found_letter.upper()
@ -303,11 +289,11 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
prompt_part_lengths = []
if numrange:
if isinstance(default, int):
default_name = six.text_type(default)
default_name = str(default)
default_name = colorize('action_default', default_name)
tmpl = '# selection (default %s)'
prompt_parts.append(tmpl % default_name)
prompt_part_lengths.append(len(tmpl % six.text_type(default)))
prompt_part_lengths.append(len(tmpl % str(default)))
else:
prompt_parts.append('# selection')
prompt_part_lengths.append(len(prompt_parts[-1]))
@ -342,9 +328,9 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
# Make a fallback prompt too. This is displayed if the user enters
# something that is not recognized.
if not fallback_prompt:
fallback_prompt = u'Enter one of '
fallback_prompt = 'Enter one of '
if numrange:
fallback_prompt += u'%i-%i, ' % numrange
fallback_prompt += '%i-%i, ' % numrange
fallback_prompt += ', '.join(display_letters) + ':'
resp = input_(prompt)
@ -383,34 +369,41 @@ def input_yn(prompt, require=False):
"yes" unless `require` is `True`, in which case there is no default.
"""
sel = input_options(
('y', 'n'), require, prompt, u'Enter Y or N:'
('y', 'n'), require, prompt, 'Enter Y or N:'
)
return sel == u'y'
return sel == 'y'
def input_select_objects(prompt, objs, rep):
def input_select_objects(prompt, objs, rep, prompt_all=None):
"""Prompt to user to choose all, none, or some of the given objects.
Return the list of selected objects.
`prompt` is the prompt string to use for each question (it should be
phrased as an imperative verb). `rep` is a function to call on each
object to print it out when confirming objects individually.
phrased as an imperative verb). If `prompt_all` is given, it is used
instead of `prompt` for the first (yes/no/select) question.
`rep` is a function to call on each object to print it out when confirming
objects individually.
"""
choice = input_options(
(u'y', u'n', u's'), False,
u'%s? (Yes/no/select)' % prompt)
('y', 'n', 's'), False,
'%s? (Yes/no/select)' % (prompt_all or prompt))
print() # Blank line.
if choice == u'y': # Yes.
if choice == 'y': # Yes.
return objs
elif choice == u's': # Select.
elif choice == 's': # Select.
out = []
for obj in objs:
rep(obj)
if input_yn(u'%s? (yes/no)' % prompt, True):
answer = input_options(
('y', 'n', 'q'), True, '%s? (yes/no/quit)' % prompt,
'Enter Y or N:'
)
if answer == 'y':
out.append(obj)
print() # go to a new line
elif answer == 'q':
return out
return out
else: # No.
@ -421,14 +414,14 @@ def input_select_objects(prompt, objs, rep):
def human_bytes(size):
"""Formats size, a number of bytes, in a human-readable way."""
powers = [u'', u'K', u'M', u'G', u'T', u'P', u'E', u'Z', u'Y', u'H']
powers = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y', 'H']
unit = 'B'
for power in powers:
if size < 1024:
return u"%3.1f %s%s" % (size, power, unit)
return f"{size:3.1f} {power}{unit}"
size /= 1024.0
unit = u'iB'
return u"big"
unit = 'iB'
return "big"
def human_seconds(interval):
@ -436,13 +429,13 @@ def human_seconds(interval):
interval using English words.
"""
units = [
(1, u'second'),
(60, u'minute'),
(60, u'hour'),
(24, u'day'),
(7, u'week'),
(52, u'year'),
(10, u'decade'),
(1, 'second'),
(60, 'minute'),
(60, 'hour'),
(24, 'day'),
(7, 'week'),
(52, 'year'),
(10, 'decade'),
]
for i in range(len(units) - 1):
increment, suffix = units[i]
@ -455,7 +448,7 @@ def human_seconds(interval):
increment, suffix = units[-1]
interval /= float(increment)
return u"%3.1f %ss" % (interval, suffix)
return f"{interval:3.1f} {suffix}s"
def human_seconds_short(interval):
@ -463,13 +456,13 @@ def human_seconds_short(interval):
string.
"""
interval = int(interval)
return u'%i:%02i' % (interval // 60, interval % 60)
return '%i:%02i' % (interval // 60, interval % 60)
# Colorization.
# ANSI terminal colorization code heavily inspired by pygments:
# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py
# https://bitbucket.org/birkenfeld/pygments-main/src/default/pygments/console.py
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
COLOR_ESCAPE = "\x1b["
DARK_COLORS = {
@ -516,7 +509,7 @@ def _colorize(color, text):
elif color in LIGHT_COLORS:
escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS[color] + 30)
else:
raise ValueError(u'no such color %s', color)
raise ValueError('no such color %s', color)
return escape + text + RESET_COLOR
@ -524,22 +517,22 @@ def colorize(color_name, text):
"""Colorize text if colored output is enabled. (Like _colorize but
conditional.)
"""
if config['ui']['color']:
global COLORS
if not COLORS:
COLORS = dict((name,
config['ui']['colors'][name].as_str())
for name in COLOR_NAMES)
# In case a 3rd party plugin is still passing the actual color ('red')
# instead of the abstract color name ('text_error')
color = COLORS.get(color_name)
if not color:
log.debug(u'Invalid color_name: {0}', color_name)
color = color_name
return _colorize(color, text)
else:
if not config['ui']['color'] or 'NO_COLOR' in os.environ.keys():
return text
global COLORS
if not COLORS:
COLORS = {name:
config['ui']['colors'][name].as_str()
for name in COLOR_NAMES}
# In case a 3rd party plugin is still passing the actual color ('red')
# instead of the abstract color name ('text_error')
color = COLORS.get(color_name)
if not color:
log.debug('Invalid color_name: {0}', color_name)
color = color_name
return _colorize(color, text)
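# The NO_COLOR part of the check above, in isolation: following the
# informal https://no-color.org/ convention, any value of the
# environment variable disables colour. The escape codes here are
# generic ANSI, not beets' configured palette.
import os

def maybe_color(text, escape='\x1b[31m', reset='\x1b[0m'):
    if 'NO_COLOR' in os.environ:
        return text
    return escape + text + reset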
def _colordiff(a, b, highlight='text_highlight',
minor_highlight='text_highlight_minor'):
@ -548,11 +541,11 @@ def _colordiff(a, b, highlight='text_highlight',
highlighted intelligently to show differences; other values are
stringified and highlighted in their entirety.
"""
if not isinstance(a, six.string_types) \
or not isinstance(b, six.string_types):
if not isinstance(a, str) \
or not isinstance(b, str):
# Non-strings: use ordinary equality.
a = six.text_type(a)
b = six.text_type(b)
a = str(a)
b = str(b)
if a == b:
return a, b
else:
@ -590,7 +583,7 @@ def _colordiff(a, b, highlight='text_highlight',
else:
assert(False)
return u''.join(a_out), u''.join(b_out)
return ''.join(a_out), ''.join(b_out)
def colordiff(a, b, highlight='text_highlight'):
@ -600,7 +593,7 @@ def colordiff(a, b, highlight='text_highlight'):
if config['ui']['color']:
return _colordiff(a, b, highlight)
else:
return six.text_type(a), six.text_type(b)
return str(a), str(b)
def get_path_formats(subview=None):
@ -611,12 +604,12 @@ def get_path_formats(subview=None):
subview = subview or config['paths']
for query, view in subview.items():
query = PF_KEY_QUERIES.get(query, query) # Expand common queries.
path_formats.append((query, Template(view.as_str())))
path_formats.append((query, template(view.as_str())))
return path_formats
def get_replacements():
"""Confit validation function that reads regex/string pairs.
"""Confuse validation function that reads regex/string pairs.
"""
replacements = []
for pattern, repl in config['replace'].get(dict).items():
@ -625,7 +618,7 @@ def get_replacements():
replacements.append((re.compile(pattern), repl))
except re.error:
raise UserError(
u'malformed regular expression in replace: {0}'.format(
'malformed regular expression in replace: {}'.format(
pattern
)
)
@ -646,7 +639,7 @@ def term_width():
try:
buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4)
except IOError:
except OSError:
return fallback
try:
height, width = struct.unpack('hh', buf)
@ -658,10 +651,10 @@ def term_width():
FLOAT_EPSILON = 0.01
def _field_diff(field, old, new):
"""Given two Model objects, format their values for `field` and
highlight changes among them. Return a human-readable string. If the
value has not changed, return None instead.
def _field_diff(field, old, old_fmt, new, new_fmt):
"""Given two Model objects and their formatted views, format their values
for `field` and highlight changes among them. Return a human-readable
string. If the value has not changed, return None instead.
"""
oldval = old.get(field)
newval = new.get(field)
@ -674,18 +667,18 @@ def _field_diff(field, old, new):
return None
# Get formatted values for output.
oldstr = old.formatted().get(field, u'')
newstr = new.formatted().get(field, u'')
oldstr = old_fmt.get(field, '')
newstr = new_fmt.get(field, '')
# For strings, highlight changes. For others, colorize the whole
# thing.
if isinstance(oldval, six.string_types):
if isinstance(oldval, str):
oldstr, newstr = colordiff(oldval, newstr)
else:
oldstr = colorize('text_error', oldstr)
newstr = colorize('text_error', newstr)
return u'{0} -> {1}'.format(oldstr, newstr)
return f'{oldstr} -> {newstr}'
def show_model_changes(new, old=None, fields=None, always=False):
@ -700,6 +693,11 @@ def show_model_changes(new, old=None, fields=None, always=False):
"""
old = old or new._db._get(type(new), new.id)
# Keep the formatted views around instead of re-creating them in each
# iteration step
old_fmt = old.formatted()
new_fmt = new.formatted()
# Build up lines showing changed fields.
changes = []
for field in old:
@ -708,25 +706,25 @@ def show_model_changes(new, old=None, fields=None, always=False):
continue
# Detect and show difference for this field.
line = _field_diff(field, old, new)
line = _field_diff(field, old, old_fmt, new, new_fmt)
if line:
changes.append(u' {0}: {1}'.format(field, line))
changes.append(f' {field}: {line}')
# New fields.
for field in set(new) - set(old):
if fields and field not in fields:
continue
changes.append(u' {0}: {1}'.format(
changes.append(' {}: {}'.format(
field,
colorize('text_highlight', new.formatted()[field])
colorize('text_highlight', new_fmt[field])
))
# Print changes.
if changes or always:
print_(format(old))
if changes:
print_(u'\n'.join(changes))
print_('\n'.join(changes))
return bool(changes)
@ -759,15 +757,21 @@ def show_path_changes(path_changes):
if max_width > col_width:
# Print every change over two lines
for source, dest in zip(sources, destinations):
log.info(u'{0} \n -> {1}', source, dest)
color_source, color_dest = colordiff(source, dest)
print_('{0} \n -> {1}'.format(color_source, color_dest))
else:
# Print every change on a single line, and add a header
title_pad = max_width - len('Source ') + len(' -> ')
log.info(u'Source {0} Destination', ' ' * title_pad)
print_('Source {0} Destination'.format(' ' * title_pad))
for source, dest in zip(sources, destinations):
pad = max_width - len(source)
log.info(u'{0} {1} -> {2}', source, ' ' * pad, dest)
color_source, color_dest = colordiff(source, dest)
print_('{0} {1} -> {2}'.format(
color_source,
' ' * pad,
color_dest,
))
# Helper functions for option parsing.
@ -783,22 +787,25 @@ def _store_dict(option, opt_str, value, parser):
if option_values is None:
# This is the first supplied ``key=value`` pair of option.
# Initialize empty dictionary and get a reference to it.
setattr(parser.values, dest, dict())
setattr(parser.values, dest, {})
option_values = getattr(parser.values, dest)
# Decode the argument using the platform's argument encoding.
value = util.text_string(value, util.arg_encoding())
try:
key, value = map(lambda s: util.text_string(s), value.split('='))
key, value = value.split('=', 1)
if not (key and value):
raise ValueError
except ValueError:
raise UserError(
"supplied argument `{0}' is not of the form `key=value'"
"supplied argument `{}' is not of the form `key=value'"
.format(value))
option_values[key] = value
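# The `key=value` parsing performed by _store_dict, reduced to a
# standalone sketch; splitting on the first '=' only lets values
# themselves contain '='.
def parse_pair(value):
    try:
        key, val = value.split('=', 1)
        if not (key and val):
            raise ValueError
    except ValueError:
        raise ValueError(
            "supplied argument `{}' is not of the form `key=value'"
            .format(value))
    return key, val

parse_pair('genre=jazz')   # ('genre', 'jazz')
parse_pair('comment=a=b')  # ('comment', 'a=b')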
class CommonOptionsParser(optparse.OptionParser, object):
class CommonOptionsParser(optparse.OptionParser):
"""Offers a simple way to add common formatting options.
Options available include:
@ -813,8 +820,9 @@ class CommonOptionsParser(optparse.OptionParser, object):
Each method is fully documented in the related method.
"""
def __init__(self, *args, **kwargs):
super(CommonOptionsParser, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self._album_flags = False
# this serves both as an indicator that we offer the feature AND allows
# us to check whether it has been specified on the CLI - bypassing the
@ -828,7 +836,7 @@ class CommonOptionsParser(optparse.OptionParser, object):
Sets the album property on the options extracted from the CLI.
"""
album = optparse.Option(*flags, action='store_true',
help=u'match albums instead of tracks')
help='match albums instead of tracks')
self.add_option(album)
self._album_flags = set(flags)
@ -846,7 +854,7 @@ class CommonOptionsParser(optparse.OptionParser, object):
elif value:
value, = decargs([value])
else:
value = u''
value = ''
parser.values.format = value
if target:
@ -873,14 +881,14 @@ class CommonOptionsParser(optparse.OptionParser, object):
By default this affects both items and albums. If add_album_option()
is used then the target will be autodetected.
Sets the format property to u'$path' on the options extracted from the
Sets the format property to '$path' on the options extracted from the
CLI.
"""
path = optparse.Option(*flags, nargs=0, action='callback',
callback=self._set_format,
callback_kwargs={'fmt': u'$path',
callback_kwargs={'fmt': '$path',
'store_true': True},
help=u'print paths for matched items or albums')
help='print paths for matched items or albums')
self.add_option(path)
def add_format_option(self, flags=('-f', '--format'), target=None):
@ -900,7 +908,7 @@ class CommonOptionsParser(optparse.OptionParser, object):
"""
kwargs = {}
if target:
if isinstance(target, six.string_types):
if isinstance(target, str):
target = {'item': library.Item,
'album': library.Album}[target]
kwargs['target'] = target
@ -908,7 +916,7 @@ class CommonOptionsParser(optparse.OptionParser, object):
opt = optparse.Option(*flags, action='callback',
callback=self._set_format,
callback_kwargs=kwargs,
help=u'print with custom format')
help='print with custom format')
self.add_option(opt)
def add_all_common_options(self):
@ -923,14 +931,15 @@ class CommonOptionsParser(optparse.OptionParser, object):
#
# This is a fairly generic subcommand parser for optparse. It is
# maintained externally here:
# http://gist.github.com/462717
# https://gist.github.com/462717
# There you will also find a better description of the code and a more
# succinct example program.
class Subcommand(object):
class Subcommand:
"""A subcommand of a root command-line application that may be
invoked by a SubcommandOptionParser.
"""
def __init__(self, name, parser=None, help='', aliases=(), hide=False):
"""Creates a new subcommand. name is the primary way to invoke
the subcommand; aliases are alternate names. parser is an
@ -958,7 +967,7 @@ class Subcommand(object):
@root_parser.setter
def root_parser(self, root_parser):
self._root_parser = root_parser
self.parser.prog = '{0} {1}'.format(
self.parser.prog = '{} {}'.format(
as_string(root_parser.get_prog_name()), self.name)
@ -974,13 +983,13 @@ class SubcommandsOptionParser(CommonOptionsParser):
"""
# A more helpful default usage.
if 'usage' not in kwargs:
kwargs['usage'] = u"""
kwargs['usage'] = """
%prog COMMAND [ARGS...]
%prog help COMMAND"""
kwargs['add_help_option'] = False
# Super constructor.
super(SubcommandsOptionParser, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
# Our root parser needs to stop on the first unrecognized argument.
self.disable_interspersed_args()
@ -997,7 +1006,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
# Add the list of subcommands to the help message.
def format_help(self, formatter=None):
# Get the original help message, to which we will append.
out = super(SubcommandsOptionParser, self).format_help(formatter)
out = super().format_help(formatter)
if formatter is None:
formatter = self.formatter
@ -1083,7 +1092,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
cmdname = args.pop(0)
subcommand = self._subcommand_for_name(cmdname)
if not subcommand:
raise UserError(u"unknown command '{0}'".format(cmdname))
raise UserError(f"unknown command '{cmdname}'")
suboptions, subargs = subcommand.parse_args(args)
return subcommand, suboptions, subargs
@ -1094,26 +1103,32 @@ optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)
# The main entry point and bootstrapping.
def _load_plugins(config):
"""Load the plugins specified in the configuration.
def _load_plugins(options, config):
"""Load the plugins specified on the command line or in the configuration.
"""
paths = config['pluginpath'].as_str_seq(split=False)
paths = [util.normpath(p) for p in paths]
log.debug(u'plugin paths: {0}', util.displayable_path(paths))
log.debug('plugin paths: {0}', util.displayable_path(paths))
# On Python 3, the search paths need to be unicode.
paths = [util.py3_path(p) for p in paths]
# Extend the `beetsplug` package to include the plugin paths.
import beetsplug
beetsplug.__path__ = paths + beetsplug.__path__
beetsplug.__path__ = paths + list(beetsplug.__path__)
# For backwards compatibility, also support plugin paths that
# *contain* a `beetsplug` package.
sys.path += paths
plugins.load_plugins(config['plugins'].as_str_seq())
plugins.send("pluginload")
# If we were given any plugins on the command line, use those.
if options.plugins is not None:
plugin_list = (options.plugins.split(',')
if len(options.plugins) > 0 else [])
else:
plugin_list = config['plugins'].as_str_seq()
plugins.load_plugins(plugin_list)
return plugins
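# The new -p/--plugins precedence in isolation: an explicit empty
# string disables all plugins instead of falling back to the
# configured list.
def effective_plugins(cli_value, config_value):
    if cli_value is not None:
        return cli_value.split(',') if len(cli_value) > 0 else []
    return config_value

effective_plugins('inline,web', ['fetchart'])  # ['inline', 'web']
effective_plugins('', ['fetchart'])            # []
effective_plugins(None, ['fetchart'])          # ['fetchart']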
@ -1127,7 +1142,20 @@ def _setup(options, lib=None):
config = _configure(options)
plugins = _load_plugins(config)
plugins = _load_plugins(options, config)
# Add types and queries defined by plugins.
plugin_types_album = plugins.types(library.Album)
library.Album._types.update(plugin_types_album)
item_types = plugin_types_album.copy()
item_types.update(library.Item._types)
item_types.update(plugins.types(library.Item))
library.Item._types = item_types
library.Item._queries.update(plugins.named_queries(library.Item))
library.Album._queries.update(plugins.named_queries(library.Album))
plugins.send("pluginload")
# Get the default subcommands.
from beets.ui.commands import default_commands
@ -1138,8 +1166,6 @@ def _setup(options, lib=None):
if lib is None:
lib = _open_library(config)
plugins.send("library_opened", lib=lib)
library.Item._types.update(plugins.types(library.Item))
library.Album._types.update(plugins.types(library.Album))
return subcommands, plugins, lib
@ -1165,18 +1191,18 @@ def _configure(options):
log.set_global_level(logging.INFO)
if overlay_path:
log.debug(u'overlaying configuration: {0}',
log.debug('overlaying configuration: {0}',
util.displayable_path(overlay_path))
config_path = config.user_config_path()
if os.path.isfile(config_path):
log.debug(u'user configuration: {0}',
log.debug('user configuration: {0}',
util.displayable_path(config_path))
else:
log.debug(u'no user configuration found at {0}',
log.debug('no user configuration found at {0}',
util.displayable_path(config_path))
log.debug(u'data directory: {0}',
log.debug('data directory: {0}',
util.displayable_path(config.config_dir()))
return config
@ -1193,13 +1219,14 @@ def _open_library(config):
get_replacements(),
)
lib.get_item(0) # Test database connection.
except (sqlite3.OperationalError, sqlite3.DatabaseError):
log.debug(u'{}', traceback.format_exc())
raise UserError(u"database file {0} could not be opened".format(
util.displayable_path(dbpath)
except (sqlite3.OperationalError, sqlite3.DatabaseError) as db_error:
log.debug('{}', traceback.format_exc())
raise UserError("database file {} cannot not be opened: {}".format(
util.displayable_path(dbpath),
db_error
))
log.debug(u'library database: {0}\n'
u'library directory: {1}',
log.debug('library database: {0}\n'
'library directory: {1}',
util.displayable_path(lib.path),
util.displayable_path(lib.directory))
return lib
@ -1213,15 +1240,17 @@ def _raw_main(args, lib=None):
parser.add_format_option(flags=('--format-item',), target=library.Item)
parser.add_format_option(flags=('--format-album',), target=library.Album)
parser.add_option('-l', '--library', dest='library',
help=u'library database file to use')
help='library database file to use')
parser.add_option('-d', '--directory', dest='directory',
help=u"destination music directory")
help="destination music directory")
parser.add_option('-v', '--verbose', dest='verbose', action='count',
help=u'log more details (use twice for even more)')
help='log more details (use twice for even more)')
parser.add_option('-c', '--config', dest='config',
help=u'path to configuration file')
help='path to configuration file')
parser.add_option('-p', '--plugins', dest='plugins',
help='a comma-separated list of plugins to load')
parser.add_option('-h', '--help', dest='help', action='store_true',
help=u'show this help message and exit')
help='show this help message and exit')
parser.add_option('--version', dest='version', action='store_true',
help=optparse.SUPPRESS_HELP)
@ -1256,7 +1285,7 @@ def main(args=None):
_raw_main(args)
except UserError as exc:
message = exc.args[0] if exc.args else None
log.error(u'error: {0}', message)
log.error('error: {0}', message)
sys.exit(1)
except util.HumanReadableException as exc:
exc.log(log)
@ -1267,13 +1296,13 @@ def main(args=None):
log.debug('{}', traceback.format_exc())
log.error('{}', exc)
sys.exit(1)
except confit.ConfigError as exc:
log.error(u'configuration error: {0}', exc)
except confuse.ConfigError as exc:
log.error('configuration error: {0}', exc)
sys.exit(1)
except db_query.InvalidQueryError as exc:
log.error(u'invalid query: {0}', exc)
log.error('invalid query: {0}', exc)
sys.exit(1)
except IOError as exc:
except OSError as exc:
if exc.errno == errno.EPIPE:
# "Broken pipe". End silently.
sys.stderr.close()
@ -1281,11 +1310,11 @@ def main(args=None):
raise
except KeyboardInterrupt:
# Silently ignore ^C except in verbose mode.
log.debug(u'{}', traceback.format_exc())
log.debug('{}', traceback.format_exc())
except db.DBAccessError as exc:
log.error(
u'database access error: {0}\n'
u'the library file might have a permissions problem',
'database access error: {0}\n'
'the library file might have a permissions problem',
exc
)
sys.exit(1)

File diff suppressed because it is too large

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -15,28 +14,28 @@
"""Miscellaneous utility functions."""
from __future__ import division, absolute_import, print_function
import os
import sys
import errno
import locale
import re
import tempfile
import shutil
import fnmatch
from collections import Counter
import functools
from collections import Counter, namedtuple
from multiprocessing.pool import ThreadPool
import traceback
import subprocess
import platform
import shlex
from beets.util import hidden
import six
from unidecode import unidecode
from enum import Enum
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = u'\\\\?\\'
SNI_SUPPORTED = sys.version_info >= (2, 7, 9)
WINDOWS_MAGIC_PREFIX = '\\\\?\\'
class HumanReadableException(Exception):
@ -58,27 +57,27 @@ class HumanReadableException(Exception):
self.reason = reason
self.verb = verb
self.tb = tb
super(HumanReadableException, self).__init__(self.get_message())
super().__init__(self.get_message())
def _gerund(self):
"""Generate a (likely) gerund form of the English verb.
"""
if u' ' in self.verb:
if ' ' in self.verb:
return self.verb
gerund = self.verb[:-1] if self.verb.endswith(u'e') else self.verb
gerund += u'ing'
gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb
gerund += 'ing'
return gerund
def _reasonstr(self):
"""Get the reason as a string."""
if isinstance(self.reason, six.text_type):
if isinstance(self.reason, str):
return self.reason
elif isinstance(self.reason, bytes):
return self.reason.decode('utf-8', 'ignore')
elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError
return self.reason.strerror
else:
return u'"{0}"'.format(six.text_type(self.reason))
return '"{}"'.format(str(self.reason))
def get_message(self):
"""Create the human-readable description of the error, sans
@ -92,7 +91,7 @@ class HumanReadableException(Exception):
"""
if self.tb:
logger.debug(self.tb)
logger.error(u'{0}: {1}', self.error_kind, self.args[0])
logger.error('{0}: {1}', self.error_kind, self.args[0])
class FilesystemError(HumanReadableException):
@ -100,29 +99,30 @@ class FilesystemError(HumanReadableException):
via a function in this module. The `paths` field is a sequence of
pathnames involved in the operation.
"""
def __init__(self, reason, verb, paths, tb=None):
self.paths = paths
super(FilesystemError, self).__init__(reason, verb, tb)
super().__init__(reason, verb, tb)
def get_message(self):
# Use a nicer English phrasing for some specific verbs.
if self.verb in ('move', 'copy', 'rename'):
clause = u'while {0} {1} to {2}'.format(
clause = 'while {} {} to {}'.format(
self._gerund(),
displayable_path(self.paths[0]),
displayable_path(self.paths[1])
)
elif self.verb in ('delete', 'write', 'create', 'read'):
clause = u'while {0} {1}'.format(
clause = 'while {} {}'.format(
self._gerund(),
displayable_path(self.paths[0])
)
else:
clause = u'during {0} of paths {1}'.format(
self.verb, u', '.join(displayable_path(p) for p in self.paths)
clause = 'during {} of paths {}'.format(
self.verb, ', '.join(displayable_path(p) for p in self.paths)
)
return u'{0} {1}'.format(self._reasonstr(), clause)
return f'{self._reasonstr()} {clause}'
class MoveOperation(Enum):
@ -132,6 +132,8 @@ class MoveOperation(Enum):
COPY = 1
LINK = 2
HARDLINK = 3
REFLINK = 4
REFLINK_AUTO = 5
def normpath(path):
@ -182,7 +184,7 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
contents = os.listdir(syspath(path))
except OSError as exc:
if logger:
logger.warning(u'could not list directory {0}: {1}'.format(
logger.warning('could not list directory {}: {}'.format(
displayable_path(path), exc.strerror
))
return
@ -195,6 +197,10 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
skip = False
for pat in ignore:
if fnmatch.fnmatch(base, pat):
if logger:
logger.debug('ignoring {} due to ignore rule {}'.format(
base, pat
))
skip = True
break
if skip:
@ -217,8 +223,14 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
for base in dirs:
cur = os.path.join(path, base)
# yield from sorted_walk(...)
for res in sorted_walk(cur, ignore, ignore_hidden, logger):
yield res
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
def path_as_posix(path):
"""Return the string representation of the path with forward (/)
slashes.
"""
return path.replace(b'\\', b'/')
def mkdirall(path):
@ -229,7 +241,7 @@ def mkdirall(path):
if not os.path.isdir(syspath(ancestor)):
try:
os.mkdir(syspath(ancestor))
except (OSError, IOError) as exc:
except OSError as exc:
raise FilesystemError(exc, 'create', (ancestor,),
traceback.format_exc())
@ -282,13 +294,13 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
continue
clutter = [bytestring_path(c) for c in clutter]
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
if fnmatch_all(match_paths, clutter):
# Directory contains only clutter (or nothing).
try:
try:
if fnmatch_all(match_paths, clutter):
# Directory contains only clutter (or nothing).
shutil.rmtree(directory)
except OSError:
else:
break
else:
except OSError:
break
@ -367,18 +379,18 @@ def bytestring_path(path):
PATH_SEP = bytestring_path(os.sep)
def displayable_path(path, separator=u'; '):
def displayable_path(path, separator='; '):
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
list or a tuple, the elements are joined with `separator`.
"""
if isinstance(path, (list, tuple)):
return separator.join(displayable_path(p) for p in path)
elif isinstance(path, six.text_type):
elif isinstance(path, str):
return path
elif not isinstance(path, bytes):
# A non-string object: just get its unicode representation.
return six.text_type(path)
return str(path)
try:
return path.decode(_fsencoding(), 'ignore')
@ -397,7 +409,7 @@ def syspath(path, prefix=True):
if os.path.__name__ != 'ntpath':
return path
if not isinstance(path, six.text_type):
if not isinstance(path, str):
# Beets currently represents Windows paths internally with UTF-8
# arbitrarily. But earlier versions used MBCS because it is
# reported as the FS encoding by Windows. Try both.
@ -410,11 +422,11 @@ def syspath(path, prefix=True):
path = path.decode(encoding, 'replace')
# Add the magic prefix if it isn't already there.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
if path.startswith(u'\\\\'):
if path.startswith('\\\\'):
# UNC path. Final path should look like \\?\UNC\...
path = u'UNC' + path[1:]
path = 'UNC' + path[1:]
path = WINDOWS_MAGIC_PREFIX + path
return path
@ -436,7 +448,7 @@ def remove(path, soft=True):
return
try:
os.remove(path)
except (OSError, IOError) as exc:
except OSError as exc:
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
@ -451,10 +463,10 @@ def copy(path, dest, replace=False):
path = syspath(path)
dest = syspath(dest)
if not replace and os.path.exists(dest):
raise FilesystemError(u'file exists', 'copy', (path, dest))
raise FilesystemError('file exists', 'copy', (path, dest))
try:
shutil.copyfile(path, dest)
except (OSError, IOError) as exc:
except OSError as exc:
raise FilesystemError(exc, 'copy', (path, dest),
traceback.format_exc())
@ -467,24 +479,37 @@ def move(path, dest, replace=False):
instead, in which case metadata will *not* be preserved. Paths are
translated to system paths.
"""
if os.path.isdir(path):
raise FilesystemError(u'source is directory', 'move', (path, dest))
if os.path.isdir(dest):
raise FilesystemError(u'destination is directory', 'move',
(path, dest))
if samefile(path, dest):
return
path = syspath(path)
dest = syspath(dest)
if os.path.exists(dest) and not replace:
raise FilesystemError(u'file exists', 'rename', (path, dest))
raise FilesystemError('file exists', 'rename', (path, dest))
# First, try renaming the file.
try:
os.rename(path, dest)
os.replace(path, dest)
except OSError:
# Otherwise, copy and delete the original.
tmp = tempfile.mktemp(suffix='.beets',
prefix=py3_path(b'.' + os.path.basename(dest)),
dir=py3_path(os.path.dirname(dest)))
tmp = syspath(tmp)
try:
shutil.copyfile(path, dest)
shutil.copyfile(path, tmp)
os.replace(tmp, dest)
tmp = None
os.remove(path)
except (OSError, IOError) as exc:
except OSError as exc:
raise FilesystemError(exc, 'move', (path, dest),
traceback.format_exc())
finally:
if tmp is not None:
os.remove(tmp)
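# A standalone sketch of the fallback strategy above, assuming a
# POSIX filesystem: copy to a temporary file next to the destination,
# then atomically replace, so a failed copy never leaves a
# half-written file at `dest`.
import os
import shutil
import tempfile

def safe_move(src, dest):
    try:
        os.replace(src, dest)  # fast path: same filesystem
    except OSError:
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(dest) or '.')
        os.close(fd)
        try:
            shutil.copyfile(src, tmp)
            os.replace(tmp, dest)  # atomic on POSIX
        except OSError:
            os.remove(tmp)
            raise
        else:
            os.remove(src)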
def link(path, dest, replace=False):
@ -496,18 +521,18 @@ def link(path, dest, replace=False):
return
if os.path.exists(syspath(dest)) and not replace:
raise FilesystemError(u'file exists', 'rename', (path, dest))
raise FilesystemError('file exists', 'rename', (path, dest))
try:
os.symlink(syspath(path), syspath(dest))
except NotImplementedError:
# raised on python >= 3.2 and Windows versions before Vista
raise FilesystemError(u'OS does not support symbolic links.'
raise FilesystemError('OS does not support symbolic links.'
'link', (path, dest), traceback.format_exc())
except OSError as exc:
# TODO: Windows version checks can be removed for python 3
if hasattr('sys', 'getwindowsversion'):
if sys.getwindowsversion()[0] < 6: # is before Vista
exc = u'OS does not support symbolic links.'
exc = 'OS does not support symbolic links.'
raise FilesystemError(exc, 'link', (path, dest),
traceback.format_exc())
@ -521,21 +546,50 @@ def hardlink(path, dest, replace=False):
return
if os.path.exists(syspath(dest)) and not replace:
raise FilesystemError(u'file exists', 'rename', (path, dest))
raise FilesystemError('file exists', 'rename', (path, dest))
try:
os.link(syspath(path), syspath(dest))
except NotImplementedError:
raise FilesystemError(u'OS does not support hard links.'
raise FilesystemError('OS does not support hard links.'
'link', (path, dest), traceback.format_exc())
except OSError as exc:
if exc.errno == errno.EXDEV:
raise FilesystemError(u'Cannot hard link across devices.'
raise FilesystemError('Cannot hard link across devices.'
'link', (path, dest), traceback.format_exc())
else:
raise FilesystemError(exc, 'link', (path, dest),
traceback.format_exc())
def reflink(path, dest, replace=False, fallback=False):
"""Create a reflink from `dest` to `path`.
Raise an `OSError` if `dest` already exists, unless `replace` is
True. If `path` == `dest`, then do nothing.
If reflinking fails and `fallback` is enabled, try copying the file
instead. Otherwise, raise an error without trying a plain copy.
May raise an `ImportError` if the `reflink` module is not available.
"""
import reflink as pyreflink
if samefile(path, dest):
return
if os.path.exists(syspath(dest)) and not replace:
raise FilesystemError('file exists', 'rename', (path, dest))
try:
pyreflink.reflink(path, dest)
except (NotImplementedError, pyreflink.ReflinkImpossibleError):
if fallback:
copy(path, dest, replace)
else:
raise FilesystemError('OS/filesystem does not support reflinks.',
'link', (path, dest), traceback.format_exc())
def unique_path(path):
"""Returns a version of ``path`` that does not exist on the
filesystem. Specifically, if ``path`` itself already exists, then
@ -553,22 +607,23 @@ def unique_path(path):
num = 0
while True:
num += 1
suffix = u'.{}'.format(num).encode() + ext
suffix = f'.{num}'.encode() + ext
new_path = base + suffix
if not os.path.exists(new_path):
return new_path
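# The suffixing scheme above in miniature: "cover.jpg" becomes
# "cover.1.jpg", "cover.2.jpg", ... until a free name is found. The
# `exists` parameter is injected here only to keep the sketch
# self-contained.
import os

def unique_name(path, exists=os.path.exists):
    if not exists(path):
        return path
    base, ext = os.path.splitext(path)
    num = 0
    while True:
        num += 1
        candidate = f'{base}.{num}{ext}'
        if not exists(candidate):
            return candidate

taken = {'cover.jpg', 'cover.1.jpg'}
unique_name('cover.jpg', exists=lambda p: p in taken)  # 'cover.2.jpg'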
# Note: The Windows "reserved characters" are, of course, allowed on
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
CHAR_REPLACE = [
(re.compile(r'[\\/]'), u'_'), # / and \ -- forbidden everywhere.
(re.compile(r'^\.'), u'_'), # Leading dot (hidden files on Unix).
(re.compile(r'[\x00-\x1f]'), u''), # Control characters.
(re.compile(r'[<>:"\?\*\|]'), u'_'), # Windows "reserved characters".
(re.compile(r'\.$'), u'_'), # Trailing dots.
(re.compile(r'\s+$'), u''), # Trailing whitespace.
(re.compile(r'[\\/]'), '_'), # / and \ -- forbidden everywhere.
(re.compile(r'^\.'), '_'), # Leading dot (hidden files on Unix).
(re.compile(r'[\x00-\x1f]'), ''), # Control characters.
(re.compile(r'[<>:"\?\*\|]'), '_'), # Windows "reserved characters".
(re.compile(r'\.$'), '_'), # Trailing dots.
(re.compile(r'\s+$'), ''), # Trailing whitespace.
]
@ -692,36 +747,29 @@ def py3_path(path):
it is. So this function helps us "smuggle" the true bytes data
through APIs that took Python 3's Unicode mandate too seriously.
"""
if isinstance(path, six.text_type):
if isinstance(path, str):
return path
assert isinstance(path, bytes)
if six.PY2:
return path
return os.fsdecode(path)
def str2bool(value):
"""Returns a boolean reflecting a human-entered string."""
return value.lower() in (u'yes', u'1', u'true', u't', u'y')
return value.lower() in ('yes', '1', 'true', 't', 'y')
def as_string(value):
"""Convert a value to a Unicode object for matching with a query.
None becomes the empty string. Bytestrings are silently decoded.
"""
if six.PY2:
buffer_types = buffer, memoryview # noqa: F821
else:
buffer_types = memoryview
if value is None:
return u''
elif isinstance(value, buffer_types):
return ''
elif isinstance(value, memoryview):
return bytes(value).decode('utf-8', 'ignore')
elif isinstance(value, bytes):
return value.decode('utf-8', 'ignore')
else:
return six.text_type(value)
return str(value)
def text_string(value, encoding='utf-8'):
@ -744,7 +792,7 @@ def plurality(objs):
"""
c = Counter(objs)
if not c:
raise ValueError(u'sequence must be non-empty')
raise ValueError('sequence must be non-empty')
return c.most_common(1)[0]
@ -761,7 +809,11 @@ def cpu_count():
num = 0
elif sys.platform == 'darwin':
try:
num = int(command_output(['/usr/sbin/sysctl', '-n', 'hw.ncpu']))
num = int(command_output([
'/usr/sbin/sysctl',
'-n',
'hw.ncpu',
]).stdout)
except (ValueError, OSError, subprocess.CalledProcessError):
num = 0
else:
@ -781,20 +833,23 @@ def convert_command_args(args):
assert isinstance(args, list)
def convert(arg):
if six.PY2:
if isinstance(arg, six.text_type):
arg = arg.encode(arg_encoding())
else:
if isinstance(arg, bytes):
arg = arg.decode(arg_encoding(), 'surrogateescape')
if isinstance(arg, bytes):
arg = arg.decode(arg_encoding(), 'surrogateescape')
return arg
return [convert(a) for a in args]
# stdout and stderr as bytes
CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr"))
def command_output(cmd, shell=False):
"""Runs the command and returns its output after it has exited.
Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain
byte strings of the respective output streams.
``cmd`` is a list of arguments starting with the command names. The
arguments are bytes on Unix and strings on Windows.
If ``shell`` is true, ``cmd`` is assumed to be a string and passed to a
@ -829,7 +884,7 @@ def command_output(cmd, shell=False):
cmd=' '.join(cmd),
output=stdout + stderr,
)
return stdout
return CommandOutput(stdout, stderr)
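# What the new return type looks like to callers, assuming a POSIX
# `echo` and the CommandOutput tuple defined above: code that
# previously received raw stdout bytes now unpacks a named tuple.
import subprocess

proc = subprocess.Popen(['echo', 'hi'],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
CommandOutput(stdout, stderr).stdout  # b'hi\n'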
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
@ -876,25 +931,6 @@ def editor_command():
return open_anything()
def shlex_split(s):
"""Split a Unicode or bytes string according to shell lexing rules.
Raise `ValueError` if the string is not a well-formed shell string.
This is a workaround for a bug in some versions of Python.
"""
if not six.PY2 or isinstance(s, bytes): # Shlex works fine.
return shlex.split(s)
elif isinstance(s, six.text_type):
# Work around a Python bug.
# http://bugs.python.org/issue6988
bs = s.encode('utf-8')
return [c.decode('utf-8') for c in shlex.split(bs)]
else:
raise TypeError(u'shlex_split called with non-string')
def interactive_open(targets, command):
"""Open the files in `targets` by `exec`ing a new `command`, given
as a Unicode string. (The new program takes over, and Python
@ -906,7 +942,7 @@ def interactive_open(targets, command):
# Split the command string into its arguments.
try:
args = shlex_split(command)
args = shlex.split(command)
except ValueError: # Malformed shell tokens.
args = [command]
@ -921,7 +957,7 @@ def _windows_long_path_name(short_path):
"""Use Windows' `GetLongPathNameW` via ctypes to get the canonical,
long path given a short filename.
"""
if not isinstance(short_path, six.text_type):
if not isinstance(short_path, str):
short_path = short_path.decode(_fsencoding())
import ctypes
@ -982,7 +1018,7 @@ def raw_seconds_short(string):
"""
match = re.match(r'^(\d+):([0-5]\d)$', string)
if not match:
raise ValueError(u'String not in M:SS format')
raise ValueError('String not in M:SS format')
minutes, seconds = map(int, match.groups())
return float(minutes * 60 + seconds)
@ -1009,3 +1045,59 @@ def asciify_path(path, sep_replace):
sep_replace
)
return os.sep.join(path_components)
def par_map(transform, items):
"""Apply the function `transform` to all the elements in the
iterable `items`, like `map(transform, items)` but with no return
value. The map *might* happen in parallel: it's parallel on Python 3
and sequential on Python 2.
The parallelism uses threads (not processes), so this is only useful
for IO-bound `transform`s.
"""
pool = ThreadPool()
pool.map(transform, items)
pool.close()
pool.join()
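# Usage sketch for par_map(), assuming the function above is in
# scope: it is meant for IO-bound side effects, so results are
# collected through a closure rather than returned.
results = []

def fetch(url):
    results.append((url, len(url)))  # stand-in for a network call

par_map(fetch, ['http://a', 'http://b'])
sorted(results)  # [('http://a', 8), ('http://b', 8)]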
def lazy_property(func):
"""A decorator that creates a lazily evaluated property. On first access,
the property is assigned the return value of `func`. This first value is
stored, so that future accesses do not have to evaluate `func` again.
This behaviour is useful when `func` is expensive to evaluate, and it is
not certain that the result will be needed.
"""
field_name = '_' + func.__name__
@property
@functools.wraps(func)
def wrapper(self):
if hasattr(self, field_name):
return getattr(self, field_name)
value = func(self)
setattr(self, field_name, value)
return value
return wrapper
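# Usage sketch, assuming the lazy_property defined above is in scope:
# the expensive computation runs once, then the stored value is
# returned on every later access.
class Stats:
    @lazy_property
    def summary(self):
        print('computing...')
        return {'tracks': 0}

s = Stats()
s.summary  # prints 'computing...' and caches the dict
s.summary  # returns the cached dict; nothing is printed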
def decode_commandline_path(path):
"""Prepare a path for substitution into commandline template.
On Python 3, we need to construct the subprocess commands to invoke as a
Unicode string. On Unix, this is a little unfortunate---the OS is
expecting bytes---so we use surrogate escaping and decode with the
argument encoding, which is the same encoding that will then be
*reversed* to recover the same bytes before invoking the OS. On
Windows, we want to preserve the Unicode filename "as is."
"""
# On Python 3, the template is a Unicode string, which only supports
# substitution of Unicode variables.
if platform.system() == 'Windows':
return path.decode(_fsencoding())
else:
return path.decode(arg_encoding(), 'surrogateescape')

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Fabrice Laporte
#
@ -16,38 +15,39 @@
"""Abstraction layer to resize images using PIL, ImageMagick, or a
public resizing proxy if neither is available.
"""
from __future__ import division, absolute_import, print_function
import subprocess
import os
import os.path
import re
from tempfile import NamedTemporaryFile
from six.moves.urllib.parse import urlencode
from urllib.parse import urlencode
from beets import logging
from beets import util
import six
# Resizing methods
PIL = 1
IMAGEMAGICK = 2
WEBPROXY = 3
if util.SNI_SUPPORTED:
PROXY_URL = 'https://images.weserv.nl/'
else:
PROXY_URL = 'http://images.weserv.nl/'
PROXY_URL = 'https://images.weserv.nl/'
log = logging.getLogger('beets')
def resize_url(url, maxwidth):
def resize_url(url, maxwidth, quality=0):
"""Return a proxied image URL that resizes the original image to
maxwidth (preserving aspect ratio).
"""
return '{0}?{1}'.format(PROXY_URL, urlencode({
params = {
'url': url.replace('http://', ''),
'w': maxwidth,
}))
}
if quality > 0:
params['q'] = quality
return '{}?{}'.format(PROXY_URL, urlencode(params))
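
For a concrete feel (the URL and numbers are made up), the proxied result looks like this:

```python
from urllib.parse import urlencode

PROXY_URL = 'https://images.weserv.nl/'
params = {'url': 'example.com/cover.jpg', 'w': 600}
params['q'] = 75   # only added when quality > 0
print('{}?{}'.format(PROXY_URL, urlencode(params)))
# https://images.weserv.nl/?url=example.com%2Fcover.jpg&w=600&q=75
```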
def temp_file_for(path):
@ -59,48 +59,102 @@ def temp_file_for(path):
return util.bytestring_path(f.name)
def pil_resize(maxwidth, path_in, path_out=None):
def pil_resize(maxwidth, path_in, path_out=None, quality=0, max_filesize=0):
"""Resize using Python Imaging Library (PIL). Return the output path
of resized image.
"""
path_out = path_out or temp_file_for(path_in)
from PIL import Image
log.debug(u'artresizer: PIL resizing {0} to {1}',
log.debug('artresizer: PIL resizing {0} to {1}',
util.displayable_path(path_in), util.displayable_path(path_out))
try:
im = Image.open(util.syspath(path_in))
size = maxwidth, maxwidth
im.thumbnail(size, Image.ANTIALIAS)
im.save(path_out)
return path_out
except IOError:
log.error(u"PIL cannot create thumbnail for '{0}'",
if quality == 0:
# Use PIL's default quality.
quality = -1
# progressive=False only affects JPEGs and is the default,
# but we include it here for explicitness.
im.save(util.py3_path(path_out), quality=quality, progressive=False)
if max_filesize > 0:
# If a maximum filesize is set, lower the quality of the JPEG
# encoding step by step, re-encoding up to 5 times.
# First, set the starting quality to the provided value, or 95.
if quality > 0:
lower_qual = quality
else:
lower_qual = 95
for i in range(5):
# 5 attempts is an arbitrary choice
filesize = os.stat(util.syspath(path_out)).st_size
log.debug("PIL Pass {0} : Output size: {1}B", i, filesize)
if filesize <= max_filesize:
return path_out
# The relationship between filesize & quality will be
# image dependent.
lower_qual -= 10
# Restrict quality dropping below 10
if lower_qual < 10:
lower_qual = 10
# Use optimize flag to improve filesize decrease
im.save(util.py3_path(path_out), quality=lower_qual,
optimize=True, progressive=False)
log.warning("PIL Failed to resize file to below {0}B",
max_filesize)
return path_out
else:
return path_out
except OSError:
log.error("PIL cannot create thumbnail for '{0}'",
util.displayable_path(path_in))
return path_in
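
Distilled from the loop above, the sequence of JPEG qualities attempted looks like this (a standalone sketch, not beets code):

```python
def quality_schedule(quality=0, attempts=5):
    q = quality if quality > 0 else 95   # start at requested quality or 95
    for _ in range(attempts):
        yield q
        q = max(q - 10, 10)              # drop by 10, never below 10

print(list(quality_schedule()))     # [95, 85, 75, 65, 55]
print(list(quality_schedule(25)))   # [25, 15, 10, 10, 10]
```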
def im_resize(maxwidth, path_in, path_out=None):
"""Resize using ImageMagick's ``convert`` tool.
Return the output path of resized image.
def im_resize(maxwidth, path_in, path_out=None, quality=0, max_filesize=0):
"""Resize using ImageMagick.
Use the ``magick`` program or ``convert`` on older versions. Return
the output path of resized image.
"""
path_out = path_out or temp_file_for(path_in)
log.debug(u'artresizer: ImageMagick resizing {0} to {1}',
log.debug('artresizer: ImageMagick resizing {0} to {1}',
util.displayable_path(path_in), util.displayable_path(path_out))
# "-resize WIDTHx>" shrinks images with the width larger
# than the given width while maintaining the aspect ratio
# with regards to the height.
# ImageMagick already seems to default to no interlace, but we include it
# here for the sake of explicitness.
cmd = ArtResizer.shared.im_convert_cmd + [
util.syspath(path_in, prefix=False),
'-resize', f'{maxwidth}x>',
'-interlace', 'none',
]
if quality > 0:
cmd += ['-quality', f'{quality}']
# "-define jpeg:extent=SIZEb" sets the target filesize for imagemagick to
# SIZE in bytes.
if max_filesize > 0:
cmd += ['-define', f'jpeg:extent={max_filesize}b']
cmd.append(util.syspath(path_out, prefix=False))
try:
util.command_output([
'convert', util.syspath(path_in, prefix=False),
'-resize', '{0}x>'.format(maxwidth),
util.syspath(path_out, prefix=False),
])
util.command_output(cmd)
except subprocess.CalledProcessError:
log.warning(u'artresizer: IM convert failed for {0}',
log.warning('artresizer: IM convert failed for {0}',
util.displayable_path(path_in))
return path_in
return path_out
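
For illustration, the argument vector assembled by a hypothetical call with maxwidth=600, quality=75 and max_filesize=200000, using the modern `magick` binary:

```python
cmd = [
    'magick', 'in.jpg',
    '-resize', '600x>',                 # shrink only if wider than 600px
    '-interlace', 'none',               # no progressive encoding
    '-quality', '75',                   # present only when quality > 0
    '-define', 'jpeg:extent=200000b',   # present only when max_filesize > 0
    'out.jpg',
]
```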
@ -112,31 +166,33 @@ BACKEND_FUNCS = {
def pil_getsize(path_in):
from PIL import Image
try:
im = Image.open(util.syspath(path_in))
return im.size
except IOError as exc:
log.error(u"PIL could not read file {}: {}",
except OSError as exc:
log.error("PIL could not read file {}: {}",
util.displayable_path(path_in), exc)
def im_getsize(path_in):
cmd = ['identify', '-format', '%w %h',
util.syspath(path_in, prefix=False)]
cmd = ArtResizer.shared.im_identify_cmd + \
['-format', '%w %h', util.syspath(path_in, prefix=False)]
try:
out = util.command_output(cmd)
out = util.command_output(cmd).stdout
except subprocess.CalledProcessError as exc:
log.warning(u'ImageMagick size query failed')
log.warning('ImageMagick size query failed')
log.debug(
u'`convert` exited with (status {}) when '
u'getting size with command {}:\n{}',
'`convert` exited with (status {}) when '
'getting size with command {}:\n{}',
exc.returncode, cmd, exc.output.strip()
)
return
try:
return tuple(map(int, out.split(b' ')))
except IndexError:
log.warning(u'Could not understand IM output: {0!r}', out)
log.warning('Could not understand IM output: {0!r}', out)
BACKEND_GET_SIZE = {
@ -145,14 +201,115 @@ BACKEND_GET_SIZE = {
}
def pil_deinterlace(path_in, path_out=None):
path_out = path_out or temp_file_for(path_in)
from PIL import Image
try:
im = Image.open(util.syspath(path_in))
im.save(util.py3_path(path_out), progressive=False)
return path_out
except IOError:
return path_in
def im_deinterlace(path_in, path_out=None):
path_out = path_out or temp_file_for(path_in)
cmd = ArtResizer.shared.im_convert_cmd + [
util.syspath(path_in, prefix=False),
'-interlace', 'none',
util.syspath(path_out, prefix=False),
]
try:
util.command_output(cmd)
return path_out
except subprocess.CalledProcessError:
return path_in
DEINTERLACE_FUNCS = {
PIL: pil_deinterlace,
IMAGEMAGICK: im_deinterlace,
}
def im_get_format(filepath):
cmd = ArtResizer.shared.im_identify_cmd + [
'-format', '%[magick]',
util.syspath(filepath)
]
try:
return util.command_output(cmd).stdout
except subprocess.CalledProcessError:
return None
def pil_get_format(filepath):
from PIL import Image, UnidentifiedImageError
try:
with Image.open(util.syspath(filepath)) as im:
return im.format
except (ValueError, TypeError, UnidentifiedImageError, FileNotFoundError):
log.exception("failed to detect image format for {}", filepath)
return None
BACKEND_GET_FORMAT = {
PIL: pil_get_format,
IMAGEMAGICK: im_get_format,
}
def im_convert_format(source, target, deinterlaced):
cmd = ArtResizer.shared.im_convert_cmd + [
util.syspath(source),
*(["-interlace", "none"] if deinterlaced else []),
util.syspath(target),
]
try:
subprocess.check_call(
cmd,
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL
)
return target
except subprocess.CalledProcessError:
return source
def pil_convert_format(source, target, deinterlaced):
from PIL import Image, UnidentifiedImageError
try:
with Image.open(util.syspath(source)) as im:
im.save(util.py3_path(target), progressive=not deinterlaced)
return target
except (ValueError, TypeError, UnidentifiedImageError, FileNotFoundError,
OSError):
log.exception("failed to convert image {} -> {}", source, target)
return source
BACKEND_CONVERT_IMAGE_FORMAT = {
PIL: pil_convert_format,
IMAGEMAGICK: im_convert_format,
}
class Shareable(type):
"""A pseudo-singleton metaclass that allows both shared and
non-shared instances. The ``MyClass.shared`` property holds a
lazily-created shared instance of ``MyClass`` while calling
``MyClass()`` to construct a new object works as usual.
"""
def __init__(cls, name, bases, dict):
super(Shareable, cls).__init__(name, bases, dict)
super().__init__(name, bases, dict)
cls._instance = None
@property
@ -162,7 +319,7 @@ class Shareable(type):
return cls._instance
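
A usage sketch for the metaclass (the `Registry` class is hypothetical):

```python
class Registry(metaclass=Shareable):
    pass

assert Registry.shared is Registry.shared   # lazily created, then reused
assert Registry() is not Registry.shared    # plain construction still works
```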
class ArtResizer(six.with_metaclass(Shareable, object)):
class ArtResizer(metaclass=Shareable):
"""A singleton class that performs image resizes.
"""
@ -170,21 +327,44 @@ class ArtResizer(six.with_metaclass(Shareable, object)):
"""Create a resizer object with an inferred method.
"""
self.method = self._check_method()
log.debug(u"artresizer: method is {0}", self.method)
log.debug("artresizer: method is {0}", self.method)
self.can_compare = self._can_compare()
def resize(self, maxwidth, path_in, path_out=None):
# Use ImageMagick's magick binary when it's available. If it's
# not, fall back to the older, separate convert and identify
# commands.
if self.method[0] == IMAGEMAGICK:
self.im_legacy = self.method[2]
if self.im_legacy:
self.im_convert_cmd = ['convert']
self.im_identify_cmd = ['identify']
else:
self.im_convert_cmd = ['magick']
self.im_identify_cmd = ['magick', 'identify']
def resize(
self, maxwidth, path_in, path_out=None, quality=0, max_filesize=0
):
"""Manipulate an image file according to the method, returning a
new path. For PIL or IMAGEMAGIC methods, resizes the image to a
temporary file. For WEBPROXY, returns `path_in` unmodified.
temporary file and encodes with the specified quality level.
For WEBPROXY, returns `path_in` unmodified.
"""
if self.local:
func = BACKEND_FUNCS[self.method[0]]
return func(maxwidth, path_in, path_out)
return func(maxwidth, path_in, path_out,
quality=quality, max_filesize=max_filesize)
else:
return path_in
def proxy_url(self, maxwidth, url):
def deinterlace(self, path_in, path_out=None):
if self.local:
func = DEINTERLACE_FUNCS[self.method[0]]
return func(path_in, path_out)
else:
return path_in
def proxy_url(self, maxwidth, url, quality=0):
"""Modifies an image URL according the method, returning a new
URL. For WEBPROXY, a URL on the proxy server is returned.
Otherwise, the URL is returned unmodified.
@ -192,7 +372,7 @@ class ArtResizer(six.with_metaclass(Shareable, object)):
if self.local:
return url
else:
return resize_url(url, maxwidth)
return resize_url(url, maxwidth, quality)
@property
def local(self):
@ -205,12 +385,50 @@ class ArtResizer(six.with_metaclass(Shareable, object)):
"""Return the size of an image file as an int couple (width, height)
in pixels.
Only available locally
Only available locally.
"""
if self.local:
func = BACKEND_GET_SIZE[self.method[0]]
return func(path_in)
def get_format(self, path_in):
"""Returns the format of the image as a string.
Only available locally.
"""
if self.local:
func = BACKEND_GET_FORMAT[self.method[0]]
return func(path_in)
def reformat(self, path_in, new_format, deinterlaced=True):
"""Converts image to desired format, updating its extension, but
keeping the same filename.
Only available locally.
"""
if not self.local:
return path_in
new_format = new_format.lower()
# A non-exhaustive map of image "types" to extension overrides
new_format = {
'jpeg': 'jpg',
}.get(new_format, new_format)
fname, ext = os.path.splitext(path_in)
path_new = fname + b'.' + new_format.encode('utf8')
func = BACKEND_CONVERT_IMAGE_FORMAT[self.method[0]]
# Let any exception from the conversion backend propagate, but
# still remove the original file once a new path was produced.
result_path = path_in
try:
result_path = func(path_in, path_new, deinterlaced)
finally:
if result_path != path_in:
os.unlink(path_in)
return result_path
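
The extension rewrite in isolation, for a hypothetical PNG-to-JPEG conversion (paths are bytes, as throughout beets):

```python
import os.path

new_format = {'jpeg': 'jpg'}.get('jpeg', 'jpeg')      # 'jpg'
fname, ext = os.path.splitext(b'/music/cover.png')    # (b'/music/cover', b'.png')
path_new = fname + b'.' + new_format.encode('utf8')   # b'/music/cover.jpg'
```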
def _can_compare(self):
"""A boolean indicating whether image comparison is available"""
@ -218,10 +436,20 @@ class ArtResizer(six.with_metaclass(Shareable, object)):
@staticmethod
def _check_method():
"""Return a tuple indicating an available method and its version."""
"""Return a tuple indicating an available method and its version.
The result has at least two elements:
- The method, either WEBPROXY, PIL, or IMAGEMAGICK.
- The version.
If the method is IMAGEMAGICK, there is also a third element: a
bool flag indicating whether to use the `magick` binary or
legacy single-purpose executables (`convert`, `identify`, etc.)
"""
version = get_im_version()
if version:
return IMAGEMAGICK, version
version, legacy = version
return IMAGEMAGICK, version, legacy
version = get_pil_version()
if version:
@ -231,31 +459,34 @@ class ArtResizer(six.with_metaclass(Shareable, object)):
def get_im_version():
"""Return Image Magick version or None if it is unavailable
Try invoking ImageMagick's "convert".
"""Get the ImageMagick version and legacy flag as a pair. Or return
None if ImageMagick is not available.
"""
try:
out = util.command_output(['convert', '--version'])
for cmd_name, legacy in ((['magick'], False), (['convert'], True)):
cmd = cmd_name + ['--version']
if b'imagemagick' in out.lower():
pattern = br".+ (\d+)\.(\d+)\.(\d+).*"
match = re.search(pattern, out)
if match:
return (int(match.group(1)),
int(match.group(2)),
int(match.group(3)))
return (0,)
try:
out = util.command_output(cmd).stdout
except (subprocess.CalledProcessError, OSError) as exc:
log.debug('ImageMagick version check failed: {}', exc)
else:
if b'imagemagick' in out.lower():
pattern = br".+ (\d+)\.(\d+)\.(\d+).*"
match = re.search(pattern, out)
if match:
version = (int(match.group(1)),
int(match.group(2)),
int(match.group(3)))
return version, legacy
except (subprocess.CalledProcessError, OSError) as exc:
log.debug(u'ImageMagick check `convert --version` failed: {}', exc)
return None
return None
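
How the version regex picks the triple out of typical `--version` output (the sample line is illustrative):

```python
import re

out = b'Version: ImageMagick 7.1.0-52 Q16-HDRI x86_64'
if b'imagemagick' in out.lower():
    match = re.search(br".+ (\d+)\.(\d+)\.(\d+).*", out)
    assert tuple(map(int, match.groups())) == (7, 1, 0)
```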
def get_pil_version():
"""Return Image Magick version or None if it is unavailable
Try importing PIL."""
"""Get the PIL/Pillow version, or None if it is unavailable.
"""
try:
__import__('PIL', fromlist=[str('Image')])
__import__('PIL', fromlist=['Image'])
return (0,)
except ImportError:
return None

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
"""Extremely simple pure-Python implementation of coroutine-style
asynchronous socket I/O. Inspired by, but inferior to, Eventlet.
Bluelet can also be thought of as a less-terrible replacement for
@ -7,9 +5,7 @@ asyncore.
Bluelet: easy concurrency without all the messy parallelism.
"""
from __future__ import division, absolute_import, print_function
import six
import socket
import select
import sys
@ -22,7 +18,7 @@ import collections
# Basic events used for thread scheduling.
class Event(object):
class Event:
"""Just a base class identifying Bluelet events. An event is an
object yielded from a Bluelet thread coroutine to suspend operation
and communicate with the scheduler.
@ -201,7 +197,7 @@ class ThreadException(Exception):
self.exc_info = exc_info
def reraise(self):
six.reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
raise self.exc_info[1].with_traceback(self.exc_info[2])
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
@ -336,16 +332,20 @@ def run(root_coro):
break
# Wait and fire.
event2coro = dict((v, k) for k, v in threads.items())
event2coro = {v: k for k, v in threads.items()}
for event in _event_select(threads.values()):
# Run the IO operation, but catch socket errors.
try:
value = event.fire()
except socket.error as exc:
except OSError as exc:
if isinstance(exc.args, tuple) and \
exc.args[0] == errno.EPIPE:
# Broken pipe. Remote host disconnected.
pass
elif isinstance(exc.args, tuple) and \
exc.args[0] == errno.ECONNRESET:
# Connection was reset by peer.
pass
else:
traceback.print_exc()
# Abort the coroutine.
@ -386,7 +386,7 @@ class SocketClosedError(Exception):
pass
class Listener(object):
class Listener:
"""A socket wrapper object for listening sockets.
"""
def __init__(self, host, port):
@ -416,7 +416,7 @@ class Listener(object):
self.sock.close()
class Connection(object):
class Connection:
"""A socket wrapper object for connected sockets.
"""
def __init__(self, sock, addr):
@ -541,7 +541,7 @@ def spawn(coro):
and child coroutines run concurrently.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError(u'%s is not a coroutine' % coro)
raise ValueError('%s is not a coroutine' % coro)
return SpawnEvent(coro)
@ -551,7 +551,7 @@ def call(coro):
returns a value using end(), then this event returns that value.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError(u'%s is not a coroutine' % coro)
raise ValueError('%s is not a coroutine' % coro)
return DelegationEvent(coro)

File diff suppressed because it is too large

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -13,7 +12,6 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from __future__ import division, absolute_import, print_function
from enum import Enum

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -27,30 +26,30 @@ This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""
from __future__ import division, absolute_import, print_function
import re
import ast
import dis
import types
import sys
import six
import functools
SYMBOL_DELIM = u'$'
FUNC_DELIM = u'%'
GROUP_OPEN = u'{'
GROUP_CLOSE = u'}'
ARG_SEP = u','
ESCAPE_CHAR = u'$'
SYMBOL_DELIM = '$'
FUNC_DELIM = '%'
GROUP_OPEN = '{'
GROUP_CLOSE = '}'
ARG_SEP = ','
ESCAPE_CHAR = '$'
VARIABLE_PREFIX = '__var_'
FUNCTION_PREFIX = '__func_'
class Environment(object):
class Environment:
"""Contains the values and functions to be substituted into a
template.
"""
def __init__(self, values, functions):
self.values = values
self.functions = functions
@ -72,15 +71,7 @@ def ex_literal(val):
"""An int, float, long, bool, string, or None literal with the given
value.
"""
if val is None:
return ast.Name('None', ast.Load())
elif isinstance(val, six.integer_types):
return ast.Num(val)
elif isinstance(val, bool):
return ast.Name(bytes(val), ast.Load())
elif isinstance(val, six.string_types):
return ast.Str(val)
raise TypeError(u'no literal for {0}'.format(type(val)))
return ast.Constant(val)
def ex_varassign(name, expr):
@ -97,7 +88,7 @@ def ex_call(func, args):
function may be an expression or the name of a function. Each
argument may be an expression or a value to be used as a literal.
"""
if isinstance(func, six.string_types):
if isinstance(func, str):
func = ex_rvalue(func)
args = list(args)
@ -105,10 +96,7 @@ def ex_call(func, args):
if not isinstance(args[i], ast.expr):
args[i] = ex_literal(args[i])
if sys.version_info[:2] < (3, 5):
return ast.Call(func, args, [], None, None)
else:
return ast.Call(func, args, [])
return ast.Call(func, args, [])
def compile_func(arg_names, statements, name='_the_func', debug=False):
@ -116,32 +104,30 @@ def compile_func(arg_names, statements, name='_the_func', debug=False):
the resulting Python function. If `debug`, then print out the
bytecode of the compiled function.
"""
if six.PY2:
func_def = ast.FunctionDef(
name=name.encode('utf-8'),
args=ast.arguments(
args=[ast.Name(n, ast.Param()) for n in arg_names],
vararg=None,
kwarg=None,
defaults=[ex_literal(None) for _ in arg_names],
),
body=statements,
decorator_list=[],
)
else:
func_def = ast.FunctionDef(
name=name,
args=ast.arguments(
args=[ast.arg(arg=n, annotation=None) for n in arg_names],
kwonlyargs=[],
kw_defaults=[],
defaults=[ex_literal(None) for _ in arg_names],
),
body=statements,
decorator_list=[],
)
args_fields = {
'args': [ast.arg(arg=n, annotation=None) for n in arg_names],
'kwonlyargs': [],
'kw_defaults': [],
'defaults': [ex_literal(None) for _ in arg_names],
}
if 'posonlyargs' in ast.arguments._fields: # Added in Python 3.8.
args_fields['posonlyargs'] = []
args = ast.arguments(**args_fields)
func_def = ast.FunctionDef(
name=name,
args=args,
body=statements,
decorator_list=[],
)
# The ast.Module signature changed in 3.8 to accept a list of types to
# ignore.
if sys.version_info >= (3, 8):
mod = ast.Module([func_def], [])
else:
mod = ast.Module([func_def])
mod = ast.Module([func_def])
ast.fix_missing_locations(mod)
prog = compile(mod, '<generated>', 'exec')
@ -160,14 +146,15 @@ def compile_func(arg_names, statements, name='_the_func', debug=False):
# AST nodes for the template language.
class Symbol(object):
class Symbol:
"""A variable-substitution symbol in a template."""
def __init__(self, ident, original):
self.ident = ident
self.original = original
def __repr__(self):
return u'Symbol(%s)' % repr(self.ident)
return 'Symbol(%s)' % repr(self.ident)
def evaluate(self, env):
"""Evaluate the symbol in the environment, returning a Unicode
@ -182,24 +169,22 @@ class Symbol(object):
def translate(self):
"""Compile the variable lookup."""
if six.PY2:
ident = self.ident.encode('utf-8')
else:
ident = self.ident
ident = self.ident
expr = ex_rvalue(VARIABLE_PREFIX + ident)
return [expr], set([ident]), set()
return [expr], {ident}, set()
class Call(object):
class Call:
"""A function call in a template."""
def __init__(self, ident, args, original):
self.ident = ident
self.args = args
self.original = original
def __repr__(self):
return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args),
repr(self.original))
return 'Call({}, {}, {})'.format(repr(self.ident), repr(self.args),
repr(self.original))
def evaluate(self, env):
"""Evaluate the function call in the environment, returning a
@ -212,19 +197,15 @@ class Call(object):
except Exception as exc:
# Function raised exception! Maybe inlining the name of
# the exception will help debug.
return u'<%s>' % six.text_type(exc)
return six.text_type(out)
return '<%s>' % str(exc)
return str(out)
else:
return self.original
def translate(self):
"""Compile the function call."""
varnames = set()
if six.PY2:
ident = self.ident.encode('utf-8')
else:
ident = self.ident
funcnames = set([ident])
funcnames = {self.ident}
arg_exprs = []
for arg in self.args:
@ -235,32 +216,33 @@ class Call(object):
# Create a subexpression that joins the result components of
# the arguments.
arg_exprs.append(ex_call(
ast.Attribute(ex_literal(u''), 'join', ast.Load()),
ast.Attribute(ex_literal(''), 'join', ast.Load()),
[ex_call(
'map',
[
ex_rvalue(six.text_type.__name__),
ex_rvalue(str.__name__),
ast.List(subexprs, ast.Load()),
]
)],
))
subexpr_call = ex_call(
FUNCTION_PREFIX + ident,
FUNCTION_PREFIX + self.ident,
arg_exprs
)
return [subexpr_call], varnames, funcnames
class Expression(object):
class Expression:
"""Top-level template construct: contains a list of text blobs,
Symbols, and Calls.
"""
def __init__(self, parts):
self.parts = parts
def __repr__(self):
return u'Expression(%s)' % (repr(self.parts))
return 'Expression(%s)' % (repr(self.parts))
def evaluate(self, env):
"""Evaluate the entire expression in the environment, returning
@ -268,11 +250,11 @@ class Expression(object):
"""
out = []
for part in self.parts:
if isinstance(part, six.string_types):
if isinstance(part, str):
out.append(part)
else:
out.append(part.evaluate(env))
return u''.join(map(six.text_type, out))
return ''.join(map(str, out))
def translate(self):
"""Compile the expression to a list of Python AST expressions, a
@ -282,7 +264,7 @@ class Expression(object):
varnames = set()
funcnames = set()
for part in self.parts:
if isinstance(part, six.string_types):
if isinstance(part, str):
expressions.append(ex_literal(part))
else:
e, v, f = part.translate()
@ -298,7 +280,7 @@ class ParseError(Exception):
pass
class Parser(object):
class Parser:
"""Parses a template expression string. Instantiate the class with
the template source and call ``parse_expression``. The ``pos`` field
will indicate the character after the expression finished and
@ -311,6 +293,7 @@ class Parser(object):
replaced with a real, accepted parsing technique (PEG, parser
generator, etc.).
"""
def __init__(self, string, in_argument=False):
""" Create a new parser.
:param in_argument: boolean that indicates the parser is to be
@ -326,7 +309,7 @@ class Parser(object):
special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
ESCAPE_CHAR)
special_char_re = re.compile(r'[%s]|\Z' %
u''.join(re.escape(c) for c in special_chars))
''.join(re.escape(c) for c in special_chars))
escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
terminator_chars = (GROUP_CLOSE,)
@ -343,7 +326,7 @@ class Parser(object):
if self.in_argument:
extra_special_chars = (ARG_SEP,)
special_char_re = re.compile(
r'[%s]|\Z' % u''.join(
r'[%s]|\Z' % ''.join(
re.escape(c) for c in
self.special_chars + extra_special_chars
)
@ -387,7 +370,7 @@ class Parser(object):
# Shift all characters collected so far into a single string.
if text_parts:
self.parts.append(u''.join(text_parts))
self.parts.append(''.join(text_parts))
text_parts = []
if char == SYMBOL_DELIM:
@ -409,7 +392,7 @@ class Parser(object):
# If any parsed characters remain, shift them into a string.
if text_parts:
self.parts.append(u''.join(text_parts))
self.parts.append(''.join(text_parts))
def parse_symbol(self):
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
@ -547,11 +530,27 @@ def _parse(template):
return Expression(parts)
# External interface.
def cached(func):
"""Like the `functools.lru_cache` decorator, but works (as a no-op)
on Python < 3.2.
"""
if hasattr(functools, 'lru_cache'):
return functools.lru_cache(maxsize=128)(func)
else:
# Do nothing when lru_cache is not available.
return func
class Template(object):
@cached
def template(fmt):
return Template(fmt)
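
The practical effect of the cache (the format string is illustrative):

```python
t1 = template('$albumartist - $title')
t2 = template('$albumartist - $title')
assert t1 is t2   # parsed once; later lookups hit the LRU cache
```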
# External interface.
class Template:
"""A string template, including text, Symbols, and Calls.
"""
def __init__(self, template):
self.expr = _parse(template)
self.original = template
@ -600,7 +599,7 @@ class Template(object):
for funcname in funcnames:
args[FUNCTION_PREFIX + funcname] = functions[funcname]
parts = func(**args)
return u''.join(parts)
return ''.join(parts)
return wrapper_func
@ -609,9 +608,9 @@ class Template(object):
if __name__ == '__main__':
import timeit
_tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar')
_tmpl = Template('foo $bar %baz{foozle $bar barzle} $bar')
_vars = {'bar': 'qux'}
_funcs = {'baz': six.text_type.upper}
_funcs = {'baz': str.upper}
interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
@ -620,4 +619,4 @@ if __name__ == '__main__':
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
print(comp_time)
print(u'Speedup:', interp_time / comp_time)
print('Speedup:', interp_time / comp_time)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -14,7 +13,6 @@
# included in all copies or substantial portions of the Software.
"""Simple library to work out if a file is hidden on different platforms."""
from __future__ import division, absolute_import, print_function
import os
import stat

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -32,12 +31,10 @@ To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
from __future__ import division, absolute_import, print_function
from six.moves import queue
import queue
from threading import Thread, Lock
import sys
import six
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
@ -91,6 +88,7 @@ class CountedQueue(queue.Queue):
still feeding into it. The queue is poisoned when all threads are
finished with the queue.
"""
def __init__(self, maxsize=0):
queue.Queue.__init__(self, maxsize)
self.nthreads = 0
@ -135,10 +133,11 @@ class CountedQueue(queue.Queue):
_invalidate_queue(self, POISON, False)
class MultiMessage(object):
class MultiMessage:
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, messages):
self.messages = messages
@ -210,8 +209,9 @@ def _allmsgs(obj):
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self, all_threads):
super(PipelineThread, self).__init__()
super().__init__()
self.abort_lock = Lock()
self.abort_flag = False
self.all_threads = all_threads
@ -241,15 +241,13 @@ class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
"""
def __init__(self, coro, out_queue, all_threads):
super(FirstPipelineThread, self).__init__(all_threads)
super().__init__(all_threads)
self.coro = coro
self.out_queue = out_queue
self.out_queue.acquire()
self.abort_lock = Lock()
self.abort_flag = False
def run(self):
try:
while True:
@ -282,8 +280,9 @@ class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
"""
def __init__(self, coro, in_queue, out_queue, all_threads):
super(MiddlePipelineThread, self).__init__(all_threads)
super().__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
self.out_queue = out_queue
@ -330,8 +329,9 @@ class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
"""
def __init__(self, coro, in_queue, all_threads):
super(LastPipelineThread, self).__init__(all_threads)
super().__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
@ -362,17 +362,18 @@ class LastPipelineThread(PipelineThread):
return
class Pipeline(object):
class Pipeline:
"""Represents a staged pattern of work. Each stage in the pipeline
is a coroutine that receives messages from the previous stage and
yields messages to be sent to the next stage.
"""
def __init__(self, stages):
"""Makes a new pipeline from a list of coroutines. There must
be at least two stages.
"""
if len(stages) < 2:
raise ValueError(u'pipeline must have at least two stages')
raise ValueError('pipeline must have at least two stages')
self.stages = []
for stage in stages:
if isinstance(stage, (list, tuple)):
@ -442,7 +443,7 @@ class Pipeline(object):
exc_info = thread.exc_info
if exc_info:
# Make the exception appear as it was raised originally.
six.reraise(exc_info[0], exc_info[1], exc_info[2])
raise exc_info[1].with_traceback(exc_info[2])
def pull(self):
"""Yield elements from the end of the pipeline. Runs the stages
@ -469,6 +470,7 @@ class Pipeline(object):
for msg in msgs:
yield msg
# Smoke test.
if __name__ == '__main__':
import time
@ -477,14 +479,14 @@ if __name__ == '__main__':
# in parallel.
def produce():
for i in range(5):
print(u'generating %i' % i)
print('generating %i' % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
print(u'processing %i' % num)
print('processing %i' % num)
time.sleep(2)
num = yield num * 2
@ -492,7 +494,7 @@ if __name__ == '__main__':
while True:
num = yield
time.sleep(1)
print(u'received %i' % num)
print('received %i' % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
@ -501,22 +503,22 @@ if __name__ == '__main__':
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
print(u'Sequential time:', ts_seq - ts_start)
print(u'Parallel time:', ts_par - ts_seq)
print(u'Multiply-parallel time:', ts_end - ts_par)
print('Sequential time:', ts_seq - ts_start)
print('Parallel time:', ts_par - ts_seq)
print('Multiply-parallel time:', ts_end - ts_par)
print()
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
print(u'generating %i' % i)
print('generating %i' % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
print(u'processing %i' % num)
print('processing %i' % num)
time.sleep(3)
if num == 3:
raise Exception()
@ -525,6 +527,6 @@ if __name__ == '__main__':
def exc_consume():
while True:
num = yield
print(u'received %i' % num)
print('received %i' % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
@ -16,7 +15,6 @@
"""A simple utility for constructing filesystem-like trees from beets
libraries.
"""
from __future__ import division, absolute_import, print_function
from collections import namedtuple
from beets import util