Merge branch 'nightly' into concurrent_stream_graph

This commit is contained in:
herby2212 2023-07-11 01:24:59 +02:00
commit 713f8f18e6
335 changed files with 15376 additions and 10025 deletions

4
.github/codeql-config.yml vendored Normal file
View file

@ -0,0 +1,4 @@
name: CodeQL Config
paths-ignore:
- lib

38
.github/workflows/codeql.yml vendored Normal file
View file

@ -0,0 +1,38 @@
name: CodeQL
on:
push:
branches: [nightly]
pull_request:
branches: [nightly]
schedule:
- cron: '05 10 * * 1'
jobs:
codeql-analysis:
name: CodeQL Analysis
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: ['javascript', 'python']
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
config-file: ./.github/codeql-config.yml
languages: ${{ matrix.language }}
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"

View file

@ -13,7 +13,7 @@ jobs:
if: ${{ !contains(github.event.head_commit.message, '[skip ci]') }} if: ${{ !contains(github.event.head_commit.message, '[skip ci]') }}
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3.2.0 uses: actions/checkout@v3
- name: Prepare - name: Prepare
id: prepare id: prepare
@ -47,7 +47,7 @@ jobs:
version: latest version: latest
- name: Cache Docker Layers - name: Cache Docker Layers
uses: actions/cache@v3.2.0 uses: actions/cache@v3
with: with:
path: /tmp/.buildx-cache path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }} key: ${{ runner.os }}-buildx-${{ github.sha }}
@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.GHCR_TOKEN }} password: ${{ secrets.GHCR_TOKEN }}
- name: Docker Build and Push - name: Docker Build and Push
uses: docker/build-push-action@v3 uses: docker/build-push-action@v4
if: success() if: success()
with: with:
context: . context: .

View file

@ -24,7 +24,7 @@ jobs:
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3.2.0 uses: actions/checkout@v3
- name: Set Release Version - name: Set Release Version
id: get_version id: get_version
@ -52,7 +52,7 @@ jobs:
echo $GITHUB_SHA > version.txt echo $GITHUB_SHA > version.txt
- name: Set Up Python - name: Set Up Python
uses: actions/setup-python@v4.4.0 uses: actions/setup-python@v4
with: with:
python-version: '3.9' python-version: '3.9'
cache: pip cache: pip
@ -119,7 +119,10 @@ jobs:
run: | run: |
CHANGELOG="$( sed -n '/^## /{p; :loop n; p; /^## /q; b loop}' CHANGELOG.md \ CHANGELOG="$( sed -n '/^## /{p; :loop n; p; /^## /q; b loop}' CHANGELOG.md \
| sed '$d' | sed '$d' | sed '$d' )" | sed '$d' | sed '$d' | sed '$d' )"
echo "CHANGELOG=${CHANGELOG}" >> $GITHUB_OUTPUT EOF=$(dd if=/dev/urandom bs=15 count=1 status=none | base64)
echo "CHANGELOG<<$EOF" >> $GITHUB_OUTPUT
echo "$CHANGELOG" >> $GITHUB_OUTPUT
echo "$EOF" >> $GITHUB_OUTPUT
- name: Create Release - name: Create Release
uses: actions/create-release@v1 uses: actions/create-release@v1

View file

@ -20,7 +20,7 @@ jobs:
- armhf - armhf
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3.2.0 uses: actions/checkout@v3
- name: Prepare - name: Prepare
id: prepare id: prepare

View file

@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3.2.0 uses: actions/checkout@v3
- name: Comment on Pull Request - name: Comment on Pull Request
uses: mshick/add-pr-comment@v2 uses: mshick/add-pr-comment@v2

View file

@ -9,6 +9,7 @@ jobs:
winget: winget:
name: Submit Winget Package name: Submit Winget Package
runs-on: windows-latest runs-on: windows-latest
if: ${{ !github.event.release.prerelease }}
steps: steps:
- name: Submit package to Windows Package Manager Community Repository - name: Submit package to Windows Package Manager Community Repository
run: | run: |

3
.gitignore vendored
View file

@ -53,6 +53,9 @@ Thumbs.db
#Ignore files generated by PyCharm #Ignore files generated by PyCharm
*.idea/* *.idea/*
#Ignore files generated by VSCode
*.vscode/*
#Ignore files generated by vi #Ignore files generated by vi
*.swp *.swp

View file

@ -1,5 +1,78 @@
# Changelog # Changelog
## v2.12.4 (2023-05-23)
* History:
* Fix: Set view offset equal to duration if a stream is stopped within the last 10 sec.
* Other:
* Fix: Database import may fail for some older databases.
* Fix: Double-quoted strings for newer versions of SQLite. (#2015, #2057)
* API:
* Change: Return the ID for async API calls (export_metadata, notify, notify_newsletter).
## v2.12.3 (2023-04-14)
* Activity:
* Fix: Incorrect subtitle decision shown when subtitles are transcoded.
* History:
* Fix: Incorrect order when sorting by the duration column in the history tables.
* Notifications:
* Fix: Logging error when running scripts that use PlexAPI.
* UI:
* Fix: Calculate file sizes setting causing the media info table to fail to load.
* Fix: Incorrect artwork and thumbnail shown for Live TV on the Most Active Libraries statistics card.
* API:
* Change: Renamed duration to play_duration in the get_history API response. (Note: duration kept for backwards compatibility.)
## v2.12.2 (2023-03-16)
* Other:
* Fix: Tautulli not starting on FreeBSD jails.
## v2.12.1 (2023-03-14)
* Activity:
* Fix: Stop checking for deprecated sync items sessions.
* Change: Do not show audio language on activity cards for music.
* Other:
* Fix: Tautulli not starting on macOS.
## v2.12.0 (2023-03-13)
* Notifications:
* New: Added support for Telegram group topics. (#1980)
* New: Added anidb_id and anidb_url notification parameters. (#1973)
* New: Added notification triggers for Intro Marker, Commercial Marker, and Credits Marker.
* New: Added various intro, commercial, and credits marker notification parameters.
* New: Allow setting a custom Pushover notification sound. (#2005)
* Change: Notification images are now uploaded directly to Discord without the need for a 3rd party image hosting service.
* Change: Automatically strip whitespace from notification condition values.
* Change: Trigger watched notifications based on the video watched completion behaviour setting.
* Exporter:
* Fix: Unable to run exporter when using the Snap package. (#2007)
* New: Added credits marker, and audio/subtitle settings to export fields.
* UI:
* Fix: Incorrect styling and missing content for collection media info pages.
* New: Added edition details field on movie media info pages. (#1957) (Thanks @herby2212)
* New: Added setting to change the video watched completion behaviour.
* New: Added watch time and user statistics to collection and playlist media info pages. (#1982, #2012) (Thanks @herby2212)
* New: Added history table to collection and playlist media info pages.
* New: Dynamically change watched status in the UI based on video watched completion behaviour setting.
* New: Added hidden setting to override server name.
* Change: Move track artist to a details field instead of in the title on track media info pages.
* API:
* New: Added section_id and user_id parameters to get_home_stats API command. (#1944)
* New: Added marker info to get_metadata API command results.
* New: Added media_type parameter to get_item_watch_time_stats and get_item_user_stats API commands. (#1982) (Thanks @herby2212)
* New: Added last_refreshed timestamp to get_library_media_info API command response.
* Other:
* Change: Migrate analytics to Google Analytics 4.
## v2.11.1 (2022-12-22) ## v2.11.1 (2022-12-22)
* Activity: * Activity:

View file

@ -57,24 +57,24 @@ Read the [Installation Guides][Installation] for instructions on how to install
[badge-release-nightly-last-commit]: https://img.shields.io/github/last-commit/Tautulli/Tautulli/nightly?style=flat-square&color=blue [badge-release-nightly-last-commit]: https://img.shields.io/github/last-commit/Tautulli/Tautulli/nightly?style=flat-square&color=blue
[badge-release-nightly-commits]: https://img.shields.io/github/commits-since/Tautulli/Tautulli/latest/nightly?style=flat-square&color=blue [badge-release-nightly-commits]: https://img.shields.io/github/commits-since/Tautulli/Tautulli/latest/nightly?style=flat-square&color=blue
[badge-docker-master]: https://img.shields.io/badge/docker-latest-blue?style=flat-square [badge-docker-master]: https://img.shields.io/badge/docker-latest-blue?style=flat-square
[badge-docker-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/master?style=flat-square [badge-docker-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=master
[badge-docker-beta]: https://img.shields.io/badge/docker-beta-blue?style=flat-square [badge-docker-beta]: https://img.shields.io/badge/docker-beta-blue?style=flat-square
[badge-docker-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/beta?style=flat-square [badge-docker-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=beta
[badge-docker-nightly]: https://img.shields.io/badge/docker-nightly-blue?style=flat-square [badge-docker-nightly]: https://img.shields.io/badge/docker-nightly-blue?style=flat-square
[badge-docker-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/nightly?style=flat-square [badge-docker-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=nightly
[badge-snap-master]: https://img.shields.io/badge/snap-stable-blue?style=flat-square [badge-snap-master]: https://img.shields.io/badge/snap-stable-blue?style=flat-square
[badge-snap-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/master?style=flat-square [badge-snap-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=master
[badge-snap-beta]: https://img.shields.io/badge/snap-beta-blue?style=flat-square [badge-snap-beta]: https://img.shields.io/badge/snap-beta-blue?style=flat-square
[badge-snap-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/beta?style=flat-square [badge-snap-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=beta
[badge-snap-nightly]: https://img.shields.io/badge/snap-edge-blue?style=flat-square [badge-snap-nightly]: https://img.shields.io/badge/snap-edge-blue?style=flat-square
[badge-snap-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/nightly?style=flat-square [badge-snap-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=nightly
[badge-installer-master-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&style=flat-square [badge-installer-master-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&style=flat-square
[badge-installer-master-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&style=flat-square [badge-installer-master-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&style=flat-square
[badge-installer-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/master?style=flat-square [badge-installer-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=master
[badge-installer-beta-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&include_prereleases&style=flat-square [badge-installer-beta-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&include_prereleases&style=flat-square
[badge-installer-beta-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&include_prereleases&style=flat-square [badge-installer-beta-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&include_prereleases&style=flat-square
[badge-installer-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/beta?style=flat-square [badge-installer-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=beta
[badge-installer-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/nightly?style=flat-square [badge-installer-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=nightly
## Support ## Support

File diff suppressed because one or more lines are too long

View file

@ -79,7 +79,6 @@ select.form-control {
color: #eee !important; color: #eee !important;
border: 0px solid #444 !important; border: 0px solid #444 !important;
background: #555 !important; background: #555 !important;
padding: 1px 2px;
transition: background-color .3s; transition: background-color .3s;
} }
.selectize-control.form-control .selectize-input { .selectize-control.form-control .selectize-input {
@ -87,7 +86,6 @@ select.form-control {
align-items: center; align-items: center;
flex-wrap: wrap; flex-wrap: wrap;
margin-bottom: 4px; margin-bottom: 4px;
padding-left: 5px;
} }
.selectize-control.form-control.selectize-pms-ip .selectize-input { .selectize-control.form-control.selectize-pms-ip .selectize-input {
padding-left: 12px !important; padding-left: 12px !important;
@ -2916,7 +2914,7 @@ a .home-platforms-list-cover-face:hover
margin-bottom: -20px; margin-bottom: -20px;
width: 100%; width: 100%;
max-width: 1750px; max-width: 1750px;
overflow: hidden; display: flow-root;
} }
.table-card-back td { .table-card-back td {
font-size: 12px; font-size: 12px;

View file

@ -265,12 +265,15 @@ DOCUMENTATION :: END
<div class="sub-heading">Audio</div> <div class="sub-heading">Audio</div>
<div class="sub-value" id="audio_decision-${sk}"> <div class="sub-value" id="audio_decision-${sk}">
% if data['stream_audio_decision']: % if data['stream_audio_decision']:
<%
audio_language = (data['audio_language'] or 'Unknown') + ' - ' if data['media_type'] != 'track' else ''
%>
% if data['stream_audio_decision'] == 'transcode': % if data['stream_audio_decision'] == 'transcode':
Transcode (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())} ${data['audio_channel_layout'].split('(')[0].capitalize()} <i class="fa fa-long-arrow-right"></i> ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()}) Transcode (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())} ${data['audio_channel_layout'].split('(')[0].capitalize()} <i class="fa fa-long-arrow-right"></i> ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% elif data['stream_audio_decision'] == 'copy': % elif data['stream_audio_decision'] == 'copy':
Direct Stream (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()}) Direct Stream (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% else: % else:
Direct Play (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()}) Direct Play (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% endif % endif
% endif % endif
</div> </div>

View file

@ -1,6 +1,7 @@
<%inherit file="base.html"/> <%inherit file="base.html"/>
<%def name="headIncludes()"> <%def name="headIncludes()">
<link rel="stylesheet" href="${http_root}css/bootstrap-select.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css"> <link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css">
<link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css"> <link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css">
</%def> </%def>
@ -14,9 +15,7 @@
<div class="button-bar"> <div class="button-bar">
<div class="btn-group" id="user-selection"> <div class="btn-group" id="user-selection">
<label> <label>
<select name="graph-user" id="graph-user" class="btn" style="color: inherit;"> <select name="graph-user" id="graph-user" multiple>
<option value="">All Users</option>
<option disabled>&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;</option>
</select> </select>
</label> </label>
</div> </div>
@ -239,6 +238,7 @@
</%def> </%def>
<%def name="javascriptIncludes()"> <%def name="javascriptIncludes()">
<script src="${http_root}js/bootstrap-select.min.js"></script>
<script src="${http_root}js/highcharts.min.js"></script> <script src="${http_root}js/highcharts.min.js"></script>
<script src="${http_root}js/jquery.dataTables.min.js"></script> <script src="${http_root}js/jquery.dataTables.min.js"></script>
<script src="${http_root}js/dataTables.bootstrap.min.js"></script> <script src="${http_root}js/dataTables.bootstrap.min.js"></script>
@ -379,8 +379,8 @@
//$(current_tab).addClass('active'); //$(current_tab).addClass('active');
$('.days').html(current_day_range); $('.days').text(current_day_range);
$('.months').html(current_month_range); $('.months').text(current_month_range);
// Load user ids and names (for the selector) // Load user ids and names (for the selector)
$.ajax({ $.ajax({
@ -388,14 +388,35 @@
type: 'get', type: 'get',
dataType: "json", dataType: "json",
success: function (data) { success: function (data) {
var select = $('#graph-user'); let select = $('#graph-user');
let by_id = {};
data.sort(function(a, b) { data.sort(function(a, b) {
return a.friendly_name.localeCompare(b.friendly_name); return a.friendly_name.localeCompare(b.friendly_name);
}); });
data.forEach(function(item) { data.forEach(function(item) {
select.append('<option value="' + item.user_id + '">' + select.append('<option value="' + item.user_id + '">' +
item.friendly_name + '</option>'); item.friendly_name + '</option>');
by_id[item.user_id] = item.friendly_name;
}); });
select.selectpicker({
countSelectedText: function(sel, total) {
if (sel === 0 || sel === total) {
return 'All users';
} else if (sel > 1) {
return sel + ' users';
} else {
return select.val().map(function(id) {
return by_id[id];
}).join(', ');
}
},
style: 'btn-dark',
actionsBox: true,
selectedTextFormat: 'count',
noneSelectedText: 'All users'
});
select.selectpicker('render');
select.selectpicker('selectAll');
} }
}); });
@ -644,11 +665,6 @@
$('#nav-tabs-total').tab('show'); $('#nav-tabs-total').tab('show');
} }
// Set initial state
if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }
// Tab1 opened // Tab1 opened
$('#nav-tabs-plays').on('shown.bs.tab', function (e) { $('#nav-tabs-plays').on('shown.bs.tab', function (e) {
e.preventDefault(); e.preventDefault();
@ -681,7 +697,7 @@
setLocalStorage('graph_days', current_day_range); setLocalStorage('graph_days', current_day_range);
if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); } if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); } if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
$('.days').html(current_day_range); $('.days').text(current_day_range);
}); });
// Month range changed // Month range changed
@ -691,12 +707,23 @@
current_month_range = $(this).val(); current_month_range = $(this).val();
setLocalStorage('graph_months', current_month_range); setLocalStorage('graph_months', current_month_range);
if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); } if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }
$('.months').html(current_month_range); $('.months').text(current_month_range);
}); });
let graph_user_last_id = undefined;
// User changed // User changed
$('#graph-user').on('change', function() { $('#graph-user').on('change', function() {
selected_user_id = $(this).val() || null; let val = $(this).val();
if (val.length === 0 || val.length === $(this).children().length) {
selected_user_id = null; // if all users are selected, just send an empty list
} else {
selected_user_id = val.join(",");
}
if (selected_user_id === graph_user_last_id) {
return;
}
graph_user_last_id = selected_user_id;
if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); } if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); } if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); } if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }

View file

@ -1,6 +1,7 @@
<%inherit file="base.html"/> <%inherit file="base.html"/>
<%def name="headIncludes()"> <%def name="headIncludes()">
<link rel="stylesheet" href="${http_root}css/bootstrap-select.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css"> <link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.colVis.css"> <link rel="stylesheet" href="${http_root}css/dataTables.colVis.css">
<link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css"> <link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css">
@ -31,9 +32,7 @@
% if _session['user_group'] == 'admin': % if _session['user_group'] == 'admin':
<div class="btn-group" id="user-selection"> <div class="btn-group" id="user-selection">
<label> <label>
<select name="history-user" id="history-user" class="btn" style="color: inherit;"> <select name="history-user" id="history-user" multiple>
<option value="">All Users</option>
<option disabled>&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;</option>
</select> </select>
</label> </label>
</div> </div>
@ -84,7 +83,7 @@
<th align="left" id="started">Started</th> <th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th> <th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th> <th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th> <th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th> <th align="left" id="percent_complete"></th>
</tr> </tr>
</thead> </thead>
@ -121,6 +120,7 @@
</%def> </%def>
<%def name="javascriptIncludes()"> <%def name="javascriptIncludes()">
<script src="${http_root}js/bootstrap-select.min.js"></script>
<script src="${http_root}js/jquery.dataTables.min.js"></script> <script src="${http_root}js/jquery.dataTables.min.js"></script>
<script src="${http_root}js/dataTables.colVis.js"></script> <script src="${http_root}js/dataTables.colVis.js"></script>
<script src="${http_root}js/dataTables.bootstrap.min.js"></script> <script src="${http_root}js/dataTables.bootstrap.min.js"></script>
@ -134,17 +134,40 @@
type: 'GET', type: 'GET',
dataType: 'json', dataType: 'json',
success: function (data) { success: function (data) {
var select = $('#history-user'); let select = $('#history-user');
let by_id = {};
data.sort(function (a, b) { data.sort(function (a, b) {
return a.friendly_name.localeCompare(b.friendly_name); return a.friendly_name.localeCompare(b.friendly_name);
}); });
data.forEach(function (item) { data.forEach(function (item) {
select.append('<option value="' + item.user_id + '">' + select.append('<option value="' + item.user_id + '">' +
item.friendly_name + '</option>'); item.friendly_name + '</option>');
by_id[item.user_id] = item.friendly_name;
}); });
select.selectpicker({
countSelectedText: function(sel, total) {
if (sel === 0 || sel === total) {
return 'All users';
} else if (sel > 1) {
return sel + ' users';
} else {
return select.val().map(function(id) {
return by_id[id];
}).join(', ');
}
},
style: 'btn-dark',
actionsBox: true,
selectedTextFormat: 'count',
noneSelectedText: 'All users'
});
select.selectpicker('render');
select.selectpicker('selectAll');
} }
}); });
let history_user_last_id = undefined;
function loadHistoryTable(media_type, transcode_decision, selected_user_id) { function loadHistoryTable(media_type, transcode_decision, selected_user_id) {
history_table_options.ajax = { history_table_options.ajax = {
url: 'get_history', url: 'get_history',
@ -187,7 +210,16 @@
}); });
$('#history-user').on('change', function () { $('#history-user').on('change', function () {
selected_user_id = $(this).val() || null; let val = $(this).val();
if (val.length === 0 || val.length === $(this).children().length) {
selected_user_id = null; // if all users are selected, just send an empty list
} else {
selected_user_id = val.join(",");
}
if (selected_user_id === history_user_last_id) {
return;
}
history_user_last_id = selected_user_id;
history_table.draw(); history_table.draw();
}); });
} }

View file

@ -32,7 +32,7 @@
<th align="left" id="started">Started</th> <th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th> <th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th> <th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th> <th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th> <th align="left" id="percent_complete"></th>
</tr> </tr>
</thead> </thead>

View file

@ -77,7 +77,8 @@ DOCUMENTATION :: END
<% fallback = 'art-live' if row0['live'] else 'art' %> <% fallback = 'art-live' if row0['live'] else 'art' %>
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'], row0['rating_key'], 500, 280, 40, '282828', 3, fallback=fallback)});"> <div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'], row0['rating_key'], 500, 280, 40, '282828', 3, fallback=fallback)});">
% elif stat_id == 'top_libraries': % elif stat_id == 'top_libraries':
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'] or row0['library_art'], None, 500, 280, 40, '282828', 3, fallback=row0['library_art'])});" data-library_art="${row0['library_art']}"> <% fallback = 'art-live' if row0['live'] else row0['library_art'] %>
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'] or row0['library_art'], None, 500, 280, 40, '282828', 3, fallback=fallback)});" data-library_art="${row0['library_art']}">
% elif stat_id == 'top_users': % elif stat_id == 'top_users':
<div id="stats-background-${stat_id}" class="dashboard-stats-background" data-blurhash="${page('pms_image_proxy', row0['user_thumb'] or 'interfaces/default/images/gravatar-default.png', None, 100, 100, 40, '282828', 0, fallback='user')}"> <div id="stats-background-${stat_id}" class="dashboard-stats-background" data-blurhash="${page('pms_image_proxy', row0['user_thumb'] or 'interfaces/default/images/gravatar-default.png', None, 100, 100, 40, '282828', 0, fallback='user')}">
% elif stat_id == 'top_platforms': % elif stat_id == 'top_platforms':
@ -109,8 +110,8 @@ DOCUMENTATION :: END
</a> </a>
</div> </div>
% elif stat_id == 'top_libraries': % elif stat_id == 'top_libraries':
% if row0['thumb'].startswith('http'): % if row0['library_thumb'].startswith('http'):
<div id="stats-thumb-${stat_id}" class="dashboard-stats-flat hidden-xs" style="background-image: url(${page('pms_image_proxy', row0['thumb'], None, 80, 80)});"></div> <div id="stats-thumb-${stat_id}" class="dashboard-stats-flat hidden-xs" style="background-image: url(${page('pms_image_proxy', row0['library_thumb'], None, 100, 100, fallback='cover')});"></div>
% else: % else:
<div id="stats-thumb-${stat_id}" class="dashboard-stats-flat svg-icon library-${row0['section_type']} hidden-xs"></div> <div id="stats-thumb-${stat_id}" class="dashboard-stats-flat svg-icon library-${row0['section_type']} hidden-xs"></div>
% endif % endif
@ -147,7 +148,8 @@ DOCUMENTATION :: END
data-rating_key="${row.get('rating_key')}" data-grandparent_rating_key="${row.get('grandparent_rating_key')}" data-guid="${row.get('guid')}" data-title="${row.get('title')}" data-rating_key="${row.get('rating_key')}" data-grandparent_rating_key="${row.get('grandparent_rating_key')}" data-guid="${row.get('guid')}" data-title="${row.get('title')}"
data-art="${row.get('art')}" data-thumb="${row.get('thumb')}" data-platform="${row.get('platform_name')}" data-library-type="${row.get('section_type')}" data-art="${row.get('art')}" data-thumb="${row.get('thumb')}" data-platform="${row.get('platform_name')}" data-library-type="${row.get('section_type')}"
data-user_id="${row.get('user_id')}" data-user="${row.get('user')}" data-friendly_name="${row.get('friendly_name')}" data-user_thumb="${row.get('user_thumb')}" data-user_id="${row.get('user_id')}" data-user="${row.get('user')}" data-friendly_name="${row.get('friendly_name')}" data-user_thumb="${row.get('user_thumb')}"
data-last_watch="${row.get('last_watch')}" data-started="${row.get('started')}" data-live="${row.get('live')}" data-library_art="${row.get('library_art', '')}"> data-last_watch="${row.get('last_watch')}" data-started="${row.get('started')}" data-live="${row.get('live')}"
data-library_art="${row.get('library_art', '')}" data-library_thumb="${row.get('library_thumb', '')}">
<div class="sub-list">${loop.index + 1}</div> <div class="sub-list">${loop.index + 1}</div>
<div class="sub-value"> <div class="sub-value">
% if stat_id in ('top_movies', 'popular_movies', 'top_tv', 'popular_tv', 'top_music', 'popular_music', 'last_watched'): % if stat_id in ('top_movies', 'popular_movies', 'top_tv', 'popular_tv', 'top_music', 'popular_music', 'last_watched'):

View file

@ -523,14 +523,15 @@
var audio_decision = ''; var audio_decision = '';
if (['movie', 'episode', 'clip', 'track'].indexOf(s.media_type) > -1 && s.stream_audio_decision) { if (['movie', 'episode', 'clip', 'track'].indexOf(s.media_type) > -1 && s.stream_audio_decision) {
var audio_language = (s.media_type !== 'track') ? (s.audio_language || 'Unknown') + ' - ' : '';
var a_codec = (s.audio_codec === 'truehd') ? 'TrueHD' : s.audio_codec.toUpperCase(); var a_codec = (s.audio_codec === 'truehd') ? 'TrueHD' : s.audio_codec.toUpperCase();
var sa_codec = (s.stream_audio_codec === 'truehd') ? 'TrueHD' : s.stream_audio_codec.toUpperCase(); var sa_codec = (s.stream_audio_codec === 'truehd') ? 'TrueHD' : s.stream_audio_codec.toUpperCase();
if (s.stream_audio_decision === 'transcode') { if (s.stream_audio_decision === 'transcode') {
audio_decision = 'Transcode ('+ (s.audio_language || 'Unknown')+ ' - ' + a_codec + ' ' + capitalizeFirstLetter(s.audio_channel_layout.split('(')[0]) + ' <i class="fa fa-long-arrow-right"></i> ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')'; audio_decision = 'Transcode (' + audio_language + a_codec + ' ' + capitalizeFirstLetter(s.audio_channel_layout.split('(')[0]) + ' <i class="fa fa-long-arrow-right"></i> ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
} else if (s.stream_audio_decision === 'copy') { } else if (s.stream_audio_decision === 'copy') {
audio_decision = 'Direct Stream ('+ (s.audio_language || 'Unknown')+ ' - ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')'; audio_decision = 'Direct Stream (' + audio_language + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
} else { } else {
audio_decision = 'Direct Play ('+ (s.audio_language || 'Unknown')+ ' - ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')'; audio_decision = 'Direct Play (' + audio_language + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
} }
} }
$('#audio_decision-' + key).html(audio_decision); $('#audio_decision-' + key).html(audio_decision);
@ -797,6 +798,7 @@
var guid = $(elem).data('guid'); var guid = $(elem).data('guid');
var live = $(elem).data('live'); var live = $(elem).data('live');
var library_art = $(elem).data('library_art'); var library_art = $(elem).data('library_art');
var library_thumb = $(elem).data('library_thumb');
var [height, fallback_poster, fallback_art] = [450, 'poster', 'art']; var [height, fallback_poster, fallback_art] = [450, 'poster', 'art'];
if ($.inArray(stat_id, ['top_music', 'popular_music']) > -1) { if ($.inArray(stat_id, ['top_music', 'popular_music']) > -1) {
[height, fallback_poster, fallback_art] = [300, 'cover', 'art']; [height, fallback_poster, fallback_art] = [300, 'cover', 'art'];
@ -808,11 +810,11 @@
if (stat_id === 'most_concurrent') { if (stat_id === 'most_concurrent') {
return return
} else if (stat_id === 'top_libraries') { } else if (stat_id === 'top_libraries') {
$('#stats-background-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', art || library_art, null, 500, 280, 40, '282828', 3, library_art || fallback_art) + ')'); $('#stats-background-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', art || library_art, null, 500, 280, 40, '282828', 3, fallback_art) + ')');
$('#stats-thumb-' + stat_id).removeClass(function (index, className) { $('#stats-thumb-' + stat_id).removeClass(function (index, className) {
return (className.match (/(^|\s)svg-icon library-\S+/g) || []).join(' ')}); return (className.match (/(^|\s)svg-icon library-\S+/g) || []).join(' ')});
if (thumb.startsWith('http')) { if (library_thumb.startsWith('http')) {
$('#stats-thumb-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', thumb, null, 300, 300, null, null, null, 'cover') + ')'); $('#stats-thumb-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', library_thumb, null, 100, 100, null, null, null, 'cover') + ')');
} else { } else {
$('#stats-thumb-' + stat_id).css('background-image', '') $('#stats-thumb-' + stat_id).css('background-image', '')
.addClass('svg-icon library-' + library_type); .addClass('svg-icon library-' + library_type);

View file

@ -12,8 +12,10 @@ data :: Usable parameters (if not applicable for media type, blank value will be
== Global keys == == Global keys ==
rating_key Returns the unique identifier for the media item. rating_key Returns the unique identifier for the media item.
media_type Returns the type of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'. media_type Returns the type of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
sub_media_type Returns the subtype of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
art Returns the location of the item's artwork art Returns the location of the item's artwork
title Returns the name of the movie, show, episode, artist, album, or track. title Returns the name of the movie, show, episode, artist, album, or track.
edition_title Returns the edition title of a movie.
duration Returns the standard runtime of the media. duration Returns the standard runtime of the media.
content_rating Returns the age rating for the media. content_rating Returns the age rating for the media.
summary Returns a brief description of the media plot. summary Returns a brief description of the media plot.
@ -212,7 +214,7 @@ DOCUMENTATION :: END
% if _session['user_group'] == 'admin': % if _session['user_group'] == 'admin':
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span> <span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
% endif % endif
% elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip'): % elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip') or data['sub_media_type'] in ('artist', 'album', 'track'):
<div class="summary-poster-face-track" style="background-image: url(${page('pms_image_proxy', data['thumb'], data['rating_key'], 300, 300, fallback='cover')});"> <div class="summary-poster-face-track" style="background-image: url(${page('pms_image_proxy', data['thumb'], data['rating_key'], 300, 300, fallback='cover')});">
<div class="summary-poster-face-overlay"> <div class="summary-poster-face-overlay">
<span></span> <span></span>
@ -266,7 +268,7 @@ DOCUMENTATION :: END
<h1><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a></h1> <h1><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a></h1>
<h2>${data['title']}</h2> <h2>${data['title']}</h2>
% elif data['media_type'] == 'track': % elif data['media_type'] == 'track':
<h1><a href="${page('info', data['grandparent_rating_key'])}">${data['original_title'] or data['grandparent_title']}</a></h1> <h1><a href="${page('info', data['grandparent_rating_key'])}">${data['grandparent_title']}</a></h1>
<h2><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a> - ${data['title']}</h2> <h2><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a> - ${data['title']}</h2>
<h3 class="hidden-xs">T${data['media_index']}</h3> <h3 class="hidden-xs">T${data['media_index']}</h3>
% elif data['media_type'] in ('photo', 'clip'): % elif data['media_type'] in ('photo', 'clip'):
@ -282,14 +284,14 @@ DOCUMENTATION :: END
padding_height = '' padding_height = ''
if data['media_type'] == 'movie' or data['live']: if data['media_type'] == 'movie' or data['live']:
padding_height = 'height: 305px;' padding_height = 'height: 305px;'
elif data['media_type'] in ('show', 'season', 'collection'): elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo') or data['sub_media_type'] in ('artist', 'album', 'track'):
padding_height = 'height: 270px;'
elif data['media_type'] == 'episode':
padding_height = 'height: 70px;'
elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo'):
padding_height = 'height: 150px;' padding_height = 'height: 150px;'
elif data['media_type'] in ('track', 'clip'): elif data['media_type'] in ('track', 'clip'):
padding_height = 'height: 180px;' padding_height = 'height: 180px;'
elif data['media_type'] == 'episode':
padding_height = 'height: 70px;'
elif data['media_type'] in ('show', 'season', 'collection'):
padding_height = 'height: 270px;'
%> %>
<div class="summary-content-padding hidden-xs hidden-sm" style="${padding_height}"> <div class="summary-content-padding hidden-xs hidden-sm" style="${padding_height}">
% if data['media_type'] in ('movie', 'episode', 'track', 'clip'): % if data['media_type'] in ('movie', 'episode', 'track', 'clip'):
@ -368,6 +370,11 @@ DOCUMENTATION :: END
Studio <strong> ${data['studio']}</strong> Studio <strong> ${data['studio']}</strong>
% endif % endif
</div> </div>
<div class="summary-content-details-tag">
% if data['media_type'] == 'track' and data['original_title']:
Track Artists <strong> ${data['original_title']}</strong>
% endif
</div>
<div class="summary-content-details-tag"> <div class="summary-content-details-tag">
% if data['media_type'] == 'movie': % if data['media_type'] == 'movie':
Year <strong> ${data['year']}</strong> Year <strong> ${data['year']}</strong>
@ -390,6 +397,11 @@ DOCUMENTATION :: END
Runtime <strong> <span id="runtime">${data['duration']}</span></strong> Runtime <strong> <span id="runtime">${data['duration']}</span></strong>
% endif % endif
</div> </div>
% if data['edition_title']:
<div class="summary-content-details-tag">
Edition <strong> ${data['edition_title']} </strong>
</div>
% endif
<div class="summary-content-details-tag"> <div class="summary-content-details-tag">
% if data['content_rating']: % if data['content_rating']:
Rated <strong> ${data['content_rating']} </strong> Rated <strong> ${data['content_rating']} </strong>
@ -542,7 +554,7 @@ DOCUMENTATION :: END
</div> </div>
</div> </div>
% endif % endif
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'): % if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<div class="col-md-12"> <div class="col-md-12">
<div class="table-card-header"> <div class="table-card-header">
<div class="header-bar"> <div class="header-bar">
@ -571,7 +583,7 @@ DOCUMENTATION :: END
</div> </div>
% endif % endif
<% <%
history_type = data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track') history_type = data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist')
history_active = 'active' if history_type else '' history_active = 'active' if history_type else ''
export_active = 'active' if not history_type else '' export_active = 'active' if not history_type else ''
%> %>
@ -634,7 +646,7 @@ DOCUMENTATION :: END
<div class="col-md-12"> <div class="col-md-12">
<div class="table-card-header"> <div class="table-card-header">
<div class="header-bar"> <div class="header-bar">
% if data['media_type'] in ('artist', 'album', 'track'): % if data['media_type'] in ('artist', 'album', 'track', 'playlist'):
<span>Play History for <strong>${data['title']}</strong></span> <span>Play History for <strong>${data['title']}</strong></span>
% else: % else:
<span>Watch History for <strong>${data['title']}</strong></span> <span>Watch History for <strong>${data['title']}</strong></span>
@ -680,7 +692,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th> <th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th> <th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th> <th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th> <th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th> <th align="left" id="percent_complete"></th>
</tr> </tr>
</thead> </thead>
@ -806,7 +818,7 @@ DOCUMENTATION :: END
% elif data['media_type'] == 'album': % elif data['media_type'] == 'album':
${data['parent_title']}<br />${data['title']} ${data['parent_title']}<br />${data['title']}
% elif data['media_type'] == 'track': % elif data['media_type'] == 'track':
${data['original_title'] or data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']} ${data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']}
% endif % endif
</strong> </strong>
</p> </p>
@ -853,7 +865,7 @@ DOCUMENTATION :: END
%> %>
<script src="${http_root}js/tables/history_table.js${cache_param}"></script> <script src="${http_root}js/tables/history_table.js${cache_param}"></script>
<script src="${http_root}js/tables/export_table.js${cache_param}"></script> <script src="${http_root}js/tables/export_table.js${cache_param}"></script>
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'): % if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<script> <script>
function loadHistoryTable(transcode_decision) { function loadHistoryTable(transcode_decision) {
// Build watch history table // Build watch history table
@ -873,6 +885,9 @@ DOCUMENTATION :: END
parent_rating_key: "${data['rating_key']}" parent_rating_key: "${data['rating_key']}"
% elif data['media_type'] in ('movie', 'episode', 'track'): % elif data['media_type'] in ('movie', 'episode', 'track'):
rating_key: "${data['rating_key']}" rating_key: "${data['rating_key']}"
% elif data['media_type'] in ('collection', 'playlist'):
media_type: "${data['media_type']}",
rating_key: "${data['rating_key']}"
% endif % endif
}; };
} }
@ -925,13 +940,16 @@ DOCUMENTATION :: END
}); });
</script> </script>
% endif % endif
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'): % if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<script> <script>
// Populate watch time stats // Populate watch time stats
$.ajax({ $.ajax({
url: 'item_watch_time_stats', url: 'item_watch_time_stats',
async: true, async: true,
data: { rating_key: "${data['rating_key']}" }, data: {
rating_key: "${data['rating_key']}",
media_type: "${data['media_type']}"
},
complete: function(xhr, status) { complete: function(xhr, status) {
$("#watch-time-stats").html(xhr.responseText); $("#watch-time-stats").html(xhr.responseText);
} }
@ -940,7 +958,10 @@ DOCUMENTATION :: END
$.ajax({ $.ajax({
url: 'item_user_stats', url: 'item_user_stats',
async: true, async: true,
data: { rating_key: "${data['rating_key']}" }, data: {
rating_key: "${data['rating_key']}",
media_type: "${data['media_type']}"
},
complete: function(xhr, status) { complete: function(xhr, status) {
$("#user-stats").html(xhr.responseText); $("#user-stats").html(xhr.responseText);
} }

View file

@ -160,6 +160,16 @@ DOCUMENTATION :: END
% endif % endif
</div> </div>
</a> </a>
<div class="item-children-instance-text-wrapper poster-item">
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
</h3>
% endif
</div>
% elif child['media_type'] == 'episode': % elif child['media_type'] == 'episode':
<a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}"> <a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">
<div class="item-children-poster"> <div class="item-children-poster">
@ -179,6 +189,29 @@ DOCUMENTATION :: END
<h3> <h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a> <a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3> </h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">${child['grandparent_title']}</a>
</h3>
<h3 class="text-muted">
<a href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${short_season(child['parent_title'])}</a>
&middot; <a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">E${child['media_index']}</a>
</h3>
% endif
</div>
% elif child['media_type'] == 'artist':
<a href="${page('info', child['rating_key'])}" title="${child['title']}">
<div class="item-children-poster">
<div class="item-children-poster-face cover-item" style="background-image: url(${page('pms_image_proxy', child['thumb'], child['rating_key'], 300, 300, fallback='cover')});"></div>
% if _session['user_group'] == 'admin':
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
% endif
</div>
</a>
<div class="item-children-instance-text-wrapper cover-item">
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
</div> </div>
% elif child['media_type'] == 'album': % elif child['media_type'] == 'album':
<a href="${page('info', child['rating_key'])}" title="${child['title']}"> <a href="${page('info', child['rating_key'])}" title="${child['title']}">
@ -193,6 +226,11 @@ DOCUMENTATION :: END
<h3> <h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a> <a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3> </h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
</h3>
% endif
</div> </div>
% elif child['media_type'] == 'track': % elif child['media_type'] == 'track':
<% e = 'even' if loop.index % 2 == 0 else 'odd' %> <% e = 'even' if loop.index % 2 == 0 else 'odd' %>
@ -205,7 +243,15 @@ DOCUMENTATION :: END
${child['title']} ${child['title']}
</span> </span>
</a> </a>
% if child['original_title']: % if media_type == 'collection':
-
<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">
<span class="thumb-tooltip" data-toggle="popover" data-img="${page('pms_image_proxy', child['grandparent_thumb'], child['grandparent_rating_key'], 300, 300, fallback='cover')}" data-height="80" data-width="80">
${child['grandparent_title']}
</span>
</a>
<span class="text-muted"> (<a class="no-highlight" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>)</span>
% elif child['original_title']:
<span class="text-muted"> - ${child['original_title']}</span> <span class="text-muted"> - ${child['original_title']}</span>
% endif % endif
</span> </span>

File diff suppressed because one or more lines are too long

View file

@ -32,7 +32,12 @@ collections_table_options = {
if (rowData['smart']) { if (rowData['smart']) {
smart = '<span class="media-type-tooltip" data-toggle="tooltip" title="Smart Collection"><i class="fa fa-cog fa-fw"></i></span>&nbsp;' smart = '<span class="media-type-tooltip" data-toggle="tooltip" title="Smart Collection"><i class="fa fa-cog fa-fw"></i></span>&nbsp;'
} }
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 450, null, null, null, 'poster') + '" data-height="120" data-width="80">' + rowData['title'] + '</span>'; console.log(rowData['subtype'])
if (rowData['subtype'] === 'artist' || rowData['subtype'] === 'album' || rowData['subtype'] === 'track') {
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 300, null, null, null, 'cover') + '" data-height="80" data-width="80">' + rowData['title'] + '</span>';
} else {
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 450, null, null, null, 'poster') + '" data-height="120" data-width="80">' + rowData['title'] + '</span>';
}
$(td).html(smart + '<a href="' + page('info', rowData['ratingKey']) + '">' + thumb_popover + '</a>'); $(td).html(smart + '<a href="' + page('info', rowData['ratingKey']) + '">' + thumb_popover + '</a>');
} }
}, },

View file

@ -247,7 +247,7 @@ history_table_options = {
}, },
{ {
"targets": [11], "targets": [11],
"data": "duration", "data": "play_duration",
"render": function (data, type, full) { "render": function (data, type, full) {
if (data !== null) { if (data !== null) {
return Math.round(moment.duration(data, 'seconds').as('minutes')) + ' mins'; return Math.round(moment.duration(data, 'seconds').as('minutes')) + ' mins';
@ -529,7 +529,7 @@ function childTableFormat(rowData) {
'<th align="left" id="started">Started</th>' + '<th align="left" id="started">Started</th>' +
'<th align="left" id="paused_counter">Paused</th>' + '<th align="left" id="paused_counter">Paused</th>' +
'<th align="left" id="stopped">Stopped</th>' + '<th align="left" id="stopped">Stopped</th>' +
'<th align="left" id="duration">Duration</th>' + '<th align="left" id="play_duration">Duration</th>' +
'<th align="left" id="percent_complete"></th>' + '<th align="left" id="percent_complete"></th>' +
'</tr>' + '</tr>' +
'</thead>' + '</thead>' +

View file

@ -248,7 +248,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th> <th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th> <th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th> <th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th> <th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th> <th align="left" id="percent_complete"></th>
</tr> </tr>
</thead> </thead>

View file

@ -453,12 +453,12 @@
$("#download-tautullilog").click(function () { $("#download-tautullilog").click(function () {
var logfile = $(".tab-pane.active").data('logfile'); var logfile = $(".tab-pane.active").data('logfile');
window.location.href = "download_log?logfile=" + logfile; window.location.href = "download_log?logfile=" + window.encodeURIComponent(logfile);
}); });
$("#download-plexserverlog").click(function () { $("#download-plexserverlog").click(function () {
var logfile = $("option:selected", "#plex-log-files").val(); var logfile = $("option:selected", "#plex-log-files").val();
window.location.href = "download_plex_log?logfile=" + logfile; window.location.href = "download_plex_log?logfile=" + window.encodeURIComponent(logfile);
}); });
$("#clear-notify-logs").click(function () { $("#clear-notify-logs").click(function () {

View file

@ -70,7 +70,7 @@ DOCUMENTATION :: END
function checkQRAddress(url) { function checkQRAddress(url) {
var parser = document.createElement('a'); var parser = document.createElement('a');
parser.href = url; parser.setAttribute('href', url);
var hostname = parser.hostname; var hostname = parser.hostname;
var protocol = parser.protocol; var protocol = parser.protocol;

View file

@ -142,8 +142,10 @@
<div class="row"> <div class="row">
<div class="col-md-12"> <div class="col-md-12">
<select class="form-control" id="${item['name']}" name="${item['name']}"> <select class="form-control" id="${item['name']}" name="${item['name']}">
% if item['select_all']:
<option value="select-all">Select All</option> <option value="select-all">Select All</option>
<option value="remove-all">Remove All</option> <option value="remove-all">Remove All</option>
% endif
% if isinstance(item['select_options'], dict): % if isinstance(item['select_options'], dict):
% for section, options in item['select_options'].items(): % for section, options in item['select_options'].items():
<optgroup label="${section}"> <optgroup label="${section}">
@ -153,7 +155,9 @@
</optgroup> </optgroup>
% endfor % endfor
% else: % else:
% if item['select_all']:
<option value="border-all"></option> <option value="border-all"></option>
% endif
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()): % for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
<option value="${option['value']}">${option['text']}</option> <option value="${option['value']}">${option['text']}</option>
% endfor % endfor

View file

@ -134,8 +134,10 @@
<div class="row"> <div class="row">
<div class="col-md-12"> <div class="col-md-12">
<select class="form-control" id="${item['name']}" name="${item['name']}"> <select class="form-control" id="${item['name']}" name="${item['name']}">
% if item['select_all']:
<option value="select-all">Select All</option> <option value="select-all">Select All</option>
<option value="remove-all">Remove All</option> <option value="remove-all">Remove All</option>
% endif
% if isinstance(item['select_options'], dict): % if isinstance(item['select_options'], dict):
% for section, options in item['select_options'].items(): % for section, options in item['select_options'].items():
<optgroup label="${section}"> <optgroup label="${section}">
@ -145,7 +147,9 @@
</optgroup> </optgroup>
% endfor % endfor
% else: % else:
% if item['select_all']:
<option value="border-all"></option> <option value="border-all"></option>
% endif
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()): % for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
<option value="${option['value']}">${option['text']}</option> <option value="${option['value']}">${option['text']}</option>
% endfor % endfor
@ -718,6 +722,12 @@
$('#pushover_priority').change( function () { $('#pushover_priority').change( function () {
pushoverPriority(); pushoverPriority();
}); });
var $pushover_sound = $('#pushover_sound').selectize({
create: true
});
var pushover_sound = $pushover_sound[0].selectize;
pushover_sound.setValue(${json.dumps(next((c['value'] for c in notifier['config_options'] if c['name'] == 'pushover_sound'), [])) | n});
% elif notifier['agent_name'] == 'plexmobileapp': % elif notifier['agent_name'] == 'plexmobileapp':
var $plexmobileapp_user_ids = $('#plexmobileapp_user_ids').selectize({ var $plexmobileapp_user_ids = $('#plexmobileapp_user_ids').selectize({

View file

@ -132,12 +132,6 @@
</label> </label>
<p class="help-block">Change the "<em>Play by day of week</em>" graph to start on Monday. Default is start on Sunday.</p> <p class="help-block">Change the "<em>Play by day of week</em>" graph to start on Monday. Default is start on Sunday.</p>
</div> </div>
<div class="checkbox advanced-setting">
<label>
<input type="checkbox" id="group_history_tables" name="group_history_tables" value="1" ${config['group_history_tables']}> Group Play History
</label>
<p class="help-block">Group play history for the same item and user as a single entry when progress is less than the watched percent.</p>
</div>
<div class="checkbox advanced-setting"> <div class="checkbox advanced-setting">
<label> <label>
<input type="checkbox" id="history_table_activity" name="history_table_activity" value="1" ${config['history_table_activity']}> Current Activity in History Tables <input type="checkbox" id="history_table_activity" name="history_table_activity" value="1" ${config['history_table_activity']}> Current Activity in History Tables
@ -213,6 +207,39 @@
</div> </div>
<p class="help-block">Set the percentage for a music track to be considered as listened. Minimum 50, Maximum 95.</p> <p class="help-block">Set the percentage for a music track to be considered as listened. Minimum 50, Maximum 95.</p>
</div> </div>
<div class="form-group">
<label for="music_watched_percent">Video Watched Completion Behaviour</label>
<div class="row">
<div class="col-md-7">
<select class="form-control" id="watched_marker" name="watched_marker">
<option value="0" ${'selected' if config['watched_marker'] == 0 else ''}>At selected threshold percentage</option>
<option value="1" ${'selected' if config['watched_marker'] == 1 else ''}>At final credits marker position</option>
<option value="2" ${'selected' if config['watched_marker'] == 2 else ''}>At first credits marker position</option>
<option value="3" ${'selected' if config['watched_marker'] == 3 else ''}>Earliest between threshold percent and first credits marker</option>
</select>
</div>
</div>
<p class="help-block">Decide whether to use end credits markers to determine the 'watched' state of video items. When markers are not available the selected threshold percentage will be used.</p>
</div>
<div class="checkbox advanced-setting">
<label>
<input type="checkbox" id="group_history_tables" name="group_history_tables" value="1" ${config['group_history_tables']}> Group Play History
</label>
<p class="help-block">Group play history for the same item and user as a single entry when progress is less than the watched percent.</p>
</div>
<div class="form-group advanced-setting">
<label>Regroup Play History</label>
<p class="help-block">
Fix grouping of play history in the database.<br />
</p>
<div class="row">
<div class="col-md-4">
<div class="btn-group">
<button class="btn btn-form" type="button" id="regroup_history">Regroup</button>
</div>
</div>
</div>
</div>
<div class="form-group advanced-setting"> <div class="form-group advanced-setting">
<label>Flush Temporary Sessions</label> <label>Flush Temporary Sessions</label>
<p class="help-block"> <p class="help-block">
@ -2470,6 +2497,12 @@ $(document).ready(function() {
confirmAjaxCall(url, msg); confirmAjaxCall(url, msg);
}); });
$("#regroup_history").click(function () {
var msg = 'Are you sure you want to regroup play history in the database?<br /><br /><strong>This make take a long time for large databases.<br />Regrouping will continue in the background.</strong>';
var url = 'regroup_history';
confirmAjaxCall(url, msg);
});
$("#delete_temp_sessions").click(function () { $("#delete_temp_sessions").click(function () {
var msg = 'Are you sure you want to flush the temporary sessions?<br /><br /><strong>This will reset all currently active sessions.</strong>'; var msg = 'Are you sure you want to flush the temporary sessions?<br /><br /><strong>This will reset all currently active sessions.</strong>';
var url = 'delete_temp_sessions'; var url = 'delete_temp_sessions';

View file

@ -212,7 +212,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th> <th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th> <th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th> <th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th> <th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th> <th align="left" id="percent_complete"></th>
</tr> </tr>
</thead> </thead>

View file

@ -1,121 +0,0 @@
#!/usr/bin/python
###############################################################################
# Formatting filter for urllib2's HTTPHandler(debuglevel=1) output
# Copyright (c) 2013, Analytics Pros
#
# This project is free software, distributed under the BSD license.
# Analytics Pros offers consulting and integration services if your firm needs
# assistance in strategy, implementation, or auditing existing work.
###############################################################################
import sys, re, os
from io import StringIO
class BufferTranslator(object):
""" Provides a buffer-compatible interface for filtering buffer content.
"""
parsers = []
def __init__(self, output):
self.output = output
self.encoding = getattr(output, 'encoding', None)
def write(self, content):
content = self.translate(content)
self.output.write(content)
@staticmethod
def stripslashes(content):
return content.decode('string_escape')
@staticmethod
def addslashes(content):
return content.encode('string_escape')
def translate(self, line):
for pattern, method in self.parsers:
match = pattern.match(line)
if match:
return method(match)
return line
class LineBufferTranslator(BufferTranslator):
""" Line buffer implementation supports translation of line-format input
even when input is not already line-buffered. Caches input until newlines
occur, and then dispatches translated input to output buffer.
"""
def __init__(self, *a, **kw):
self._linepending = []
super(LineBufferTranslator, self).__init__(*a, **kw)
def write(self, _input):
lines = _input.splitlines(True)
for i in range(0, len(lines)):
last = i
if lines[i].endswith('\n'):
prefix = len(self._linepending) and ''.join(self._linepending) or ''
self.output.write(self.translate(prefix + lines[i]))
del self._linepending[0:]
last = -1
if last >= 0:
self._linepending.append(lines[ last ])
def __del__(self):
if len(self._linepending):
self.output.write(self.translate(''.join(self._linepending)))
class HTTPTranslator(LineBufferTranslator):
""" Translates output from |urllib2| HTTPHandler(debuglevel = 1) into
HTTP-compatible, readible text structures for human analysis.
"""
RE_LINE_PARSER = re.compile(r'^(?:([a-z]+):)\s*(\'?)([^\r\n]*)\2(?:[\r\n]*)$')
RE_LINE_BREAK = re.compile(r'(\r?\n|(?:\\r)?\\n)')
RE_HTTP_METHOD = re.compile(r'^(POST|GET|HEAD|DELETE|PUT|TRACE|OPTIONS)')
RE_PARAMETER_SPACER = re.compile(r'&([a-z0-9]+)=')
@classmethod
def spacer(cls, line):
return cls.RE_PARAMETER_SPACER.sub(r' &\1= ', line)
def translate(self, line):
parsed = self.RE_LINE_PARSER.match(line)
if parsed:
value = parsed.group(3)
stage = parsed.group(1)
if stage == 'send': # query string is rendered here
return '\n# HTTP Request:\n' + self.stripslashes(value)
elif stage == 'reply':
return '\n\n# HTTP Response:\n' + self.stripslashes(value)
elif stage == 'header':
return value + '\n'
else:
return value
return line
def consume(outbuffer = None): # Capture standard output
sys.stdout = HTTPTranslator(outbuffer or sys.stdout)
return sys.stdout
if __name__ == '__main__':
consume(sys.stdout).write(sys.stdin.read())
print('\n')
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4

View file

@ -1,424 +0,0 @@
from future.moves.urllib.request import urlopen, build_opener, install_opener
from future.moves.urllib.request import Request, HTTPSHandler
from future.moves.urllib.error import URLError, HTTPError
from future.moves.urllib.parse import urlencode
import random
import datetime
import time
import uuid
import hashlib
import socket
def generate_uuid(basedata=None):
""" Provides a _random_ UUID with no input, or a UUID4-format MD5 checksum of any input data provided """
if basedata is None:
return str(uuid.uuid4())
elif isinstance(basedata, str):
checksum = hashlib.md5(str(basedata).encode('utf-8')).hexdigest()
return '%8s-%4s-%4s-%4s-%12s' % (
checksum[0:8], checksum[8:12], checksum[12:16], checksum[16:20], checksum[20:32])
class Time(datetime.datetime):
""" Wrappers and convenience methods for processing various time representations """
@classmethod
def from_unix(cls, seconds, milliseconds=0):
""" Produce a full |datetime.datetime| object from a Unix timestamp """
base = list(time.gmtime(seconds))[0:6]
base.append(milliseconds * 1000) # microseconds
return cls(*base)
@classmethod
def to_unix(cls, timestamp):
""" Wrapper over time module to produce Unix epoch time as a float """
if not isinstance(timestamp, datetime.datetime):
raise TypeError('Time.milliseconds expects a datetime object')
base = time.mktime(timestamp.timetuple())
return base
@classmethod
def milliseconds_offset(cls, timestamp, now=None):
""" Offset time (in milliseconds) from a |datetime.datetime| object to now """
if isinstance(timestamp, (int, float)):
base = timestamp
else:
base = cls.to_unix(timestamp)
base = base + (timestamp.microsecond / 1000000)
if now is None:
now = time.time()
return (now - base) * 1000
class HTTPRequest(object):
""" URL Construction and request handling abstraction.
This is not intended to be used outside this module.
Automates mapping of persistent state (i.e. query parameters)
onto transcient datasets for each query.
"""
endpoint = 'https://www.google-analytics.com/collect'
@staticmethod
def debug():
""" Activate debugging on urllib2 """
handler = HTTPSHandler(debuglevel=1)
opener = build_opener(handler)
install_opener(opener)
# Store properties for all requests
def __init__(self, user_agent=None, *args, **opts):
self.user_agent = user_agent or 'Analytics Pros - Universal Analytics (Python)'
@classmethod
def fixUTF8(cls, data): # Ensure proper encoding for UA's servers...
""" Convert all strings to UTF-8 """
for key in data:
if isinstance(data[key], str):
data[key] = data[key].encode('utf-8')
return data
# Apply stored properties to the given dataset & POST to the configured endpoint
def send(self, data):
request = Request(
self.endpoint + '?' + urlencode(self.fixUTF8(data)).encode('utf-8'),
headers={
'User-Agent': self.user_agent
}
)
self.open(request)
def open(self, request):
try:
return urlopen(request)
except HTTPError as e:
return False
except URLError as e:
self.cache_request(request)
return False
def cache_request(self, request):
# TODO: implement a proper caching mechanism here for re-transmitting hits
# record = (Time.now(), request.get_full_url(), request.get_data(), request.headers)
pass
class HTTPPost(HTTPRequest):
# Apply stored properties to the given dataset & POST to the configured endpoint
def send(self, data):
request = Request(
self.endpoint,
data=urlencode(self.fixUTF8(data)).encode('utf-8'),
headers={
'User-Agent': self.user_agent
}
)
self.open(request)
class Tracker(object):
""" Primary tracking interface for Universal Analytics """
params = None
parameter_alias = {}
valid_hittypes = ('pageview', 'event', 'social', 'screenview', 'transaction', 'item', 'exception', 'timing')
@classmethod
def alias(cls, typemap, base, *names):
""" Declare an alternate (humane) name for a measurement protocol parameter """
cls.parameter_alias[base] = (typemap, base)
for i in names:
cls.parameter_alias[i] = (typemap, base)
@classmethod
def coerceParameter(cls, name, value=None):
if isinstance(name, str) and name[0] == '&':
return name[1:], str(value)
elif name in cls.parameter_alias:
typecast, param_name = cls.parameter_alias.get(name)
return param_name, typecast(value)
else:
raise KeyError('Parameter "{0}" is not recognized'.format(name))
def payload(self, data):
for key, value in data.items():
try:
yield self.coerceParameter(key, value)
except KeyError:
continue
option_sequence = {
'pageview': [(str, 'dp')],
'event': [(str, 'ec'), (str, 'ea'), (str, 'el'), (int, 'ev')],
'social': [(str, 'sn'), (str, 'sa'), (str, 'st')],
'timing': [(str, 'utc'), (str, 'utv'), (str, 'utt'), (str, 'utl')]
}
@classmethod
def consume_options(cls, data, hittype, args):
""" Interpret sequential arguments related to known hittypes based on declared structures """
opt_position = 0
data['t'] = hittype # integrate hit type parameter
if hittype in cls.option_sequence:
for expected_type, optname in cls.option_sequence[hittype]:
if opt_position < len(args) and isinstance(args[opt_position], expected_type):
data[optname] = args[opt_position]
opt_position += 1
@classmethod
def hittime(cls, timestamp=None, age=None, milliseconds=None):
""" Returns an integer represeting the milliseconds offset for a given hit (relative to now) """
if isinstance(timestamp, (int, float)):
return int(Time.milliseconds_offset(Time.from_unix(timestamp, milliseconds=milliseconds)))
if isinstance(timestamp, datetime.datetime):
return int(Time.milliseconds_offset(timestamp))
if isinstance(age, (int, float)):
return int(age * 1000) + (milliseconds or 0)
@property
def account(self):
return self.params.get('tid', None)
def __init__(self, account, name=None, client_id=None, hash_client_id=False, user_id=None, user_agent=None,
use_post=True):
if use_post is False:
self.http = HTTPRequest(user_agent=user_agent)
else:
self.http = HTTPPost(user_agent=user_agent)
self.params = {'v': 1, 'tid': account}
if client_id is None:
client_id = generate_uuid()
self.params['cid'] = client_id
self.hash_client_id = hash_client_id
if user_id is not None:
self.params['uid'] = user_id
def set_timestamp(self, data):
""" Interpret time-related options, apply queue-time parameter as needed """
if 'hittime' in data: # an absolute timestamp
data['qt'] = self.hittime(timestamp=data.pop('hittime', None))
if 'hitage' in data: # a relative age (in seconds)
data['qt'] = self.hittime(age=data.pop('hitage', None))
def send(self, hittype, *args, **data):
""" Transmit HTTP requests to Google Analytics using the measurement protocol """
if hittype not in self.valid_hittypes:
raise KeyError('Unsupported Universal Analytics Hit Type: {0}'.format(repr(hittype)))
self.set_timestamp(data)
self.consume_options(data, hittype, args)
for item in args: # process dictionary-object arguments of transcient data
if isinstance(item, dict):
for key, val in self.payload(item):
data[key] = val
for k, v in self.params.items(): # update only absent parameters
if k not in data:
data[k] = v
data = dict(self.payload(data))
if self.hash_client_id:
data['cid'] = generate_uuid(data['cid'])
# Transmit the hit to Google...
self.http.send(data)
# Setting persistent attibutes of the session/hit/etc (inc. custom dimensions/metrics)
def set(self, name, value=None):
if isinstance(name, dict):
for key, value in name.items():
try:
param, value = self.coerceParameter(key, value)
self.params[param] = value
except KeyError:
pass
elif isinstance(name, str):
try:
param, value = self.coerceParameter(name, value)
self.params[param] = value
except KeyError:
pass
def __getitem__(self, name):
param, value = self.coerceParameter(name, None)
return self.params.get(param, None)
def __setitem__(self, name, value):
param, value = self.coerceParameter(name, value)
self.params[param] = value
def __delitem__(self, name):
param, value = self.coerceParameter(name, None)
if param in self.params:
del self.params[param]
def safe_unicode(obj):
""" Safe convertion to the Unicode string version of the object """
try:
return str(obj)
except UnicodeDecodeError:
return obj.decode('utf-8')
# Declaring name mappings for Measurement Protocol parameters
MAX_CUSTOM_DEFINITIONS = 200
MAX_EC_LISTS = 11 # 1-based index
MAX_EC_PRODUCTS = 11 # 1-based index
MAX_EC_PROMOTIONS = 11 # 1-based index
Tracker.alias(int, 'v', 'protocol-version')
Tracker.alias(safe_unicode, 'cid', 'client-id', 'clientId', 'clientid')
Tracker.alias(safe_unicode, 'tid', 'trackingId', 'account')
Tracker.alias(safe_unicode, 'uid', 'user-id', 'userId', 'userid')
Tracker.alias(safe_unicode, 'uip', 'user-ip', 'userIp', 'ipaddr')
Tracker.alias(safe_unicode, 'ua', 'userAgent', 'userAgentOverride', 'user-agent')
Tracker.alias(safe_unicode, 'dp', 'page', 'path')
Tracker.alias(safe_unicode, 'dt', 'title', 'pagetitle', 'pageTitle' 'page-title')
Tracker.alias(safe_unicode, 'dl', 'location')
Tracker.alias(safe_unicode, 'dh', 'hostname')
Tracker.alias(safe_unicode, 'sc', 'sessioncontrol', 'session-control', 'sessionControl')
Tracker.alias(safe_unicode, 'dr', 'referrer', 'referer')
Tracker.alias(int, 'qt', 'queueTime', 'queue-time')
Tracker.alias(safe_unicode, 't', 'hitType', 'hittype')
Tracker.alias(int, 'aip', 'anonymizeIp', 'anonIp', 'anonymize-ip')
Tracker.alias(safe_unicode, 'ds', 'dataSource', 'data-source')
# Campaign attribution
Tracker.alias(safe_unicode, 'cn', 'campaign', 'campaignName', 'campaign-name')
Tracker.alias(safe_unicode, 'cs', 'source', 'campaignSource', 'campaign-source')
Tracker.alias(safe_unicode, 'cm', 'medium', 'campaignMedium', 'campaign-medium')
Tracker.alias(safe_unicode, 'ck', 'keyword', 'campaignKeyword', 'campaign-keyword')
Tracker.alias(safe_unicode, 'cc', 'content', 'campaignContent', 'campaign-content')
Tracker.alias(safe_unicode, 'ci', 'campaignId', 'campaignID', 'campaign-id')
# Technical specs
Tracker.alias(safe_unicode, 'sr', 'screenResolution', 'screen-resolution', 'resolution')
Tracker.alias(safe_unicode, 'vp', 'viewport', 'viewportSize', 'viewport-size')
Tracker.alias(safe_unicode, 'de', 'encoding', 'documentEncoding', 'document-encoding')
Tracker.alias(int, 'sd', 'colors', 'screenColors', 'screen-colors')
Tracker.alias(safe_unicode, 'ul', 'language', 'user-language', 'userLanguage')
# Mobile app
Tracker.alias(safe_unicode, 'an', 'appName', 'app-name', 'app')
Tracker.alias(safe_unicode, 'cd', 'contentDescription', 'screenName', 'screen-name', 'content-description')
Tracker.alias(safe_unicode, 'av', 'appVersion', 'app-version', 'version')
Tracker.alias(safe_unicode, 'aid', 'appID', 'appId', 'application-id', 'app-id', 'applicationId')
Tracker.alias(safe_unicode, 'aiid', 'appInstallerId', 'app-installer-id')
# Ecommerce
Tracker.alias(safe_unicode, 'ta', 'affiliation', 'transactionAffiliation', 'transaction-affiliation')
Tracker.alias(safe_unicode, 'ti', 'transaction', 'transactionId', 'transaction-id')
Tracker.alias(float, 'tr', 'revenue', 'transactionRevenue', 'transaction-revenue')
Tracker.alias(float, 'ts', 'shipping', 'transactionShipping', 'transaction-shipping')
Tracker.alias(float, 'tt', 'tax', 'transactionTax', 'transaction-tax')
Tracker.alias(safe_unicode, 'cu', 'currency', 'transactionCurrency',
'transaction-currency') # Currency code, e.g. USD, EUR
Tracker.alias(safe_unicode, 'in', 'item-name', 'itemName')
Tracker.alias(float, 'ip', 'item-price', 'itemPrice')
Tracker.alias(float, 'iq', 'item-quantity', 'itemQuantity')
Tracker.alias(safe_unicode, 'ic', 'item-code', 'sku', 'itemCode')
Tracker.alias(safe_unicode, 'iv', 'item-variation', 'item-category', 'itemCategory', 'itemVariation')
# Events
Tracker.alias(safe_unicode, 'ec', 'event-category', 'eventCategory', 'category')
Tracker.alias(safe_unicode, 'ea', 'event-action', 'eventAction', 'action')
Tracker.alias(safe_unicode, 'el', 'event-label', 'eventLabel', 'label')
Tracker.alias(int, 'ev', 'event-value', 'eventValue', 'value')
Tracker.alias(int, 'ni', 'noninteractive', 'nonInteractive', 'noninteraction', 'nonInteraction')
# Social
Tracker.alias(safe_unicode, 'sa', 'social-action', 'socialAction')
Tracker.alias(safe_unicode, 'sn', 'social-network', 'socialNetwork')
Tracker.alias(safe_unicode, 'st', 'social-target', 'socialTarget')
# Exceptions
Tracker.alias(safe_unicode, 'exd', 'exception-description', 'exceptionDescription', 'exDescription')
Tracker.alias(int, 'exf', 'exception-fatal', 'exceptionFatal', 'exFatal')
# User Timing
Tracker.alias(safe_unicode, 'utc', 'timingCategory', 'timing-category')
Tracker.alias(safe_unicode, 'utv', 'timingVariable', 'timing-variable')
Tracker.alias(float, 'utt', 'time', 'timingTime', 'timing-time')
Tracker.alias(safe_unicode, 'utl', 'timingLabel', 'timing-label')
Tracker.alias(float, 'dns', 'timingDNS', 'timing-dns')
Tracker.alias(float, 'pdt', 'timingPageLoad', 'timing-page-load')
Tracker.alias(float, 'rrt', 'timingRedirect', 'timing-redirect')
Tracker.alias(safe_unicode, 'tcp', 'timingTCPConnect', 'timing-tcp-connect')
Tracker.alias(safe_unicode, 'srt', 'timingServerResponse', 'timing-server-response')
# Custom dimensions and metrics
for i in range(0, 200):
Tracker.alias(safe_unicode, 'cd{0}'.format(i), 'dimension{0}'.format(i))
Tracker.alias(int, 'cm{0}'.format(i), 'metric{0}'.format(i))
# Content groups
for i in range(0, 5):
Tracker.alias(safe_unicode, 'cg{0}'.format(i), 'contentGroup{0}'.format(i))
# Enhanced Ecommerce
Tracker.alias(str, 'pa') # Product action
Tracker.alias(str, 'tcc') # Coupon code
Tracker.alias(str, 'pal') # Product action list
Tracker.alias(int, 'cos') # Checkout step
Tracker.alias(str, 'col') # Checkout step option
Tracker.alias(str, 'promoa') # Promotion action
for product_index in range(1, MAX_EC_PRODUCTS):
Tracker.alias(str, 'pr{0}id'.format(product_index)) # Product SKU
Tracker.alias(str, 'pr{0}nm'.format(product_index)) # Product name
Tracker.alias(str, 'pr{0}br'.format(product_index)) # Product brand
Tracker.alias(str, 'pr{0}ca'.format(product_index)) # Product category
Tracker.alias(str, 'pr{0}va'.format(product_index)) # Product variant
Tracker.alias(str, 'pr{0}pr'.format(product_index)) # Product price
Tracker.alias(int, 'pr{0}qt'.format(product_index)) # Product quantity
Tracker.alias(str, 'pr{0}cc'.format(product_index)) # Product coupon code
Tracker.alias(int, 'pr{0}ps'.format(product_index)) # Product position
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
Tracker.alias(str, 'pr{0}cd{1}'.format(product_index, custom_index)) # Product custom dimension
Tracker.alias(int, 'pr{0}cm{1}'.format(product_index, custom_index)) # Product custom metric
for list_index in range(1, MAX_EC_LISTS):
Tracker.alias(str, 'il{0}pi{1}id'.format(list_index, product_index)) # Product impression SKU
Tracker.alias(str, 'il{0}pi{1}nm'.format(list_index, product_index)) # Product impression name
Tracker.alias(str, 'il{0}pi{1}br'.format(list_index, product_index)) # Product impression brand
Tracker.alias(str, 'il{0}pi{1}ca'.format(list_index, product_index)) # Product impression category
Tracker.alias(str, 'il{0}pi{1}va'.format(list_index, product_index)) # Product impression variant
Tracker.alias(int, 'il{0}pi{1}ps'.format(list_index, product_index)) # Product impression position
Tracker.alias(int, 'il{0}pi{1}pr'.format(list_index, product_index)) # Product impression price
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
Tracker.alias(str, 'il{0}pi{1}cd{2}'.format(list_index, product_index,
custom_index)) # Product impression custom dimension
Tracker.alias(int, 'il{0}pi{1}cm{2}'.format(list_index, product_index,
custom_index)) # Product impression custom metric
for list_index in range(1, MAX_EC_LISTS):
Tracker.alias(str, 'il{0}nm'.format(list_index)) # Product impression list name
for promotion_index in range(1, MAX_EC_PROMOTIONS):
Tracker.alias(str, 'promo{0}id'.format(promotion_index)) # Promotion ID
Tracker.alias(str, 'promo{0}nm'.format(promotion_index)) # Promotion name
Tracker.alias(str, 'promo{0}cr'.format(promotion_index)) # Promotion creative
Tracker.alias(str, 'promo{0}ps'.format(promotion_index)) # Promotion position
# Shortcut for creating trackers
def create(account, *args, **kwargs):
return Tracker(account, *args, **kwargs)
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4

View file

@ -1 +0,0 @@
from . import Tracker

View file

@ -3,13 +3,9 @@ from __future__ import absolute_import
import sys import sys
from apscheduler.executors.base import BaseExecutor, run_job from apscheduler.executors.base import BaseExecutor, run_job
from apscheduler.executors.base_py3 import run_coroutine_job
from apscheduler.util import iscoroutinefunction_partial from apscheduler.util import iscoroutinefunction_partial
try:
from apscheduler.executors.base_py3 import run_coroutine_job
except ImportError:
run_coroutine_job = None
class AsyncIOExecutor(BaseExecutor): class AsyncIOExecutor(BaseExecutor):
""" """
@ -46,11 +42,8 @@ class AsyncIOExecutor(BaseExecutor):
self._run_job_success(job.id, events) self._run_job_success(job.id, events)
if iscoroutinefunction_partial(job.func): if iscoroutinefunction_partial(job.func):
if run_coroutine_job is not None: coro = run_coroutine_job(job, job._jobstore_alias, run_times, self._logger.name)
coro = run_coroutine_job(job, job._jobstore_alias, run_times, self._logger.name) f = self._eventloop.create_task(coro)
f = self._eventloop.create_task(coro)
else:
raise Exception('Executing coroutine based jobs is not supported with Trollius')
else: else:
f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times, f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times,
self._logger.name) self._logger.name)

View file

@ -57,7 +57,7 @@ class SQLAlchemyJobStore(BaseJobStore):
# 25 = precision that translates to an 8-byte float # 25 = precision that translates to an 8-byte float
self.jobs_t = Table( self.jobs_t = Table(
tablename, metadata, tablename, metadata,
Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True), Column('id', Unicode(191), primary_key=True),
Column('next_run_time', Float(25), index=True), Column('next_run_time', Float(25), index=True),
Column('job_state', LargeBinary, nullable=False), Column('job_state', LargeBinary, nullable=False),
schema=tableschema schema=tableschema
@ -68,20 +68,22 @@ class SQLAlchemyJobStore(BaseJobStore):
self.jobs_t.create(self.engine, True) self.jobs_t.create(self.engine, True)
def lookup_job(self, job_id): def lookup_job(self, job_id):
selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id) selectable = select(self.jobs_t.c.job_state).where(self.jobs_t.c.id == job_id)
job_state = self.engine.execute(selectable).scalar() with self.engine.begin() as connection:
return self._reconstitute_job(job_state) if job_state else None job_state = connection.execute(selectable).scalar()
return self._reconstitute_job(job_state) if job_state else None
def get_due_jobs(self, now): def get_due_jobs(self, now):
timestamp = datetime_to_utc_timestamp(now) timestamp = datetime_to_utc_timestamp(now)
return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp) return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp)
def get_next_run_time(self): def get_next_run_time(self):
selectable = select([self.jobs_t.c.next_run_time]).\ selectable = select(self.jobs_t.c.next_run_time).\
where(self.jobs_t.c.next_run_time != null()).\ where(self.jobs_t.c.next_run_time != null()).\
order_by(self.jobs_t.c.next_run_time).limit(1) order_by(self.jobs_t.c.next_run_time).limit(1)
next_run_time = self.engine.execute(selectable).scalar() with self.engine.begin() as connection:
return utc_timestamp_to_datetime(next_run_time) next_run_time = connection.execute(selectable).scalar()
return utc_timestamp_to_datetime(next_run_time)
def get_all_jobs(self): def get_all_jobs(self):
jobs = self._get_jobs() jobs = self._get_jobs()
@ -94,29 +96,33 @@ class SQLAlchemyJobStore(BaseJobStore):
'next_run_time': datetime_to_utc_timestamp(job.next_run_time), 'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol) 'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
}) })
try: with self.engine.begin() as connection:
self.engine.execute(insert) try:
except IntegrityError: connection.execute(insert)
raise ConflictingIdError(job.id) except IntegrityError:
raise ConflictingIdError(job.id)
def update_job(self, job): def update_job(self, job):
update = self.jobs_t.update().values(**{ update = self.jobs_t.update().values(**{
'next_run_time': datetime_to_utc_timestamp(job.next_run_time), 'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol) 'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
}).where(self.jobs_t.c.id == job.id) }).where(self.jobs_t.c.id == job.id)
result = self.engine.execute(update) with self.engine.begin() as connection:
if result.rowcount == 0: result = connection.execute(update)
raise JobLookupError(job.id) if result.rowcount == 0:
raise JobLookupError(job.id)
def remove_job(self, job_id): def remove_job(self, job_id):
delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id) delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id)
result = self.engine.execute(delete) with self.engine.begin() as connection:
if result.rowcount == 0: result = connection.execute(delete)
raise JobLookupError(job_id) if result.rowcount == 0:
raise JobLookupError(job_id)
def remove_all_jobs(self): def remove_all_jobs(self):
delete = self.jobs_t.delete() delete = self.jobs_t.delete()
self.engine.execute(delete) with self.engine.begin() as connection:
connection.execute(delete)
def shutdown(self): def shutdown(self):
self.engine.dispose() self.engine.dispose()
@ -132,21 +138,22 @@ class SQLAlchemyJobStore(BaseJobStore):
def _get_jobs(self, *conditions): def _get_jobs(self, *conditions):
jobs = [] jobs = []
selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).\ selectable = select(self.jobs_t.c.id, self.jobs_t.c.job_state).\
order_by(self.jobs_t.c.next_run_time) order_by(self.jobs_t.c.next_run_time)
selectable = selectable.where(and_(*conditions)) if conditions else selectable selectable = selectable.where(and_(*conditions)) if conditions else selectable
failed_job_ids = set() failed_job_ids = set()
for row in self.engine.execute(selectable): with self.engine.begin() as connection:
try: for row in connection.execute(selectable):
jobs.append(self._reconstitute_job(row.job_state)) try:
except BaseException: jobs.append(self._reconstitute_job(row.job_state))
self._logger.exception('Unable to restore job "%s" -- removing it', row.id) except BaseException:
failed_job_ids.add(row.id) self._logger.exception('Unable to restore job "%s" -- removing it', row.id)
failed_job_ids.add(row.id)
# Remove all the jobs we failed to restore # Remove all the jobs we failed to restore
if failed_job_ids: if failed_job_ids:
delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids)) delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids))
self.engine.execute(delete) connection.execute(delete)
return jobs return jobs

View file

@ -1,18 +1,10 @@
from __future__ import absolute_import from __future__ import absolute_import
import asyncio
from functools import wraps, partial from functools import wraps, partial
from apscheduler.schedulers.base import BaseScheduler from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref from apscheduler.util import maybe_ref
try:
import asyncio
except ImportError: # pragma: nocover
try:
import trollius as asyncio
except ImportError:
raise ImportError(
'AsyncIOScheduler requires either Python 3.4 or the asyncio package installed')
def run_in_event_loop(func): def run_in_event_loop(func):
@wraps(func) @wraps(func)

View file

@ -33,7 +33,7 @@ class QtScheduler(BaseScheduler):
def _start_timer(self, wait_seconds): def _start_timer(self, wait_seconds):
self._stop_timer() self._stop_timer()
if wait_seconds is not None: if wait_seconds is not None:
wait_time = min(wait_seconds * 1000, 2147483647) wait_time = min(int(wait_seconds * 1000), 2147483647)
self._timer = QTimer.singleShot(wait_time, self._process_jobs) self._timer = QTimer.singleShot(wait_time, self._process_jobs)
def _stop_timer(self): def _stop_timer(self):

View file

@ -2,6 +2,7 @@
from __future__ import division from __future__ import division
from asyncio import iscoroutinefunction
from datetime import date, datetime, time, timedelta, tzinfo from datetime import date, datetime, time, timedelta, tzinfo
from calendar import timegm from calendar import timegm
from functools import partial from functools import partial
@ -22,15 +23,6 @@ try:
except ImportError: except ImportError:
TIMEOUT_MAX = 4294967 # Maximum value accepted by Event.wait() on Windows TIMEOUT_MAX = 4294967 # Maximum value accepted by Event.wait() on Windows
try:
from asyncio import iscoroutinefunction
except ImportError:
try:
from trollius import iscoroutinefunction
except ImportError:
def iscoroutinefunction(func):
return False
__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp', __all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp',
'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name', 'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name',
'obj_to_ref', 'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args', 'obj_to_ref', 'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args',

View file

@ -11,9 +11,9 @@ from bleach.sanitizer import (
# yyyymmdd # yyyymmdd
__releasedate__ = "20220627" __releasedate__ = "20230123"
# x.y.z or x.y.z.dev0 -- semver # x.y.z or x.y.z.dev0 -- semver
__version__ = "5.0.1" __version__ = "6.0.0"
__all__ = ["clean", "linkify"] __all__ = ["clean", "linkify"]
@ -52,7 +52,7 @@ def clean(
:arg str text: the text to clean :arg str text: the text to clean
:arg list tags: allowed list of tags; defaults to :arg set tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS`` ``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict; :arg dict attributes: allowed attributes; can be a callable, list or dict;

View file

@ -38,6 +38,9 @@ from bleach._vendor.html5lib.filters.sanitizer import (
allowed_protocols, allowed_protocols,
allowed_css_properties, allowed_css_properties,
allowed_svg_properties, allowed_svg_properties,
attr_val_is_uri,
svg_attr_val_allows_ref,
svg_allow_local_href,
) # noqa: E402 module level import not at top of file ) # noqa: E402 module level import not at top of file
from bleach._vendor.html5lib.filters.sanitizer import ( from bleach._vendor.html5lib.filters.sanitizer import (
Filter as SanitizerFilter, Filter as SanitizerFilter,
@ -78,127 +81,129 @@ TAG_TOKEN_TYPE_PARSEERROR = constants.tokenTypes["ParseError"]
#: List of valid HTML tags, from WHATWG HTML Living Standard as of 2018-10-17 #: List of valid HTML tags, from WHATWG HTML Living Standard as of 2018-10-17
#: https://html.spec.whatwg.org/multipage/indices.html#elements-3 #: https://html.spec.whatwg.org/multipage/indices.html#elements-3
HTML_TAGS = [ HTML_TAGS = frozenset(
"a", (
"abbr", "a",
"address", "abbr",
"area", "address",
"article", "area",
"aside", "article",
"audio", "aside",
"b", "audio",
"base", "b",
"bdi", "base",
"bdo", "bdi",
"blockquote", "bdo",
"body", "blockquote",
"br", "body",
"button", "br",
"canvas", "button",
"caption", "canvas",
"cite", "caption",
"code", "cite",
"col", "code",
"colgroup", "col",
"data", "colgroup",
"datalist", "data",
"dd", "datalist",
"del", "dd",
"details", "del",
"dfn", "details",
"dialog", "dfn",
"div", "dialog",
"dl", "div",
"dt", "dl",
"em", "dt",
"embed", "em",
"fieldset", "embed",
"figcaption", "fieldset",
"figure", "figcaption",
"footer", "figure",
"form", "footer",
"h1", "form",
"h2", "h1",
"h3", "h2",
"h4", "h3",
"h5", "h4",
"h6", "h5",
"head", "h6",
"header", "head",
"hgroup", "header",
"hr", "hgroup",
"html", "hr",
"i", "html",
"iframe", "i",
"img", "iframe",
"input", "img",
"ins", "input",
"kbd", "ins",
"keygen", "kbd",
"label", "keygen",
"legend", "label",
"li", "legend",
"link", "li",
"map", "link",
"mark", "map",
"menu", "mark",
"meta", "menu",
"meter", "meta",
"nav", "meter",
"noscript", "nav",
"object", "noscript",
"ol", "object",
"optgroup", "ol",
"option", "optgroup",
"output", "option",
"p", "output",
"param", "p",
"picture", "param",
"pre", "picture",
"progress", "pre",
"q", "progress",
"rp", "q",
"rt", "rp",
"ruby", "rt",
"s", "ruby",
"samp", "s",
"script", "samp",
"section", "script",
"select", "section",
"slot", "select",
"small", "slot",
"source", "small",
"span", "source",
"strong", "span",
"style", "strong",
"sub", "style",
"summary", "sub",
"sup", "summary",
"table", "sup",
"tbody", "table",
"td", "tbody",
"template", "td",
"textarea", "template",
"tfoot", "textarea",
"th", "tfoot",
"thead", "th",
"time", "thead",
"title", "time",
"tr", "title",
"track", "tr",
"u", "track",
"ul", "u",
"var", "ul",
"video", "var",
"wbr", "video",
] "wbr",
)
)
#: List of block level HTML tags, as per https://github.com/mozilla/bleach/issues/369 #: List of block level HTML tags, as per https://github.com/mozilla/bleach/issues/369
#: from mozilla on 2019.07.11 #: from mozilla on 2019.07.11
#: https://developer.mozilla.org/en-US/docs/Web/HTML/Block-level_elements#Elements #: https://developer.mozilla.org/en-US/docs/Web/HTML/Block-level_elements#Elements
HTML_TAGS_BLOCK_LEVEL = frozenset( HTML_TAGS_BLOCK_LEVEL = frozenset(
[ (
"address", "address",
"article", "article",
"aside", "aside",
@ -232,7 +237,7 @@ HTML_TAGS_BLOCK_LEVEL = frozenset(
"section", "section",
"table", "table",
"ul", "ul",
] )
) )
@ -473,7 +478,7 @@ class BleachHTMLParser(HTMLParser):
def __init__(self, tags, strip, consume_entities, **kwargs): def __init__(self, tags, strip, consume_entities, **kwargs):
""" """
:arg tags: list of allowed tags--everything else is either stripped or :arg tags: set of allowed tags--everything else is either stripped or
escaped; if None, then this doesn't look at tags at all escaped; if None, then this doesn't look at tags at all
:arg strip: whether to strip disallowed tags (True) or escape them (False); :arg strip: whether to strip disallowed tags (True) or escape them (False);
if tags=None, then this doesn't have any effect if tags=None, then this doesn't have any effect
@ -481,7 +486,9 @@ class BleachHTMLParser(HTMLParser):
leave them as is when tokenizing (BleachHTMLTokenizer-added behavior) leave them as is when tokenizing (BleachHTMLTokenizer-added behavior)
""" """
self.tags = [tag.lower() for tag in tags] if tags is not None else None self.tags = (
frozenset((tag.lower() for tag in tags)) if tags is not None else None
)
self.strip = strip self.strip = strip
self.consume_entities = consume_entities self.consume_entities = consume_entities
super().__init__(**kwargs) super().__init__(**kwargs)
@ -691,7 +698,7 @@ class BleachHTMLSerializer(HTMLSerializer):
# Only leave entities in that are not ambiguous. If they're # Only leave entities in that are not ambiguous. If they're
# ambiguous, then we escape the ampersand. # ambiguous, then we escape the ampersand.
if entity is not None and convert_entity(entity) is not None: if entity is not None and convert_entity(entity) is not None:
yield "&" + entity + ";" yield f"&{entity};"
# Length of the entity plus 2--one for & at the beginning # Length of the entity plus 2--one for & at the beginning
# and one for ; at the end # and one for ; at the end

View file

@ -120,9 +120,10 @@ class Linker:
:arg list callbacks: list of callbacks to run when adjusting tag attributes; :arg list callbacks: list of callbacks to run when adjusting tag attributes;
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS`` defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
:arg list skip_tags: list of tags that you don't want to linkify the :arg set skip_tags: set of tags that you don't want to linkify the
contents of; for example, you could set this to ``['pre']`` to skip contents of; for example, you could set this to ``{'pre'}`` to skip
linkifying contents of ``pre`` tags linkifying contents of ``pre`` tags; ``None`` means you don't
want linkify to skip any tags
:arg bool parse_email: whether or not to linkify email addresses :arg bool parse_email: whether or not to linkify email addresses
@ -130,7 +131,7 @@ class Linker:
:arg email_re: email matching regex :arg email_re: email matching regex
:arg list recognized_tags: the list of tags that linkify knows about; :arg set recognized_tags: the set of tags that linkify knows about;
everything else gets escaped everything else gets escaped
:returns: linkified text as unicode :returns: linkified text as unicode
@ -145,15 +146,18 @@ class Linker:
# Create a parser/tokenizer that allows all HTML tags and escapes # Create a parser/tokenizer that allows all HTML tags and escapes
# anything not in that list. # anything not in that list.
self.parser = html5lib_shim.BleachHTMLParser( self.parser = html5lib_shim.BleachHTMLParser(
tags=recognized_tags, tags=frozenset(recognized_tags),
strip=False, strip=False,
consume_entities=True, consume_entities=False,
namespaceHTMLElements=False, namespaceHTMLElements=False,
) )
self.walker = html5lib_shim.getTreeWalker("etree") self.walker = html5lib_shim.getTreeWalker("etree")
self.serializer = html5lib_shim.BleachHTMLSerializer( self.serializer = html5lib_shim.BleachHTMLSerializer(
quote_attr_values="always", quote_attr_values="always",
omit_optional_tags=False, omit_optional_tags=False,
# We want to leave entities as they are without escaping or
# resolving or expanding
resolve_entities=False,
# linkify does not sanitize # linkify does not sanitize
sanitize=False, sanitize=False,
# linkify preserves attr order # linkify preserves attr order
@ -218,8 +222,8 @@ class LinkifyFilter(html5lib_shim.Filter):
:arg list callbacks: list of callbacks to run when adjusting tag attributes; :arg list callbacks: list of callbacks to run when adjusting tag attributes;
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS`` defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
:arg list skip_tags: list of tags that you don't want to linkify the :arg set skip_tags: set of tags that you don't want to linkify the
contents of; for example, you could set this to ``['pre']`` to skip contents of; for example, you could set this to ``{'pre'}`` to skip
linkifying contents of ``pre`` tags linkifying contents of ``pre`` tags
:arg bool parse_email: whether or not to linkify email addresses :arg bool parse_email: whether or not to linkify email addresses
@ -232,7 +236,7 @@ class LinkifyFilter(html5lib_shim.Filter):
super().__init__(source) super().__init__(source)
self.callbacks = callbacks or [] self.callbacks = callbacks or []
self.skip_tags = skip_tags or [] self.skip_tags = skip_tags or {}
self.parse_email = parse_email self.parse_email = parse_email
self.url_re = url_re self.url_re = url_re
@ -510,6 +514,62 @@ class LinkifyFilter(html5lib_shim.Filter):
yield {"type": "Characters", "data": str(new_text)} yield {"type": "Characters", "data": str(new_text)}
yield token_buffer[-1] yield token_buffer[-1]
def extract_entities(self, token):
"""Handles Characters tokens with entities
Our overridden tokenizer doesn't do anything with entities. However,
that means that the serializer will convert all ``&`` in Characters
tokens to ``&amp;``.
Since we don't want that, we extract entities here and convert them to
Entity tokens so the serializer will let them be.
:arg token: the Characters token to work on
:returns: generator of tokens
"""
data = token.get("data", "")
# If there isn't a & in the data, we can return now
if "&" not in data:
yield token
return
new_tokens = []
# For each possible entity that starts with a "&", we try to extract an
# actual entity and re-tokenize accordingly
for part in html5lib_shim.next_possible_entity(data):
if not part:
continue
if part.startswith("&"):
entity = html5lib_shim.match_entity(part)
if entity is not None:
if entity == "amp":
# LinkifyFilter can't match urls across token boundaries
# which is problematic with &amp; since that shows up in
# querystrings all the time. This special-cases &amp;
# and converts it to a & and sticks it in as a
# Characters token. It'll get merged with surrounding
# tokens in the BleachSanitizerfilter.__iter__ and
# escaped in the serializer.
new_tokens.append({"type": "Characters", "data": "&"})
else:
new_tokens.append({"type": "Entity", "name": entity})
# Length of the entity plus 2--one for & at the beginning
# and one for ; at the end
remainder = part[len(entity) + 2 :]
if remainder:
new_tokens.append({"type": "Characters", "data": remainder})
continue
new_tokens.append({"type": "Characters", "data": part})
yield from new_tokens
def __iter__(self): def __iter__(self):
in_a = False in_a = False
in_skip_tag = None in_skip_tag = None
@ -564,8 +624,8 @@ class LinkifyFilter(html5lib_shim.Filter):
new_stream = self.handle_links(new_stream) new_stream = self.handle_links(new_stream)
for token in new_stream: for new_token in new_stream:
yield token yield from self.extract_entities(new_token)
# We've already yielded this token, so continue # We've already yielded this token, so continue
continue continue

View file

@ -8,21 +8,23 @@ from bleach import html5lib_shim
from bleach import parse_shim from bleach import parse_shim
#: List of allowed tags #: Set of allowed tags
ALLOWED_TAGS = [ ALLOWED_TAGS = frozenset(
"a", (
"abbr", "a",
"acronym", "abbr",
"b", "acronym",
"blockquote", "b",
"code", "blockquote",
"em", "code",
"i", "em",
"li", "i",
"ol", "li",
"strong", "ol",
"ul", "strong",
] "ul",
)
)
#: Map of allowed attributes by tag #: Map of allowed attributes by tag
@ -33,7 +35,7 @@ ALLOWED_ATTRIBUTES = {
} }
#: List of allowed protocols #: List of allowed protocols
ALLOWED_PROTOCOLS = ["http", "https", "mailto"] ALLOWED_PROTOCOLS = frozenset(("http", "https", "mailto"))
#: Invisible characters--0 to and including 31 except 9 (tab), 10 (lf), and 13 (cr) #: Invisible characters--0 to and including 31 except 9 (tab), 10 (lf), and 13 (cr)
INVISIBLE_CHARACTERS = "".join( INVISIBLE_CHARACTERS = "".join(
@ -48,6 +50,10 @@ INVISIBLE_CHARACTERS_RE = re.compile("[" + INVISIBLE_CHARACTERS + "]", re.UNICOD
INVISIBLE_REPLACEMENT_CHAR = "?" INVISIBLE_REPLACEMENT_CHAR = "?"
class NoCssSanitizerWarning(UserWarning):
pass
class Cleaner: class Cleaner:
"""Cleaner for cleaning HTML fragments of malicious content """Cleaner for cleaning HTML fragments of malicious content
@ -89,7 +95,7 @@ class Cleaner:
): ):
"""Initializes a Cleaner """Initializes a Cleaner
:arg list tags: allowed list of tags; defaults to :arg set tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS`` ``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict; :arg dict attributes: allowed attributes; can be a callable, list or dict;
@ -143,6 +149,25 @@ class Cleaner:
alphabetical_attributes=False, alphabetical_attributes=False,
) )
if css_sanitizer is None:
# FIXME(willkg): this doesn't handle when attributes or an
# attributes value is a callable
attributes_values = []
if isinstance(attributes, list):
attributes_values = attributes
elif isinstance(attributes, dict):
attributes_values = []
for values in attributes.values():
if isinstance(values, (list, tuple)):
attributes_values.extend(values)
if "style" in attributes_values:
warnings.warn(
"'style' attribute specified, but css_sanitizer not set.",
category=NoCssSanitizerWarning,
)
def clean(self, text): def clean(self, text):
"""Cleans text and returns sanitized result as unicode """Cleans text and returns sanitized result as unicode
@ -155,9 +180,8 @@ class Cleaner:
""" """
if not isinstance(text, str): if not isinstance(text, str):
message = ( message = (
"argument cannot be of '{name}' type, must be of text type".format( f"argument cannot be of {text.__class__.__name__!r} type, "
name=text.__class__.__name__ + "must be of text type"
)
) )
raise TypeError(message) raise TypeError(message)
@ -167,13 +191,11 @@ class Cleaner:
dom = self.parser.parseFragment(text) dom = self.parser.parseFragment(text)
filtered = BleachSanitizerFilter( filtered = BleachSanitizerFilter(
source=self.walker(dom), source=self.walker(dom),
# Bleach-sanitizer-specific things allowed_tags=self.tags,
attributes=self.attributes, attributes=self.attributes,
strip_disallowed_elements=self.strip, strip_disallowed_tags=self.strip,
strip_html_comments=self.strip_comments, strip_html_comments=self.strip_comments,
css_sanitizer=self.css_sanitizer, css_sanitizer=self.css_sanitizer,
# html5lib-sanitizer things
allowed_elements=self.tags,
allowed_protocols=self.protocols, allowed_protocols=self.protocols,
) )
@ -237,19 +259,21 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
def __init__( def __init__(
self, self,
source, source,
allowed_elements=ALLOWED_TAGS, allowed_tags=ALLOWED_TAGS,
attributes=ALLOWED_ATTRIBUTES, attributes=ALLOWED_ATTRIBUTES,
allowed_protocols=ALLOWED_PROTOCOLS, allowed_protocols=ALLOWED_PROTOCOLS,
strip_disallowed_elements=False, attr_val_is_uri=html5lib_shim.attr_val_is_uri,
svg_attr_val_allows_ref=html5lib_shim.svg_attr_val_allows_ref,
svg_allow_local_href=html5lib_shim.svg_allow_local_href,
strip_disallowed_tags=False,
strip_html_comments=True, strip_html_comments=True,
css_sanitizer=None, css_sanitizer=None,
**kwargs,
): ):
"""Creates a BleachSanitizerFilter instance """Creates a BleachSanitizerFilter instance
:arg source: html5lib TreeWalker stream as an html5lib TreeWalker :arg source: html5lib TreeWalker stream as an html5lib TreeWalker
:arg list allowed_elements: allowed list of tags; defaults to :arg set allowed_tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS`` ``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict; :arg dict attributes: allowed attributes; can be a callable, list or dict;
@ -258,8 +282,16 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
:arg list allowed_protocols: allowed list of protocols for links; defaults :arg list allowed_protocols: allowed list of protocols for links; defaults
to ``bleach.sanitizer.ALLOWED_PROTOCOLS`` to ``bleach.sanitizer.ALLOWED_PROTOCOLS``
:arg bool strip_disallowed_elements: whether or not to strip disallowed :arg attr_val_is_uri: set of attributes that have URI values
elements
:arg svg_attr_val_allows_ref: set of SVG attributes that can have
references
:arg svg_allow_local_href: set of SVG elements that can have local
hrefs
:arg bool strip_disallowed_tags: whether or not to strip disallowed
tags
:arg bool strip_html_comments: whether or not to strip HTML comments :arg bool strip_html_comments: whether or not to strip HTML comments
@ -267,24 +299,24 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
sanitizing style attribute values and style text; defaults to None sanitizing style attribute values and style text; defaults to None
""" """
self.attr_filter = attribute_filter_factory(attributes) # NOTE(willkg): This is the superclass of
self.strip_disallowed_elements = strip_disallowed_elements # html5lib.filters.sanitizer.Filter. We call this directly skipping the
self.strip_html_comments = strip_html_comments # __init__ for html5lib.filters.sanitizer.Filter because that does
self.css_sanitizer = css_sanitizer # things we don't need to do and kicks up the deprecation warning for
# using Sanitizer.
html5lib_shim.Filter.__init__(self, source)
# filter out html5lib deprecation warnings to use bleach from BleachSanitizerFilter init self.allowed_tags = frozenset(allowed_tags)
warnings.filterwarnings( self.allowed_protocols = frozenset(allowed_protocols)
"ignore",
message="html5lib's sanitizer is deprecated", self.attr_filter = attribute_filter_factory(attributes)
category=DeprecationWarning, self.strip_disallowed_tags = strip_disallowed_tags
module="bleach._vendor.html5lib", self.strip_html_comments = strip_html_comments
)
return super().__init__( self.attr_val_is_uri = attr_val_is_uri
source, self.svg_attr_val_allows_ref = svg_attr_val_allows_ref
allowed_elements=allowed_elements, self.css_sanitizer = css_sanitizer
allowed_protocols=allowed_protocols, self.svg_allow_local_href = svg_allow_local_href
**kwargs,
)
def sanitize_stream(self, token_iterator): def sanitize_stream(self, token_iterator):
for token in token_iterator: for token in token_iterator:
@ -354,10 +386,10 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
""" """
token_type = token["type"] token_type = token["type"]
if token_type in ["StartTag", "EndTag", "EmptyTag"]: if token_type in ["StartTag", "EndTag", "EmptyTag"]:
if token["name"] in self.allowed_elements: if token["name"] in self.allowed_tags:
return self.allow_token(token) return self.allow_token(token)
elif self.strip_disallowed_elements: elif self.strip_disallowed_tags:
return None return None
else: else:
@ -570,7 +602,7 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
def disallowed_token(self, token): def disallowed_token(self, token):
token_type = token["type"] token_type = token["type"]
if token_type == "EndTag": if token_type == "EndTag":
token["data"] = "</%s>" % token["name"] token["data"] = f"</{token['name']}>"
elif token["data"]: elif token["data"]:
assert token_type in ("StartTag", "EmptyTag") assert token_type in ("StartTag", "EmptyTag")
@ -586,25 +618,19 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
if ns is None or ns not in html5lib_shim.prefixes: if ns is None or ns not in html5lib_shim.prefixes:
namespaced_name = name namespaced_name = name
else: else:
namespaced_name = "{}:{}".format(html5lib_shim.prefixes[ns], name) namespaced_name = f"{html5lib_shim.prefixes[ns]}:{name}"
attrs.append( # NOTE(willkg): HTMLSerializer escapes attribute values
' %s="%s"' # already, so if we do it here (like HTMLSerializer does),
% ( # then we end up double-escaping.
namespaced_name, attrs.append(f' {namespaced_name}="{v}"')
# NOTE(willkg): HTMLSerializer escapes attribute values token["data"] = f"<{token['name']}{''.join(attrs)}>"
# already, so if we do it here (like HTMLSerializer does),
# then we end up double-escaping.
v,
)
)
token["data"] = "<{}{}>".format(token["name"], "".join(attrs))
else: else:
token["data"] = "<%s>" % token["name"] token["data"] = f"<{token['name']}>"
if token.get("selfClosing"): if token.get("selfClosing"):
token["data"] = token["data"][:-1] + "/>" token["data"] = f"{token['data'][:-1]}/>"
token["type"] = "Characters" token["type"] = "Characters"

View file

@ -7,7 +7,7 @@ Beautiful Soup uses a pluggable XML or HTML parser to parse a
provides methods and Pythonic idioms that make it easy to navigate, provides methods and Pythonic idioms that make it easy to navigate,
search, and modify the parse tree. search, and modify the parse tree.
Beautiful Soup works with Python 3.5 and up. It works better if lxml Beautiful Soup works with Python 3.6 and up. It works better if lxml
and/or html5lib is installed. and/or html5lib is installed.
For more than you ever wanted to know about Beautiful Soup, see the For more than you ever wanted to know about Beautiful Soup, see the
@ -15,8 +15,8 @@ documentation: http://www.crummy.com/software/BeautifulSoup/bs4/doc/
""" """
__author__ = "Leonard Richardson (leonardr@segfault.org)" __author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.11.1" __version__ = "4.11.2"
__copyright__ = "Copyright (c) 2004-2022 Leonard Richardson" __copyright__ = "Copyright (c) 2004-2023 Leonard Richardson"
# Use of this source code is governed by the MIT license. # Use of this source code is governed by the MIT license.
__license__ = "MIT" __license__ = "MIT"
@ -211,7 +211,7 @@ class BeautifulSoup(Tag):
warnings.warn( warnings.warn(
'The "%s" argument to the BeautifulSoup constructor ' 'The "%s" argument to the BeautifulSoup constructor '
'has been renamed to "%s."' % (old_name, new_name), 'has been renamed to "%s."' % (old_name, new_name),
DeprecationWarning DeprecationWarning, stacklevel=3
) )
return kwargs.pop(old_name) return kwargs.pop(old_name)
return None return None
@ -405,7 +405,8 @@ class BeautifulSoup(Tag):
'The input looks more like a URL than markup. You may want to use' 'The input looks more like a URL than markup. You may want to use'
' an HTTP client like requests to get the document behind' ' an HTTP client like requests to get the document behind'
' the URL, and feed that document to Beautiful Soup.', ' the URL, and feed that document to Beautiful Soup.',
MarkupResemblesLocatorWarning MarkupResemblesLocatorWarning,
stacklevel=3
) )
return True return True
return False return False
@ -436,7 +437,7 @@ class BeautifulSoup(Tag):
'The input looks more like a filename than markup. You may' 'The input looks more like a filename than markup. You may'
' want to open this file and pass the filehandle into' ' want to open this file and pass the filehandle into'
' Beautiful Soup.', ' Beautiful Soup.',
MarkupResemblesLocatorWarning MarkupResemblesLocatorWarning, stacklevel=3
) )
return True return True
return False return False
@ -789,7 +790,7 @@ class BeautifulStoneSoup(BeautifulSoup):
warnings.warn( warnings.warn(
'The BeautifulStoneSoup class is deprecated. Instead of using ' 'The BeautifulStoneSoup class is deprecated. Instead of using '
'it, pass features="xml" into the BeautifulSoup constructor.', 'it, pass features="xml" into the BeautifulSoup constructor.',
DeprecationWarning DeprecationWarning, stacklevel=2
) )
super(BeautifulStoneSoup, self).__init__(*args, **kwargs) super(BeautifulStoneSoup, self).__init__(*args, **kwargs)

View file

@ -122,7 +122,7 @@ class TreeBuilder(object):
# A value for these tag/attribute combinations is a space- or # A value for these tag/attribute combinations is a space- or
# comma-separated list of CDATA, rather than a single CDATA. # comma-separated list of CDATA, rather than a single CDATA.
DEFAULT_CDATA_LIST_ATTRIBUTES = {} DEFAULT_CDATA_LIST_ATTRIBUTES = defaultdict(list)
# Whitespace should be preserved inside these tags. # Whitespace should be preserved inside these tags.
DEFAULT_PRESERVE_WHITESPACE_TAGS = set() DEFAULT_PRESERVE_WHITESPACE_TAGS = set()

View file

@ -70,7 +70,10 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
# ATM because the html5lib TreeBuilder doesn't use # ATM because the html5lib TreeBuilder doesn't use
# UnicodeDammit. # UnicodeDammit.
if exclude_encodings: if exclude_encodings:
warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.") warnings.warn(
"You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.",
stacklevel=3
)
# html5lib only parses HTML, so if it's given XML that's worth # html5lib only parses HTML, so if it's given XML that's worth
# noting. # noting.
@ -81,7 +84,10 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
# These methods are defined by Beautiful Soup. # These methods are defined by Beautiful Soup.
def feed(self, markup): def feed(self, markup):
if self.soup.parse_only is not None: if self.soup.parse_only is not None:
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.") warnings.warn(
"You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.",
stacklevel=4
)
parser = html5lib.HTMLParser(tree=self.create_treebuilder) parser = html5lib.HTMLParser(tree=self.create_treebuilder)
self.underlying_builder.parser = parser self.underlying_builder.parser = parser
extra_kwargs = dict() extra_kwargs = dict()
@ -249,9 +255,9 @@ class AttrList(object):
# If this attribute is a multi-valued attribute for this element, # If this attribute is a multi-valued attribute for this element,
# turn its value into a list. # turn its value into a list.
list_attr = self.element.cdata_list_attributes or {} list_attr = self.element.cdata_list_attributes or {}
if (name in list_attr.get('*') if (name in list_attr.get('*', [])
or (self.element.name in list_attr or (self.element.name in list_attr
and name in list_attr[self.element.name])): and name in list_attr.get(self.element.name, []))):
# A node that is being cloned may have already undergone # A node that is being cloned may have already undergone
# this procedure. # this procedure.
if not isinstance(value, list): if not isinstance(value, list):

View file

@ -10,30 +10,9 @@ __all__ = [
from html.parser import HTMLParser from html.parser import HTMLParser
try:
from html.parser import HTMLParseError
except ImportError as e:
# HTMLParseError is removed in Python 3.5. Since it can never be
# thrown in 3.5, we can just define our own class as a placeholder.
class HTMLParseError(Exception):
pass
import sys import sys
import warnings import warnings
# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
# argument, which we'd like to set to False. Unfortunately,
# http://bugs.python.org/issue13273 makes strict=True a better bet
# before Python 3.2.3.
#
# At the end of this file, we monkeypatch HTMLParser so that
# strict=True works well on Python 3.2.2.
major, minor, release = sys.version_info[:3]
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
from bs4.element import ( from bs4.element import (
CData, CData,
Comment, Comment,
@ -90,20 +69,7 @@ class BeautifulSoupHTMLParser(HTMLParser, DetectsXMLParsedAsHTML):
self.already_closed_empty_element = [] self.already_closed_empty_element = []
self._initialize_xml_detector() self._initialize_xml_detector()
def error(self, msg):
"""In Python 3, HTMLParser subclasses must implement error(), although
this requirement doesn't appear to be documented.
In Python 2, HTMLParser implements error() by raising an exception,
which we don't want to do.
In any event, this method is called only on very strange
markup and our best strategy is to pretend it didn't happen
and keep going.
"""
warnings.warn(msg)
def handle_startendtag(self, name, attrs): def handle_startendtag(self, name, attrs):
"""Handle an incoming empty-element tag. """Handle an incoming empty-element tag.
@ -203,9 +169,10 @@ class BeautifulSoupHTMLParser(HTMLParser, DetectsXMLParsedAsHTML):
:param name: Character number, possibly in hexadecimal. :param name: Character number, possibly in hexadecimal.
""" """
# XXX workaround for a bug in HTMLParser. Remove this once # TODO: This was originally a workaround for a bug in
# it's fixed in all supported versions. # HTMLParser. (http://bugs.python.org/issue13633) The bug has
# http://bugs.python.org/issue13633 # been fixed, but removing this code still makes some
# Beautiful Soup tests fail. This needs investigation.
if name.startswith('x'): if name.startswith('x'):
real_name = int(name.lstrip('x'), 16) real_name = int(name.lstrip('x'), 16)
elif name.startswith('X'): elif name.startswith('X'):
@ -333,10 +300,7 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
parser_args = parser_args or [] parser_args = parser_args or []
parser_kwargs = parser_kwargs or {} parser_kwargs = parser_kwargs or {}
parser_kwargs.update(extra_parser_kwargs) parser_kwargs.update(extra_parser_kwargs)
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED: parser_kwargs['convert_charrefs'] = False
parser_kwargs['strict'] = False
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
parser_kwargs['convert_charrefs'] = False
self.parser_args = (parser_args, parser_kwargs) self.parser_args = (parser_args, parser_kwargs)
def prepare_markup(self, markup, user_specified_encoding=None, def prepare_markup(self, markup, user_specified_encoding=None,
@ -395,105 +359,6 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
args, kwargs = self.parser_args args, kwargs = self.parser_args
parser = BeautifulSoupHTMLParser(*args, **kwargs) parser = BeautifulSoupHTMLParser(*args, **kwargs)
parser.soup = self.soup parser.soup = self.soup
try: parser.feed(markup)
parser.feed(markup) parser.close()
parser.close()
except HTMLParseError as e:
warnings.warn(RuntimeWarning(
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
raise e
parser.already_closed_empty_element = [] parser.already_closed_empty_element = []
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
# string.
#
# XXX This code can be removed once most Python 3 users are on 3.2.3.
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
import re
attrfind_tolerant = re.compile(
r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
from html.parser import tagfind, attrfind
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = rawdata[i+1:k].lower()
while k < endpos:
if self.strict:
m = attrfind.match(rawdata, k)
else:
m = attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
if self.strict:
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
return endpos
def set_cdata_mode(self, elem):
self.cdata_elem = elem.lower()
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
BeautifulSoupHTMLParser.parse_starttag = parse_starttag
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
CONSTRUCTOR_TAKES_STRICT = True

View file

@ -496,13 +496,16 @@ class PageElement(object):
def extend(self, tags): def extend(self, tags):
"""Appends the given PageElements to this one's contents. """Appends the given PageElements to this one's contents.
:param tags: A list of PageElements. :param tags: A list of PageElements. If a single Tag is
provided instead, this PageElement's contents will be extended
with that Tag's contents.
""" """
if isinstance(tags, Tag): if isinstance(tags, Tag):
# Calling self.append() on another tag's contents will change tags = tags.contents
# the list we're iterating over. Make a list that won't if isinstance(tags, list):
# change. # Moving items around the tree may change their position in
tags = list(tags.contents) # the original list. Make a list that won't change.
tags = list(tags)
for tag in tags: for tag in tags:
self.append(tag) self.append(tag)
@ -586,8 +589,9 @@ class PageElement(object):
:kwargs: A dictionary of filters on attribute values. :kwargs: A dictionary of filters on attribute values.
:return: A ResultSet containing PageElements. :return: A ResultSet containing PageElements.
""" """
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, string, limit, self.next_elements, return self._find_all(name, attrs, string, limit, self.next_elements,
**kwargs) _stacklevel=_stacklevel+1, **kwargs)
findAllNext = find_all_next # BS3 findAllNext = find_all_next # BS3
def find_next_sibling(self, name=None, attrs={}, string=None, **kwargs): def find_next_sibling(self, name=None, attrs={}, string=None, **kwargs):
@ -624,8 +628,11 @@ class PageElement(object):
:return: A ResultSet of PageElements. :return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet :rtype: bs4.element.ResultSet
""" """
return self._find_all(name, attrs, string, limit, _stacklevel = kwargs.pop('_stacklevel', 2)
self.next_siblings, **kwargs) return self._find_all(
name, attrs, string, limit,
self.next_siblings, _stacklevel=_stacklevel+1, **kwargs
)
findNextSiblings = find_next_siblings # BS3 findNextSiblings = find_next_siblings # BS3
fetchNextSiblings = find_next_siblings # BS2 fetchNextSiblings = find_next_siblings # BS2
@ -663,8 +670,11 @@ class PageElement(object):
:return: A ResultSet of PageElements. :return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet :rtype: bs4.element.ResultSet
""" """
return self._find_all(name, attrs, string, limit, self.previous_elements, _stacklevel = kwargs.pop('_stacklevel', 2)
**kwargs) return self._find_all(
name, attrs, string, limit, self.previous_elements,
_stacklevel=_stacklevel+1, **kwargs
)
findAllPrevious = find_all_previous # BS3 findAllPrevious = find_all_previous # BS3
fetchPrevious = find_all_previous # BS2 fetchPrevious = find_all_previous # BS2
@ -702,8 +712,11 @@ class PageElement(object):
:return: A ResultSet of PageElements. :return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet :rtype: bs4.element.ResultSet
""" """
return self._find_all(name, attrs, string, limit, _stacklevel = kwargs.pop('_stacklevel', 2)
self.previous_siblings, **kwargs) return self._find_all(
name, attrs, string, limit,
self.previous_siblings, _stacklevel=_stacklevel+1, **kwargs
)
findPreviousSiblings = find_previous_siblings # BS3 findPreviousSiblings = find_previous_siblings # BS3
fetchPreviousSiblings = find_previous_siblings # BS2 fetchPreviousSiblings = find_previous_siblings # BS2
@ -724,7 +737,7 @@ class PageElement(object):
# NOTE: We can't use _find_one because findParents takes a different # NOTE: We can't use _find_one because findParents takes a different
# set of arguments. # set of arguments.
r = None r = None
l = self.find_parents(name, attrs, 1, **kwargs) l = self.find_parents(name, attrs, 1, _stacklevel=3, **kwargs)
if l: if l:
r = l[0] r = l[0]
return r return r
@ -744,8 +757,9 @@ class PageElement(object):
:return: A PageElement. :return: A PageElement.
:rtype: bs4.element.Tag | bs4.element.NavigableString :rtype: bs4.element.Tag | bs4.element.NavigableString
""" """
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, None, limit, self.parents, return self._find_all(name, attrs, None, limit, self.parents,
**kwargs) _stacklevel=_stacklevel+1, **kwargs)
findParents = find_parents # BS3 findParents = find_parents # BS3
fetchParents = find_parents # BS2 fetchParents = find_parents # BS2
@ -771,19 +785,20 @@ class PageElement(object):
def _find_one(self, method, name, attrs, string, **kwargs): def _find_one(self, method, name, attrs, string, **kwargs):
r = None r = None
l = method(name, attrs, string, 1, **kwargs) l = method(name, attrs, string, 1, _stacklevel=4, **kwargs)
if l: if l:
r = l[0] r = l[0]
return r return r
def _find_all(self, name, attrs, string, limit, generator, **kwargs): def _find_all(self, name, attrs, string, limit, generator, **kwargs):
"Iterates over a generator looking for things that match." "Iterates over a generator looking for things that match."
_stacklevel = kwargs.pop('_stacklevel', 3)
if string is None and 'text' in kwargs: if string is None and 'text' in kwargs:
string = kwargs.pop('text') string = kwargs.pop('text')
warnings.warn( warnings.warn(
"The 'text' argument to find()-type methods is deprecated. Use 'string' instead.", "The 'text' argument to find()-type methods is deprecated. Use 'string' instead.",
DeprecationWarning DeprecationWarning, stacklevel=_stacklevel
) )
if isinstance(name, SoupStrainer): if isinstance(name, SoupStrainer):
@ -1306,7 +1321,8 @@ class Tag(PageElement):
sourceline=self.sourceline, sourcepos=self.sourcepos, sourceline=self.sourceline, sourcepos=self.sourcepos,
can_be_empty_element=self.can_be_empty_element, can_be_empty_element=self.can_be_empty_element,
cdata_list_attributes=self.cdata_list_attributes, cdata_list_attributes=self.cdata_list_attributes,
preserve_whitespace_tags=self.preserve_whitespace_tags preserve_whitespace_tags=self.preserve_whitespace_tags,
interesting_string_types=self.interesting_string_types
) )
for attr in ('can_be_empty_element', 'hidden'): for attr in ('can_be_empty_element', 'hidden'):
setattr(clone, attr, getattr(self, attr)) setattr(clone, attr, getattr(self, attr))
@ -1558,7 +1574,7 @@ class Tag(PageElement):
'.%(name)sTag is deprecated, use .find("%(name)s") instead. If you really were looking for a tag called %(name)sTag, use .find("%(name)sTag")' % dict( '.%(name)sTag is deprecated, use .find("%(name)s") instead. If you really were looking for a tag called %(name)sTag, use .find("%(name)sTag")' % dict(
name=tag_name name=tag_name
), ),
DeprecationWarning DeprecationWarning, stacklevel=2
) )
return self.find(tag_name) return self.find(tag_name)
# We special case contents to avoid recursion. # We special case contents to avoid recursion.
@ -1862,7 +1878,8 @@ class Tag(PageElement):
:rtype: bs4.element.Tag | bs4.element.NavigableString :rtype: bs4.element.Tag | bs4.element.NavigableString
""" """
r = None r = None
l = self.find_all(name, attrs, recursive, string, 1, **kwargs) l = self.find_all(name, attrs, recursive, string, 1, _stacklevel=3,
**kwargs)
if l: if l:
r = l[0] r = l[0]
return r return r
@ -1889,7 +1906,9 @@ class Tag(PageElement):
generator = self.descendants generator = self.descendants
if not recursive: if not recursive:
generator = self.children generator = self.children
return self._find_all(name, attrs, string, limit, generator, **kwargs) _stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, string, limit, generator,
_stacklevel=_stacklevel+1, **kwargs)
findAll = find_all # BS3 findAll = find_all # BS3
findChildren = find_all # BS2 findChildren = find_all # BS2
@ -1993,7 +2012,7 @@ class Tag(PageElement):
""" """
warnings.warn( warnings.warn(
'has_key is deprecated. Use has_attr(key) instead.', 'has_key is deprecated. Use has_attr(key) instead.',
DeprecationWarning DeprecationWarning, stacklevel=2
) )
return self.has_attr(key) return self.has_attr(key)
@ -2024,7 +2043,7 @@ class SoupStrainer(object):
string = kwargs.pop('text') string = kwargs.pop('text')
warnings.warn( warnings.warn(
"The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead.", "The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead.",
DeprecationWarning DeprecationWarning, stacklevel=2
) )
self.name = self._normalize_search_value(name) self.name = self._normalize_search_value(name)

View file

@ -149,14 +149,14 @@ class HTMLFormatter(Formatter):
"""A generic Formatter for HTML.""" """A generic Formatter for HTML."""
REGISTRY = {} REGISTRY = {}
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs) super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
class XMLFormatter(Formatter): class XMLFormatter(Formatter):
"""A generic Formatter for XML.""" """A generic Formatter for XML."""
REGISTRY = {} REGISTRY = {}
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs) super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
# Set up aliases for the default formatters. # Set up aliases for the default formatters.

View file

@ -29,6 +29,29 @@ from bs4.builder import (
) )
default_builder = HTMLParserTreeBuilder default_builder = HTMLParserTreeBuilder
# Some tests depend on specific third-party libraries. We use
# @pytest.mark.skipIf on the following conditionals to skip them
# if the libraries are not installed.
try:
from soupsieve import SelectorSyntaxError
SOUP_SIEVE_PRESENT = True
except ImportError:
SOUP_SIEVE_PRESENT = False
try:
import html5lib
HTML5LIB_PRESENT = True
except ImportError:
HTML5LIB_PRESENT = False
try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError:
LXML_PRESENT = False
LXML_VERSION = (0,)
BAD_DOCUMENT = """A bare string BAD_DOCUMENT = """A bare string
<!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd"> <!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd">
<!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd"> <!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd">
@ -258,10 +281,10 @@ class TreeBuilderSmokeTest(object):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"multi_valued_attributes", "multi_valued_attributes",
[None, dict(b=['class']), {'*': ['notclass']}] [None, {}, dict(b=['class']), {'*': ['notclass']}]
) )
def test_attribute_not_multi_valued(self, multi_valued_attributes): def test_attribute_not_multi_valued(self, multi_valued_attributes):
markup = '<a class="a b c">' markup = '<html xmlns="http://www.w3.org/1999/xhtml"><a class="a b c"></html>'
soup = self.soup(markup, multi_valued_attributes=multi_valued_attributes) soup = self.soup(markup, multi_valued_attributes=multi_valued_attributes)
assert soup.a['class'] == 'a b c' assert soup.a['class'] == 'a b c'
@ -820,26 +843,27 @@ Hello, world!
soup = self.soup(text) soup = self.soup(text)
assert soup.p.encode("utf-8") == expected assert soup.p.encode("utf-8") == expected
def test_real_iso_latin_document(self): def test_real_iso_8859_document(self):
# Smoke test of interrelated functionality, using an # Smoke test of interrelated functionality, using an
# easy-to-understand document. # easy-to-understand document.
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1. # Here it is in Unicode. Note that it claims to be in ISO-8859-1.
unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>' unicode_html = '<html><head><meta content="text/html; charset=ISO-8859-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# That's because we're going to encode it into ISO-Latin-1, and use # That's because we're going to encode it into ISO-8859-1,
# that to test. # and use that to test.
iso_latin_html = unicode_html.encode("iso-8859-1") iso_latin_html = unicode_html.encode("iso-8859-1")
# Parse the ISO-Latin-1 HTML. # Parse the ISO-8859-1 HTML.
soup = self.soup(iso_latin_html) soup = self.soup(iso_latin_html)
# Encode it to UTF-8. # Encode it to UTF-8.
result = soup.encode("utf-8") result = soup.encode("utf-8")
# What do we expect the result to look like? Well, it would # What do we expect the result to look like? Well, it would
# look like unicode_html, except that the META tag would say # look like unicode_html, except that the META tag would say
# UTF-8 instead of ISO-Latin-1. # UTF-8 instead of ISO-8859-1.
expected = unicode_html.replace("ISO-Latin-1", "utf-8") expected = unicode_html.replace("ISO-8859-1", "utf-8")
# And, of course, it would be in UTF-8, not Unicode. # And, of course, it would be in UTF-8, not Unicode.
expected = expected.encode("utf-8") expected = expected.encode("utf-8")
@ -1177,15 +1201,3 @@ class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
assert isinstance(soup.contents[0], Comment) assert isinstance(soup.contents[0], Comment)
assert soup.contents[0] == '?xml version="1.0" encoding="utf-8"?' assert soup.contents[0] == '?xml version="1.0" encoding="utf-8"?'
assert "html" == soup.contents[0].next_element.name assert "html" == soup.contents[0].next_element.name
def skipIf(condition, reason):
def nothing(test, *args, **kwargs):
return None
def decorator(test_item):
if condition:
return nothing
else:
return test_item
return decorator

View file

@ -10,22 +10,23 @@ from bs4.builder import (
TreeBuilderRegistry, TreeBuilderRegistry,
) )
try: from . import (
from bs4.builder import HTML5TreeBuilder HTML5LIB_PRESENT,
HTML5LIB_PRESENT = True LXML_PRESENT,
except ImportError: )
HTML5LIB_PRESENT = False
try: if HTML5LIB_PRESENT:
from bs4.builder import HTML5TreeBuilder
if LXML_PRESENT:
from bs4.builder import ( from bs4.builder import (
LXMLTreeBuilderForXML, LXMLTreeBuilderForXML,
LXMLTreeBuilder, LXMLTreeBuilder,
) )
LXML_PRESENT = True
except ImportError:
LXML_PRESENT = False
# TODO: Split out the lxml and html5lib tests into their own classes
# and gate with pytest.mark.skipIf.
class TestBuiltInRegistry(object): class TestBuiltInRegistry(object):
"""Test the built-in registry with the default builders registered.""" """Test the built-in registry with the default builders registered."""

View file

@ -17,26 +17,24 @@ class TestUnicodeDammit(object):
dammit = UnicodeDammit(markup) dammit = UnicodeDammit(markup)
assert dammit.unicode_markup == markup assert dammit.unicode_markup == markup
def test_smart_quotes_to_unicode(self): @pytest.mark.parametrize(
"smart_quotes_to,expect_converted",
[(None, "\u2018\u2019\u201c\u201d"),
("xml", "&#x2018;&#x2019;&#x201C;&#x201D;"),
("html", "&lsquo;&rsquo;&ldquo;&rdquo;"),
("ascii", "''" + '""'),
]
)
def test_smart_quotes_to(self, smart_quotes_to, expect_converted):
"""Verify the functionality of the smart_quotes_to argument
to the UnicodeDammit constructor."""
markup = b"<foo>\x91\x92\x93\x94</foo>" markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup) converted = UnicodeDammit(
assert dammit.unicode_markup == "<foo>\u2018\u2019\u201c\u201d</foo>" markup, known_definite_encodings=["windows-1252"],
smart_quotes_to=smart_quotes_to
def test_smart_quotes_to_xml_entities(self): ).unicode_markup
markup = b"<foo>\x91\x92\x93\x94</foo>" assert converted == "<foo>{}</foo>".format(expect_converted)
dammit = UnicodeDammit(markup, smart_quotes_to="xml")
assert dammit.unicode_markup == "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>"
def test_smart_quotes_to_html_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="html")
assert dammit.unicode_markup == "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>"
def test_smart_quotes_to_ascii(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
assert dammit.unicode_markup == """<foo>''""</foo>"""
def test_detect_utf8(self): def test_detect_utf8(self):
utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83" utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
dammit = UnicodeDammit(utf8) dammit = UnicodeDammit(utf8)
@ -275,23 +273,24 @@ class TestEntitySubstitution(object):
def setup_method(self): def setup_method(self):
self.sub = EntitySubstitution self.sub = EntitySubstitution
def test_simple_html_substitution(self):
# Unicode characters corresponding to named HTML entites
# are substituted, and no others.
s = "foo\u2200\N{SNOWMAN}\u00f5bar"
assert self.sub.substitute_html(s) == "foo&forall;\N{SNOWMAN}&otilde;bar"
def test_smart_quote_substitution(self): @pytest.mark.parametrize(
# MS smart quotes are a common source of frustration, so we "original,substituted",
# give them a special test. [
quotes = b"\x91\x92foo\x93\x94" # Basic case. Unicode characters corresponding to named
dammit = UnicodeDammit(quotes) # HTML entites are substituted; others are not.
assert self.sub.substitute_html(dammit.markup) == "&lsquo;&rsquo;foo&ldquo;&rdquo;" ("foo\u2200\N{SNOWMAN}\u00f5bar",
"foo&forall;\N{SNOWMAN}&otilde;bar"),
# MS smart quotes are a common source of frustration, so we
# give them a special test.
('foo“”', "&lsquo;&rsquo;foo&ldquo;&rdquo;"),
]
)
def test_substitute_html(self, original, substituted):
assert self.sub.substitute_html(original) == substituted
def test_html5_entity(self): def test_html5_entity(self):
# Some HTML5 entities correspond to single- or multi-character
# Unicode sequences.
for entity, u in ( for entity, u in (
# A few spot checks of our ability to recognize # A few spot checks of our ability to recognize
# special character sequences and convert them # special character sequences and convert them

View file

@ -1,27 +1,26 @@
"""Tests to ensure that the html5lib tree builder generates good trees.""" """Tests to ensure that the html5lib tree builder generates good trees."""
import pytest
import warnings import warnings
try: from bs4 import BeautifulSoup
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError as e:
HTML5LIB_PRESENT = False
from bs4.element import SoupStrainer from bs4.element import SoupStrainer
from . import ( from . import (
HTML5LIB_PRESENT,
HTML5TreeBuilderSmokeTest, HTML5TreeBuilderSmokeTest,
SoupTest, SoupTest,
skipIf,
) )
@skipIf( @pytest.mark.skipif(
not HTML5LIB_PRESENT, not HTML5LIB_PRESENT,
"html5lib seems not to be present, not testing its tree builder.") reason="html5lib seems not to be present, not testing its tree builder."
)
class TestHTML5LibBuilder(SoupTest, HTML5TreeBuilderSmokeTest): class TestHTML5LibBuilder(SoupTest, HTML5TreeBuilderSmokeTest):
"""See ``HTML5TreeBuilderSmokeTest``.""" """See ``HTML5TreeBuilderSmokeTest``."""
@property @property
def default_builder(self): def default_builder(self):
from bs4.builder import HTML5TreeBuilder
return HTML5TreeBuilder return HTML5TreeBuilder
def test_soupstrainer(self): def test_soupstrainer(self):
@ -29,10 +28,12 @@ class TestHTML5LibBuilder(SoupTest, HTML5TreeBuilderSmokeTest):
strainer = SoupStrainer("b") strainer = SoupStrainer("b")
markup = "<p>A <b>bold</b> statement.</p>" markup = "<p>A <b>bold</b> statement.</p>"
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = self.soup(markup, parse_only=strainer) soup = BeautifulSoup(markup, "html5lib", parse_only=strainer)
assert soup.decode() == self.document_for(markup) assert soup.decode() == self.document_for(markup)
assert "the html5lib tree builder doesn't support parse_only" in str(w[0].message) [warning] = w
assert warning.filename == __file__
assert "the html5lib tree builder doesn't support parse_only" in str(warning.message)
def test_correctly_nested_tables(self): def test_correctly_nested_tables(self):
"""html5lib inserts <tbody> tags where other parsers don't.""" """html5lib inserts <tbody> tags where other parsers don't."""

View file

@ -122,15 +122,3 @@ class TestHTMLParserTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
with_element = div.encode(formatter="html") with_element = div.encode(formatter="html")
expect = b"<div>%s</div>" % output_element expect = b"<div>%s</div>" % output_element
assert with_element == expect assert with_element == expect
class TestHTMLParserSubclass(SoupTest):
def test_error(self):
"""Verify that our HTMLParser subclass implements error() in a way
that doesn't cause a crash.
"""
parser = BeautifulSoupHTMLParser()
with warnings.catch_warnings(record=True) as warns:
parser.error("don't crash")
[warning] = warns
assert "don't crash" == str(warning.message)

View file

@ -1,16 +1,10 @@
"""Tests to ensure that the lxml tree builder generates good trees.""" """Tests to ensure that the lxml tree builder generates good trees."""
import pickle import pickle
import pytest
import re import re
import warnings import warnings
from . import LXML_PRESENT, LXML_VERSION
try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError as e:
LXML_PRESENT = False
LXML_VERSION = (0,)
if LXML_PRESENT: if LXML_PRESENT:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
@ -23,13 +17,14 @@ from bs4.element import Comment, Doctype, SoupStrainer
from . import ( from . import (
HTMLTreeBuilderSmokeTest, HTMLTreeBuilderSmokeTest,
XMLTreeBuilderSmokeTest, XMLTreeBuilderSmokeTest,
SOUP_SIEVE_PRESENT,
SoupTest, SoupTest,
skipIf,
) )
@skipIf( @pytest.mark.skipif(
not LXML_PRESENT, not LXML_PRESENT,
"lxml seems not to be present, not testing its tree builder.") reason="lxml seems not to be present, not testing its tree builder."
)
class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest): class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``.""" """See ``HTMLTreeBuilderSmokeTest``."""
@ -54,9 +49,10 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
# In lxml < 2.3.5, an empty doctype causes a segfault. Skip this # In lxml < 2.3.5, an empty doctype causes a segfault. Skip this
# test if an old version of lxml is installed. # test if an old version of lxml is installed.
@skipIf( @pytest.mark.skipif(
not LXML_PRESENT or LXML_VERSION < (2,3,5,0), not LXML_PRESENT or LXML_VERSION < (2,3,5,0),
"Skipping doctype test for old version of lxml to avoid segfault.") reason="Skipping doctype test for old version of lxml to avoid segfault."
)
def test_empty_doctype(self): def test_empty_doctype(self):
soup = self.soup("<!DOCTYPE>") soup = self.soup("<!DOCTYPE>")
doctype = soup.contents[0] doctype = soup.contents[0]
@ -68,7 +64,9 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = BeautifulStoneSoup("<b />") soup = BeautifulStoneSoup("<b />")
assert "<b/>" == str(soup.b) assert "<b/>" == str(soup.b)
assert "BeautifulStoneSoup class is deprecated" in str(w[0].message) [warning] = w
assert warning.filename == __file__
assert "BeautifulStoneSoup class is deprecated" in str(warning.message)
def test_tracking_line_numbers(self): def test_tracking_line_numbers(self):
# The lxml TreeBuilder cannot keep track of line numbers from # The lxml TreeBuilder cannot keep track of line numbers from
@ -85,9 +83,10 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
assert "sourceline" == soup.p.sourceline.name assert "sourceline" == soup.p.sourceline.name
assert "sourcepos" == soup.p.sourcepos.name assert "sourcepos" == soup.p.sourcepos.name
@skipIf( @pytest.mark.skipif(
not LXML_PRESENT, not LXML_PRESENT,
"lxml seems not to be present, not testing its XML tree builder.") reason="lxml seems not to be present, not testing its XML tree builder."
)
class TestLXMLXMLTreeBuilder(SoupTest, XMLTreeBuilderSmokeTest): class TestLXMLXMLTreeBuilder(SoupTest, XMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``.""" """See ``HTMLTreeBuilderSmokeTest``."""
@ -148,6 +147,9 @@ class TestLXMLXMLTreeBuilder(SoupTest, XMLTreeBuilderSmokeTest):
} }
@pytest.mark.skipif(
not SOUP_SIEVE_PRESENT, reason="Soup Sieve not installed"
)
def test_namespace_interaction_with_select_and_find(self): def test_namespace_interaction_with_select_and_find(self):
# Demonstrate how namespaces interact with select* and # Demonstrate how namespaces interact with select* and
# find* methods. # find* methods.

View file

@ -3,15 +3,18 @@ import copy
import pickle import pickle
import pytest import pytest
from soupsieve import SelectorSyntaxError
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from bs4.element import ( from bs4.element import (
Comment, Comment,
SoupStrainer, SoupStrainer,
) )
from . import SoupTest from . import (
SoupTest,
SOUP_SIEVE_PRESENT,
)
if SOUP_SIEVE_PRESENT:
from soupsieve import SelectorSyntaxError
class TestEncoding(SoupTest): class TestEncoding(SoupTest):
"""Test the ability to encode objects into strings.""" """Test the ability to encode objects into strings."""
@ -213,6 +216,7 @@ class TestFormatters(SoupTest):
assert soup.contents[0].name == 'pre' assert soup.contents[0].name == 'pre'
@pytest.mark.skipif(not SOUP_SIEVE_PRESENT, reason="Soup Sieve not installed")
class TestCSSSelectors(SoupTest): class TestCSSSelectors(SoupTest):
"""Test basic CSS selector functionality. """Test basic CSS selector functionality.
@ -694,6 +698,7 @@ class TestPersistence(SoupTest):
assert tag.can_be_empty_element == copied.can_be_empty_element assert tag.can_be_empty_element == copied.can_be_empty_element
assert tag.cdata_list_attributes == copied.cdata_list_attributes assert tag.cdata_list_attributes == copied.cdata_list_attributes
assert tag.preserve_whitespace_tags == copied.preserve_whitespace_tags assert tag.preserve_whitespace_tags == copied.preserve_whitespace_tags
assert tag.interesting_string_types == copied.interesting_string_types
def test_unicode_pickle(self): def test_unicode_pickle(self):
# A tree containing Unicode characters can be pickled. # A tree containing Unicode characters can be pickled.

View file

@ -30,19 +30,11 @@ from bs4.element import (
from . import ( from . import (
default_builder, default_builder,
LXML_PRESENT,
SoupTest, SoupTest,
skipIf,
) )
import warnings import warnings
try:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
LXML_PRESENT = True
except ImportError as e:
LXML_PRESENT = False
PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
class TestConstructor(SoupTest): class TestConstructor(SoupTest):
def test_short_unicode_input(self): def test_short_unicode_input(self):
@ -139,7 +131,7 @@ class TestConstructor(SoupTest):
assert " an id " == a['id'] assert " an id " == a['id']
assert ["a", "class"] == a['class'] assert ["a", "class"] == a['class']
# TreeBuilder takes an argument called 'mutli_valued_attributes' which lets # TreeBuilder takes an argument called 'multi_valued_attributes' which lets
# you customize or disable this. As always, you can customize the TreeBuilder # you customize or disable this. As always, you can customize the TreeBuilder
# by passing in a keyword argument to the BeautifulSoup constructor. # by passing in a keyword argument to the BeautifulSoup constructor.
soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None) soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None)
@ -219,10 +211,17 @@ class TestConstructor(SoupTest):
class TestWarnings(SoupTest): class TestWarnings(SoupTest):
# Note that some of the tests in this class create BeautifulSoup
# objects directly rather than using self.soup(). That's
# because SoupTest.soup is defined in a different file,
# which will throw off the assertion in _assert_warning
# that the code that triggered the warning is in the same
# file as the test.
def _assert_warning(self, warnings, cls): def _assert_warning(self, warnings, cls):
for w in warnings: for w in warnings:
if isinstance(w.message, cls): if isinstance(w.message, cls):
assert w.filename == __file__
return w return w
raise Exception("%s warning not found in %r" % (cls, warnings)) raise Exception("%s warning not found in %r" % (cls, warnings))
@ -243,13 +242,17 @@ class TestWarnings(SoupTest):
def test_no_warning_if_explicit_parser_specified(self): def test_no_warning_if_explicit_parser_specified(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = BeautifulSoup("<a><b></b></a>", "html.parser") soup = self.soup("<a><b></b></a>")
assert [] == w assert [] == w
def test_parseOnlyThese_renamed_to_parse_only(self): def test_parseOnlyThese_renamed_to_parse_only(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", parseOnlyThese=SoupStrainer("b")) soup = BeautifulSoup(
msg = str(w[0].message) "<a><b></b></a>", "html.parser",
parseOnlyThese=SoupStrainer("b"),
)
warning = self._assert_warning(w, DeprecationWarning)
msg = str(warning.message)
assert "parseOnlyThese" in msg assert "parseOnlyThese" in msg
assert "parse_only" in msg assert "parse_only" in msg
assert b"<b></b>" == soup.encode() assert b"<b></b>" == soup.encode()
@ -257,8 +260,11 @@ class TestWarnings(SoupTest):
def test_fromEncoding_renamed_to_from_encoding(self): def test_fromEncoding_renamed_to_from_encoding(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
utf8 = b"\xc3\xa9" utf8 = b"\xc3\xa9"
soup = self.soup(utf8, fromEncoding="utf8") soup = BeautifulSoup(
msg = str(w[0].message) utf8, "html.parser", fromEncoding="utf8"
)
warning = self._assert_warning(w, DeprecationWarning)
msg = str(warning.message)
assert "fromEncoding" in msg assert "fromEncoding" in msg
assert "from_encoding" in msg assert "from_encoding" in msg
assert "utf8" == soup.original_encoding assert "utf8" == soup.original_encoding
@ -276,7 +282,7 @@ class TestWarnings(SoupTest):
# A warning is issued if the "markup" looks like the name of # A warning is issued if the "markup" looks like the name of
# an HTML or text file, or a full path to a file on disk. # an HTML or text file, or a full path to a file on disk.
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = self.soup("markup" + extension) soup = BeautifulSoup("markup" + extension, "html.parser")
warning = self._assert_warning(w, MarkupResemblesLocatorWarning) warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
assert "looks more like a filename" in str(warning.message) assert "looks more like a filename" in str(warning.message)
@ -291,11 +297,11 @@ class TestWarnings(SoupTest):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
soup = self.soup("markup" + extension) soup = self.soup("markup" + extension)
assert [] == w assert [] == w
def test_url_warning_with_bytes_url(self): def test_url_warning_with_bytes_url(self):
url = b"http://www.crummybytes.com/" url = b"http://www.crummybytes.com/"
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(url) soup = BeautifulSoup(url, "html.parser")
warning = self._assert_warning( warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning warning_list, MarkupResemblesLocatorWarning
) )
@ -307,7 +313,7 @@ class TestWarnings(SoupTest):
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
# note - this url must differ from the bytes one otherwise # note - this url must differ from the bytes one otherwise
# python's warnings system swallows the second warning # python's warnings system swallows the second warning
soup = self.soup(url) soup = BeautifulSoup(url, "html.parser")
warning = self._assert_warning( warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning warning_list, MarkupResemblesLocatorWarning
) )
@ -347,18 +353,22 @@ class TestNewTag(SoupTest):
assert "foo" == new_tag.name assert "foo" == new_tag.name
assert dict(bar="baz", name="a name") == new_tag.attrs assert dict(bar="baz", name="a name") == new_tag.attrs
assert None == new_tag.parent assert None == new_tag.parent
@pytest.mark.skipif(
not LXML_PRESENT,
reason="lxml not installed, cannot parse XML document"
)
def test_xml_tag_inherits_self_closing_rules_from_builder(self):
xml_soup = BeautifulSoup("", "xml")
xml_br = xml_soup.new_tag("br")
xml_p = xml_soup.new_tag("p")
# Both the <br> and <p> tag are empty-element, just because
# they have no contents.
assert b"<br/>" == xml_br.encode()
assert b"<p/>" == xml_p.encode()
def test_tag_inherits_self_closing_rules_from_builder(self): def test_tag_inherits_self_closing_rules_from_builder(self):
if LXML_PRESENT:
xml_soup = BeautifulSoup("", "lxml-xml")
xml_br = xml_soup.new_tag("br")
xml_p = xml_soup.new_tag("p")
# Both the <br> and <p> tag are empty-element, just because
# they have no contents.
assert b"<br/>" == xml_br.encode()
assert b"<p/>" == xml_p.encode()
html_soup = BeautifulSoup("", "html.parser") html_soup = BeautifulSoup("", "html.parser")
html_br = html_soup.new_tag("br") html_br = html_soup.new_tag("br")
html_p = html_soup.new_tag("p") html_p = html_soup.new_tag("p")
@ -450,13 +460,3 @@ class TestEncodingConversion(SoupTest):
# The internal data structures can be encoded as UTF-8. # The internal data structures can be encoded as UTF-8.
soup_from_unicode = self.soup(self.unicode_data) soup_from_unicode = self.soup(self.unicode_data)
assert soup_from_unicode.encode('utf-8') == self.utf8_data assert soup_from_unicode.encode('utf-8') == self.utf8_data
@skipIf(
PYTHON_3_PRE_3_2,
"Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.")
def test_attribute_name_containing_unicode_characters(self):
markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
assert self.soup(markup).div.encode("utf8") == markup.encode("utf8")

View file

@ -33,7 +33,6 @@ from bs4.element import (
) )
from . import ( from . import (
SoupTest, SoupTest,
skipIf,
) )
class TestFind(SoupTest): class TestFind(SoupTest):
@ -910,12 +909,16 @@ class TestTreeModification(SoupTest):
soup.a.extend(l) soup.a.extend(l)
assert "<a><g></g><f></f><e></e><d></d><c></c><b></b></a>" == soup.decode() assert "<a><g></g><f></f><e></e><d></d><c></c><b></b></a>" == soup.decode()
def test_extend_with_another_tags_contents(self): @pytest.mark.parametrize(
"get_tags", [lambda tag: tag, lambda tag: tag.contents]
)
def test_extend_with_another_tags_contents(self, get_tags):
data = '<body><div id="d1"><a>1</a><a>2</a><a>3</a><a>4</a></div><div id="d2"></div></body>' data = '<body><div id="d1"><a>1</a><a>2</a><a>3</a><a>4</a></div><div id="d2"></div></body>'
soup = self.soup(data) soup = self.soup(data)
d1 = soup.find('div', id='d1') d1 = soup.find('div', id='d1')
d2 = soup.find('div', id='d2') d2 = soup.find('div', id='d2')
d2.extend(d1) tags = get_tags(d1)
d2.extend(tags)
assert '<div id="d1"></div>' == d1.decode() assert '<div id="d1"></div>' == d1.decode()
assert '<div id="d2"><a>1</a><a>2</a><a>3</a><a>4</a></div>' == d2.decode() assert '<div id="d2"><a>1</a><a>2</a><a>3</a><a>4</a></div>' == d2.decode()
@ -1272,19 +1275,30 @@ class TestTreeModification(SoupTest):
class TestDeprecatedArguments(SoupTest): class TestDeprecatedArguments(SoupTest):
def test_find_type_method_string(self): @pytest.mark.parametrize(
"method_name", [
"find", "find_all", "find_parent", "find_parents",
"find_next", "find_all_next", "find_previous",
"find_all_previous", "find_next_sibling", "find_next_siblings",
"find_previous_sibling", "find_previous_siblings",
]
)
def test_find_type_method_string(self, method_name):
soup = self.soup("<a>some</a><b>markup</b>") soup = self.soup("<a>some</a><b>markup</b>")
method = getattr(soup.b, method_name)
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
[result] = soup.find_all(text='markup') method(text='markup')
assert result == 'markup' [warning] = w
assert result.parent.name == 'b' assert warning.filename == __file__
msg = str(w[0].message) msg = str(warning.message)
assert msg == "The 'text' argument to find()-type methods is deprecated. Use 'string' instead." assert msg == "The 'text' argument to find()-type methods is deprecated. Use 'string' instead."
def test_soupstrainer_constructor_string(self): def test_soupstrainer_constructor_string(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
strainer = SoupStrainer(text="text") strainer = SoupStrainer(text="text")
assert strainer.text == 'text' assert strainer.text == 'text'
msg = str(w[0].message) [warning] = w
msg = str(warning.message)
assert warning.filename == __file__
assert msg == "The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead." assert msg == "The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead."

View file

@ -21,14 +21,8 @@ at <https://github.com/Ousret/charset_normalizer>.
""" """
import logging import logging
from .api import from_bytes, from_fp, from_path, normalize from .api import from_bytes, from_fp, from_path
from .legacy import ( from .legacy import detect
CharsetDetector,
CharsetDoctor,
CharsetNormalizerMatch,
CharsetNormalizerMatches,
detect,
)
from .models import CharsetMatch, CharsetMatches from .models import CharsetMatch, CharsetMatches
from .utils import set_logging_handler from .utils import set_logging_handler
from .version import VERSION, __version__ from .version import VERSION, __version__
@ -37,14 +31,9 @@ __all__ = (
"from_fp", "from_fp",
"from_path", "from_path",
"from_bytes", "from_bytes",
"normalize",
"detect", "detect",
"CharsetMatch", "CharsetMatch",
"CharsetMatches", "CharsetMatches",
"CharsetNormalizerMatch",
"CharsetNormalizerMatches",
"CharsetDetector",
"CharsetDoctor",
"__version__", "__version__",
"VERSION", "VERSION",
"set_logging_handler", "set_logging_handler",

View file

@ -1,7 +1,5 @@
import logging import logging
import warnings
from os import PathLike from os import PathLike
from os.path import basename, splitext
from typing import Any, BinaryIO, List, Optional, Set from typing import Any, BinaryIO, List, Optional, Set
from .cd import ( from .cd import (
@ -41,11 +39,12 @@ def from_bytes(
cp_exclusion: Optional[List[str]] = None, cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True, preemptive_behaviour: bool = True,
explain: bool = False, explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches: ) -> CharsetMatches:
""" """
Given a raw bytes sequence, return the best possibles charset usable to render str objects. Given a raw bytes sequence, return the best possibles charset usable to render str objects.
If there is no results, it is a strong indicator that the source is binary/not text. If there is no results, it is a strong indicator that the source is binary/not text.
By default, the process will extract 5 blocs of 512o each to assess the mess and coherence of a given sequence. By default, the process will extract 5 blocks of 512o each to assess the mess and coherence of a given sequence.
And will give up a particular code page after 20% of measured mess. Those criteria are customizable at will. And will give up a particular code page after 20% of measured mess. Those criteria are customizable at will.
The preemptive behavior DOES NOT replace the traditional detection workflow, it prioritize a particular code page The preemptive behavior DOES NOT replace the traditional detection workflow, it prioritize a particular code page
@ -176,7 +175,6 @@ def from_bytes(
prioritized_encodings.append("utf_8") prioritized_encodings.append("utf_8")
for encoding_iana in prioritized_encodings + IANA_SUPPORTED: for encoding_iana in prioritized_encodings + IANA_SUPPORTED:
if cp_isolation and encoding_iana not in cp_isolation: if cp_isolation and encoding_iana not in cp_isolation:
continue continue
@ -197,7 +195,14 @@ def from_bytes(
if encoding_iana in {"utf_16", "utf_32"} and not bom_or_sig_available: if encoding_iana in {"utf_16", "utf_32"} and not bom_or_sig_available:
logger.log( logger.log(
TRACE, TRACE,
"Encoding %s wont be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.", "Encoding %s won't be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.",
encoding_iana,
)
continue
if encoding_iana in {"utf_7"} and not bom_or_sig_available:
logger.log(
TRACE,
"Encoding %s won't be tested as-is because detection is unreliable without BOM/SIG.",
encoding_iana, encoding_iana,
) )
continue continue
@ -297,7 +302,13 @@ def from_bytes(
): ):
md_chunks.append(chunk) md_chunks.append(chunk)
md_ratios.append(mess_ratio(chunk, threshold)) md_ratios.append(
mess_ratio(
chunk,
threshold,
explain is True and 1 <= len(cp_isolation) <= 2,
)
)
if md_ratios[-1] >= threshold: if md_ratios[-1] >= threshold:
early_stop_count += 1 early_stop_count += 1
@ -306,7 +317,9 @@ def from_bytes(
bom_or_sig_available and strip_sig_or_bom is False bom_or_sig_available and strip_sig_or_bom is False
): ):
break break
except UnicodeDecodeError as e: # Lazy str loading may have missed something there except (
UnicodeDecodeError
) as e: # Lazy str loading may have missed something there
logger.log( logger.log(
TRACE, TRACE,
"LazyStr Loading: After MD chunk decode, code page %s does not fit given bytes sequence at ALL. %s", "LazyStr Loading: After MD chunk decode, code page %s does not fit given bytes sequence at ALL. %s",
@ -389,7 +402,9 @@ def from_bytes(
if encoding_iana != "ascii": if encoding_iana != "ascii":
for chunk in md_chunks: for chunk in md_chunks:
chunk_languages = coherence_ratio( chunk_languages = coherence_ratio(
chunk, 0.1, ",".join(target_languages) if target_languages else None chunk,
language_threshold,
",".join(target_languages) if target_languages else None,
) )
cd_ratios.append(chunk_languages) cd_ratios.append(chunk_languages)
@ -491,6 +506,7 @@ def from_fp(
cp_exclusion: Optional[List[str]] = None, cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True, preemptive_behaviour: bool = True,
explain: bool = False, explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches: ) -> CharsetMatches:
""" """
Same thing than the function from_bytes but using a file pointer that is already ready. Same thing than the function from_bytes but using a file pointer that is already ready.
@ -505,6 +521,7 @@ def from_fp(
cp_exclusion, cp_exclusion,
preemptive_behaviour, preemptive_behaviour,
explain, explain,
language_threshold,
) )
@ -517,6 +534,7 @@ def from_path(
cp_exclusion: Optional[List[str]] = None, cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True, preemptive_behaviour: bool = True,
explain: bool = False, explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches: ) -> CharsetMatches:
""" """
Same thing than the function from_bytes but with one extra step. Opening and reading given file path in binary mode. Same thing than the function from_bytes but with one extra step. Opening and reading given file path in binary mode.
@ -532,53 +550,5 @@ def from_path(
cp_exclusion, cp_exclusion,
preemptive_behaviour, preemptive_behaviour,
explain, explain,
language_threshold,
) )
def normalize(
path: "PathLike[Any]",
steps: int = 5,
chunk_size: int = 512,
threshold: float = 0.20,
cp_isolation: Optional[List[str]] = None,
cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True,
) -> CharsetMatch:
"""
Take a (text-based) file path and try to create another file next to it, this time using UTF-8.
"""
warnings.warn(
"normalize is deprecated and will be removed in 3.0",
DeprecationWarning,
)
results = from_path(
path,
steps,
chunk_size,
threshold,
cp_isolation,
cp_exclusion,
preemptive_behaviour,
)
filename = basename(path)
target_extensions = list(splitext(filename))
if len(results) == 0:
raise IOError(
'Unable to normalize "{}", no encoding charset seems to fit.'.format(
filename
)
)
result = results.best()
target_extensions[0] += "-" + result.encoding # type: ignore
with open(
"{}".format(str(path).replace(filename, "".join(target_extensions))), "wb"
) as fp:
fp.write(result.output()) # type: ignore
return result # type: ignore

View file

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Dict, List from typing import Dict, List
# Language label that contain the em dash "—"
# character are to be considered alternative seq to origin
FREQUENCIES: Dict[str, List[str]] = { FREQUENCIES: Dict[str, List[str]] = {
"English": [ "English": [
"e", "e",
@ -30,6 +32,34 @@ FREQUENCIES: Dict[str, List[str]] = {
"z", "z",
"q", "q",
], ],
"English—": [
"e",
"a",
"t",
"i",
"o",
"n",
"s",
"r",
"h",
"l",
"d",
"c",
"m",
"u",
"f",
"p",
"g",
"w",
"b",
"y",
"v",
"k",
"j",
"x",
"z",
"q",
],
"German": [ "German": [
"e", "e",
"n", "n",
@ -226,33 +256,303 @@ FREQUENCIES: Dict[str, List[str]] = {
"ж", "ж",
"ц", "ц",
], ],
# Jap-Kanji
"Japanese": [ "Japanese": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"丿",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"广",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
# Jap-Katakana
"Japanese—": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
# Jap-Hiragana
"Japanese——": [
"", "",
"", "",
"", "",
"", "",
"",
"",
"", "",
"",
"", "",
"",
"", "",
"", "",
"", "",
"", "",
"",
"",
"",
"", "",
"", "",
"",
"",
"",
"", "",
"", "",
"",
"", "",
"", "",
"", "",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
], ],
"Portuguese": [ "Portuguese": [
"a", "a",
@ -340,6 +640,77 @@ FREQUENCIES: Dict[str, List[str]] = {
"", "",
"", "",
"", "",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
], ],
"Ukrainian": [ "Ukrainian": [
"о", "о",
@ -956,34 +1327,6 @@ FREQUENCIES: Dict[str, List[str]] = {
"ö", "ö",
"y", "y",
], ],
"Simple English": [
"e",
"a",
"t",
"i",
"o",
"n",
"s",
"r",
"h",
"l",
"d",
"c",
"m",
"u",
"f",
"p",
"g",
"w",
"b",
"y",
"v",
"k",
"j",
"x",
"z",
"q",
],
"Thai": [ "Thai": [
"", "",
"", "",
@ -1066,31 +1409,6 @@ FREQUENCIES: Dict[str, List[str]] = {
"", "",
"", "",
], ],
"Classical Chinese": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
"Kazakh": [ "Kazakh": [
"а", "а",
"ы", "ы",

View file

@ -105,7 +105,7 @@ def mb_encoding_languages(iana_name: str) -> List[str]:
): ):
return ["Japanese"] return ["Japanese"]
if iana_name.startswith("gb") or iana_name in ZH_NAMES: if iana_name.startswith("gb") or iana_name in ZH_NAMES:
return ["Chinese", "Classical Chinese"] return ["Chinese"]
if iana_name.startswith("iso2022_kr") or iana_name in KO_NAMES: if iana_name.startswith("iso2022_kr") or iana_name in KO_NAMES:
return ["Korean"] return ["Korean"]
@ -140,7 +140,6 @@ def alphabet_languages(
source_have_accents = any(is_accentuated(character) for character in characters) source_have_accents = any(is_accentuated(character) for character in characters)
for language, language_characters in FREQUENCIES.items(): for language, language_characters in FREQUENCIES.items():
target_have_accents, target_pure_latin = get_target_features(language) target_have_accents, target_pure_latin = get_target_features(language)
if ignore_non_latin and target_pure_latin is False: if ignore_non_latin and target_pure_latin is False:
@ -179,22 +178,45 @@ def characters_popularity_compare(
character_approved_count: int = 0 character_approved_count: int = 0
FREQUENCIES_language_set = set(FREQUENCIES[language]) FREQUENCIES_language_set = set(FREQUENCIES[language])
for character in ordered_characters: ordered_characters_count: int = len(ordered_characters)
target_language_characters_count: int = len(FREQUENCIES[language])
large_alphabet: bool = target_language_characters_count > 26
for character, character_rank in zip(
ordered_characters, range(0, ordered_characters_count)
):
if character not in FREQUENCIES_language_set: if character not in FREQUENCIES_language_set:
continue continue
character_rank_in_language: int = FREQUENCIES[language].index(character)
expected_projection_ratio: float = (
target_language_characters_count / ordered_characters_count
)
character_rank_projection: int = int(character_rank * expected_projection_ratio)
if (
large_alphabet is False
and abs(character_rank_projection - character_rank_in_language) > 4
):
continue
if (
large_alphabet is True
and abs(character_rank_projection - character_rank_in_language)
< target_language_characters_count / 3
):
character_approved_count += 1
continue
characters_before_source: List[str] = FREQUENCIES[language][ characters_before_source: List[str] = FREQUENCIES[language][
0 : FREQUENCIES[language].index(character) 0:character_rank_in_language
] ]
characters_after_source: List[str] = FREQUENCIES[language][ characters_after_source: List[str] = FREQUENCIES[language][
FREQUENCIES[language].index(character) : character_rank_in_language:
]
characters_before: List[str] = ordered_characters[
0 : ordered_characters.index(character)
]
characters_after: List[str] = ordered_characters[
ordered_characters.index(character) :
] ]
characters_before: List[str] = ordered_characters[0:character_rank]
characters_after: List[str] = ordered_characters[character_rank:]
before_match_count: int = len( before_match_count: int = len(
set(characters_before) & set(characters_before_source) set(characters_before) & set(characters_before_source)
@ -289,6 +311,33 @@ def merge_coherence_ratios(results: List[CoherenceMatches]) -> CoherenceMatches:
return sorted(merge, key=lambda x: x[1], reverse=True) return sorted(merge, key=lambda x: x[1], reverse=True)
def filter_alt_coherence_matches(results: CoherenceMatches) -> CoherenceMatches:
"""
We shall NOT return "English—" in CoherenceMatches because it is an alternative
of "English". This function only keeps the best match and remove the em-dash in it.
"""
index_results: Dict[str, List[float]] = dict()
for result in results:
language, ratio = result
no_em_name: str = language.replace("", "")
if no_em_name not in index_results:
index_results[no_em_name] = []
index_results[no_em_name].append(ratio)
if any(len(index_results[e]) > 1 for e in index_results):
filtered_results: CoherenceMatches = []
for language in index_results:
filtered_results.append((language, max(index_results[language])))
return filtered_results
return results
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def coherence_ratio( def coherence_ratio(
decoded_sequence: str, threshold: float = 0.1, lg_inclusion: Optional[str] = None decoded_sequence: str, threshold: float = 0.1, lg_inclusion: Optional[str] = None
@ -336,4 +385,6 @@ def coherence_ratio(
if sufficient_match_count >= 3: if sufficient_match_count >= 3:
break break
return sorted(results, key=lambda x: x[1], reverse=True) return sorted(
filter_alt_coherence_matches(results), key=lambda x: x[1], reverse=True
)

View file

@ -1,15 +1,12 @@
import argparse import argparse
import sys import sys
from json import dumps from json import dumps
from os.path import abspath from os.path import abspath, basename, dirname, join, realpath
from platform import python_version from platform import python_version
from typing import List, Optional from typing import List, Optional
from unicodedata import unidata_version
try: import charset_normalizer.md as md_module
from unicodedata2 import unidata_version
except ImportError:
from unicodedata import unidata_version
from charset_normalizer import from_fp from charset_normalizer import from_fp
from charset_normalizer.models import CliDetectionResult from charset_normalizer.models import CliDetectionResult
from charset_normalizer.version import __version__ from charset_normalizer.version import __version__
@ -124,8 +121,11 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
parser.add_argument( parser.add_argument(
"--version", "--version",
action="version", action="version",
version="Charset-Normalizer {} - Python {} - Unicode {}".format( version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
__version__, python_version(), unidata_version __version__,
python_version(),
unidata_version,
"OFF" if md_module.__file__.lower().endswith(".py") else "ON",
), ),
help="Show version information and exit.", help="Show version information and exit.",
) )
@ -147,7 +147,6 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
x_ = [] x_ = []
for my_file in args.files: for my_file in args.files:
matches = from_fp(my_file, threshold=args.threshold, explain=args.verbose) matches = from_fp(my_file, threshold=args.threshold, explain=args.verbose)
best_guess = matches.best() best_guess = matches.best()
@ -222,7 +221,6 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
) )
if args.normalize is True: if args.normalize is True:
if best_guess.encoding.startswith("utf") is True: if best_guess.encoding.startswith("utf") is True:
print( print(
'"{}" file does not need to be normalized, as it already came from unicode.'.format( '"{}" file does not need to be normalized, as it already came from unicode.'.format(
@ -234,7 +232,10 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
my_file.close() my_file.close()
continue continue
o_: List[str] = my_file.name.split(".") dir_path = dirname(realpath(my_file.name))
file_name = basename(realpath(my_file.name))
o_: List[str] = file_name.split(".")
if args.replace is False: if args.replace is False:
o_.insert(-1, best_guess.encoding) o_.insert(-1, best_guess.encoding)
@ -255,7 +256,7 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
continue continue
try: try:
x_[0].unicode_path = abspath("./{}".format(".".join(o_))) x_[0].unicode_path = join(dir_path, ".".join(o_))
with open(x_[0].unicode_path, "w", encoding="utf-8") as fp: with open(x_[0].unicode_path, "w", encoding="utf-8") as fp:
fp.write(str(best_guess)) fp.write(str(best_guess))

View file

@ -489,9 +489,7 @@ COMMON_SAFE_ASCII_CHARACTERS: Set[str] = {
KO_NAMES: Set[str] = {"johab", "cp949", "euc_kr"} KO_NAMES: Set[str] = {"johab", "cp949", "euc_kr"}
ZH_NAMES: Set[str] = {"big5", "cp950", "big5hkscs", "hz"} ZH_NAMES: Set[str] = {"big5", "cp950", "big5hkscs", "hz"}
NOT_PRINTABLE_PATTERN = re_compile(r"[0-9\W\n\r\t]+")
LANGUAGE_SUPPORTED_COUNT: int = len(FREQUENCIES) LANGUAGE_SUPPORTED_COUNT: int = len(FREQUENCIES)
# Logging LEVEL bellow DEBUG # Logging LEVEL below DEBUG
TRACE: int = 5 TRACE: int = 5

View file

@ -1,12 +1,13 @@
import warnings from typing import Any, Dict, Optional, Union
from typing import Dict, Optional, Union from warnings import warn
from .api import from_bytes, from_fp, from_path, normalize from .api import from_bytes
from .constant import CHARDET_CORRESPONDENCE from .constant import CHARDET_CORRESPONDENCE
from .models import CharsetMatch, CharsetMatches
def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]: def detect(
byte_str: bytes, should_rename_legacy: bool = False, **kwargs: Any
) -> Dict[str, Optional[Union[str, float]]]:
""" """
chardet legacy method chardet legacy method
Detect the encoding of the given byte string. It should be mostly backward-compatible. Detect the encoding of the given byte string. It should be mostly backward-compatible.
@ -15,7 +16,14 @@ def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
further information. Not planned for removal. further information. Not planned for removal.
:param byte_str: The byte sequence to examine. :param byte_str: The byte sequence to examine.
:param should_rename_legacy: Should we rename legacy encodings
to their more modern equivalents?
""" """
if len(kwargs):
warn(
f"charset-normalizer disregard arguments '{','.join(list(kwargs.keys()))}' in legacy function detect()"
)
if not isinstance(byte_str, (bytearray, bytes)): if not isinstance(byte_str, (bytearray, bytes)):
raise TypeError( # pragma: nocover raise TypeError( # pragma: nocover
"Expected object of type bytes or bytearray, got: " "Expected object of type bytes or bytearray, got: "
@ -36,60 +44,11 @@ def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
if r is not None and encoding == "utf_8" and r.bom: if r is not None and encoding == "utf_8" and r.bom:
encoding += "_sig" encoding += "_sig"
if should_rename_legacy is False and encoding in CHARDET_CORRESPONDENCE:
encoding = CHARDET_CORRESPONDENCE[encoding]
return { return {
"encoding": encoding "encoding": encoding,
if encoding not in CHARDET_CORRESPONDENCE
else CHARDET_CORRESPONDENCE[encoding],
"language": language, "language": language,
"confidence": confidence, "confidence": confidence,
} }
class CharsetNormalizerMatch(CharsetMatch):
pass
class CharsetNormalizerMatches(CharsetMatches):
@staticmethod
def from_fp(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_fp(*args, **kwargs) # pragma: nocover
@staticmethod
def from_bytes(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_bytes(*args, **kwargs) # pragma: nocover
@staticmethod
def from_path(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_path(*args, **kwargs) # pragma: nocover
@staticmethod
def normalize(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return normalize(*args, **kwargs) # pragma: nocover
class CharsetDetector(CharsetNormalizerMatches):
pass
class CharsetDoctor(CharsetNormalizerMatches):
pass

View file

@ -1,7 +1,12 @@
from functools import lru_cache from functools import lru_cache
from logging import getLogger
from typing import List, Optional from typing import List, Optional
from .constant import COMMON_SAFE_ASCII_CHARACTERS, UNICODE_SECONDARY_RANGE_KEYWORD from .constant import (
COMMON_SAFE_ASCII_CHARACTERS,
TRACE,
UNICODE_SECONDARY_RANGE_KEYWORD,
)
from .utils import ( from .utils import (
is_accentuated, is_accentuated,
is_ascii, is_ascii,
@ -123,7 +128,7 @@ class TooManyAccentuatedPlugin(MessDetectorPlugin):
@property @property
def ratio(self) -> float: def ratio(self) -> float:
if self._character_count == 0: if self._character_count == 0 or self._character_count < 8:
return 0.0 return 0.0
ratio_of_accentuation: float = self._accentuated_count / self._character_count ratio_of_accentuation: float = self._accentuated_count / self._character_count
return ratio_of_accentuation if ratio_of_accentuation >= 0.35 else 0.0 return ratio_of_accentuation if ratio_of_accentuation >= 0.35 else 0.0
@ -547,7 +552,20 @@ def mess_ratio(
break break
if debug: if debug:
logger = getLogger("charset_normalizer")
logger.log(
TRACE,
"Mess-detector extended-analysis start. "
f"intermediary_mean_mess_ratio_calc={intermediary_mean_mess_ratio_calc} mean_mess_ratio={mean_mess_ratio} "
f"maximum_threshold={maximum_threshold}",
)
if len(decoded_sequence) > 16:
logger.log(TRACE, f"Starting with: {decoded_sequence[:16]}")
logger.log(TRACE, f"Ending with: {decoded_sequence[-16::]}")
for dt in detectors: # pragma: nocover for dt in detectors: # pragma: nocover
print(dt.__class__, dt.ratio) logger.log(TRACE, f"{dt.__class__}: {dt.ratio}")
return round(mean_mess_ratio, 3) return round(mean_mess_ratio, 3)

View file

@ -1,22 +1,9 @@
import warnings
from collections import Counter
from encodings.aliases import aliases from encodings.aliases import aliases
from hashlib import sha256 from hashlib import sha256
from json import dumps from json import dumps
from re import sub from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Counter as TypeCounter,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
from .constant import NOT_PRINTABLE_PATTERN, TOO_BIG_SEQUENCE from .constant import TOO_BIG_SEQUENCE
from .md import mess_ratio
from .utils import iana_name, is_multi_byte_encoding, unicode_range from .utils import iana_name, is_multi_byte_encoding, unicode_range
@ -65,7 +52,7 @@ class CharsetMatch:
chaos_difference: float = abs(self.chaos - other.chaos) chaos_difference: float = abs(self.chaos - other.chaos)
coherence_difference: float = abs(self.coherence - other.coherence) coherence_difference: float = abs(self.coherence - other.coherence)
# Bellow 1% difference --> Use Coherence # Below 1% difference --> Use Coherence
if chaos_difference < 0.01 and coherence_difference > 0.02: if chaos_difference < 0.01 and coherence_difference > 0.02:
# When having a tough decision, use the result that decoded as many multi-byte as possible. # When having a tough decision, use the result that decoded as many multi-byte as possible.
if chaos_difference == 0.0 and self.coherence == other.coherence: if chaos_difference == 0.0 and self.coherence == other.coherence:
@ -78,45 +65,6 @@ class CharsetMatch:
def multi_byte_usage(self) -> float: def multi_byte_usage(self) -> float:
return 1.0 - len(str(self)) / len(self.raw) return 1.0 - len(str(self)) / len(self.raw)
@property
def chaos_secondary_pass(self) -> float:
"""
Check once again chaos in decoded text, except this time, with full content.
Use with caution, this can be very slow.
Notice: Will be removed in 3.0
"""
warnings.warn(
"chaos_secondary_pass is deprecated and will be removed in 3.0",
DeprecationWarning,
)
return mess_ratio(str(self), 1.0)
@property
def coherence_non_latin(self) -> float:
"""
Coherence ratio on the first non-latin language detected if ANY.
Notice: Will be removed in 3.0
"""
warnings.warn(
"coherence_non_latin is deprecated and will be removed in 3.0",
DeprecationWarning,
)
return 0.0
@property
def w_counter(self) -> TypeCounter[str]:
"""
Word counter instance on decoded text.
Notice: Will be removed in 3.0
"""
warnings.warn(
"w_counter is deprecated and will be removed in 3.0", DeprecationWarning
)
string_printable_only = sub(NOT_PRINTABLE_PATTERN, " ", str(self).lower())
return Counter(string_printable_only.split())
def __str__(self) -> str: def __str__(self) -> str:
# Lazy Str Loading # Lazy Str Loading
if self._string is None: if self._string is None:
@ -252,18 +200,6 @@ class CharsetMatch:
""" """
return [self._encoding] + [m.encoding for m in self._leaves] return [self._encoding] + [m.encoding for m in self._leaves]
def first(self) -> "CharsetMatch":
"""
Kept for BC reasons. Will be removed in 3.0.
"""
return self
def best(self) -> "CharsetMatch":
"""
Kept for BC reasons. Will be removed in 3.0.
"""
return self
def output(self, encoding: str = "utf_8") -> bytes: def output(self, encoding: str = "utf_8") -> bytes:
""" """
Method to get re-encoded bytes payload using given target encoding. Default to UTF-8. Method to get re-encoded bytes payload using given target encoding. Default to UTF-8.

View file

@ -1,12 +1,6 @@
try:
# WARNING: unicodedata2 support is going to be removed in 3.0
# Python is quickly catching up.
import unicodedata2 as unicodedata
except ImportError:
import unicodedata # type: ignore[no-redef]
import importlib import importlib
import logging import logging
import unicodedata
from codecs import IncrementalDecoder from codecs import IncrementalDecoder
from encodings.aliases import aliases from encodings.aliases import aliases
from functools import lru_cache from functools import lru_cache
@ -317,7 +311,6 @@ def range_scan(decoded_sequence: str) -> List[str]:
def cp_similarity(iana_name_a: str, iana_name_b: str) -> float: def cp_similarity(iana_name_a: str, iana_name_b: str) -> float:
if is_multi_byte_encoding(iana_name_a) or is_multi_byte_encoding(iana_name_b): if is_multi_byte_encoding(iana_name_a) or is_multi_byte_encoding(iana_name_b):
return 0.0 return 0.0
@ -357,7 +350,6 @@ def set_logging_handler(
level: int = logging.INFO, level: int = logging.INFO,
format_string: str = "%(asctime)s | %(levelname)s | %(message)s", format_string: str = "%(asctime)s | %(levelname)s | %(message)s",
) -> None: ) -> None:
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)
@ -377,7 +369,6 @@ def cut_sequence_chunks(
is_multi_byte_decoder: bool, is_multi_byte_decoder: bool,
decoded_payload: Optional[str] = None, decoded_payload: Optional[str] = None,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
if decoded_payload and is_multi_byte_decoder is False: if decoded_payload and is_multi_byte_decoder is False:
for i in offsets: for i in offsets:
chunk = decoded_payload[i : i + chunk_size] chunk = decoded_payload[i : i + chunk_size]
@ -402,8 +393,7 @@ def cut_sequence_chunks(
# multi-byte bad cutting detector and adjustment # multi-byte bad cutting detector and adjustment
# not the cleanest way to perform that fix but clever enough for now. # not the cleanest way to perform that fix but clever enough for now.
if is_multi_byte_decoder and i > 0 and sequences[i] >= 0x80: if is_multi_byte_decoder and i > 0:
chunk_partial_size_chk: int = min(chunk_size, 16) chunk_partial_size_chk: int = min(chunk_size, 16)
if ( if (

View file

@ -2,5 +2,5 @@
Expose version Expose version
""" """
__version__ = "2.1.1" __version__ = "3.1.0"
VERSION = __version__.split(".") VERSION = __version__.split(".")

View file

@ -38,7 +38,7 @@ CL_BLANK = "
URI_SCHEME = "cloudinary" URI_SCHEME = "cloudinary"
API_VERSION = "v1_1" API_VERSION = "v1_1"
VERSION = "1.30.0" VERSION = "1.32.0"
_USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version()))) _USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version())))
@ -234,6 +234,7 @@ _http_client = HttpClient()
# FIXME: circular import issue # FIXME: circular import issue
from cloudinary.search import Search from cloudinary.search import Search
from cloudinary.search_folders import SearchFolders
@python_2_unicode_compatible @python_2_unicode_compatible

View file

@ -188,9 +188,9 @@ def _prepare_asset_details_params(**options):
:internal :internal
""" """
return only(options, "exif", "faces", "colors", "image_metadata", "cinemagraph_analysis", return only(options, "exif", "faces", "colors", "image_metadata", "media_metadata", "cinemagraph_analysis",
"pages", "phash", "coordinates", "max_results", "quality_analysis", "derived_next_cursor", "pages", "phash", "coordinates", "max_results", "quality_analysis", "derived_next_cursor",
"accessibility_analysis", "versions") "accessibility_analysis", "versions", "related", "related_next_cursor")
def update(public_id, **options): def update(public_id, **options):
@ -223,6 +223,8 @@ def update(public_id, **options):
params["display_name"] = options.get("display_name") params["display_name"] = options.get("display_name")
if "unique_display_name" in options: if "unique_display_name" in options:
params["unique_display_name"] = options.get("unique_display_name") params["unique_display_name"] = options.get("unique_display_name")
if "clear_invalid" in options:
params["clear_invalid"] = options.get("clear_invalid")
return call_api("post", uri, params, **options) return call_api("post", uri, params, **options)
@ -293,6 +295,50 @@ def delete_derived_by_transformation(public_ids, transformations,
return call_api("delete", uri, params, **options) return call_api("delete", uri, params, **options)
def add_related_assets(public_id, assets_to_relate, resource_type="image", type="upload", **options):
"""
Relates an asset to other assets by public IDs.
:param public_id: The public ID of the asset to update.
:type public_id: str
:param assets_to_relate: The array of up to 10 fully_qualified_public_ids given as resource_type/type/public_id.
:type assets_to_relate: list[str]
:param type: The upload type. Defaults to "upload".
:type type: str
:param resource_type: The type of the resource. Defaults to "image".
:type resource_type: str
:param options: Additional options.
:type options: dict, optional
:return: The result of the command.
:rtype: dict
"""
uri = ["resources", "related_assets", resource_type, type, public_id]
params = {"assets_to_relate": utils.build_array(assets_to_relate)}
return call_json_api("post", uri, params, **options)
def delete_related_assets(public_id, assets_to_unrelate, resource_type="image", type="upload", **options):
"""
Unrelates an asset from other assets by public IDs.
:param public_id: The public ID of the asset to update.
:type public_id: str
:param assets_to_unrelate: The array of up to 10 fully_qualified_public_ids given as resource_type/type/public_id.
:type assets_to_unrelate: list[str]
:param type: The upload type.
:type type: str
:param resource_type: The type of the resource: defaults to "image".
:type resource_type: str
:param options: Additional options.
:type options: dict, optional
:return: The result of the command.
:rtype: dict
"""
uri = ["resources", "related_assets", resource_type, type, public_id]
params = {"assets_to_unrelate": utils.build_array(assets_to_unrelate)}
return call_json_api("delete", uri, params, **options)
def tags(**options): def tags(**options):
resource_type = options.pop("resource_type", "image") resource_type = options.pop("resource_type", "image")
uri = ["tags", resource_type] uri = ["tags", resource_type]

View file

@ -68,9 +68,9 @@ def execute_request(http_connector, method, params, headers, auth, api_url, **op
response = http_connector.request(method.upper(), api_url, processed_params, req_headers, **kw) response = http_connector.request(method.upper(), api_url, processed_params, req_headers, **kw)
body = response.data body = response.data
except HTTPError as e: except HTTPError as e:
raise GeneralError("Unexpected error {0}", e.message) raise GeneralError("Unexpected error %s" % str(e))
except socket.error as e: except socket.error as e:
raise GeneralError("Socket Error: %s" % (str(e))) raise GeneralError("Socket Error: %s" % str(e))
try: try:
result = json.loads(body.decode('utf-8')) result = json.loads(body.decode('utf-8'))

View file

@ -4,7 +4,11 @@ from cloudinary.api_client.call_api import call_json_api
from cloudinary.utils import unique from cloudinary.utils import unique
class Search: class Search(object):
ASSETS = 'resources'
_endpoint = ASSETS
_KEYS_WITH_UNIQUE_VALUES = { _KEYS_WITH_UNIQUE_VALUES = {
'sort_by': lambda x: next(iter(x)), 'sort_by': lambda x: next(iter(x)),
'aggregate': None, 'aggregate': None,
@ -53,7 +57,7 @@ class Search:
def execute(self, **options): def execute(self, **options):
"""Execute the search and return results.""" """Execute the search and return results."""
options["content_type"] = 'application/json' options["content_type"] = 'application/json'
uri = ['resources', 'search'] uri = [self._endpoint, 'search']
return call_json_api('post', uri, self.as_dict(), **options) return call_json_api('post', uri, self.as_dict(), **options)
def _add(self, name, value): def _add(self, name, value):
@ -72,3 +76,7 @@ class Search:
to_return[key] = value to_return[key] = value
return to_return return to_return
def endpoint(self, endpoint):
self._endpoint = endpoint
return self

View file

@ -0,0 +1,10 @@
from cloudinary import Search
class SearchFolders(Search):
FOLDERS = 'folders'
def __init__(self):
super(SearchFolders, self).__init__()
self.endpoint(self.FOLDERS)

View file

@ -168,7 +168,8 @@ def update_metadata(metadata, public_ids, **options):
"timestamp": utils.now(), "timestamp": utils.now(),
"metadata": utils.encode_context(metadata), "metadata": utils.encode_context(metadata),
"public_ids": utils.build_array(public_ids), "public_ids": utils.build_array(public_ids),
"type": options.get("type") "type": options.get("type"),
"clear_invalid": options.get("clear_invalid")
} }
return call_api("metadata", params, **options) return call_api("metadata", params, **options)

View file

@ -78,6 +78,7 @@ __SIMPLE_UPLOAD_PARAMS = [
"backup", "backup",
"faces", "faces",
"image_metadata", "image_metadata",
"media_metadata",
"exif", "exif",
"colors", "colors",
"use_filename", "use_filename",
@ -1052,7 +1053,8 @@ def build_custom_headers(headers):
def build_upload_params(**options): def build_upload_params(**options):
params = {param_name: options.get(param_name) for param_name in __SIMPLE_UPLOAD_PARAMS} params = {param_name: options.get(param_name) for param_name in __SIMPLE_UPLOAD_PARAMS if param_name in options}
params["upload_preset"] = params.pop("upload_preset", cloudinary.config().upload_preset)
serialized_params = { serialized_params = {
"timestamp": now(), "timestamp": now(),
@ -1577,3 +1579,19 @@ def unique(collection, key=None):
to_return[key(element)] = element to_return[key(element)] = element
return list(to_return.values()) return list(to_return.values())
def fq_public_id(public_id, resource_type="image", type="upload"):
"""
Returns the fully qualified public id of form resource_type/type/public_id.
:param public_id: The public ID of the asset.
:type public_id: str
:param resource_type: The type of the asset. Defaults to "image".
:type resource_type: str
:param type: The upload type. Defaults to "upload".
:type type: str
:return:
"""
return "{resource_type}/{type}/{public_id}".format(resource_type=resource_type, type=type, public_id=public_id)

View file

@ -18,49 +18,52 @@
"""dnspython DNS toolkit""" """dnspython DNS toolkit"""
__all__ = [ __all__ = [
'asyncbackend', "asyncbackend",
'asyncquery', "asyncquery",
'asyncresolver', "asyncresolver",
'dnssec', "dnssec",
'e164', "dnssectypes",
'edns', "e164",
'entropy', "edns",
'exception', "entropy",
'flags', "exception",
'immutable', "flags",
'inet', "immutable",
'ipv4', "inet",
'ipv6', "ipv4",
'message', "ipv6",
'name', "message",
'namedict', "name",
'node', "namedict",
'opcode', "node",
'query', "opcode",
'rcode', "query",
'rdata', "quic",
'rdataclass', "rcode",
'rdataset', "rdata",
'rdatatype', "rdataclass",
'renderer', "rdataset",
'resolver', "rdatatype",
'reversename', "renderer",
'rrset', "resolver",
'serial', "reversename",
'set', "rrset",
'tokenizer', "serial",
'transaction', "set",
'tsig', "tokenizer",
'tsigkeyring', "transaction",
'ttl', "tsig",
'rdtypes', "tsigkeyring",
'update', "ttl",
'version', "rdtypes",
'versioned', "update",
'wire', "version",
'xfr', "versioned",
'zone', "wire",
'zonefile', "xfr",
"zone",
"zonetypes",
"zonefile",
] ]
from dns.version import version as __version__ # noqa from dns.version import version as __version__ # noqa

View file

@ -3,6 +3,7 @@
# This is a nullcontext for both sync and async. 3.7 has a nullcontext, # This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use. # but it is only for sync use.
class NullContext: class NullContext:
def __init__(self, enter_result=None): def __init__(self, enter_result=None):
self.enter_result = enter_result self.enter_result = enter_result
@ -23,6 +24,7 @@ class NullContext:
# These are declared here so backends can import them without creating # These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend. # circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover class Socket: # pragma: no cover
async def close(self): async def close(self):
pass pass
@ -41,6 +43,9 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
raise NotImplementedError raise NotImplementedError
@ -56,14 +61,25 @@ class StreamSocket(Socket): # pragma: no cover
raise NotImplementedError raise NotImplementedError
class Backend: # pragma: no cover class Backend: # pragma: no cover
def name(self): def name(self):
return 'unknown' return "unknown"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError raise NotImplementedError
def datagram_connection_required(self): def datagram_connection_required(self):
return False return False
async def sleep(self, interval):
raise NotImplementedError

View file

@ -10,7 +10,8 @@ import dns._asyncbackend
import dns.exception import dns.exception
_is_win32 = sys.platform == 'win32' _is_win32 = sys.platform == "win32"
def _get_running_loop(): def _get_running_loop():
try: try:
@ -30,7 +31,6 @@ class _DatagramProtocol:
def datagram_received(self, data, addr): def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr)) self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol): def __init__(self, family, transport, protocol):
self.family = family super().__init__(family)
self.transport = transport self.transport = transport
self.protocol = protocol self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto # no timeout for asyncio sendto
self.transport.sendto(what, destination) self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it # ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future() done = _get_running_loop().create_future()
assert self.protocol.recvfrom is None try:
self.protocol.recvfrom = done assert self.protocol.recvfrom is None
await _maybe_wait_for(done, timeout) self.protocol.recvfrom = done
return done.result() await _maybe_wait_for(done, timeout)
return done.result()
finally:
self.protocol.recvfrom = None
async def close(self): async def close(self):
self.protocol.close() self.protocol.close()
async def getpeername(self): async def getpeername(self):
return self.transport.get_extra_info('peername') return self.transport.get_extra_info("peername")
async def getsockname(self): async def getsockname(self):
return self.transport.get_extra_info('sockname') return self.transport.get_extra_info("sockname")
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return await _maybe_wait_for(self.writer.drain(), timeout) return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout): async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size), return await _maybe_wait_for(self.reader.read(size), timeout)
timeout)
async def close(self): async def close(self):
self.writer.close() self.writer.close()
@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
pass pass
async def getpeername(self): async def getpeername(self):
return self.writer.get_extra_info('peername') return self.writer.get_extra_info("peername")
async def getsockname(self): async def getsockname(self):
return self.writer.get_extra_info('sockname') return self.writer.get_extra_info("sockname")
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'asyncio' return "asyncio"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
if destination is None and socktype == socket.SOCK_DGRAM and \ socktype,
_is_win32: proto=0,
raise NotImplementedError('destinationless datagram sockets ' source=None,
'are not supported by asyncio ' destination=None,
'on Windows') timeout=None,
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop() loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af, _DatagramProtocol,
proto=proto, remote_addr=destination) source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol) return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM: elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for( (r, w) = await _maybe_wait_for(
asyncio.open_connection(destination[0], asyncio.open_connection(
destination[1], destination[0],
ssl=ssl_context, destination[1],
family=af, ssl=ssl_context,
proto=proto, family=af,
local_addr=source, proto=proto,
server_hostname=server_hostname), local_addr=source,
timeout) server_hostname=server_hostname,
),
timeout,
)
return StreamSocket(af, r, w) return StreamSocket(af, r, w)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await asyncio.sleep(interval) await asyncio.sleep(interval)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination) return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size) return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.socket.close() await self.socket.close()
@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout): async def sendall(self, what, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.sendall(what) return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout): async def recv(self, size, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.recv(size) return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.socket.close() await self.socket.close()
@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'curio' return "curio"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto) s = curio.socket.socket(af, socktype, proto)
try: try:
@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend):
else: else:
source_addr = None source_addr = None
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
s = await curio.open_connection(destination[0], destination[1], s = await curio.open_connection(
ssl=ssl_context, destination[0],
source_addr=source_addr, destination[1],
server_hostname=server_hostname) ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname,
)
return StreamSocket(s) return StreamSocket(s)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await curio.sleep(interval) await curio.sleep(interval)

View file

@ -1,84 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator is for python 3.6,
# which doesn't have Context Variables. This implementation is somewhat
# costly for classes with slots, as it adds a __dict__ to them.
import inspect
class _Immutable:
"""Immutable mixin class"""
# Note we MUST NOT have __slots__ as that causes
#
# TypeError: multiple bases have instance lay-out conflict
#
# when we get mixed in with another class with slots. When we
# get mixed into something with slots, it effectively adds __dict__ to
# the slots of the other class, which allows attribute setting to work,
# albeit at the cost of the dictionary.
def __setattr__(self, name, value):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
try:
# Are we already initializing an immutable class?
previous = args[0]._immutable_init
except AttributeError:
# We are the first!
previous = None
object.__setattr__(args[0], '_immutable_init', args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
if not previous:
# If we started the initialization, establish immutability
# by removing the attribute that allows mutation
object.__delattr__(args[0], '_immutable_init')
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View file

@ -8,7 +8,7 @@ import contextvars
import inspect import inspect
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable: class _Immutable:
@ -41,6 +41,7 @@ def _immutable_init(f):
f(*args, **kwargs) f(*args, **kwargs)
finally: finally:
_in__init__.reset(previous) _in__init__.reset(previous)
nf.__signature__ = inspect.signature(f) nf.__signature__ = inspect.signature(f)
return nf return nf
@ -50,7 +51,7 @@ def immutable(cls):
# Some ancestor already has the mixin, so just make sure we keep # Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol. # following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__) cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'): if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__) cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls ncls = cls
else: else:
@ -63,7 +64,8 @@ def immutable(cls):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'): if hasattr(cls, "__setstate__"):
@_immutable_init @_immutable_init
def __setstate__(self, *args, **kwargs): def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs) super().__setstate__(*args, **kwargs)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination) return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.socket.recvfrom(size) return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
self.socket.close() self.socket.close()
@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout): async def sendall(self, what, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.stream.send_all(what) return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout): async def recv(self, size, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.stream.receive_some(size) return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.stream.aclose() await self.stream.aclose()
@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'trio' return "trio"
async def make_socket(self, af, socktype, proto=0, source=None, async def make_socket(
destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto) s = trio.socket.socket(af, socktype, proto)
stream = None stream = None
try: try:
@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend):
return DatagramSocket(s) return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM: elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s) stream = trio.SocketStream(s)
s = None
tls = False tls = False
if ssl_context: if ssl_context:
tls = True tls = True
try: try:
stream = trio.SSLStream(stream, ssl_context, stream = trio.SSLStream(
server_hostname=server_hostname) stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover except Exception: # pragma: no cover
await stream.aclose() await stream.aclose()
raise raise
return StreamSocket(af, stream, tls) return StreamSocket(af, stream, tls)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await trio.sleep(interval) await trio.sleep(interval)

View file

@ -1,26 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception import dns.exception
# pylint: disable=unused-import # pylint: disable=unused-import
from dns._asyncbackend import Socket, DatagramSocket, \ from dns._asyncbackend import (
StreamSocket, Backend # noqa: Socket,
DatagramSocket,
StreamSocket,
Backend,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import # pylint: enable=unused-import
_default_backend = None _default_backend = None
_backends = {} _backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes # Allow sniffio import to be disabled for testing purposes
_no_sniffio = False _no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException): class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass pass
def get_backend(name): def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend. """Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio", *name*, a ``str``, the name of the backend. Currently the "trio",
@ -32,22 +39,25 @@ def get_backend(name):
backend = _backends.get(name) backend = _backends.get(name)
if backend: if backend:
return backend return backend
if name == 'trio': if name == "trio":
import dns._trio_backend import dns._trio_backend
backend = dns._trio_backend.Backend() backend = dns._trio_backend.Backend()
elif name == 'curio': elif name == "curio":
import dns._curio_backend import dns._curio_backend
backend = dns._curio_backend.Backend() backend = dns._curio_backend.Backend()
elif name == 'asyncio': elif name == "asyncio":
import dns._asyncio_backend import dns._asyncio_backend
backend = dns._asyncio_backend.Backend() backend = dns._asyncio_backend.Backend()
else: else:
raise NotImplementedError(f'unimplemented async backend {name}') raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend _backends[name] = backend
return backend return backend
def sniff(): def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using """Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available. the ``sniffio`` module if it is available.
@ -59,35 +69,32 @@ def sniff():
if _no_sniffio: if _no_sniffio:
raise ImportError raise ImportError
import sniffio import sniffio
try: try:
return sniffio.current_async_library() return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError: except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError('sniffio cannot determine ' + raise AsyncLibraryNotFoundError(
'async library') "sniffio cannot determine " + "async library"
)
except ImportError: except ImportError:
import asyncio import asyncio
try: try:
asyncio.get_running_loop() asyncio.get_running_loop()
return 'asyncio' return "asyncio"
except RuntimeError: except RuntimeError:
raise AsyncLibraryNotFoundError('no async library detected') raise AsyncLibraryNotFoundError("no async library detected")
except AttributeError: # pragma: no cover
# we have to check current_task on 3.6
if not asyncio.Task.current_task():
raise AsyncLibraryNotFoundError('no async library detected')
return 'asyncio'
def get_default_backend(): def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary. """Get the default backend, initializing it if necessary."""
"""
if _default_backend: if _default_backend:
return _default_backend return _default_backend
return set_default_backend(sniff()) return set_default_backend(sniff())
def set_default_backend(name): def set_default_backend(name: str) -> Backend:
"""Set the default backend. """Set the default backend.
It's not normally necessary to call this method, as It's not normally necessary to call this method, as

View file

@ -1,13 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
class Backend:
...
def get_backend(name: str) -> Backend:
...
def sniff() -> str:
...
def get_default_backend() -> Backend:
...
def set_default_backend(name: str) -> Backend:
...

View file

@ -17,7 +17,10 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib
import socket import socket
import struct import struct
import time import time
@ -27,12 +30,24 @@ import dns.exception
import dns.inet import dns.inet
import dns.name import dns.name
import dns.message import dns.message
import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.transaction
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ from dns._asyncbackend import NullContext
UDPMode, _have_httpx, _have_http2, NoDOH from dns.query import (
_compute_times,
_matches_destination,
BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH,
NoDOQ,
)
if _have_httpx: if _have_httpx:
import httpx import httpx
@ -47,11 +62,11 @@ def _source_tuple(af, address, port):
if address or port: if address or port:
if address is None: if address is None:
if af == socket.AF_INET: if af == socket.AF_INET:
address = '0.0.0.0' address = "0.0.0.0"
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
address = '::' address = "::"
else: else:
raise NotImplementedError(f'unknown address family {af}') raise NotImplementedError(f"unknown address family {af}")
return (address, port) return (address, port)
else: else:
return None return None
@ -66,7 +81,12 @@ def _timeout(expiration, now=None):
return None return None
async def send_udp(sock, what, destination, expiration=None): async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: Union[dns.message.Message, bytes],
destination: Any,
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket. """Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``. *sock*, a ``dns.asyncbackend.DatagramSocket``.
@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None):
*expiration*, a ``float`` or ``None``, the absolute time at which *expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will a timeout exception should be raised. If ``None``, no timeout will
occur. occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
""" """
@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None):
return (n, sent_time) return (n, sent_time)
async def receive_udp(sock, destination=None, expiration=None, async def receive_udp(
ignore_unexpected=False, one_rr_per_rrset=False, sock: dns.asyncbackend.DatagramSocket,
keyring=None, request_mac=b'', ignore_trailing=False, destination: Optional[Any] = None,
raise_on_truncation=False): expiration: Optional[float] = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
) -> Any:
"""Read a DNS message from a UDP socket. """Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``. *sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, exceptions, and return type of this method. parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
""" """
wire = b'' wire = b""
while 1: while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination(sock.family, from_address, destination, if _matches_destination(
ignore_unexpected): sock.family, from_address, destination, ignore_unexpected
):
break break
received_time = time.time() received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, r = dns.message.from_wire(
one_rr_per_rrset=one_rr_per_rrset, wire,
ignore_trailing=ignore_trailing, keyring=keyring,
raise_on_truncation=raise_on_truncation) request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
return (r, received_time, from_address) return (r, received_time, from_address)
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, async def udp(
ignore_trailing=False, raise_on_truncation=False, sock=None, q: dns.message.Message,
backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP. """Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@ -134,42 +181,52 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
""" """
wire = q.to_wire() wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
s = None af = dns.inet.af_for_address(where)
# After 3.6 is no longer supported, this can use an AsyncExitStack. destination = _lltuple((where, port), af)
try: if sock:
af = dns.inet.af_for_address(where) cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
destination = _lltuple((where, port), af) else:
if sock: if not backend:
s = sock backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
if backend.datagram_connection_required():
dtuple = (where, port)
else: else:
if not backend: dtuple = None
backend = dns.asyncbackend.get_default_backend() cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
stuple = _source_tuple(af, source, source_port) async with cm as s:
if backend.datagram_connection_required():
dtuple = (where, port)
else:
dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
dtuple)
await send_udp(s, wire, destination, expiration) await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration, (r, received_time, _) = await receive_udp(
ignore_unexpected, s,
one_rr_per_rrset, destination,
q.keyring, q.mac, expiration,
ignore_trailing, ignore_unexpected,
raise_on_truncation) one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
raise_on_truncation,
)
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
finally:
if not sock and s:
await s.close()
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
source_port=0, ignore_unexpected=False, async def udp_with_fallback(
one_rr_per_rrset=False, ignore_trailing=False, q: dns.message.Message,
udp_sock=None, tcp_sock=None, backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back """Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response. to TCP if UDP results in a truncated response.
@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
method. method.
""" """
try: try:
response = await udp(q, where, timeout, port, source, source_port, response = await udp(
ignore_unexpected, one_rr_per_rrset, q,
ignore_trailing, True, udp_sock, backend) where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
)
return (response, False) return (response, False)
except dns.message.Truncated: except dns.message.Truncated:
response = await tcp(q, where, timeout, port, source, source_port, response = await tcp(
one_rr_per_rrset, ignore_trailing, tcp_sock, q,
backend) where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True) return (response, True)
async def send_tcp(sock, what, expiration=None): async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: Union[dns.message.Message, bytes],
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket. """Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None):
""" """
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
what = what.to_wire() wire = what.to_wire()
l = len(what) else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us # copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed # avoid writev() or doing a short write that would get pushed
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + what tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time() sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we """Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF. either get the desired amount, or we hit EOF.
""" """
s = b'' s = b""
while count > 0: while count > 0:
n = await sock.recv(count, _timeout(expiration)) n = await sock.recv(count, _timeout(expiration))
if n == b'': if n == b"":
raise EOFError raise EOFError
count = count - len(n) count = count - len(n)
s = s + n s = s + n
return s return s
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, async def receive_tcp(
keyring=None, request_mac=b'', ignore_trailing=False): sock: dns.asyncbackend.StreamSocket,
expiration: Optional[float] = None,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket. """Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
(l,) = struct.unpack("!H", ldata) (l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration) wire = await _read_exactly(sock, l, expiration)
received_time = time.time() received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, r = dns.message.from_wire(
one_rr_per_rrset=one_rr_per_rrset, wire,
ignore_trailing=ignore_trailing) keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return (r, received_time) return (r, received_time)
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, async def tcp(
one_rr_per_rrset=False, ignore_trailing=False, sock=None, q: dns.message.Message,
backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP. """Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
wire = q.to_wire() wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
s = None if sock:
# After 3.6 is no longer supported, this can use an AsyncExitStack. # Verify that the socket is connected, as if it's not connected,
try: # it's not writable, and the polling in send_tcp() will time out or
if sock: # hang forever.
# Verify that the socket is connected, as if it's not connected, await sock.getpeername()
# it's not writable, and the polling in send_tcp() will time out or cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
# hang forever. else:
await sock.getpeername() # These are simple (address, port) pairs, not family-dependent tuples
s = sock # you pass to low-level socket code.
else: af = dns.inet.af_for_address(where)
# These are simple (address, port) pairs, not stuple = _source_tuple(af, source, source_port)
# family-dependent tuples you pass to lowlevel socket dtuple = (where, port)
# code. if not backend:
af = dns.inet.af_for_address(where) backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port) cm = await backend.make_socket(
dtuple = (where, port) af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
if not backend: )
backend = dns.asyncbackend.get_default_backend() async with cm as s:
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout)
await send_tcp(s, wire, expiration) await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, (r, received_time) = await receive_tcp(
q.keyring, q.mac, s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
ignore_trailing) )
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
finally:
if not sock and s:
await s.close()
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None, async def tls(
backend=None, ssl_context=None, server_hostname=None): q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS. """Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
See :py:func:`dns.query.tls()` for the documentation of the other See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method. parameters, exceptions, and return type of this method.
""" """
# After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
if not sock: if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None: if ssl_context is None:
ssl_context = ssl.create_default_context() # See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None: if server_hostname is None:
ssl_context.check_hostname = False ssl_context.check_hostname = False
else: else:
@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
dtuple = (where, port) dtuple = (where, port)
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, cm = await backend.make_socket(
dtuple, timeout, ssl_context, af,
server_hostname) socket.SOCK_STREAM,
else: 0,
s = sock stuple,
try: dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration) timeout = _timeout(expiration)
response = await tcp(q, where, timeout, port, source, source_port, response = await tcp(
one_rr_per_rrset, ignore_trailing, s, backend) q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time() end_time = time.time()
response.time = end_time - begin_time response.time = end_time - begin_time
return response return response
finally:
if not sock and s:
await s.close()
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, client=None, async def https(
path='/dns-query', post=True, verify=True): q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 443,
source: Optional[str] = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient"] = None,
path: str = "/dns-query",
post: bool = True,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for *client*, a ``httpx.AsyncClient``. If provided, the client to use for
@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
""" """
if not _have_httpx: if not _have_httpx:
raise NoDOH('httpx is not available.') # pragma: no cover raise NoDOH("httpx is not available.") # pragma: no cover
wire = q.to_wire() wire = q.to_wire()
try: try:
@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
except ValueError: except ValueError:
af = None af = None
transport = None transport = None
headers = { headers = {"accept": "application/dns-message"}
"accept": "application/dns-message"
}
if af is not None: if af is not None:
if af == socket.AF_INET: if af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
else: else:
url = where url = where
if source is not None: if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0]) transport = httpx.AsyncHTTPTransport(local_address=source[0])
# After 3.6 is no longer supported, this can use an AsyncExitStack if client:
client_to_close = None cm: contextlib.AbstractAsyncContextManager = NullContext(client)
try: else:
if not client: cm = httpx.AsyncClient(
client = httpx.AsyncClient(http1=True, http2=_have_http2, http1=True, http2=_have_http2, verify=verify, transport=transport
verify=verify, transport=transport) )
client_to_close = client
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
if post: if post:
headers.update({ headers.update(
"content-type": "application/dns-message", {
"content-length": str(len(wire)) "content-type": "application/dns-message",
}) "content-length": str(len(wire)),
response = await client.post(url, headers=headers, content=wire, }
timeout=timeout) )
response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
wire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout, response = await the_client.get(
params={"dns": wire}) url, headers=headers, timeout=timeout, params={"dns": twire}
finally: )
if client_to_close:
await client.aclose()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
if response.status_code < 200 or response.status_code > 299: if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}' raise ValueError(
'\nResponse body: {}'.format(where, "{} responded with status code {}"
response.status_code, "\nResponse body: {!r}".format(
response.content)) where, response.status_code, response.content
r = dns.message.from_wire(response.content, )
keyring=q.keyring, )
request_mac=q.request_mac, r = dns.message.from_wire(
one_rr_per_rrset=one_rr_per_rrset, response.content,
ignore_trailing=ignore_trailing) keyring=q.keyring,
r.time = response.elapsed request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = response.elapsed.total_seconds()
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
async def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None, async def inbound_xfr(
source_port=0, udp_mode=UDPMode.NEVER, backend=None): where: str,
txn_manager: dns.transaction.TransactionManager,
query: Optional[dns.message.Message] = None,
port: int = 53,
timeout: Optional[float] = None,
lifetime: Optional[float] = None,
source: Optional[str] = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the """Conduct an inbound transfer and apply it via a transaction from the
txn_manager. txn_manager.
@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None,
is_udp = False is_udp = False
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, s = await backend.make_socket(
_timeout(expiration)) af, sock_type, 0, stuple, dtuple, _timeout(expiration)
)
async with s: async with s:
if is_udp: if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration)) await s.sendto(wire, dtuple, _timeout(expiration))
else: else:
tcpmsg = struct.pack("!H", len(wire)) + wire tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration) await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
is_udp) as inbound:
done = False done = False
tsig_ctx = None tsig_ctx = None
while not done: while not done:
(_, mexpiration) = _compute_times(timeout) (_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \ if mexpiration is None or (
(expiration is not None and mexpiration > expiration): expiration is not None and mexpiration > expiration
):
mexpiration = expiration mexpiration = expiration
if is_udp: if is_udp:
destination = _lltuple((where, port), af) destination = _lltuple((where, port), af)
while True: while True:
timeout = _timeout(mexpiration) timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535, (rwire, from_address) = await s.recvfrom(65535, timeout)
timeout) if _matches_destination(
if _matches_destination(af, from_address, af, from_address, destination, True
destination, True): ):
break break
else: else:
ldata = await _read_exactly(s, 2, mexpiration) ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata) (l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration) rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = (rdtype == dns.rdatatype.IXFR) is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(rwire, keyring=query.keyring, r = dns.message.from_wire(
request_mac=query.mac, xfr=True, rwire,
origin=origin, tsig_ctx=tsig_ctx, keyring=query.keyring,
multi=(not is_udp), request_mac=query.mac,
one_rr_per_rrset=is_ixfr) xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try: try:
done = inbound.process_message(r) done = inbound.process_message(r)
except dns.xfr.UseTCP: except dns.xfr.UseTCP:
@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None,
tsig_ctx = r.tsig_ctx tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig: if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG") raise dns.exception.FormError("missing TSIG")
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
stream = await the_connection.make_stream()
async with stream:
await stream.send(wire, True)
wire = await stream.receive(timeout)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r

View file

@ -1,43 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
async def udp(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.DatagramSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tls(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -17,13 +17,18 @@
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union
import time import time
import dns.asyncbackend import dns.asyncbackend
import dns.asyncquery import dns.asyncquery
import dns.exception import dns.exception
import dns.name
import dns.query import dns.query
import dns.resolver import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity # import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver): class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
async def resolve(self, qname, rdtype=dns.rdatatype.A, async def resolve(
rdclass=dns.rdataclass.IN, self,
tcp=False, source=None, raise_on_no_answer=True, qname: Union[dns.name.Name, str],
source_port=0, lifetime=None, search=None, rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
backend=None): rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver):
type of this method. type of this method.
""" """
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, resolution = dns.resolver._Resolution(
raise_on_no_answer, search) self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
start = time.time() start = time.time()
@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None: if answer is not None:
# cache hit! # cache hit!
return answer return answer
assert request is not None # needed for type checking
done = False done = False
while not done: while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
await backend.sleep(backoff) await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, timeout = self._compute_timeout(start, lifetime, resolution.errors)
resolution.errors)
try: try:
if dns.inet.is_address(nameserver): if dns.inet.is_address(nameserver):
if tcp: if tcp:
response = await _tcp(request, nameserver, response = await _tcp(
timeout, port, request,
source, source_port, nameserver,
backend=backend) timeout,
port,
source,
source_port,
backend=backend,
)
else: else:
response = await _udp(request, nameserver, response = await _udp(
timeout, port, request,
source, source_port, nameserver,
raise_on_truncation=True, timeout,
backend=backend) port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else: else:
response = await dns.asyncquery.https(request, response = await dns.asyncquery.https(
nameserver, request, nameserver, timeout=timeout
timeout=timeout) )
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None: if answer is not None:
return answer return answer
async def resolve_address(self, ipaddr, *args, **kwargs): async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR """Use an asynchronous resolver to run a reverse query for PTR
records. records.
@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver):
function. function.
""" """
# We make a modified kwargs for type checking happiness, as otherwise
return await self.resolve(dns.reversename.from_address(ipaddr), # we get a legit warning about possibly having rdtype and rdclass
rdtype=dns.rdatatype.PTR, # in the kwargs more than once.
rdclass=dns.rdataclass.IN, modified_kwargs: Dict[str, Any] = {}
*args, **kwargs) modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
async def canonical_name(self, name): async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries The canonical name is the name the resolver uses for queries
@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver):
default_resolver = None default_resolver = None
def get_default_resolver(): def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary.""" """Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None: if default_resolver is None:
reset_default_resolver() reset_default_resolver()
assert default_resolver is not None
return default_resolver return default_resolver
def reset_default_resolver(): def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver. """Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@ -167,9 +199,18 @@ def reset_default_resolver():
default_resolver = Resolver() default_resolver = Resolver()
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, async def resolve(
tcp=False, source=None, raise_on_no_answer=True, qname: Union[dns.name.Name, str],
source_port=0, lifetime=None, search=None, backend=None): rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver This is a convenience function that uses the default resolver
@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
information on the parameters. information on the parameters.
""" """
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, return await get_default_resolver().resolve(
source, raise_on_no_answer, qname,
source_port, lifetime, search, rdtype,
backend) rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(ipaddr, *args, **kwargs): async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records. """Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs):
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def canonical_name(name):
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more See :py:func:`dns.resolver.Resolver.canonical_name` for more
@ -203,8 +255,14 @@ async def canonical_name(name):
return await get_default_resolver().canonical_name(name) return await get_default_resolver().canonical_name(name)
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
resolver=None, backend=None): async def zone_for_name(
name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Optional[Resolver] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name. """Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more See :py:func:`dns.resolver.Resolver.zone_for_name` for more
@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
raise NotAbsolute(name) raise NotAbsolute(name)
while True: while True:
try: try:
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, answer = await resolver.resolve(
tcp, backend=backend) name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name: if answer.rrset.name == name:
return name return name
# otherwise we were CNAMEd or DNAMEd and need to look higher # otherwise we were CNAMEd or DNAMEd and need to look higher

View file

@ -1,26 +0,0 @@
from typing import Union, Optional, List, Any, Dict
from . import exception, rdataclass, name, rdatatype, asyncbackend
async def resolve(qname : str, rdtype : Union[int,str] = 0,
rdclass : Union[int,str] = 0,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...
async def resolve_address(self, ipaddr: str,
*args: Any, **kwargs: Optional[Dict]):
...
class Resolver:
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
configure : Optional[bool] = True):
self.nameservers : List[str]
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
rdclass : Union[int,str] = rdataclass.IN,
tcp : bool = False, source : Optional[str] = None,
raise_on_no_answer=True, source_port : int = 0,
lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...

File diff suppressed because it is too large Load diff

View file

@ -1,21 +0,0 @@
from typing import Union, Dict, Tuple, Optional
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
import dns.rdtypes.ANY.DS as DS
import dns.rdtypes.ANY.DNSKEY as DNSKEY
_have_pyca : bool
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
...
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
...
class ValidationFailure(exception.DNSException):
...
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
...
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
...

71
lib/dns/dnssectypes.py Normal file
View file

@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255

View file

@ -17,15 +17,19 @@
"""DNS E.164 helpers.""" """DNS E.164 helpers."""
from typing import Iterable, Optional, Union
import dns.exception import dns.exception
import dns.name import dns.name
import dns.resolver import dns.resolver
#: The public E.164 domain. #: The public E.164 domain.
public_enum_domain = dns.name.from_text('e164.arpa.') public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(text, origin=public_enum_domain): def from_e164(
text: str, origin: Optional[dns.name.Name] = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose """Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number. value is the ENUM domain name for that number.
@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain):
parts = [d for d in text if d.isdigit()] parts = [d for d in text if d.isdigit()]
parts.reverse() parts.reverse()
return dns.name.from_text('.'.join(parts), origin=origin) return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): def to_e164(
name: dns.name.Name,
origin: Optional[dns.name.Name] = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number. """Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred Note that dnspython does not have any information about preferred
@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
name = name.relativize(origin) name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1] dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels): if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError('non-digit labels in ENUM domain name') raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse() dlabels.reverse()
text = b''.join(dlabels) text = b"".join(dlabels)
if want_plus_prefix: if want_plus_prefix:
text = b'+' + text text = b"+" + text
return text.decode() return text.decode()
def query(number, domains, resolver=None): def query(
number: str,
domains: Iterable[Union[dns.name.Name, str]],
resolver: Optional[dns.resolver.Resolver] = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains. """Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
@ -98,7 +110,7 @@ def query(number, domains, resolver=None):
domain = dns.name.from_text(domain) domain = dns.name.from_text(domain)
qname = dns.e164.from_e164(number, domain) qname = dns.e164.from_e164(number, domain)
try: try:
return resolver.resolve(qname, 'NAPTR') return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e: except dns.resolver.NXDOMAIN as e:
e_nx += e e_nx += e
raise e_nx raise e_nx

View file

@ -1,10 +0,0 @@
from typing import Optional, Iterable
from . import name, resolver
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
...
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
...
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
...

View file

@ -17,6 +17,8 @@
"""EDNS Options""" """EDNS Options"""
from typing import Any, Dict, Optional, Union
import math import math
import socket import socket
import struct import struct
@ -24,6 +26,7 @@ import struct
import dns.enum import dns.enum
import dns.inet import dns.inet
import dns.rdata import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum): class OptionType(dns.enum.IntEnum):
@ -59,14 +62,14 @@ class Option:
"""Base class for all EDNS option types.""" """Base class for all EDNS option types."""
def __init__(self, otype): def __init__(self, otype: Union[OptionType, str]):
"""Initialize an option. """Initialize an option.
*otype*, an ``int``, is the option type. *otype*, a ``dns.edns.OptionType``, is the option type.
""" """
self.otype = OptionType.make(otype) self.otype = OptionType.make(otype)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
"""Convert an option to wire format. """Convert an option to wire format.
Returns a ``bytes`` or ``None``. Returns a ``bytes`` or ``None``.
@ -75,10 +78,10 @@ class Option:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be *parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length. restructed to the option length.
@ -115,26 +118,22 @@ class Option:
return self._cmp(other) != 0 return self._cmp(other) != 0
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) < 0 return self._cmp(other) < 0
def __le__(self, other): def __le__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) <= 0 return self._cmp(other) <= 0
def __ge__(self, other): def __ge__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) >= 0 return self._cmp(other) >= 0
def __gt__(self, other): def __gt__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) > 0 return self._cmp(other) > 0
@ -142,7 +141,7 @@ class Option:
return self.to_text() return self.to_text()
class GenericOption(Option): class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class """Generic Option Class
@ -150,28 +149,31 @@ class GenericOption(Option):
implementation. implementation.
""" """
def __init__(self, otype, data): def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
super().__init__(otype) super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True) self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
if file: if file:
file.write(self.data) file.write(self.data)
return None
else: else:
return self.data return self.data
def to_text(self): def to_text(self) -> str:
return "Generic %d" % self.otype return "Generic %d" % self.otype
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining()) return cls(otype, parser.get_remaining())
class ECSOption(Option): class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)""" """EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address, srclen=None, scopelen=0): def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information. """*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the *srclen*, an ``int``, the source prefix length, which is the
@ -200,8 +202,9 @@ class ECSOption(Option):
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen) else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family') raise ValueError("Bad address family")
assert srclen is not None
self.address = address self.address = address
self.srclen = srclen self.srclen = srclen
self.scopelen = scopelen self.scopelen = scopelen
@ -214,16 +217,14 @@ class ECSOption(Option):
self.addrdata = addrdata[:nbytes] self.addrdata = addrdata[:nbytes]
nbits = srclen % 8 nbits = srclen % 8
if nbits != 0: if nbits != 0:
last = struct.pack('B', last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last self.addrdata = self.addrdata[:-1] + last
def to_text(self): def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
self.scopelen)
@staticmethod @staticmethod
def from_text(text): def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption` """Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option. *text*, a `str`, the text form of the option.
@ -246,7 +247,7 @@ class ECSOption(Option):
>>> # it understands results from `dns.edns.ECSOption.to_text()` >>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
""" """
optional_prefix = 'ECS' optional_prefix = "ECS"
tokens = text.split() tokens = text.split()
ecs_text = None ecs_text = None
if len(tokens) == 1: if len(tokens) == 1:
@ -257,47 +258,53 @@ class ECSOption(Option):
ecs_text = tokens[1] ecs_text = tokens[1]
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError('could not parse ECS from "{}"'.format(text))
n_slashes = ecs_text.count('/') n_slashes = ecs_text.count("/")
if n_slashes == 1: if n_slashes == 1:
address, srclen = ecs_text.split('/') address, tsrclen = ecs_text.split("/")
scope = 0 tscope = "0"
elif n_slashes == 2: elif n_slashes == 2:
address, srclen, scope = ecs_text.split('/') address, tsrclen, tscope = ecs_text.split("/")
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError('could not parse ECS from "{}"'.format(text))
try: try:
scope = int(scope) scope = int(tscope)
except ValueError: except ValueError:
raise ValueError('invalid scope ' + raise ValueError(
'"{}": scope must be an integer'.format(scope)) "invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try: try:
srclen = int(srclen) srclen = int(tsrclen)
except ValueError: except ValueError:
raise ValueError('invalid srclen ' + raise ValueError(
'"{}": srclen must be an integer'.format(srclen)) "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
)
return ECSOption(address, srclen, scope) return ECSOption(address, srclen, scope)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + value = (
self.addrdata) struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file: if file:
file.write(value) file.write(value)
return None
else: else:
return value return value
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
family, src, scope = parser.get_struct('!HBB') cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0)) addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen) prefix = parser.get_bytes(addrlen)
if family == 1: if family == 1:
pad = 4 - addrlen pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2: elif family == 2:
pad = 16 - addrlen pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else: else:
raise ValueError('unsupported family') raise ValueError("unsupported family")
return cls(addr, src, scope) return cls(addr, src, scope)
@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum):
return 65535 return 65535
class EDEOption(Option): class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)""" """Extended DNS Error (EDE, RFC8914)"""
def __init__(self, code, text=None): def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error. extended error.
@ -349,49 +356,50 @@ class EDEOption(Option):
self.code = EDECode.make(code) self.code = EDECode.make(code)
if text is not None and not isinstance(text, str): if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None') raise ValueError("text must be string or None")
self.code = code
self.text = text self.text = text
def to_text(self): def to_text(self) -> str:
output = f'EDE {self.code}' output = f"EDE {self.code}"
if self.text is not None: if self.text is not None:
output += f': {self.text}' output += f": {self.text}"
return output return output
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = struct.pack('!H', self.code) value = struct.pack("!H", self.code)
if self.text is not None: if self.text is not None:
value += self.text.encode('utf8') value += self.text.encode("utf8")
if file: if file:
file.write(value) file.write(value)
return None
else: else:
return value return value
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
code = parser.get_uint16() cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
the_code = EDECode.make(parser.get_uint16())
text = parser.get_remaining() text = parser.get_remaining()
if text: if text:
if text[-1] == 0: # text MAY be null-terminated if text[-1] == 0: # text MAY be null-terminated
text = text[:-1] text = text[:-1]
text = text.decode('utf8') btext = text.decode("utf8")
else: else:
text = None btext = None
return cls(code, text) return cls(the_code, btext)
_type_to_class = { _type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption, OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption, OptionType.EDE: EDEOption,
} }
def get_option_class(otype): def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type. """Return the class for the specified option type.
The GenericOption class is used if a more specific class is not The GenericOption class is used if a more specific class is not
@ -404,7 +412,9 @@ def get_option_class(otype):
return cls return cls
def option_from_wire_parser(otype, parser): def option_from_wire_parser(
otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, an ``int``, is the option type.
@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser):
Returns an instance of a subclass of ``dns.edns.Option``. Returns an instance of a subclass of ``dns.edns.Option``.
""" """
cls = get_option_class(otype) the_otype = OptionType.make(otype)
otype = OptionType.make(otype) cls = get_option_class(the_otype)
return cls.from_wire_parser(otype, parser) return cls.from_wire_parser(otype, parser)
def option_from_wire(otype, wire, current, olen): def option_from_wire(
otype: Union[OptionType, str], wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, an ``int``, is the option type.
@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen):
with parser.restrict_to(olen): with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser) return option_from_wire_parser(otype, parser)
def register_type(implementation, otype):
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type. """Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``. *implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
@ -447,6 +460,7 @@ def register_type(implementation, otype):
_type_to_class[otype] = implementation _type_to_class[otype] = implementation
### BEGIN generated OptionType constants ### BEGIN generated OptionType constants
NSID = OptionType.NSID NSID = OptionType.NSID

View file

@ -15,14 +15,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os import os
import hashlib import hashlib
import random import random
import threading
import time import time
try:
import threading as _threading
except ImportError: # pragma: no cover
import dummy_threading as _threading # type: ignore
class EntropyPool: class EntropyPool:
@ -32,51 +31,51 @@ class EntropyPool:
# leaving this code doesn't hurt anything as the library code # leaving this code doesn't hurt anything as the library code
# is used if present. # is used if present.
def __init__(self, seed=None): def __init__(self, seed: Optional[bytes] = None):
self.pool_index = 0 self.pool_index = 0
self.digest = None self.digest: Optional[bytearray] = None
self.next_byte = 0 self.next_byte = 0
self.lock = _threading.Lock() self.lock = threading.Lock()
self.hash = hashlib.sha1() self.hash = hashlib.sha1()
self.hash_len = 20 self.hash_len = 20
self.pool = bytearray(b'\0' * self.hash_len) self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None: if seed is not None:
self._stir(bytearray(seed)) self._stir(seed)
self.seeded = True self.seeded = True
self.seed_pid = os.getpid() self.seed_pid = os.getpid()
else: else:
self.seeded = False self.seeded = False
self.seed_pid = 0 self.seed_pid = 0
def _stir(self, entropy): def _stir(self, entropy: bytes) -> None:
for c in entropy: for c in entropy:
if self.pool_index == self.hash_len: if self.pool_index == self.hash_len:
self.pool_index = 0 self.pool_index = 0
b = c & 0xff b = c & 0xFF
self.pool[self.pool_index] ^= b self.pool[self.pool_index] ^= b
self.pool_index += 1 self.pool_index += 1
def stir(self, entropy): def stir(self, entropy: bytes) -> None:
with self.lock: with self.lock:
self._stir(entropy) self._stir(entropy)
def _maybe_seed(self): def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid(): if not self.seeded or self.seed_pid != os.getpid():
try: try:
seed = os.urandom(16) seed = os.urandom(16)
except Exception: # pragma: no cover except Exception: # pragma: no cover
try: try:
with open('/dev/urandom', 'rb', 0) as r: with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16) seed = r.read(16)
except Exception: except Exception:
seed = str(time.time()) seed = str(time.time()).encode()
self.seeded = True self.seeded = True
self.seed_pid = os.getpid() self.seed_pid = os.getpid()
self.digest = None self.digest = None
seed = bytearray(seed) seed = bytearray(seed)
self._stir(seed) self._stir(seed)
def random_8(self): def random_8(self) -> int:
with self.lock: with self.lock:
self._maybe_seed() self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len: if self.digest is None or self.next_byte == self.hash_len:
@ -88,16 +87,16 @@ class EntropyPool:
self.next_byte += 1 self.next_byte += 1
return value return value
def random_16(self): def random_16(self) -> int:
return self.random_8() * 256 + self.random_8() return self.random_8() * 256 + self.random_8()
def random_32(self): def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16() return self.random_16() * 65536 + self.random_16()
def random_between(self, first, last): def random_between(self, first: int, last: int) -> int:
size = last - first + 1 size = last - first + 1
if size > 4294967296: if size > 4294967296:
raise ValueError('too big') raise ValueError("too big")
if size > 65536: if size > 65536:
rand = self.random_32 rand = self.random_32
max = 4294967295 max = 4294967295
@ -109,20 +108,24 @@ class EntropyPool:
max = 255 max = 255
return first + size * rand() // (max + 1) return first + size * rand() // (max + 1)
pool = EntropyPool() pool = EntropyPool()
system_random: Optional[Any]
try: try:
system_random = random.SystemRandom() system_random = random.SystemRandom()
except Exception: # pragma: no cover except Exception: # pragma: no cover
system_random = None system_random = None
def random_16():
def random_16() -> int:
if system_random is not None: if system_random is not None:
return system_random.randrange(0, 65536) return system_random.randrange(0, 65536)
else: else:
return pool.random_16() return pool.random_16()
def between(first, last):
def between(first: int, last: int) -> int:
if system_random is not None: if system_random is not None:
return system_random.randrange(first, last + 1) return system_random.randrange(first, last + 1)
else: else:

View file

@ -1,10 +0,0 @@
from typing import Optional
from random import SystemRandom
system_random : Optional[SystemRandom]
def random_16() -> int:
pass
def between(first: int, last: int) -> int:
pass

View file

@ -17,6 +17,7 @@
import enum import enum
class IntEnum(enum.IntEnum): class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _check_value(cls, value): def _check_value(cls, value):
@ -32,9 +33,12 @@ class IntEnum(enum.IntEnum):
return cls[text] return cls[text]
except KeyError: except KeyError:
pass pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix() prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix):].isdigit(): if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix):]) value = int(text[len(prefix) :])
cls._check_value(value) cls._check_value(value)
try: try:
return cls(value) return cls(value)
@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum):
def to_text(cls, value): def to_text(cls, value):
cls._check_value(value) cls._check_value(value)
try: try:
return cls(value).name text = cls(value).name
except ValueError: except ValueError:
return f"{cls._prefix()}{value}" text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
@classmethod @classmethod
def make(cls, value): def make(cls, value):
@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _prefix(cls): def _prefix(cls):
return '' return ""
@classmethod
def _extra_from_text(cls, text): # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod @classmethod
def _unknown_exception_class(cls): def _unknown_exception_class(cls):

View file

@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``. always be subclasses of ``DNSException``.
""" """
from typing import Optional, Set
class DNSException(Exception): class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions. """Abstract base class shared by all dnspython exceptions.
@ -44,14 +48,15 @@ class DNSException(Exception):
and ``fmt`` class variables to get nice parametrized messages. and ``fmt`` class variables to get nice parametrized messages.
""" """
msg = None # non-parametrized message msg: Optional[str] = None # non-parametrized message
supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt = None # message parametrized with results from _fmt_kwargs fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs) self._check_params(*args, **kwargs)
if kwargs: if kwargs:
self.kwargs = self._check_kwargs(**kwargs) # This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self) self.msg = str(self)
else: else:
self.kwargs = dict() # defined but empty for old mode exceptions self.kwargs = dict() # defined but empty for old mode exceptions
@ -68,14 +73,15 @@ class DNSException(Exception):
For sanity we do not allow to mix old and new behavior.""" For sanity we do not allow to mix old and new behavior."""
if args or kwargs: if args or kwargs:
assert bool(args) != bool(kwargs), \ assert bool(args) != bool(
'keyword arguments are mutually exclusive with positional args' kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs): def _check_kwargs(self, **kwargs):
if kwargs: if kwargs:
assert set(kwargs.keys()) == self.supp_kwargs, \ assert (
'following set of keyword args is required: %s' % ( set(kwargs.keys()) == self.supp_kwargs
self.supp_kwargs) ), "following set of keyword args is required: %s" % (self.supp_kwargs)
return kwargs return kwargs
def _fmt_kwargs(self, **kwargs): def _fmt_kwargs(self, **kwargs):
@ -124,9 +130,15 @@ class TooBig(DNSException):
class Timeout(DNSException): class Timeout(DNSException):
"""The DNS operation timed out.""" """The DNS operation timed out."""
supp_kwargs = {'timeout'}
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds" fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ExceptionWrapper: class ExceptionWrapper:
def __init__(self, exception_class): def __init__(self, exception_class):
@ -136,7 +148,6 @@ class ExceptionWrapper:
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val, if exc_type is not None and not isinstance(exc_val, self.exception_class):
self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val raise self.exception_class(str(exc_val)) from exc_val
return False return False

View file

@ -1,12 +0,0 @@
from typing import Set, Optional, Dict
class DNSException(Exception):
supp_kwargs : Set[str]
kwargs : Optional[Dict]
fmt : Optional[str]
class SyntaxError(DNSException): ...
class FormError(DNSException): ...
class Timeout(DNSException): ...
class TooBig(DNSException): ...
class UnexpectedEnd(SyntaxError): ...

View file

@ -17,10 +17,13 @@
"""DNS Message Flags.""" """DNS Message Flags."""
from typing import Any
import enum import enum
# Standard DNS flags # Standard DNS flags
class Flag(enum.IntFlag): class Flag(enum.IntFlag):
#: Query Response #: Query Response
QR = 0x8000 QR = 0x8000
@ -40,12 +43,13 @@ class Flag(enum.IntFlag):
# EDNS flags # EDNS flags
class EDNSFlag(enum.IntFlag): class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK #: DNSSEC answer OK
DO = 0x8000 DO = 0x8000
def _from_text(text, enum_class): def _from_text(text: str, enum_class: Any) -> int:
flags = 0 flags = 0
tokens = text.split() tokens = text.split()
for t in tokens: for t in tokens:
@ -53,15 +57,15 @@ def _from_text(text, enum_class):
return flags return flags
def _to_text(flags, enum_class): def _to_text(flags: int, enum_class: Any) -> str:
text_flags = [] text_flags = []
for k, v in enum_class.__members__.items(): for k, v in enum_class.__members__.items():
if flags & v != 0: if flags & v != 0:
text_flags.append(k) text_flags.append(k)
return ' '.join(text_flags) return " ".join(text_flags)
def from_text(text): def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags """Convert a space-separated list of flag text values into a flags
value. value.
@ -71,7 +75,7 @@ def from_text(text):
return _from_text(text, Flag) return _from_text(text, Flag)
def to_text(flags): def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text """Convert a flags value into a space-separated list of flag text
values. values.
@ -81,7 +85,7 @@ def to_text(flags):
return _to_text(flags, Flag) return _to_text(flags, Flag)
def edns_from_text(text): def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS """Convert a space-separated list of EDNS flag text values into a EDNS
flags value. flags value.
@ -91,7 +95,7 @@ def edns_from_text(text):
return _from_text(text, EDNSFlag) return _from_text(text, EDNSFlag)
def edns_to_text(flags): def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag """Convert an EDNS flags value into a space-separated list of EDNS flag
text values. text values.
@ -100,6 +104,7 @@ def edns_to_text(flags):
return _to_text(flags, EDNSFlag) return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants ### BEGIN generated Flag constants
QR = Flag.QR QR = Flag.QR

Some files were not shown because too many files have changed in this diff Show more