Merge branch 'nightly' into concurrent_stream_graph

herby2212 2023-07-11 01:24:59 +02:00
commit 713f8f18e6
335 changed files with 15376 additions and 10025 deletions

.github/codeql-config.yml (new file)

@@ -0,0 +1,4 @@
name: CodeQL Config
paths-ignore:
- lib

.github/workflows/codeql.yml (new file)

@@ -0,0 +1,38 @@
name: CodeQL
on:
push:
branches: [nightly]
pull_request:
branches: [nightly]
schedule:
- cron: '05 10 * * 1'
jobs:
codeql-analysis:
name: CodeQL Analysis
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: ['javascript', 'python']
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
config-file: ./.github/codeql-config.yml
languages: ${{ matrix.language }}
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"

View file

@@ -13,7 +13,7 @@ jobs:
if: ${{ !contains(github.event.head_commit.message, '[skip ci]') }}
steps:
- name: Checkout Code
-uses: actions/checkout@v3.2.0
+uses: actions/checkout@v3
- name: Prepare
id: prepare
@@ -47,7 +47,7 @@ jobs:
version: latest
- name: Cache Docker Layers
-uses: actions/cache@v3.2.0
+uses: actions/cache@v3
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
@@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.GHCR_TOKEN }}
- name: Docker Build and Push
-uses: docker/build-push-action@v3
+uses: docker/build-push-action@v4
if: success()
with:
context: .

View file

@@ -24,7 +24,7 @@ jobs:
steps:
- name: Checkout Code
-uses: actions/checkout@v3.2.0
+uses: actions/checkout@v3
- name: Set Release Version
id: get_version
@@ -52,7 +52,7 @@ jobs:
echo $GITHUB_SHA > version.txt
- name: Set Up Python
-uses: actions/setup-python@v4.4.0
+uses: actions/setup-python@v4
with:
python-version: '3.9'
cache: pip
@@ -119,7 +119,10 @@ jobs:
run: |
CHANGELOG="$( sed -n '/^## /{p; :loop n; p; /^## /q; b loop}' CHANGELOG.md \
| sed '$d' | sed '$d' | sed '$d' )"
echo "CHANGELOG=${CHANGELOG}" >> $GITHUB_OUTPUT
EOF=$(dd if=/dev/urandom bs=15 count=1 status=none | base64)
echo "CHANGELOG<<$EOF" >> $GITHUB_OUTPUT
echo "$CHANGELOG" >> $GITHUB_OUTPUT
echo "$EOF" >> $GITHUB_OUTPUT
- name: Create Release
uses: actions/create-release@v1
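
The replaced echo wrote the changelog as a single key=value line, which breaks for multiline values; the new lines use GitHub's key<<DELIMITER heredoc syntax for step outputs, with a random delimiter so the changelog text cannot terminate the block early. A minimal sketch of the same pattern in Python, assuming it runs inside a workflow step where GITHUB_OUTPUT is set (the helper name is illustrative, not part of the commit):

```python
import os
import secrets

def set_multiline_output(name: str, value: str) -> None:
    # A random delimiter plays the same role as the dd | base64 value
    # in the workflow step above: the payload can never collide with it.
    delimiter = f"ghadelimiter_{secrets.token_hex(16)}"
    with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as fh:
        fh.write(f"{name}<<{delimiter}\n{value}\n{delimiter}\n")
```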

View file

@@ -20,7 +20,7 @@
- armhf
steps:
- name: Checkout Code
-uses: actions/checkout@v3.2.0
+uses: actions/checkout@v3
- name: Prepare
id: prepare

View file

@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
-uses: actions/checkout@v3.2.0
+uses: actions/checkout@v3
- name: Comment on Pull Request
uses: mshick/add-pr-comment@v2

View file

@@ -9,6 +9,7 @@ jobs:
winget:
name: Submit Winget Package
runs-on: windows-latest
if: ${{ !github.event.release.prerelease }}
steps:
- name: Submit package to Windows Package Manager Community Repository
run: |

.gitignore

@@ -53,6 +53,9 @@ Thumbs.db
#Ignore files generated by PyCharm
*.idea/*
#Ignore files generated by VSCode
*.vscode/*
#Ignore files generated by vi
*.swp

View file

@@ -1,5 +1,78 @@
# Changelog
## v2.12.4 (2023-05-23)
* History:
* Fix: Set view offset equal to duration if a stream is stopped within the last 10 sec.
* Other:
* Fix: Database import may fail for some older databases.
* Fix: Double-quoted strings for newer versions of SQLite. (#2015, #2057)
* API:
* Change: Return the ID for async API calls (export_metadata, notify, notify_newsletter).
## v2.12.3 (2023-04-14)
* Activity:
* Fix: Incorrect subtitle decision shown when subtitles are transcoded.
* History:
* Fix: Incorrect order when sorting by the duration column in the history tables.
* Notifications:
* Fix: Logging error when running scripts that use PlexAPI.
* UI:
* Fix: Calculate file sizes setting causing the media info table to fail to load.
* Fix: Incorrect artwork and thumbnail shown for Live TV on the Most Active Libraries statistics card.
* API:
* Change: Renamed duration to play_duration in the get_history API response. (Note: duration kept for backwards compatibility; see the API sketch after this changelog excerpt.)
## v2.12.2 (2023-03-16)
* Other:
* Fix: Tautulli not starting on FreeBSD jails.
## v2.12.1 (2023-03-14)
* Activity:
* Fix: Stop checking for deprecated sync items sessions.
* Change: Do not show audio language on activity cards for music.
* Other:
* Fix: Tautulli not starting on macOS.
## v2.12.0 (2023-03-13)
* Notifications:
* New: Added support for Telegram group topics. (#1980)
* New: Added anidb_id and anidb_url notification parameters. (#1973)
* New: Added notification triggers for Intro Marker, Commercial Marker, and Credits Marker.
* New: Added various intro, commercial, and credits marker notification parameters.
* New: Allow setting a custom Pushover notification sound. (#2005)
* Change: Notification images are now uploaded directly to Discord without the need for a 3rd party image hosting service.
* Change: Automatically strip whitespace from notification condition values.
* Change: Trigger watched notifications based on the video watched completion behaviour setting.
* Exporter:
* Fix: Unable to run exporter when using the Snap package. (#2007)
* New: Added credits marker, and audio/subtitle settings to export fields.
* UI:
* Fix: Incorrect styling and missing content for collection media info pages.
* New: Added edition details field on movie media info pages. (#1957) (Thanks @herby2212)
* New: Added setting to change the video watched completion behaviour.
* New: Added watch time and user statistics to collection and playlist media info pages. (#1982, #2012) (Thanks @herby2212)
* New: Added history table to collection and playlist media info pages.
* New: Dynamically change watched status in the UI based on video watched completion behaviour setting.
* New: Added hidden setting to override server name.
* Change: Move track artist to a details field instead of in the title on track media info pages.
* API:
* New: Added section_id and user_id parameters to get_home_stats API command. (#1944)
* New: Added marker info to get_metadata API command results.
* New: Added media_type parameter to get_item_watch_time_stats and get_item_user_stats API commands. (#1982) (Thanks @herby2212)
* New: Added last_refreshed timestamp to get_library_media_info API command response.
* Other:
* Change: Migrate analytics to Google Analytics 4.
## v2.11.1 (2022-12-22)
* Activity:
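
For the play_duration rename noted under v2.12.3 above, a minimal client sketch against Tautulli's /api/v2 endpoint; the host and API key are placeholders, and the nested response.data.data envelope is assumed from the API's usual response shape:

```python
import requests

TAUTULLI_URL = "http://localhost:8181"  # placeholder
API_KEY = "0123456789abcdef"            # placeholder

resp = requests.get(f"{TAUTULLI_URL}/api/v2", params={
    "apikey": API_KEY, "cmd": "get_history", "length": 5,
})
for row in resp.json()["response"]["data"]["data"]:
    # Prefer the new key; fall back to the old one kept for compatibility.
    print(row.get("full_title"), row.get("play_duration", row.get("duration")))
```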

View file

@@ -57,24 +57,24 @@ Read the [Installation Guides][Installation] for instructions on how to install
[badge-release-nightly-last-commit]: https://img.shields.io/github/last-commit/Tautulli/Tautulli/nightly?style=flat-square&color=blue
[badge-release-nightly-commits]: https://img.shields.io/github/commits-since/Tautulli/Tautulli/latest/nightly?style=flat-square&color=blue
[badge-docker-master]: https://img.shields.io/badge/docker-latest-blue?style=flat-square
-[badge-docker-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/master?style=flat-square
+[badge-docker-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=master
[badge-docker-beta]: https://img.shields.io/badge/docker-beta-blue?style=flat-square
-[badge-docker-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/beta?style=flat-square
+[badge-docker-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=beta
[badge-docker-nightly]: https://img.shields.io/badge/docker-nightly-blue?style=flat-square
-[badge-docker-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Docker/nightly?style=flat-square
+[badge-docker-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-docker.yml?style=flat-square&branch=nightly
[badge-snap-master]: https://img.shields.io/badge/snap-stable-blue?style=flat-square
-[badge-snap-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/master?style=flat-square
+[badge-snap-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=master
[badge-snap-beta]: https://img.shields.io/badge/snap-beta-blue?style=flat-square
-[badge-snap-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/beta?style=flat-square
+[badge-snap-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=beta
[badge-snap-nightly]: https://img.shields.io/badge/snap-edge-blue?style=flat-square
-[badge-snap-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Snap/nightly?style=flat-square
+[badge-snap-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-snap.yml?style=flat-square&branch=nightly
[badge-installer-master-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&style=flat-square
[badge-installer-master-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&style=flat-square
-[badge-installer-master-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/master?style=flat-square
+[badge-installer-master-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=master
[badge-installer-beta-win]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=windows&include_prereleases&style=flat-square
[badge-installer-beta-macos]: https://img.shields.io/github/v/release/Tautulli/Tautulli?label=macos&include_prereleases&style=flat-square
-[badge-installer-beta-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/beta?style=flat-square
-[badge-installer-nightly-ci]: https://img.shields.io/github/workflow/status/Tautulli/Tautulli/Publish%20Installers/nightly?style=flat-square
+[badge-installer-beta-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=beta
+[badge-installer-nightly-ci]: https://img.shields.io/github/actions/workflow/status/Tautulli/Tautulli/.github/workflows/publish-installers.yml?style=flat-square&branch=nightly
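
These badge updates track shields.io's move from the deprecated /github/workflow/status/<org>/<repo>/<workflow name>/<branch> route to /github/actions/workflow/status/<org>/<repo>/<workflow file path>?branch=<branch>. A small illustrative helper for generating the new-style URLs (not part of the commit):

```python
def actions_badge(org: str, repo: str, workflow: str, branch: str) -> str:
    # New shields.io route: the workflow is addressed by its file path
    # and the branch moves into a query parameter.
    return ("https://img.shields.io/github/actions/workflow/status/"
            f"{org}/{repo}/.github/workflows/{workflow}"
            f"?style=flat-square&branch={branch}")

print(actions_badge("Tautulli", "Tautulli", "publish-docker.yml", "nightly"))
```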
## Support

File diff suppressed because one or more lines are too long

View file

@@ -79,7 +79,6 @@ select.form-control {
color: #eee !important;
border: 0px solid #444 !important;
background: #555 !important;
-padding: 1px 2px;
transition: background-color .3s;
}
.selectize-control.form-control .selectize-input {
@@ -87,7 +86,6 @@ select.form-control {
align-items: center;
flex-wrap: wrap;
margin-bottom: 4px;
-padding-left: 5px;
}
.selectize-control.form-control.selectize-pms-ip .selectize-input {
padding-left: 12px !important;
@@ -2916,7 +2914,7 @@ a .home-platforms-list-cover-face:hover
margin-bottom: -20px;
width: 100%;
max-width: 1750px;
-overflow: hidden;
+display: flow-root;
}
.table-card-back td {
font-size: 12px;

View file

@@ -265,12 +265,15 @@ DOCUMENTATION :: END
<div class="sub-heading">Audio</div>
<div class="sub-value" id="audio_decision-${sk}">
% if data['stream_audio_decision']:
<%
audio_language = (data['audio_language'] or 'Unknown') + ' - ' if data['media_type'] != 'track' else ''
%>
% if data['stream_audio_decision'] == 'transcode':
-Transcode (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())} ${data['audio_channel_layout'].split('(')[0].capitalize()} <i class="fa fa-long-arrow-right"></i> ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
+Transcode (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())} ${data['audio_channel_layout'].split('(')[0].capitalize()} <i class="fa fa-long-arrow-right"></i> ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% elif data['stream_audio_decision'] == 'copy':
-Direct Stream (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
+Direct Stream (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% else:
-Direct Play (${data['audio_language'] or 'Unknown'} - ${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
+Direct Play (${audio_language}${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())} ${data['stream_audio_channel_layout'].split('(')[0].capitalize()})
% endif
% endif
</div>
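
This template change (mirrored in the dashboard JavaScript further down) prefixes the audio language only for non-track media, since music tracks carry no meaningful audio language. A hedged Python sketch of the resulting label logic, simplified to the stream side; AUDIO_CODEC_OVERRIDES is assumed to map codec ids to display names as it does above:

```python
AUDIO_CODEC_OVERRIDES = {'truehd': 'TrueHD'}  # assumed mapping

def audio_decision_label(data):
    # Drop the language prefix for music tracks.
    lang = '' if data['media_type'] == 'track' else (data['audio_language'] or 'Unknown') + ' - '
    codec = AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())
    layout = data['stream_audio_channel_layout'].split('(')[0].capitalize()
    decision = {'transcode': 'Transcode', 'copy': 'Direct Stream'}.get(
        data['stream_audio_decision'], 'Direct Play')
    return f'{decision} ({lang}{codec} {layout})'

print(audio_decision_label({'media_type': 'episode', 'audio_language': None,
                            'stream_audio_decision': 'copy', 'stream_audio_codec': 'eac3',
                            'stream_audio_channel_layout': '5.1(side)'}))
# -> Direct Stream (Unknown - EAC3 5.1)
```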

View file

@@ -1,6 +1,7 @@
<%inherit file="base.html"/>
<%def name="headIncludes()">
<link rel="stylesheet" href="${http_root}css/bootstrap-select.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css">
<link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css">
</%def>
@@ -14,9 +15,7 @@
<div class="button-bar">
<div class="btn-group" id="user-selection">
<label>
<select name="graph-user" id="graph-user" class="btn" style="color: inherit;">
<option value="">All Users</option>
<option disabled>&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;</option>
<select name="graph-user" id="graph-user" multiple>
</select>
</label>
</div>
@@ -239,6 +238,7 @@
</%def>
<%def name="javascriptIncludes()">
<script src="${http_root}js/bootstrap-select.min.js"></script>
<script src="${http_root}js/highcharts.min.js"></script>
<script src="${http_root}js/jquery.dataTables.min.js"></script>
<script src="${http_root}js/dataTables.bootstrap.min.js"></script>
@@ -379,8 +379,8 @@
//$(current_tab).addClass('active');
-$('.days').html(current_day_range);
-$('.months').html(current_month_range);
+$('.days').text(current_day_range);
+$('.months').text(current_month_range);
// Load user ids and names (for the selector)
$.ajax({
@@ -388,14 +388,35 @@
type: 'get',
dataType: "json",
success: function (data) {
-var select = $('#graph-user');
+let select = $('#graph-user');
+let by_id = {};
data.sort(function(a, b) {
return a.friendly_name.localeCompare(b.friendly_name);
});
data.forEach(function(item) {
select.append('<option value="' + item.user_id + '">' +
item.friendly_name + '</option>');
by_id[item.user_id] = item.friendly_name;
});
select.selectpicker({
countSelectedText: function(sel, total) {
if (sel === 0 || sel === total) {
return 'All users';
} else if (sel > 1) {
return sel + ' users';
} else {
return select.val().map(function(id) {
return by_id[id];
}).join(', ');
}
},
style: 'btn-dark',
actionsBox: true,
selectedTextFormat: 'count',
noneSelectedText: 'All users'
});
select.selectpicker('render');
select.selectpicker('selectAll');
}
});
@@ -644,11 +665,6 @@
$('#nav-tabs-total').tab('show');
}
-// Set initial state
-if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
-if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
-if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }
// Tab1 opened
$('#nav-tabs-plays').on('shown.bs.tab', function (e) {
e.preventDefault();
@@ -681,7 +697,7 @@
setLocalStorage('graph_days', current_day_range);
if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
-$('.days').html(current_day_range);
+$('.days').text(current_day_range);
});
// Month range changed
@@ -691,12 +707,23 @@
current_month_range = $(this).val();
setLocalStorage('graph_months', current_month_range);
if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }
-$('.months').html(current_month_range);
+$('.months').text(current_month_range);
});
let graph_user_last_id = undefined;
// User changed
$('#graph-user').on('change', function() {
-selected_user_id = $(this).val() || null;
+let val = $(this).val();
+if (val.length === 0 || val.length === $(this).children().length) {
+selected_user_id = null; // if all users are selected, just send an empty list
+} else {
+selected_user_id = val.join(",");
+}
+if (selected_user_id === graph_user_last_id) {
+return;
+}
+graph_user_last_id = selected_user_id;
if (current_tab === '#tabs-plays') { loadGraphsTab1(current_day_range, yaxis); }
if (current_tab === '#tabs-stream') { loadGraphsTab2(current_day_range, yaxis); }
if (current_tab === '#tabs-total') { loadGraphsTab3(current_month_range, yaxis); }
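
With the multi-select picker above, the client now sends either an empty value (meaning all users) or a comma-separated id list, and graph_user_last_id suppresses redundant reloads when the effective selection has not changed. A sketch of how a handler might normalize such a parameter server-side; the function is an assumption for illustration, not the commit's code:

```python
from typing import List, Optional

def parse_user_ids(param: Optional[str]) -> Optional[List[int]]:
    """None means 'all users'; otherwise the selected user ids."""
    if not param:  # missing or empty string -> no filtering
        return None
    return [int(user_id) for user_id in param.split(',')]

assert parse_user_ids('') is None
assert parse_user_ids('10,42') == [10, 42]
```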

View file

@@ -1,6 +1,7 @@
<%inherit file="base.html"/>
<%def name="headIncludes()">
<link rel="stylesheet" href="${http_root}css/bootstrap-select.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.bootstrap.min.css">
<link rel="stylesheet" href="${http_root}css/dataTables.colVis.css">
<link rel="stylesheet" href="${http_root}css/tautulli-dataTables.css">
@@ -31,9 +32,7 @@
% if _session['user_group'] == 'admin':
<div class="btn-group" id="user-selection">
<label>
<select name="history-user" id="history-user" class="btn" style="color: inherit;">
<option value="">All Users</option>
<option disabled>&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;&#9472;</option>
<select name="history-user" id="history-user" multiple>
</select>
</label>
</div>
@@ -84,7 +83,7 @@
<th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th>
<th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th>
</tr>
</thead>
@@ -121,6 +120,7 @@
</%def>
<%def name="javascriptIncludes()">
<script src="${http_root}js/bootstrap-select.min.js"></script>
<script src="${http_root}js/jquery.dataTables.min.js"></script>
<script src="${http_root}js/dataTables.colVis.js"></script>
<script src="${http_root}js/dataTables.bootstrap.min.js"></script>
@@ -134,17 +134,40 @@
type: 'GET',
dataType: 'json',
success: function (data) {
-var select = $('#history-user');
+let select = $('#history-user');
+let by_id = {};
data.sort(function (a, b) {
return a.friendly_name.localeCompare(b.friendly_name);
});
data.forEach(function (item) {
select.append('<option value="' + item.user_id + '">' +
item.friendly_name + '</option>');
by_id[item.user_id] = item.friendly_name;
});
select.selectpicker({
countSelectedText: function(sel, total) {
if (sel === 0 || sel === total) {
return 'All users';
} else if (sel > 1) {
return sel + ' users';
} else {
return select.val().map(function(id) {
return by_id[id];
}).join(', ');
}
},
style: 'btn-dark',
actionsBox: true,
selectedTextFormat: 'count',
noneSelectedText: 'All users'
});
select.selectpicker('render');
select.selectpicker('selectAll');
}
});
let history_user_last_id = undefined;
function loadHistoryTable(media_type, transcode_decision, selected_user_id) {
history_table_options.ajax = {
url: 'get_history',
@@ -187,7 +210,16 @@
});
$('#history-user').on('change', function () {
-selected_user_id = $(this).val() || null;
+let val = $(this).val();
+if (val.length === 0 || val.length === $(this).children().length) {
+selected_user_id = null; // if all users are selected, just send an empty list
+} else {
+selected_user_id = val.join(",");
+}
+if (selected_user_id === history_user_last_id) {
+return;
+}
+history_user_last_id = selected_user_id;
history_table.draw();
});
}

View file

@@ -32,7 +32,7 @@
<th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th>
<th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th>
</tr>
</thead>

View file

@@ -77,7 +77,8 @@ DOCUMENTATION :: END
<% fallback = 'art-live' if row0['live'] else 'art' %>
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'], row0['rating_key'], 500, 280, 40, '282828', 3, fallback=fallback)});">
% elif stat_id == 'top_libraries':
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'] or row0['library_art'], None, 500, 280, 40, '282828', 3, fallback=row0['library_art'])});" data-library_art="${row0['library_art']}">
<% fallback = 'art-live' if row0['live'] else row0['library_art'] %>
<div id="stats-background-${stat_id}" class="dashboard-stats-background" style="background-image: url(${page('pms_image_proxy', row0['art'] or row0['library_art'], None, 500, 280, 40, '282828', 3, fallback=fallback)});" data-library_art="${row0['library_art']}">
% elif stat_id == 'top_users':
<div id="stats-background-${stat_id}" class="dashboard-stats-background" data-blurhash="${page('pms_image_proxy', row0['user_thumb'] or 'interfaces/default/images/gravatar-default.png', None, 100, 100, 40, '282828', 0, fallback='user')}">
% elif stat_id == 'top_platforms':
@@ -109,8 +110,8 @@ DOCUMENTATION :: END
</a>
</div>
% elif stat_id == 'top_libraries':
-% if row0['thumb'].startswith('http'):
-<div id="stats-thumb-${stat_id}" class="dashboard-stats-flat hidden-xs" style="background-image: url(${page('pms_image_proxy', row0['thumb'], None, 80, 80)});"></div>
+% if row0['library_thumb'].startswith('http'):
+<div id="stats-thumb-${stat_id}" class="dashboard-stats-flat hidden-xs" style="background-image: url(${page('pms_image_proxy', row0['library_thumb'], None, 100, 100, fallback='cover')});"></div>
% else:
<div id="stats-thumb-${stat_id}" class="dashboard-stats-flat svg-icon library-${row0['section_type']} hidden-xs"></div>
% endif
@@ -147,7 +148,8 @@ DOCUMENTATION :: END
data-rating_key="${row.get('rating_key')}" data-grandparent_rating_key="${row.get('grandparent_rating_key')}" data-guid="${row.get('guid')}" data-title="${row.get('title')}"
data-art="${row.get('art')}" data-thumb="${row.get('thumb')}" data-platform="${row.get('platform_name')}" data-library-type="${row.get('section_type')}"
data-user_id="${row.get('user_id')}" data-user="${row.get('user')}" data-friendly_name="${row.get('friendly_name')}" data-user_thumb="${row.get('user_thumb')}"
data-last_watch="${row.get('last_watch')}" data-started="${row.get('started')}" data-live="${row.get('live')}" data-library_art="${row.get('library_art', '')}">
data-last_watch="${row.get('last_watch')}" data-started="${row.get('started')}" data-live="${row.get('live')}"
data-library_art="${row.get('library_art', '')}" data-library_thumb="${row.get('library_thumb', '')}">
<div class="sub-list">${loop.index + 1}</div>
<div class="sub-value">
% if stat_id in ('top_movies', 'popular_movies', 'top_tv', 'popular_tv', 'top_music', 'popular_music', 'last_watched'):

View file

@@ -523,14 +523,15 @@
var audio_decision = '';
if (['movie', 'episode', 'clip', 'track'].indexOf(s.media_type) > -1 && s.stream_audio_decision) {
var audio_language = (s.media_type !== 'track') ? (s.audio_language || 'Unknown') + ' - ' : '';
var a_codec = (s.audio_codec === 'truehd') ? 'TrueHD' : s.audio_codec.toUpperCase();
var sa_codec = (s.stream_audio_codec === 'truehd') ? 'TrueHD' : s.stream_audio_codec.toUpperCase();
if (s.stream_audio_decision === 'transcode') {
-audio_decision = 'Transcode ('+ (s.audio_language || 'Unknown')+ ' - ' + a_codec + ' ' + capitalizeFirstLetter(s.audio_channel_layout.split('(')[0]) + ' <i class="fa fa-long-arrow-right"></i> ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
+audio_decision = 'Transcode (' + audio_language + a_codec + ' ' + capitalizeFirstLetter(s.audio_channel_layout.split('(')[0]) + ' <i class="fa fa-long-arrow-right"></i> ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
} else if (s.stream_audio_decision === 'copy') {
-audio_decision = 'Direct Stream ('+ (s.audio_language || 'Unknown')+ ' - ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
+audio_decision = 'Direct Stream (' + audio_language + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
} else {
-audio_decision = 'Direct Play ('+ (s.audio_language || 'Unknown')+ ' - ' + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
+audio_decision = 'Direct Play (' + audio_language + sa_codec + ' ' + capitalizeFirstLetter(s.stream_audio_channel_layout.split('(')[0]) + ')';
}
}
$('#audio_decision-' + key).html(audio_decision);
@@ -797,6 +798,7 @@
var guid = $(elem).data('guid');
var live = $(elem).data('live');
var library_art = $(elem).data('library_art');
var library_thumb = $(elem).data('library_thumb');
var [height, fallback_poster, fallback_art] = [450, 'poster', 'art'];
if ($.inArray(stat_id, ['top_music', 'popular_music']) > -1) {
[height, fallback_poster, fallback_art] = [300, 'cover', 'art'];
@@ -808,11 +810,11 @@
if (stat_id === 'most_concurrent') {
return
} else if (stat_id === 'top_libraries') {
-$('#stats-background-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', art || library_art, null, 500, 280, 40, '282828', 3, library_art || fallback_art) + ')');
+$('#stats-background-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', art || library_art, null, 500, 280, 40, '282828', 3, fallback_art) + ')');
$('#stats-thumb-' + stat_id).removeClass(function (index, className) {
return (className.match (/(^|\s)svg-icon library-\S+/g) || []).join(' ')});
-if (thumb.startsWith('http')) {
-$('#stats-thumb-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', thumb, null, 300, 300, null, null, null, 'cover') + ')');
+if (library_thumb.startsWith('http')) {
+$('#stats-thumb-' + stat_id).css('background-image', 'url(' + page('pms_image_proxy', library_thumb, null, 100, 100, null, null, null, 'cover') + ')');
} else {
$('#stats-thumb-' + stat_id).css('background-image', '')
.addClass('svg-icon library-' + library_type);

View file

@@ -12,8 +12,10 @@ data :: Usable parameters (if not applicable for media type, blank value will be
== Global keys ==
rating_key Returns the unique identifier for the media item.
media_type Returns the type of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
sub_media_type Returns the subtype of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
art Returns the location of the item's artwork
title Returns the name of the movie, show, episode, artist, album, or track.
edition_title Returns the edition title of a movie.
duration Returns the standard runtime of the media.
content_rating Returns the age rating for the media.
summary Returns a brief description of the media plot.
@@ -212,7 +214,7 @@ DOCUMENTATION :: END
% if _session['user_group'] == 'admin':
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
% endif
-% elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip'):
+% elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip') or data['sub_media_type'] in ('artist', 'album', 'track'):
<div class="summary-poster-face-track" style="background-image: url(${page('pms_image_proxy', data['thumb'], data['rating_key'], 300, 300, fallback='cover')});">
<div class="summary-poster-face-overlay">
<span></span>
@@ -266,7 +268,7 @@ DOCUMENTATION :: END
<h1><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a></h1>
<h2>${data['title']}</h2>
% elif data['media_type'] == 'track':
<h1><a href="${page('info', data['grandparent_rating_key'])}">${data['original_title'] or data['grandparent_title']}</a></h1>
<h1><a href="${page('info', data['grandparent_rating_key'])}">${data['grandparent_title']}</a></h1>
<h2><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a> - ${data['title']}</h2>
<h3 class="hidden-xs">T${data['media_index']}</h3>
% elif data['media_type'] in ('photo', 'clip'):
@@ -282,14 +284,14 @@ DOCUMENTATION :: END
padding_height = ''
if data['media_type'] == 'movie' or data['live']:
padding_height = 'height: 305px;'
-elif data['media_type'] in ('show', 'season', 'collection'):
-padding_height = 'height: 270px;'
-elif data['media_type'] == 'episode':
-padding_height = 'height: 70px;'
-elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo'):
+elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo') or data['sub_media_type'] in ('artist', 'album', 'track'):
padding_height = 'height: 150px;'
elif data['media_type'] in ('track', 'clip'):
padding_height = 'height: 180px;'
+elif data['media_type'] == 'episode':
+padding_height = 'height: 70px;'
+elif data['media_type'] in ('show', 'season', 'collection'):
+padding_height = 'height: 270px;'
%>
<div class="summary-content-padding hidden-xs hidden-sm" style="${padding_height}">
% if data['media_type'] in ('movie', 'episode', 'track', 'clip'):
@@ -368,6 +370,11 @@ DOCUMENTATION :: END
Studio <strong> ${data['studio']}</strong>
% endif
</div>
<div class="summary-content-details-tag">
% if data['media_type'] == 'track' and data['original_title']:
Track Artists <strong> ${data['original_title']}</strong>
% endif
</div>
<div class="summary-content-details-tag">
% if data['media_type'] == 'movie':
Year <strong> ${data['year']}</strong>
@@ -390,6 +397,11 @@ DOCUMENTATION :: END
Runtime <strong> <span id="runtime">${data['duration']}</span></strong>
% endif
</div>
% if data['edition_title']:
<div class="summary-content-details-tag">
Edition <strong> ${data['edition_title']} </strong>
</div>
% endif
<div class="summary-content-details-tag">
% if data['content_rating']:
Rated <strong> ${data['content_rating']} </strong>
@@ -542,7 +554,7 @@ DOCUMENTATION :: END
</div>
</div>
% endif
-% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'):
+% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<div class="col-md-12">
<div class="table-card-header">
<div class="header-bar">
@@ -571,7 +583,7 @@ DOCUMENTATION :: END
</div>
% endif
<%
-history_type = data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track')
+history_type = data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist')
history_active = 'active' if history_type else ''
export_active = 'active' if not history_type else ''
%>
@@ -634,7 +646,7 @@ DOCUMENTATION :: END
<div class="col-md-12">
<div class="table-card-header">
<div class="header-bar">
-% if data['media_type'] in ('artist', 'album', 'track'):
+% if data['media_type'] in ('artist', 'album', 'track', 'playlist'):
<span>Play History for <strong>${data['title']}</strong></span>
% else:
<span>Watch History for <strong>${data['title']}</strong></span>
@@ -680,7 +692,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th>
<th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th>
</tr>
</thead>
@@ -806,7 +818,7 @@ DOCUMENTATION :: END
% elif data['media_type'] == 'album':
${data['parent_title']}<br />${data['title']}
% elif data['media_type'] == 'track':
-${data['original_title'] or data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']}
+${data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']}
% endif
</strong>
</p>
@@ -853,7 +865,7 @@ DOCUMENTATION :: END
%>
<script src="${http_root}js/tables/history_table.js${cache_param}"></script>
<script src="${http_root}js/tables/export_table.js${cache_param}"></script>
-% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'):
+% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<script>
function loadHistoryTable(transcode_decision) {
// Build watch history table
@@ -873,6 +885,9 @@ DOCUMENTATION :: END
parent_rating_key: "${data['rating_key']}"
% elif data['media_type'] in ('movie', 'episode', 'track'):
rating_key: "${data['rating_key']}"
% elif data['media_type'] in ('collection', 'playlist'):
media_type: "${data['media_type']}",
rating_key: "${data['rating_key']}"
% endif
};
}
@@ -925,13 +940,16 @@ DOCUMENTATION :: END
});
</script>
% endif
-% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'):
+% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
<script>
// Populate watch time stats
$.ajax({
url: 'item_watch_time_stats',
async: true,
-data: { rating_key: "${data['rating_key']}" },
+data: {
+rating_key: "${data['rating_key']}",
+media_type: "${data['media_type']}"
+},
complete: function(xhr, status) {
$("#watch-time-stats").html(xhr.responseText);
}
@@ -940,7 +958,10 @@ DOCUMENTATION :: END
$.ajax({
url: 'item_user_stats',
async: true,
-data: { rating_key: "${data['rating_key']}" },
+data: {
+rating_key: "${data['rating_key']}",
+media_type: "${data['media_type']}"
+},
complete: function(xhr, status) {
$("#user-stats").html(xhr.responseText);
}
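
Both AJAX calls above now pass media_type alongside rating_key, matching the changelog's note that get_item_watch_time_stats and get_item_user_stats accept a media_type parameter so collections and playlists resolve correctly. An equivalent direct API call, with placeholder host, key, and rating key:

```python
import requests

resp = requests.get("http://localhost:8181/api/v2", params={  # placeholder host
    "apikey": "0123456789abcdef",        # placeholder
    "cmd": "get_item_watch_time_stats",
    "rating_key": 12345,                 # e.g. a collection's rating key
    "media_type": "collection",
})
print(resp.json()["response"]["data"])
```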

View file

@@ -160,6 +160,16 @@ DOCUMENTATION :: END
% endif
</div>
</a>
<div class="item-children-instance-text-wrapper poster-item">
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
</h3>
% endif
</div>
% elif child['media_type'] == 'episode':
<a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">
<div class="item-children-poster">
@@ -179,6 +189,29 @@ DOCUMENTATION :: END
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">${child['grandparent_title']}</a>
</h3>
<h3 class="text-muted">
<a href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${short_season(child['parent_title'])}</a>
&middot; <a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">E${child['media_index']}</a>
</h3>
% endif
</div>
% elif child['media_type'] == 'artist':
<a href="${page('info', child['rating_key'])}" title="${child['title']}">
<div class="item-children-poster">
<div class="item-children-poster-face cover-item" style="background-image: url(${page('pms_image_proxy', child['thumb'], child['rating_key'], 300, 300, fallback='cover')});"></div>
% if _session['user_group'] == 'admin':
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
% endif
</div>
</a>
<div class="item-children-instance-text-wrapper cover-item">
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
</div>
% elif child['media_type'] == 'album':
<a href="${page('info', child['rating_key'])}" title="${child['title']}">
@@ -193,6 +226,11 @@ DOCUMENTATION :: END
<h3>
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
</h3>
% if media_type == 'collection':
<h3 class="text-muted">
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
</h3>
% endif
</div>
% elif child['media_type'] == 'track':
<% e = 'even' if loop.index % 2 == 0 else 'odd' %>
@@ -205,7 +243,15 @@ DOCUMENTATION :: END
${child['title']}
</span>
</a>
-% if child['original_title']:
+% if media_type == 'collection':
+-
+<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">
+<span class="thumb-tooltip" data-toggle="popover" data-img="${page('pms_image_proxy', child['grandparent_thumb'], child['grandparent_rating_key'], 300, 300, fallback='cover')}" data-height="80" data-width="80">
+${child['grandparent_title']}
+</span>
+</a>
+<span class="text-muted"> (<a class="no-highlight" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>)</span>
+% elif child['original_title']:
<span class="text-muted"> - ${child['original_title']}</span>
% endif
</span>

File diff suppressed because one or more lines are too long

View file

@@ -32,7 +32,12 @@ collections_table_options = {
if (rowData['smart']) {
smart = '<span class="media-type-tooltip" data-toggle="tooltip" title="Smart Collection"><i class="fa fa-cog fa-fw"></i></span>&nbsp;'
}
console.log(rowData['subtype'])
if (rowData['subtype'] === 'artist' || rowData['subtype'] === 'album' || rowData['subtype'] === 'track') {
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 300, null, null, null, 'cover') + '" data-height="80" data-width="80">' + rowData['title'] + '</span>';
} else {
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 450, null, null, null, 'poster') + '" data-height="120" data-width="80">' + rowData['title'] + '</span>';
}
$(td).html(smart + '<a href="' + page('info', rowData['ratingKey']) + '">' + thumb_popover + '</a>');
}
},

View file

@@ -247,7 +247,7 @@ history_table_options = {
},
{
"targets": [11],
"data": "duration",
"data": "play_duration",
"render": function (data, type, full) {
if (data !== null) {
return Math.round(moment.duration(data, 'seconds').as('minutes')) + ' mins';
@@ -529,7 +529,7 @@ function childTableFormat(rowData) {
'<th align="left" id="started">Started</th>' +
'<th align="left" id="paused_counter">Paused</th>' +
'<th align="left" id="stopped">Stopped</th>' +
'<th align="left" id="duration">Duration</th>' +
'<th align="left" id="play_duration">Duration</th>' +
'<th align="left" id="percent_complete"></th>' +
'</tr>' +
'</thead>' +

View file

@@ -248,7 +248,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th>
<th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th>
</tr>
</thead>

View file

@@ -453,12 +453,12 @@
$("#download-tautullilog").click(function () {
var logfile = $(".tab-pane.active").data('logfile');
window.location.href = "download_log?logfile=" + logfile;
window.location.href = "download_log?logfile=" + window.encodeURIComponent(logfile);
});
$("#download-plexserverlog").click(function () {
var logfile = $("option:selected", "#plex-log-files").val();
window.location.href = "download_plex_log?logfile=" + logfile;
window.location.href = "download_plex_log?logfile=" + window.encodeURIComponent(logfile);
});
$("#clear-notify-logs").click(function () {

View file

@@ -70,7 +70,7 @@ DOCUMENTATION :: END
function checkQRAddress(url) {
var parser = document.createElement('a');
-parser.href = url;
+parser.setAttribute('href', url);
var hostname = parser.hostname;
var protocol = parser.protocol;

View file

@@ -142,8 +142,10 @@
<div class="row">
<div class="col-md-12">
<select class="form-control" id="${item['name']}" name="${item['name']}">
% if item['select_all']:
<option value="select-all">Select All</option>
<option value="remove-all">Remove All</option>
% endif
% if isinstance(item['select_options'], dict):
% for section, options in item['select_options'].items():
<optgroup label="${section}">
@@ -153,7 +155,9 @@
</optgroup>
% endfor
% else:
% if item['select_all']:
<option value="border-all"></option>
% endif
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
<option value="${option['value']}">${option['text']}</option>
% endfor

View file

@@ -134,8 +134,10 @@
<div class="row">
<div class="col-md-12">
<select class="form-control" id="${item['name']}" name="${item['name']}">
% if item['select_all']:
<option value="select-all">Select All</option>
<option value="remove-all">Remove All</option>
% endif
% if isinstance(item['select_options'], dict):
% for section, options in item['select_options'].items():
<optgroup label="${section}">
@@ -145,7 +147,9 @@
</optgroup>
% endfor
% else:
% if item['select_all']:
<option value="border-all"></option>
% endif
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
<option value="${option['value']}">${option['text']}</option>
% endfor
@@ -719,6 +723,12 @@
pushoverPriority();
});
var $pushover_sound = $('#pushover_sound').selectize({
create: true
});
var pushover_sound = $pushover_sound[0].selectize;
pushover_sound.setValue(${json.dumps(next((c['value'] for c in notifier['config_options'] if c['name'] == 'pushover_sound'), [])) | n});
% elif notifier['agent_name'] == 'plexmobileapp':
var $plexmobileapp_user_ids = $('#plexmobileapp_user_ids').selectize({
plugins: ['remove_button'],

View file

@@ -132,12 +132,6 @@
</label>
<p class="help-block">Change the "<em>Play by day of week</em>" graph to start on Monday. Default is start on Sunday.</p>
</div>
<div class="checkbox advanced-setting">
<label>
<input type="checkbox" id="group_history_tables" name="group_history_tables" value="1" ${config['group_history_tables']}> Group Play History
</label>
<p class="help-block">Group play history for the same item and user as a single entry when progress is less than the watched percent.</p>
</div>
<div class="checkbox advanced-setting">
<label>
<input type="checkbox" id="history_table_activity" name="history_table_activity" value="1" ${config['history_table_activity']}> Current Activity in History Tables
@@ -213,6 +207,39 @@
</div>
<p class="help-block">Set the percentage for a music track to be considered as listened. Minimum 50, Maximum 95.</p>
</div>
<div class="form-group">
<label for="music_watched_percent">Video Watched Completion Behaviour</label>
<div class="row">
<div class="col-md-7">
<select class="form-control" id="watched_marker" name="watched_marker">
<option value="0" ${'selected' if config['watched_marker'] == 0 else ''}>At selected threshold percentage</option>
<option value="1" ${'selected' if config['watched_marker'] == 1 else ''}>At final credits marker position</option>
<option value="2" ${'selected' if config['watched_marker'] == 2 else ''}>At first credits marker position</option>
<option value="3" ${'selected' if config['watched_marker'] == 3 else ''}>Earliest between threshold percent and first credits marker</option>
</select>
</div>
</div>
<p class="help-block">Decide whether to use end credits markers to determine the 'watched' state of video items. When markers are not available the selected threshold percentage will be used.</p>
</div>
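
The four options above control when a video counts as watched. A hedged sketch of the decision logic in Python; the function name and the seconds-based interface are assumptions for illustration, not the commit's implementation:

```python
def watched_offset(duration, watched_percent, first_marker=None, final_marker=None, mode=0):
    """Offset (seconds) at which an item counts as watched.

    mode 0: threshold percentage, 1: final credits marker,
    2: first credits marker, 3: earliest of threshold and first marker.
    Falls back to the percentage when no marker is available.
    """
    threshold = duration * watched_percent / 100.0
    if mode == 1 and final_marker is not None:
        return final_marker
    if mode == 2 and first_marker is not None:
        return first_marker
    if mode == 3 and first_marker is not None:
        return min(threshold, first_marker)
    return threshold

# 60 min episode, 85% threshold, credits at 50 min: mode 3 -> 3000 s
print(watched_offset(3600, 85, first_marker=3000, mode=3))
```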
<div class="checkbox advanced-setting">
<label>
<input type="checkbox" id="group_history_tables" name="group_history_tables" value="1" ${config['group_history_tables']}> Group Play History
</label>
<p class="help-block">Group play history for the same item and user as a single entry when progress is less than the watched percent.</p>
</div>
<div class="form-group advanced-setting">
<label>Regroup Play History</label>
<p class="help-block">
Fix grouping of play history in the database.<br />
</p>
<div class="row">
<div class="col-md-4">
<div class="btn-group">
<button class="btn btn-form" type="button" id="regroup_history">Regroup</button>
</div>
</div>
</div>
</div>
<div class="form-group advanced-setting">
<label>Flush Temporary Sessions</label>
<p class="help-block">
@@ -2470,6 +2497,12 @@ $(document).ready(function() {
confirmAjaxCall(url, msg);
});
$("#regroup_history").click(function () {
var msg = 'Are you sure you want to regroup play history in the database?<br /><br /><strong>This may take a long time for large databases.<br />Regrouping will continue in the background.</strong>';
var url = 'regroup_history';
confirmAjaxCall(url, msg);
});
$("#delete_temp_sessions").click(function () {
var msg = 'Are you sure you want to flush the temporary sessions?<br /><br /><strong>This will reset all currently active sessions.</strong>';
var url = 'delete_temp_sessions';

View file

@@ -212,7 +212,7 @@ DOCUMENTATION :: END
<th align="left" id="started">Started</th>
<th align="left" id="paused_counter">Paused</th>
<th align="left" id="stopped">Stopped</th>
<th align="left" id="duration">Duration</th>
<th align="left" id="play_duration">Duration</th>
<th align="left" id="percent_complete"></th>
</tr>
</thead>

View file

@@ -1,121 +0,0 @@
#!/usr/bin/python
###############################################################################
# Formatting filter for urllib2's HTTPHandler(debuglevel=1) output
# Copyright (c) 2013, Analytics Pros
#
# This project is free software, distributed under the BSD license.
# Analytics Pros offers consulting and integration services if your firm needs
# assistance in strategy, implementation, or auditing existing work.
###############################################################################
import sys, re, os
from io import StringIO
class BufferTranslator(object):
""" Provides a buffer-compatible interface for filtering buffer content.
"""
parsers = []
def __init__(self, output):
self.output = output
self.encoding = getattr(output, 'encoding', None)
def write(self, content):
content = self.translate(content)
self.output.write(content)
@staticmethod
def stripslashes(content):
return content.decode('string_escape')
@staticmethod
def addslashes(content):
return content.encode('string_escape')
def translate(self, line):
for pattern, method in self.parsers:
match = pattern.match(line)
if match:
return method(match)
return line
class LineBufferTranslator(BufferTranslator):
""" Line buffer implementation supports translation of line-format input
even when input is not already line-buffered. Caches input until newlines
occur, and then dispatches translated input to output buffer.
"""
def __init__(self, *a, **kw):
self._linepending = []
super(LineBufferTranslator, self).__init__(*a, **kw)
def write(self, _input):
lines = _input.splitlines(True)
for i in range(0, len(lines)):
last = i
if lines[i].endswith('\n'):
prefix = len(self._linepending) and ''.join(self._linepending) or ''
self.output.write(self.translate(prefix + lines[i]))
del self._linepending[0:]
last = -1
if last >= 0:
self._linepending.append(lines[ last ])
def __del__(self):
if len(self._linepending):
self.output.write(self.translate(''.join(self._linepending)))
class HTTPTranslator(LineBufferTranslator):
""" Translates output from |urllib2| HTTPHandler(debuglevel = 1) into
HTTP-compatible, readible text structures for human analysis.
"""
RE_LINE_PARSER = re.compile(r'^(?:([a-z]+):)\s*(\'?)([^\r\n]*)\2(?:[\r\n]*)$')
RE_LINE_BREAK = re.compile(r'(\r?\n|(?:\\r)?\\n)')
RE_HTTP_METHOD = re.compile(r'^(POST|GET|HEAD|DELETE|PUT|TRACE|OPTIONS)')
RE_PARAMETER_SPACER = re.compile(r'&([a-z0-9]+)=')
@classmethod
def spacer(cls, line):
return cls.RE_PARAMETER_SPACER.sub(r' &\1= ', line)
def translate(self, line):
parsed = self.RE_LINE_PARSER.match(line)
if parsed:
value = parsed.group(3)
stage = parsed.group(1)
if stage == 'send': # query string is rendered here
return '\n# HTTP Request:\n' + self.stripslashes(value)
elif stage == 'reply':
return '\n\n# HTTP Response:\n' + self.stripslashes(value)
elif stage == 'header':
return value + '\n'
else:
return value
return line
def consume(outbuffer = None): # Capture standard output
sys.stdout = HTTPTranslator(outbuffer or sys.stdout)
return sys.stdout
if __name__ == '__main__':
consume(sys.stdout).write(sys.stdin.read())
print('\n')
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4

View file

@@ -1,424 +0,0 @@
from future.moves.urllib.request import urlopen, build_opener, install_opener
from future.moves.urllib.request import Request, HTTPSHandler
from future.moves.urllib.error import URLError, HTTPError
from future.moves.urllib.parse import urlencode
import random
import datetime
import time
import uuid
import hashlib
import socket
def generate_uuid(basedata=None):
""" Provides a _random_ UUID with no input, or a UUID4-format MD5 checksum of any input data provided """
if basedata is None:
return str(uuid.uuid4())
elif isinstance(basedata, str):
checksum = hashlib.md5(str(basedata).encode('utf-8')).hexdigest()
return '%8s-%4s-%4s-%4s-%12s' % (
checksum[0:8], checksum[8:12], checksum[12:16], checksum[16:20], checksum[20:32])
class Time(datetime.datetime):
""" Wrappers and convenience methods for processing various time representations """
@classmethod
def from_unix(cls, seconds, milliseconds=0):
""" Produce a full |datetime.datetime| object from a Unix timestamp """
base = list(time.gmtime(seconds))[0:6]
base.append(milliseconds * 1000) # microseconds
return cls(*base)
@classmethod
def to_unix(cls, timestamp):
""" Wrapper over time module to produce Unix epoch time as a float """
if not isinstance(timestamp, datetime.datetime):
raise TypeError('Time.milliseconds expects a datetime object')
base = time.mktime(timestamp.timetuple())
return base
@classmethod
def milliseconds_offset(cls, timestamp, now=None):
""" Offset time (in milliseconds) from a |datetime.datetime| object to now """
if isinstance(timestamp, (int, float)):
base = timestamp
else:
base = cls.to_unix(timestamp)
base = base + (timestamp.microsecond / 1000000)
if now is None:
now = time.time()
return (now - base) * 1000
class HTTPRequest(object):
""" URL Construction and request handling abstraction.
This is not intended to be used outside this module.
Automates mapping of persistent state (i.e. query parameters)
onto transcient datasets for each query.
"""
endpoint = 'https://www.google-analytics.com/collect'
@staticmethod
def debug():
""" Activate debugging on urllib2 """
handler = HTTPSHandler(debuglevel=1)
opener = build_opener(handler)
install_opener(opener)
# Store properties for all requests
def __init__(self, user_agent=None, *args, **opts):
self.user_agent = user_agent or 'Analytics Pros - Universal Analytics (Python)'
@classmethod
def fixUTF8(cls, data): # Ensure proper encoding for UA's servers...
""" Convert all strings to UTF-8 """
for key in data:
if isinstance(data[key], str):
data[key] = data[key].encode('utf-8')
return data
# Apply stored properties to the given dataset & POST to the configured endpoint
def send(self, data):
request = Request(
self.endpoint + '?' + urlencode(self.fixUTF8(data)).encode('utf-8'),
headers={
'User-Agent': self.user_agent
}
)
self.open(request)
def open(self, request):
try:
return urlopen(request)
except HTTPError as e:
return False
except URLError as e:
self.cache_request(request)
return False
def cache_request(self, request):
# TODO: implement a proper caching mechanism here for re-transmitting hits
# record = (Time.now(), request.get_full_url(), request.get_data(), request.headers)
pass
class HTTPPost(HTTPRequest):
# Apply stored properties to the given dataset & POST to the configured endpoint
def send(self, data):
request = Request(
self.endpoint,
data=urlencode(self.fixUTF8(data)).encode('utf-8'),
headers={
'User-Agent': self.user_agent
}
)
self.open(request)
class Tracker(object):
""" Primary tracking interface for Universal Analytics """
params = None
parameter_alias = {}
valid_hittypes = ('pageview', 'event', 'social', 'screenview', 'transaction', 'item', 'exception', 'timing')
@classmethod
def alias(cls, typemap, base, *names):
""" Declare an alternate (humane) name for a measurement protocol parameter """
cls.parameter_alias[base] = (typemap, base)
for i in names:
cls.parameter_alias[i] = (typemap, base)
@classmethod
def coerceParameter(cls, name, value=None):
if isinstance(name, str) and name[0] == '&':
return name[1:], str(value)
elif name in cls.parameter_alias:
typecast, param_name = cls.parameter_alias.get(name)
return param_name, typecast(value)
else:
raise KeyError('Parameter "{0}" is not recognized'.format(name))
def payload(self, data):
for key, value in data.items():
try:
yield self.coerceParameter(key, value)
except KeyError:
continue
option_sequence = {
'pageview': [(str, 'dp')],
'event': [(str, 'ec'), (str, 'ea'), (str, 'el'), (int, 'ev')],
'social': [(str, 'sn'), (str, 'sa'), (str, 'st')],
'timing': [(str, 'utc'), (str, 'utv'), (str, 'utt'), (str, 'utl')]
}
@classmethod
def consume_options(cls, data, hittype, args):
""" Interpret sequential arguments related to known hittypes based on declared structures """
opt_position = 0
data['t'] = hittype # integrate hit type parameter
if hittype in cls.option_sequence:
for expected_type, optname in cls.option_sequence[hittype]:
if opt_position < len(args) and isinstance(args[opt_position], expected_type):
data[optname] = args[opt_position]
opt_position += 1
@classmethod
def hittime(cls, timestamp=None, age=None, milliseconds=None):
""" Returns an integer represeting the milliseconds offset for a given hit (relative to now) """
if isinstance(timestamp, (int, float)):
return int(Time.milliseconds_offset(Time.from_unix(timestamp, milliseconds=milliseconds)))
if isinstance(timestamp, datetime.datetime):
return int(Time.milliseconds_offset(timestamp))
if isinstance(age, (int, float)):
return int(age * 1000) + (milliseconds or 0)
@property
def account(self):
return self.params.get('tid', None)
def __init__(self, account, name=None, client_id=None, hash_client_id=False, user_id=None, user_agent=None,
use_post=True):
if use_post is False:
self.http = HTTPRequest(user_agent=user_agent)
else:
self.http = HTTPPost(user_agent=user_agent)
self.params = {'v': 1, 'tid': account}
if client_id is None:
client_id = generate_uuid()
self.params['cid'] = client_id
self.hash_client_id = hash_client_id
if user_id is not None:
self.params['uid'] = user_id
def set_timestamp(self, data):
""" Interpret time-related options, apply queue-time parameter as needed """
if 'hittime' in data: # an absolute timestamp
data['qt'] = self.hittime(timestamp=data.pop('hittime', None))
if 'hitage' in data: # a relative age (in seconds)
data['qt'] = self.hittime(age=data.pop('hitage', None))
def send(self, hittype, *args, **data):
""" Transmit HTTP requests to Google Analytics using the measurement protocol """
if hittype not in self.valid_hittypes:
raise KeyError('Unsupported Universal Analytics Hit Type: {0}'.format(repr(hittype)))
self.set_timestamp(data)
self.consume_options(data, hittype, args)
for item in args: # process dictionary-object arguments of transcient data
if isinstance(item, dict):
for key, val in self.payload(item):
data[key] = val
for k, v in self.params.items(): # update only absent parameters
if k not in data:
data[k] = v
data = dict(self.payload(data))
if self.hash_client_id:
data['cid'] = generate_uuid(data['cid'])
# Transmit the hit to Google...
self.http.send(data)
# Setting persistent attibutes of the session/hit/etc (inc. custom dimensions/metrics)
def set(self, name, value=None):
if isinstance(name, dict):
for key, value in name.items():
try:
param, value = self.coerceParameter(key, value)
self.params[param] = value
except KeyError:
pass
elif isinstance(name, str):
try:
param, value = self.coerceParameter(name, value)
self.params[param] = value
except KeyError:
pass
def __getitem__(self, name):
param, value = self.coerceParameter(name, None)
return self.params.get(param, None)
def __setitem__(self, name, value):
param, value = self.coerceParameter(name, value)
self.params[param] = value
def __delitem__(self, name):
param, value = self.coerceParameter(name, None)
if param in self.params:
del self.params[param]
def safe_unicode(obj):
""" Safe convertion to the Unicode string version of the object """
try:
return str(obj)
except UnicodeDecodeError:
return obj.decode('utf-8')
# Declaring name mappings for Measurement Protocol parameters
MAX_CUSTOM_DEFINITIONS = 200
MAX_EC_LISTS = 11 # 1-based index
MAX_EC_PRODUCTS = 11 # 1-based index
MAX_EC_PROMOTIONS = 11 # 1-based index
Tracker.alias(int, 'v', 'protocol-version')
Tracker.alias(safe_unicode, 'cid', 'client-id', 'clientId', 'clientid')
Tracker.alias(safe_unicode, 'tid', 'trackingId', 'account')
Tracker.alias(safe_unicode, 'uid', 'user-id', 'userId', 'userid')
Tracker.alias(safe_unicode, 'uip', 'user-ip', 'userIp', 'ipaddr')
Tracker.alias(safe_unicode, 'ua', 'userAgent', 'userAgentOverride', 'user-agent')
Tracker.alias(safe_unicode, 'dp', 'page', 'path')
Tracker.alias(safe_unicode, 'dt', 'title', 'pagetitle', 'pageTitle', 'page-title')
Tracker.alias(safe_unicode, 'dl', 'location')
Tracker.alias(safe_unicode, 'dh', 'hostname')
Tracker.alias(safe_unicode, 'sc', 'sessioncontrol', 'session-control', 'sessionControl')
Tracker.alias(safe_unicode, 'dr', 'referrer', 'referer')
Tracker.alias(int, 'qt', 'queueTime', 'queue-time')
Tracker.alias(safe_unicode, 't', 'hitType', 'hittype')
Tracker.alias(int, 'aip', 'anonymizeIp', 'anonIp', 'anonymize-ip')
Tracker.alias(safe_unicode, 'ds', 'dataSource', 'data-source')
# Campaign attribution
Tracker.alias(safe_unicode, 'cn', 'campaign', 'campaignName', 'campaign-name')
Tracker.alias(safe_unicode, 'cs', 'source', 'campaignSource', 'campaign-source')
Tracker.alias(safe_unicode, 'cm', 'medium', 'campaignMedium', 'campaign-medium')
Tracker.alias(safe_unicode, 'ck', 'keyword', 'campaignKeyword', 'campaign-keyword')
Tracker.alias(safe_unicode, 'cc', 'content', 'campaignContent', 'campaign-content')
Tracker.alias(safe_unicode, 'ci', 'campaignId', 'campaignID', 'campaign-id')
# Technical specs
Tracker.alias(safe_unicode, 'sr', 'screenResolution', 'screen-resolution', 'resolution')
Tracker.alias(safe_unicode, 'vp', 'viewport', 'viewportSize', 'viewport-size')
Tracker.alias(safe_unicode, 'de', 'encoding', 'documentEncoding', 'document-encoding')
Tracker.alias(int, 'sd', 'colors', 'screenColors', 'screen-colors')
Tracker.alias(safe_unicode, 'ul', 'language', 'user-language', 'userLanguage')
# Mobile app
Tracker.alias(safe_unicode, 'an', 'appName', 'app-name', 'app')
Tracker.alias(safe_unicode, 'cd', 'contentDescription', 'screenName', 'screen-name', 'content-description')
Tracker.alias(safe_unicode, 'av', 'appVersion', 'app-version', 'version')
Tracker.alias(safe_unicode, 'aid', 'appID', 'appId', 'application-id', 'app-id', 'applicationId')
Tracker.alias(safe_unicode, 'aiid', 'appInstallerId', 'app-installer-id')
# Ecommerce
Tracker.alias(safe_unicode, 'ta', 'affiliation', 'transactionAffiliation', 'transaction-affiliation')
Tracker.alias(safe_unicode, 'ti', 'transaction', 'transactionId', 'transaction-id')
Tracker.alias(float, 'tr', 'revenue', 'transactionRevenue', 'transaction-revenue')
Tracker.alias(float, 'ts', 'shipping', 'transactionShipping', 'transaction-shipping')
Tracker.alias(float, 'tt', 'tax', 'transactionTax', 'transaction-tax')
Tracker.alias(safe_unicode, 'cu', 'currency', 'transactionCurrency',
'transaction-currency') # Currency code, e.g. USD, EUR
Tracker.alias(safe_unicode, 'in', 'item-name', 'itemName')
Tracker.alias(float, 'ip', 'item-price', 'itemPrice')
Tracker.alias(float, 'iq', 'item-quantity', 'itemQuantity')
Tracker.alias(safe_unicode, 'ic', 'item-code', 'sku', 'itemCode')
Tracker.alias(safe_unicode, 'iv', 'item-variation', 'item-category', 'itemCategory', 'itemVariation')
# Events
Tracker.alias(safe_unicode, 'ec', 'event-category', 'eventCategory', 'category')
Tracker.alias(safe_unicode, 'ea', 'event-action', 'eventAction', 'action')
Tracker.alias(safe_unicode, 'el', 'event-label', 'eventLabel', 'label')
Tracker.alias(int, 'ev', 'event-value', 'eventValue', 'value')
Tracker.alias(int, 'ni', 'noninteractive', 'nonInteractive', 'noninteraction', 'nonInteraction')
# Social
Tracker.alias(safe_unicode, 'sa', 'social-action', 'socialAction')
Tracker.alias(safe_unicode, 'sn', 'social-network', 'socialNetwork')
Tracker.alias(safe_unicode, 'st', 'social-target', 'socialTarget')
# Exceptions
Tracker.alias(safe_unicode, 'exd', 'exception-description', 'exceptionDescription', 'exDescription')
Tracker.alias(int, 'exf', 'exception-fatal', 'exceptionFatal', 'exFatal')
# User Timing
Tracker.alias(safe_unicode, 'utc', 'timingCategory', 'timing-category')
Tracker.alias(safe_unicode, 'utv', 'timingVariable', 'timing-variable')
Tracker.alias(float, 'utt', 'time', 'timingTime', 'timing-time')
Tracker.alias(safe_unicode, 'utl', 'timingLabel', 'timing-label')
Tracker.alias(float, 'dns', 'timingDNS', 'timing-dns')
Tracker.alias(float, 'pdt', 'timingPageLoad', 'timing-page-load')
Tracker.alias(float, 'rrt', 'timingRedirect', 'timing-redirect')
Tracker.alias(safe_unicode, 'tcp', 'timingTCPConnect', 'timing-tcp-connect')
Tracker.alias(safe_unicode, 'srt', 'timingServerResponse', 'timing-server-response')
# Custom dimensions and metrics
for i in range(0, 200):
Tracker.alias(safe_unicode, 'cd{0}'.format(i), 'dimension{0}'.format(i))
Tracker.alias(int, 'cm{0}'.format(i), 'metric{0}'.format(i))
# Content groups
for i in range(0, 5):
Tracker.alias(safe_unicode, 'cg{0}'.format(i), 'contentGroup{0}'.format(i))
# Enhanced Ecommerce
Tracker.alias(str, 'pa') # Product action
Tracker.alias(str, 'tcc') # Coupon code
Tracker.alias(str, 'pal') # Product action list
Tracker.alias(int, 'cos') # Checkout step
Tracker.alias(str, 'col') # Checkout step option
Tracker.alias(str, 'promoa') # Promotion action
for product_index in range(1, MAX_EC_PRODUCTS):
Tracker.alias(str, 'pr{0}id'.format(product_index)) # Product SKU
Tracker.alias(str, 'pr{0}nm'.format(product_index)) # Product name
Tracker.alias(str, 'pr{0}br'.format(product_index)) # Product brand
Tracker.alias(str, 'pr{0}ca'.format(product_index)) # Product category
Tracker.alias(str, 'pr{0}va'.format(product_index)) # Product variant
Tracker.alias(str, 'pr{0}pr'.format(product_index)) # Product price
Tracker.alias(int, 'pr{0}qt'.format(product_index)) # Product quantity
Tracker.alias(str, 'pr{0}cc'.format(product_index)) # Product coupon code
Tracker.alias(int, 'pr{0}ps'.format(product_index)) # Product position
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
Tracker.alias(str, 'pr{0}cd{1}'.format(product_index, custom_index)) # Product custom dimension
Tracker.alias(int, 'pr{0}cm{1}'.format(product_index, custom_index)) # Product custom metric
for list_index in range(1, MAX_EC_LISTS):
Tracker.alias(str, 'il{0}pi{1}id'.format(list_index, product_index)) # Product impression SKU
Tracker.alias(str, 'il{0}pi{1}nm'.format(list_index, product_index)) # Product impression name
Tracker.alias(str, 'il{0}pi{1}br'.format(list_index, product_index)) # Product impression brand
Tracker.alias(str, 'il{0}pi{1}ca'.format(list_index, product_index)) # Product impression category
Tracker.alias(str, 'il{0}pi{1}va'.format(list_index, product_index)) # Product impression variant
Tracker.alias(int, 'il{0}pi{1}ps'.format(list_index, product_index)) # Product impression position
Tracker.alias(int, 'il{0}pi{1}pr'.format(list_index, product_index)) # Product impression price
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
Tracker.alias(str, 'il{0}pi{1}cd{2}'.format(list_index, product_index,
custom_index)) # Product impression custom dimension
Tracker.alias(int, 'il{0}pi{1}cm{2}'.format(list_index, product_index,
custom_index)) # Product impression custom metric
for list_index in range(1, MAX_EC_LISTS):
Tracker.alias(str, 'il{0}nm'.format(list_index)) # Product impression list name
for promotion_index in range(1, MAX_EC_PROMOTIONS):
Tracker.alias(str, 'promo{0}id'.format(promotion_index)) # Promotion ID
Tracker.alias(str, 'promo{0}nm'.format(promotion_index)) # Promotion name
Tracker.alias(str, 'promo{0}cr'.format(promotion_index)) # Promotion creative
Tracker.alias(str, 'promo{0}ps'.format(promotion_index)) # Promotion position
# Shortcut for creating trackers
def create(account, *args, **kwargs):
return Tracker(account, *args, **kwargs)
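# Hedged usage sketch ('UA-XXXXX-Y' is a placeholder property id, not a real one):
#   tracker = create('UA-XXXXX-Y', client_id='some-client')
#   tracker.send('event', 'videos', 'play', 'intro')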
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4

View file

@ -1 +0,0 @@
from . import Tracker

View file

@ -3,13 +3,9 @@ from __future__ import absolute_import
import sys
from apscheduler.executors.base import BaseExecutor, run_job
from apscheduler.executors.base_py3 import run_coroutine_job
from apscheduler.util import iscoroutinefunction_partial
try:
from apscheduler.executors.base_py3 import run_coroutine_job
except ImportError:
run_coroutine_job = None
class AsyncIOExecutor(BaseExecutor):
"""
@ -46,11 +42,8 @@ class AsyncIOExecutor(BaseExecutor):
self._run_job_success(job.id, events)
if iscoroutinefunction_partial(job.func):
if run_coroutine_job is not None:
coro = run_coroutine_job(job, job._jobstore_alias, run_times, self._logger.name)
f = self._eventloop.create_task(coro)
else:
raise Exception('Executing coroutine based jobs is not supported with Trollius')
else:
f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times,
self._logger.name)

View file

@ -57,7 +57,7 @@ class SQLAlchemyJobStore(BaseJobStore):
# 25 = precision that translates to an 8-byte float
self.jobs_t = Table(
tablename, metadata,
Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True),
Column('id', Unicode(191), primary_key=True),
Column('next_run_time', Float(25), index=True),
Column('job_state', LargeBinary, nullable=False),
schema=tableschema
@ -68,8 +68,9 @@ class SQLAlchemyJobStore(BaseJobStore):
self.jobs_t.create(self.engine, True)
def lookup_job(self, job_id):
selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id)
job_state = self.engine.execute(selectable).scalar()
selectable = select(self.jobs_t.c.job_state).where(self.jobs_t.c.id == job_id)
with self.engine.begin() as connection:
job_state = connection.execute(selectable).scalar()
return self._reconstitute_job(job_state) if job_state else None
def get_due_jobs(self, now):
@ -77,10 +78,11 @@ class SQLAlchemyJobStore(BaseJobStore):
return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp)
def get_next_run_time(self):
selectable = select([self.jobs_t.c.next_run_time]).\
selectable = select(self.jobs_t.c.next_run_time).\
where(self.jobs_t.c.next_run_time != null()).\
order_by(self.jobs_t.c.next_run_time).limit(1)
next_run_time = self.engine.execute(selectable).scalar()
with self.engine.begin() as connection:
next_run_time = connection.execute(selectable).scalar()
return utc_timestamp_to_datetime(next_run_time)
def get_all_jobs(self):
@ -94,8 +96,9 @@ class SQLAlchemyJobStore(BaseJobStore):
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
})
with self.engine.begin() as connection:
try:
self.engine.execute(insert)
connection.execute(insert)
except IntegrityError:
raise ConflictingIdError(job.id)
@ -104,19 +107,22 @@ class SQLAlchemyJobStore(BaseJobStore):
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
}).where(self.jobs_t.c.id == job.id)
result = self.engine.execute(update)
with self.engine.begin() as connection:
result = connection.execute(update)
if result.rowcount == 0:
raise JobLookupError(job.id)
def remove_job(self, job_id):
delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id)
result = self.engine.execute(delete)
with self.engine.begin() as connection:
result = connection.execute(delete)
if result.rowcount == 0:
raise JobLookupError(job_id)
def remove_all_jobs(self):
delete = self.jobs_t.delete()
self.engine.execute(delete)
with self.engine.begin() as connection:
connection.execute(delete)
def shutdown(self):
self.engine.dispose()
@ -132,11 +138,12 @@ class SQLAlchemyJobStore(BaseJobStore):
def _get_jobs(self, *conditions):
jobs = []
selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).\
selectable = select(self.jobs_t.c.id, self.jobs_t.c.job_state).\
order_by(self.jobs_t.c.next_run_time)
selectable = selectable.where(and_(*conditions)) if conditions else selectable
failed_job_ids = set()
for row in self.engine.execute(selectable):
with self.engine.begin() as connection:
for row in connection.execute(selectable):
try:
jobs.append(self._reconstitute_job(row.job_state))
except BaseException:
@ -146,7 +153,7 @@ class SQLAlchemyJobStore(BaseJobStore):
# Remove all the jobs we failed to restore
if failed_job_ids:
delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids))
self.engine.execute(delete)
connection.execute(delete)
return jobs
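The hunks above follow SQLAlchemy's 2.0-style API: select() takes columns positionally instead of as a list, and statements run on a connection from engine.begin() rather than on the engine itself. A minimal, self-contained sketch of that pattern (the table and engine below are illustrative stand-ins, not apscheduler's own):

from sqlalchemy import Column, Float, LargeBinary, MetaData, Table, Unicode, create_engine, select

metadata = MetaData()
jobs_t = Table(
    'apscheduler_jobs', metadata,
    Column('id', Unicode(191), primary_key=True),
    Column('next_run_time', Float(25), index=True),
    Column('job_state', LargeBinary, nullable=False),
)
engine = create_engine('sqlite://')
metadata.create_all(engine)

stmt = select(jobs_t.c.id, jobs_t.c.job_state)   # columns passed positionally (2.0 style)
with engine.begin() as connection:               # explicit connection; engine.execute() is gone in 2.0
    rows = connection.execute(stmt).all()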

View file

@ -1,18 +1,10 @@
from __future__ import absolute_import
import asyncio
from functools import wraps, partial
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref
try:
import asyncio
except ImportError: # pragma: nocover
try:
import trollius as asyncio
except ImportError:
raise ImportError(
'AsyncIOScheduler requires either Python 3.4 or the asyncio package installed')
def run_in_event_loop(func):
@wraps(func)

View file

@ -33,7 +33,7 @@ class QtScheduler(BaseScheduler):
def _start_timer(self, wait_seconds):
self._stop_timer()
if wait_seconds is not None:
wait_time = min(wait_seconds * 1000, 2147483647)
wait_time = min(int(wait_seconds * 1000), 2147483647)
self._timer = QTimer.singleShot(wait_time, self._process_jobs)
def _stop_timer(self):
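Context for the int() cast above: wait_seconds * 1000 can be a float, and Qt's single-shot timer expects integer milliseconds. A hedged one-liner (PyQt5 shown as an assumption; apscheduler supports several Qt bindings):

from PyQt5.QtCore import QTimer  # assumption: PyQt5 installed
QTimer.singleShot(min(int(2.5 * 1000), 2147483647), lambda: None)  # msec must be an int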

View file

@ -2,6 +2,7 @@
from __future__ import division
from asyncio import iscoroutinefunction
from datetime import date, datetime, time, timedelta, tzinfo
from calendar import timegm
from functools import partial
@ -22,15 +23,6 @@ try:
except ImportError:
TIMEOUT_MAX = 4294967 # Maximum value accepted by Event.wait() on Windows
try:
from asyncio import iscoroutinefunction
except ImportError:
try:
from trollius import iscoroutinefunction
except ImportError:
def iscoroutinefunction(func):
return False
__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp',
'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name',
'obj_to_ref', 'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args',
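With the trollius fallback removed, the coroutine check comes straight from the standard library. A minimal sketch of the test the executor relies on:

import asyncio

async def tick():
    pass

assert asyncio.iscoroutinefunction(tick)       # coroutine jobs take the create_task path
assert not asyncio.iscoroutinefunction(print)  # plain callables go to run_in_executor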

View file

@ -11,9 +11,9 @@ from bleach.sanitizer import (
# yyyymmdd
__releasedate__ = "20220627"
__releasedate__ = "20230123"
# x.y.z or x.y.z.dev0 -- semver
__version__ = "5.0.1"
__version__ = "6.0.0"
__all__ = ["clean", "linkify"]
@ -52,7 +52,7 @@ def clean(
:arg str text: the text to clean
:arg list tags: allowed list of tags; defaults to
:arg set tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict;
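Bleach 6.0 moves the allowed-tags containers from lists to (frozen)sets, as the docstring changes above reflect. A hedged usage sketch against the 6.0 API:

import bleach

# tags now accepts a set; extending the defaults means building a new frozenset
tags = bleach.sanitizer.ALLOWED_TAGS | {"p", "span"}
print(bleach.clean('<p onclick="x()">hi</p><script>bad()</script>', tags=tags))
# -> '<p>hi</p>&lt;script&gt;bad()&lt;/script&gt;'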

View file

@ -38,6 +38,9 @@ from bleach._vendor.html5lib.filters.sanitizer import (
allowed_protocols,
allowed_css_properties,
allowed_svg_properties,
attr_val_is_uri,
svg_attr_val_allows_ref,
svg_allow_local_href,
) # noqa: E402 module level import not at top of file
from bleach._vendor.html5lib.filters.sanitizer import (
Filter as SanitizerFilter,
@ -78,7 +81,8 @@ TAG_TOKEN_TYPE_PARSEERROR = constants.tokenTypes["ParseError"]
#: List of valid HTML tags, from WHATWG HTML Living Standard as of 2018-10-17
#: https://html.spec.whatwg.org/multipage/indices.html#elements-3
HTML_TAGS = [
HTML_TAGS = frozenset(
(
"a",
"abbr",
"address",
@ -191,14 +195,15 @@ HTML_TAGS = [
"var",
"video",
"wbr",
]
)
)
#: List of block level HTML tags, as per https://github.com/mozilla/bleach/issues/369
#: from mozilla on 2019.07.11
#: https://developer.mozilla.org/en-US/docs/Web/HTML/Block-level_elements#Elements
HTML_TAGS_BLOCK_LEVEL = frozenset(
[
(
"address",
"article",
"aside",
@ -232,7 +237,7 @@ HTML_TAGS_BLOCK_LEVEL = frozenset(
"section",
"table",
"ul",
]
)
)
@ -473,7 +478,7 @@ class BleachHTMLParser(HTMLParser):
def __init__(self, tags, strip, consume_entities, **kwargs):
"""
:arg tags: list of allowed tags--everything else is either stripped or
:arg tags: set of allowed tags--everything else is either stripped or
escaped; if None, then this doesn't look at tags at all
:arg strip: whether to strip disallowed tags (True) or escape them (False);
if tags=None, then this doesn't have any effect
@ -481,7 +486,9 @@ class BleachHTMLParser(HTMLParser):
leave them as is when tokenizing (BleachHTMLTokenizer-added behavior)
"""
self.tags = [tag.lower() for tag in tags] if tags is not None else None
self.tags = (
frozenset((tag.lower() for tag in tags)) if tags is not None else None
)
self.strip = strip
self.consume_entities = consume_entities
super().__init__(**kwargs)
@ -691,7 +698,7 @@ class BleachHTMLSerializer(HTMLSerializer):
# Only leave entities in that are not ambiguous. If they're
# ambiguous, then we escape the ampersand.
if entity is not None and convert_entity(entity) is not None:
yield "&" + entity + ";"
yield f"&{entity};"
# Length of the entity plus 2--one for & at the beginning
# and one for ; at the end

View file

@ -120,9 +120,10 @@ class Linker:
:arg list callbacks: list of callbacks to run when adjusting tag attributes;
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
:arg list skip_tags: list of tags that you don't want to linkify the
contents of; for example, you could set this to ``['pre']`` to skip
linkifying contents of ``pre`` tags
:arg set skip_tags: set of tags that you don't want to linkify the
contents of; for example, you could set this to ``{'pre'}`` to skip
linkifying contents of ``pre`` tags; ``None`` means you don't
want linkify to skip any tags
:arg bool parse_email: whether or not to linkify email addresses
@ -130,7 +131,7 @@ class Linker:
:arg email_re: email matching regex
:arg list recognized_tags: the list of tags that linkify knows about;
:arg set recognized_tags: the set of tags that linkify knows about;
everything else gets escaped
:returns: linkified text as unicode
@ -145,15 +146,18 @@ class Linker:
# Create a parser/tokenizer that allows all HTML tags and escapes
# anything not in that list.
self.parser = html5lib_shim.BleachHTMLParser(
tags=recognized_tags,
tags=frozenset(recognized_tags),
strip=False,
consume_entities=True,
consume_entities=False,
namespaceHTMLElements=False,
)
self.walker = html5lib_shim.getTreeWalker("etree")
self.serializer = html5lib_shim.BleachHTMLSerializer(
quote_attr_values="always",
omit_optional_tags=False,
# We want to leave entities as they are without escaping or
# resolving or expanding
resolve_entities=False,
# linkify does not sanitize
sanitize=False,
# linkify preserves attr order
@ -218,8 +222,8 @@ class LinkifyFilter(html5lib_shim.Filter):
:arg list callbacks: list of callbacks to run when adjusting tag attributes;
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
:arg list skip_tags: list of tags that you don't want to linkify the
contents of; for example, you could set this to ``['pre']`` to skip
:arg set skip_tags: set of tags that you don't want to linkify the
contents of; for example, you could set this to ``{'pre'}`` to skip
linkifying contents of ``pre`` tags
:arg bool parse_email: whether or not to linkify email addresses
@ -232,7 +236,7 @@ class LinkifyFilter(html5lib_shim.Filter):
super().__init__(source)
self.callbacks = callbacks or []
self.skip_tags = skip_tags or []
self.skip_tags = skip_tags or {}
self.parse_email = parse_email
self.url_re = url_re
@ -510,6 +514,62 @@ class LinkifyFilter(html5lib_shim.Filter):
yield {"type": "Characters", "data": str(new_text)}
yield token_buffer[-1]
def extract_entities(self, token):
"""Handles Characters tokens with entities
Our overridden tokenizer doesn't do anything with entities. However,
that means that the serializer will convert all ``&`` in Characters
tokens to ``&amp;``.
Since we don't want that, we extract entities here and convert them to
Entity tokens so the serializer will let them be.
:arg token: the Characters token to work on
:returns: generator of tokens
"""
data = token.get("data", "")
# If there isn't a & in the data, we can return now
if "&" not in data:
yield token
return
new_tokens = []
# For each possible entity that starts with a "&", we try to extract an
# actual entity and re-tokenize accordingly
for part in html5lib_shim.next_possible_entity(data):
if not part:
continue
if part.startswith("&"):
entity = html5lib_shim.match_entity(part)
if entity is not None:
if entity == "amp":
# LinkifyFilter can't match urls across token boundaries
# which is problematic with &amp; since that shows up in
# querystrings all the time. This special-cases &amp;
# and converts it to a & and sticks it in as a
# Characters token. It'll get merged with surrounding
# tokens in the BleachSanitizerFilter.__iter__ and
# escaped in the serializer.
new_tokens.append({"type": "Characters", "data": "&"})
else:
new_tokens.append({"type": "Entity", "name": entity})
# Length of the entity plus 2--one for & at the beginning
# and one for ; at the end
remainder = part[len(entity) + 2 :]
if remainder:
new_tokens.append({"type": "Characters", "data": remainder})
continue
new_tokens.append({"type": "Characters", "data": part})
yield from new_tokens
def __iter__(self):
in_a = False
in_skip_tag = None
@ -564,8 +624,8 @@ class LinkifyFilter(html5lib_shim.Filter):
new_stream = self.handle_links(new_stream)
for token in new_stream:
yield token
for new_token in new_stream:
yield from self.extract_entities(new_token)
# We've already yielded this token, so continue
continue
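The skip_tags argument is likewise a set in bleach 6.0. A usage sketch (output shape described in the comment, not asserted):

from bleach.linkifier import Linker

linker = Linker(skip_tags={"pre"})   # a set now, matching the docs above
html = 'see https://example.com and <pre>https://not-linkified.example</pre>'
print(linker.linkify(html))          # only the URL outside <pre> becomes an <a> tag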

View file

@ -8,8 +8,9 @@ from bleach import html5lib_shim
from bleach import parse_shim
#: List of allowed tags
ALLOWED_TAGS = [
#: Set of allowed tags
ALLOWED_TAGS = frozenset(
(
"a",
"abbr",
"acronym",
@ -22,7 +23,8 @@ ALLOWED_TAGS = [
"ol",
"strong",
"ul",
]
)
)
#: Map of allowed attributes by tag
@ -33,7 +35,7 @@ ALLOWED_ATTRIBUTES = {
}
#: List of allowed protocols
ALLOWED_PROTOCOLS = ["http", "https", "mailto"]
ALLOWED_PROTOCOLS = frozenset(("http", "https", "mailto"))
#: Invisible characters--0 to and including 31 except 9 (tab), 10 (lf), and 13 (cr)
INVISIBLE_CHARACTERS = "".join(
@ -48,6 +50,10 @@ INVISIBLE_CHARACTERS_RE = re.compile("[" + INVISIBLE_CHARACTERS + "]", re.UNICOD
INVISIBLE_REPLACEMENT_CHAR = "?"
class NoCssSanitizerWarning(UserWarning):
pass
class Cleaner:
"""Cleaner for cleaning HTML fragments of malicious content
@ -89,7 +95,7 @@ class Cleaner:
):
"""Initializes a Cleaner
:arg list tags: allowed list of tags; defaults to
:arg set tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict;
@ -143,6 +149,25 @@ class Cleaner:
alphabetical_attributes=False,
)
if css_sanitizer is None:
# FIXME(willkg): this doesn't handle when attributes or an
# attributes value is a callable
attributes_values = []
if isinstance(attributes, list):
attributes_values = attributes
elif isinstance(attributes, dict):
attributes_values = []
for values in attributes.values():
if isinstance(values, (list, tuple)):
attributes_values.extend(values)
if "style" in attributes_values:
warnings.warn(
"'style' attribute specified, but css_sanitizer not set.",
category=NoCssSanitizerWarning,
)
def clean(self, text):
"""Cleans text and returns sanitized result as unicode
@ -155,9 +180,8 @@ class Cleaner:
"""
if not isinstance(text, str):
message = (
"argument cannot be of '{name}' type, must be of text type".format(
name=text.__class__.__name__
)
f"argument cannot be of {text.__class__.__name__!r} type, "
+ "must be of text type"
)
raise TypeError(message)
@ -167,13 +191,11 @@ class Cleaner:
dom = self.parser.parseFragment(text)
filtered = BleachSanitizerFilter(
source=self.walker(dom),
# Bleach-sanitizer-specific things
allowed_tags=self.tags,
attributes=self.attributes,
strip_disallowed_elements=self.strip,
strip_disallowed_tags=self.strip,
strip_html_comments=self.strip_comments,
css_sanitizer=self.css_sanitizer,
# html5lib-sanitizer things
allowed_elements=self.tags,
allowed_protocols=self.protocols,
)
@ -237,19 +259,21 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
def __init__(
self,
source,
allowed_elements=ALLOWED_TAGS,
allowed_tags=ALLOWED_TAGS,
attributes=ALLOWED_ATTRIBUTES,
allowed_protocols=ALLOWED_PROTOCOLS,
strip_disallowed_elements=False,
attr_val_is_uri=html5lib_shim.attr_val_is_uri,
svg_attr_val_allows_ref=html5lib_shim.svg_attr_val_allows_ref,
svg_allow_local_href=html5lib_shim.svg_allow_local_href,
strip_disallowed_tags=False,
strip_html_comments=True,
css_sanitizer=None,
**kwargs,
):
"""Creates a BleachSanitizerFilter instance
:arg source: html5lib TreeWalker stream as an html5lib TreeWalker
:arg list allowed_elements: allowed list of tags; defaults to
:arg set allowed_tags: set of allowed tags; defaults to
``bleach.sanitizer.ALLOWED_TAGS``
:arg dict attributes: allowed attributes; can be a callable, list or dict;
@ -258,8 +282,16 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
:arg list allowed_protocols: allowed list of protocols for links; defaults
to ``bleach.sanitizer.ALLOWED_PROTOCOLS``
:arg bool strip_disallowed_elements: whether or not to strip disallowed
elements
:arg attr_val_is_uri: set of attributes that have URI values
:arg svg_attr_val_allows_ref: set of SVG attributes that can have
references
:arg svg_allow_local_href: set of SVG elements that can have local
hrefs
:arg bool strip_disallowed_tags: whether or not to strip disallowed
tags
:arg bool strip_html_comments: whether or not to strip HTML comments
@ -267,24 +299,24 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
sanitizing style attribute values and style text; defaults to None
"""
self.attr_filter = attribute_filter_factory(attributes)
self.strip_disallowed_elements = strip_disallowed_elements
self.strip_html_comments = strip_html_comments
self.css_sanitizer = css_sanitizer
# NOTE(willkg): This is the superclass of
# html5lib.filters.sanitizer.Filter. We call this directly skipping the
# __init__ for html5lib.filters.sanitizer.Filter because that does
# things we don't need to do and kicks up the deprecation warning for
# using Sanitizer.
html5lib_shim.Filter.__init__(self, source)
# filter out html5lib deprecation warnings to use bleach from BleachSanitizerFilter init
warnings.filterwarnings(
"ignore",
message="html5lib's sanitizer is deprecated",
category=DeprecationWarning,
module="bleach._vendor.html5lib",
)
return super().__init__(
source,
allowed_elements=allowed_elements,
allowed_protocols=allowed_protocols,
**kwargs,
)
self.allowed_tags = frozenset(allowed_tags)
self.allowed_protocols = frozenset(allowed_protocols)
self.attr_filter = attribute_filter_factory(attributes)
self.strip_disallowed_tags = strip_disallowed_tags
self.strip_html_comments = strip_html_comments
self.attr_val_is_uri = attr_val_is_uri
self.svg_attr_val_allows_ref = svg_attr_val_allows_ref
self.css_sanitizer = css_sanitizer
self.svg_allow_local_href = svg_allow_local_href
def sanitize_stream(self, token_iterator):
for token in token_iterator:
@ -354,10 +386,10 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
"""
token_type = token["type"]
if token_type in ["StartTag", "EndTag", "EmptyTag"]:
if token["name"] in self.allowed_elements:
if token["name"] in self.allowed_tags:
return self.allow_token(token)
elif self.strip_disallowed_elements:
elif self.strip_disallowed_tags:
return None
else:
@ -570,7 +602,7 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
def disallowed_token(self, token):
token_type = token["type"]
if token_type == "EndTag":
token["data"] = "</%s>" % token["name"]
token["data"] = f"</{token['name']}>"
elif token["data"]:
assert token_type in ("StartTag", "EmptyTag")
@ -586,25 +618,19 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
if ns is None or ns not in html5lib_shim.prefixes:
namespaced_name = name
else:
namespaced_name = "{}:{}".format(html5lib_shim.prefixes[ns], name)
namespaced_name = f"{html5lib_shim.prefixes[ns]}:{name}"
attrs.append(
' %s="%s"'
% (
namespaced_name,
# NOTE(willkg): HTMLSerializer escapes attribute values
# already, so if we do it here (like HTMLSerializer does),
# then we end up double-escaping.
v,
)
)
token["data"] = "<{}{}>".format(token["name"], "".join(attrs))
attrs.append(f' {namespaced_name}="{v}"')
token["data"] = f"<{token['name']}{''.join(attrs)}>"
else:
token["data"] = "<%s>" % token["name"]
token["data"] = f"<{token['name']}>"
if token.get("selfClosing"):
token["data"] = token["data"][:-1] + "/>"
token["data"] = f"{token['data'][:-1]}/>"
token["type"] = "Characters"

View file

@ -7,7 +7,7 @@ Beautiful Soup uses a pluggable XML or HTML parser to parse a
provides methods and Pythonic idioms that make it easy to navigate,
search, and modify the parse tree.
Beautiful Soup works with Python 3.5 and up. It works better if lxml
Beautiful Soup works with Python 3.6 and up. It works better if lxml
and/or html5lib is installed.
For more than you ever wanted to know about Beautiful Soup, see the
@ -15,8 +15,8 @@ documentation: http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""
__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.11.1"
__copyright__ = "Copyright (c) 2004-2022 Leonard Richardson"
__version__ = "4.11.2"
__copyright__ = "Copyright (c) 2004-2023 Leonard Richardson"
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
@ -211,7 +211,7 @@ class BeautifulSoup(Tag):
warnings.warn(
'The "%s" argument to the BeautifulSoup constructor '
'has been renamed to "%s."' % (old_name, new_name),
DeprecationWarning
DeprecationWarning, stacklevel=3
)
return kwargs.pop(old_name)
return None
@ -405,7 +405,8 @@ class BeautifulSoup(Tag):
'The input looks more like a URL than markup. You may want to use'
' an HTTP client like requests to get the document behind'
' the URL, and feed that document to Beautiful Soup.',
MarkupResemblesLocatorWarning
MarkupResemblesLocatorWarning,
stacklevel=3
)
return True
return False
@ -436,7 +437,7 @@ class BeautifulSoup(Tag):
'The input looks more like a filename than markup. You may'
' want to open this file and pass the filehandle into'
' Beautiful Soup.',
MarkupResemblesLocatorWarning
MarkupResemblesLocatorWarning, stacklevel=3
)
return True
return False
@ -789,7 +790,7 @@ class BeautifulStoneSoup(BeautifulSoup):
warnings.warn(
'The BeautifulStoneSoup class is deprecated. Instead of using '
'it, pass features="xml" into the BeautifulSoup constructor.',
DeprecationWarning
DeprecationWarning, stacklevel=2
)
super(BeautifulStoneSoup, self).__init__(*args, **kwargs)

View file

@ -122,7 +122,7 @@ class TreeBuilder(object):
# A value for these tag/attribute combinations is a space- or
# comma-separated list of CDATA, rather than a single CDATA.
DEFAULT_CDATA_LIST_ATTRIBUTES = {}
DEFAULT_CDATA_LIST_ATTRIBUTES = defaultdict(list)
# Whitespace should be preserved inside these tags.
DEFAULT_PRESERVE_WHITESPACE_TAGS = set()

View file

@ -70,7 +70,10 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
# ATM because the html5lib TreeBuilder doesn't use
# UnicodeDammit.
if exclude_encodings:
warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.")
warnings.warn(
"You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.",
stacklevel=3
)
# html5lib only parses HTML, so if it's given XML that's worth
# noting.
@ -81,7 +84,10 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
# These methods are defined by Beautiful Soup.
def feed(self, markup):
if self.soup.parse_only is not None:
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
warnings.warn(
"You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.",
stacklevel=4
)
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
self.underlying_builder.parser = parser
extra_kwargs = dict()
@ -249,9 +255,9 @@ class AttrList(object):
# If this attribute is a multi-valued attribute for this element,
# turn its value into a list.
list_attr = self.element.cdata_list_attributes or {}
if (name in list_attr.get('*')
if (name in list_attr.get('*', [])
or (self.element.name in list_attr
and name in list_attr[self.element.name])):
and name in list_attr.get(self.element.name, []))):
# A node that is being cloned may have already undergone
# this procedure.
if not isinstance(value, list):

View file

@ -10,30 +10,9 @@ __all__ = [
from html.parser import HTMLParser
try:
from html.parser import HTMLParseError
except ImportError as e:
# HTMLParseError is removed in Python 3.5. Since it can never be
# thrown in 3.5, we can just define our own class as a placeholder.
class HTMLParseError(Exception):
pass
import sys
import warnings
# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
# argument, which we'd like to set to False. Unfortunately,
# http://bugs.python.org/issue13273 makes strict=True a better bet
# before Python 3.2.3.
#
# At the end of this file, we monkeypatch HTMLParser so that
# strict=True works well on Python 3.2.2.
major, minor, release = sys.version_info[:3]
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
from bs4.element import (
CData,
Comment,
@ -91,19 +70,6 @@ class BeautifulSoupHTMLParser(HTMLParser, DetectsXMLParsedAsHTML):
self._initialize_xml_detector()
def error(self, msg):
"""In Python 3, HTMLParser subclasses must implement error(), although
this requirement doesn't appear to be documented.
In Python 2, HTMLParser implements error() by raising an exception,
which we don't want to do.
In any event, this method is called only on very strange
markup and our best strategy is to pretend it didn't happen
and keep going.
"""
warnings.warn(msg)
def handle_startendtag(self, name, attrs):
"""Handle an incoming empty-element tag.
@ -203,9 +169,10 @@ class BeautifulSoupHTMLParser(HTMLParser, DetectsXMLParsedAsHTML):
:param name: Character number, possibly in hexadecimal.
"""
# XXX workaround for a bug in HTMLParser. Remove this once
# it's fixed in all supported versions.
# http://bugs.python.org/issue13633
# TODO: This was originally a workaround for a bug in
# HTMLParser. (http://bugs.python.org/issue13633) The bug has
# been fixed, but removing this code still makes some
# Beautiful Soup tests fail. This needs investigation.
if name.startswith('x'):
real_name = int(name.lstrip('x'), 16)
elif name.startswith('X'):
@ -333,9 +300,6 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
parser_args = parser_args or []
parser_kwargs = parser_kwargs or {}
parser_kwargs.update(extra_parser_kwargs)
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
parser_kwargs['strict'] = False
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
parser_kwargs['convert_charrefs'] = False
self.parser_args = (parser_args, parser_kwargs)
@ -395,105 +359,6 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
args, kwargs = self.parser_args
parser = BeautifulSoupHTMLParser(*args, **kwargs)
parser.soup = self.soup
try:
parser.feed(markup)
parser.close()
except HTMLParseError as e:
warnings.warn(RuntimeWarning(
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
raise e
parser.already_closed_empty_element = []
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
# string.
#
# XXX This code can be removed once most Python 3 users are on 3.2.3.
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
import re
attrfind_tolerant = re.compile(
r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
from html.parser import tagfind, attrfind
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = rawdata[i+1:k].lower()
while k < endpos:
if self.strict:
m = attrfind.match(rawdata, k)
else:
m = attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
if self.strict:
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
return endpos
def set_cdata_mode(self, elem):
self.cdata_elem = elem.lower()
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
BeautifulSoupHTMLParser.parse_starttag = parse_starttag
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
CONSTRUCTOR_TAKES_STRICT = True

View file

@ -496,13 +496,16 @@ class PageElement(object):
def extend(self, tags):
"""Appends the given PageElements to this one's contents.
:param tags: A list of PageElements.
:param tags: A list of PageElements. If a single Tag is
provided instead, this PageElement's contents will be extended
with that Tag's contents.
"""
if isinstance(tags, Tag):
# Calling self.append() on another tag's contents will change
# the list we're iterating over. Make a list that won't
# change.
tags = list(tags.contents)
tags = tags.contents
if isinstance(tags, list):
# Moving items around the tree may change their position in
# the original list. Make a list that won't change.
tags = list(tags)
for tag in tags:
self.append(tag)
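The revised extend() also accepts a single Tag, moving that Tag's children into this element. A quick sketch:

from bs4 import BeautifulSoup

soup = BeautifulSoup('<div id="a"><b>1</b></div><div id="b"><i>2</i></div>', "html.parser")
a, b = soup.find(id="a"), soup.find(id="b")
a.extend(b)   # moves <i>2</i> out of div#b and appends it to div#a
print(a)      # <div id="a"><b>1</b><i>2</i></div>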
@ -586,8 +589,9 @@ class PageElement(object):
:kwargs: A dictionary of filters on attribute values.
:return: A ResultSet containing PageElements.
"""
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, string, limit, self.next_elements,
**kwargs)
_stacklevel=_stacklevel+1, **kwargs)
findAllNext = find_all_next # BS3
def find_next_sibling(self, name=None, attrs={}, string=None, **kwargs):
@ -624,8 +628,11 @@ class PageElement(object):
:return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet
"""
return self._find_all(name, attrs, string, limit,
self.next_siblings, **kwargs)
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(
name, attrs, string, limit,
self.next_siblings, _stacklevel=_stacklevel+1, **kwargs
)
findNextSiblings = find_next_siblings # BS3
fetchNextSiblings = find_next_siblings # BS2
@ -663,8 +670,11 @@ class PageElement(object):
:return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet
"""
return self._find_all(name, attrs, string, limit, self.previous_elements,
**kwargs)
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(
name, attrs, string, limit, self.previous_elements,
_stacklevel=_stacklevel+1, **kwargs
)
findAllPrevious = find_all_previous # BS3
fetchPrevious = find_all_previous # BS2
@ -702,8 +712,11 @@ class PageElement(object):
:return: A ResultSet of PageElements.
:rtype: bs4.element.ResultSet
"""
return self._find_all(name, attrs, string, limit,
self.previous_siblings, **kwargs)
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(
name, attrs, string, limit,
self.previous_siblings, _stacklevel=_stacklevel+1, **kwargs
)
findPreviousSiblings = find_previous_siblings # BS3
fetchPreviousSiblings = find_previous_siblings # BS2
@ -724,7 +737,7 @@ class PageElement(object):
# NOTE: We can't use _find_one because findParents takes a different
# set of arguments.
r = None
l = self.find_parents(name, attrs, 1, **kwargs)
l = self.find_parents(name, attrs, 1, _stacklevel=3, **kwargs)
if l:
r = l[0]
return r
@ -744,8 +757,9 @@ class PageElement(object):
:return: A PageElement.
:rtype: bs4.element.Tag | bs4.element.NavigableString
"""
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, None, limit, self.parents,
**kwargs)
_stacklevel=_stacklevel+1, **kwargs)
findParents = find_parents # BS3
fetchParents = find_parents # BS2
@ -771,19 +785,20 @@ class PageElement(object):
def _find_one(self, method, name, attrs, string, **kwargs):
r = None
l = method(name, attrs, string, 1, **kwargs)
l = method(name, attrs, string, 1, _stacklevel=4, **kwargs)
if l:
r = l[0]
return r
def _find_all(self, name, attrs, string, limit, generator, **kwargs):
"Iterates over a generator looking for things that match."
_stacklevel = kwargs.pop('_stacklevel', 3)
if string is None and 'text' in kwargs:
string = kwargs.pop('text')
warnings.warn(
"The 'text' argument to find()-type methods is deprecated. Use 'string' instead.",
DeprecationWarning
DeprecationWarning, stacklevel=_stacklevel
)
if isinstance(name, SoupStrainer):
@ -1306,7 +1321,8 @@ class Tag(PageElement):
sourceline=self.sourceline, sourcepos=self.sourcepos,
can_be_empty_element=self.can_be_empty_element,
cdata_list_attributes=self.cdata_list_attributes,
preserve_whitespace_tags=self.preserve_whitespace_tags
preserve_whitespace_tags=self.preserve_whitespace_tags,
interesting_string_types=self.interesting_string_types
)
for attr in ('can_be_empty_element', 'hidden'):
setattr(clone, attr, getattr(self, attr))
@ -1558,7 +1574,7 @@ class Tag(PageElement):
'.%(name)sTag is deprecated, use .find("%(name)s") instead. If you really were looking for a tag called %(name)sTag, use .find("%(name)sTag")' % dict(
name=tag_name
),
DeprecationWarning
DeprecationWarning, stacklevel=2
)
return self.find(tag_name)
# We special case contents to avoid recursion.
@ -1862,7 +1878,8 @@ class Tag(PageElement):
:rtype: bs4.element.Tag | bs4.element.NavigableString
"""
r = None
l = self.find_all(name, attrs, recursive, string, 1, **kwargs)
l = self.find_all(name, attrs, recursive, string, 1, _stacklevel=3,
**kwargs)
if l:
r = l[0]
return r
@ -1889,7 +1906,9 @@ class Tag(PageElement):
generator = self.descendants
if not recursive:
generator = self.children
return self._find_all(name, attrs, string, limit, generator, **kwargs)
_stacklevel = kwargs.pop('_stacklevel', 2)
return self._find_all(name, attrs, string, limit, generator,
_stacklevel=_stacklevel+1, **kwargs)
findAll = find_all # BS3
findChildren = find_all # BS2
@ -1993,7 +2012,7 @@ class Tag(PageElement):
"""
warnings.warn(
'has_key is deprecated. Use has_attr(key) instead.',
DeprecationWarning
DeprecationWarning, stacklevel=2
)
return self.has_attr(key)
@ -2024,7 +2043,7 @@ class SoupStrainer(object):
string = kwargs.pop('text')
warnings.warn(
"The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead.",
DeprecationWarning
DeprecationWarning, stacklevel=2
)
self.name = self._normalize_search_value(name)

View file

@ -149,14 +149,14 @@ class HTMLFormatter(Formatter):
"""A generic Formatter for HTML."""
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
class XMLFormatter(Formatter):
"""A generic Formatter for XML."""
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
# Set up aliases for the default formatters.

View file

@ -29,6 +29,29 @@ from bs4.builder import (
)
default_builder = HTMLParserTreeBuilder
# Some tests depend on specific third-party libraries. We use
# @pytest.mark.skipIf on the following conditionals to skip them
# if the libraries are not installed.
try:
from soupsieve import SelectorSyntaxError
SOUP_SIEVE_PRESENT = True
except ImportError:
SOUP_SIEVE_PRESENT = False
try:
import html5lib
HTML5LIB_PRESENT = True
except ImportError:
HTML5LIB_PRESENT = False
try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError:
LXML_PRESENT = False
LXML_VERSION = (0,)
BAD_DOCUMENT = """A bare string
<!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd">
<!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd">
@ -258,10 +281,10 @@ class TreeBuilderSmokeTest(object):
@pytest.mark.parametrize(
"multi_valued_attributes",
[None, dict(b=['class']), {'*': ['notclass']}]
[None, {}, dict(b=['class']), {'*': ['notclass']}]
)
def test_attribute_not_multi_valued(self, multi_valued_attributes):
markup = '<a class="a b c">'
markup = '<html xmlns="http://www.w3.org/1999/xhtml"><a class="a b c"></html>'
soup = self.soup(markup, multi_valued_attributes=multi_valued_attributes)
assert soup.a['class'] == 'a b c'
@ -820,26 +843,27 @@ Hello, world!
soup = self.soup(text)
assert soup.p.encode("utf-8") == expected
def test_real_iso_latin_document(self):
def test_real_iso_8859_document(self):
# Smoke test of interrelated functionality, using an
# easy-to-understand document.
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# Here it is in Unicode. Note that it claims to be in ISO-8859-1.
unicode_html = '<html><head><meta content="text/html; charset=ISO-8859-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# That's because we're going to encode it into ISO-Latin-1, and use
# that to test.
# That's because we're going to encode it into ISO-8859-1,
# and use that to test.
iso_latin_html = unicode_html.encode("iso-8859-1")
# Parse the ISO-Latin-1 HTML.
# Parse the ISO-8859-1 HTML.
soup = self.soup(iso_latin_html)
# Encode it to UTF-8.
result = soup.encode("utf-8")
# What do we expect the result to look like? Well, it would
# look like unicode_html, except that the META tag would say
# UTF-8 instead of ISO-Latin-1.
expected = unicode_html.replace("ISO-Latin-1", "utf-8")
# UTF-8 instead of ISO-8859-1.
expected = unicode_html.replace("ISO-8859-1", "utf-8")
# And, of course, it would be in UTF-8, not Unicode.
expected = expected.encode("utf-8")
@ -1177,15 +1201,3 @@ class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
assert isinstance(soup.contents[0], Comment)
assert soup.contents[0] == '?xml version="1.0" encoding="utf-8"?'
assert "html" == soup.contents[0].next_element.name
def skipIf(condition, reason):
def nothing(test, *args, **kwargs):
return None
def decorator(test_item):
if condition:
return nothing
else:
return test_item
return decorator

View file

@ -10,22 +10,23 @@ from bs4.builder import (
TreeBuilderRegistry,
)
try:
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError:
HTML5LIB_PRESENT = False
from . import (
HTML5LIB_PRESENT,
LXML_PRESENT,
)
try:
if HTML5LIB_PRESENT:
from bs4.builder import HTML5TreeBuilder
if LXML_PRESENT:
from bs4.builder import (
LXMLTreeBuilderForXML,
LXMLTreeBuilder,
)
LXML_PRESENT = True
except ImportError:
LXML_PRESENT = False
# TODO: Split out the lxml and html5lib tests into their own classes
# and gate with pytest.mark.skipIf.
class TestBuiltInRegistry(object):
"""Test the built-in registry with the default builders registered."""

View file

@ -17,25 +17,23 @@ class TestUnicodeDammit(object):
dammit = UnicodeDammit(markup)
assert dammit.unicode_markup == markup
def test_smart_quotes_to_unicode(self):
@pytest.mark.parametrize(
"smart_quotes_to,expect_converted",
[(None, "\u2018\u2019\u201c\u201d"),
("xml", "&#x2018;&#x2019;&#x201C;&#x201D;"),
("html", "&lsquo;&rsquo;&ldquo;&rdquo;"),
("ascii", "''" + '""'),
]
)
def test_smart_quotes_to(self, smart_quotes_to, expect_converted):
"""Verify the functionality of the smart_quotes_to argument
to the UnicodeDammit constructor."""
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup)
assert dammit.unicode_markup == "<foo>\u2018\u2019\u201c\u201d</foo>"
def test_smart_quotes_to_xml_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="xml")
assert dammit.unicode_markup == "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>"
def test_smart_quotes_to_html_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="html")
assert dammit.unicode_markup == "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>"
def test_smart_quotes_to_ascii(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
assert dammit.unicode_markup == """<foo>''""</foo>"""
converted = UnicodeDammit(
markup, known_definite_encodings=["windows-1252"],
smart_quotes_to=smart_quotes_to
).unicode_markup
assert converted == "<foo>{}</foo>".format(expect_converted)
def test_detect_utf8(self):
utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
@ -275,23 +273,24 @@ class TestEntitySubstitution(object):
def setup_method(self):
self.sub = EntitySubstitution
def test_simple_html_substitution(self):
# Unicode characters corresponding to named HTML entities
# are substituted, and no others.
s = "foo\u2200\N{SNOWMAN}\u00f5bar"
assert self.sub.substitute_html(s) == "foo&forall;\N{SNOWMAN}&otilde;bar"
def test_smart_quote_substitution(self):
@pytest.mark.parametrize(
"original,substituted",
[
# Basic case. Unicode characters corresponding to named
# HTML entities are substituted; others are not.
("foo\u2200\N{SNOWMAN}\u00f5bar",
"foo&forall;\N{SNOWMAN}&otilde;bar"),
# MS smart quotes are a common source of frustration, so we
# give them a special test.
quotes = b"\x91\x92foo\x93\x94"
dammit = UnicodeDammit(quotes)
assert self.sub.substitute_html(dammit.markup) == "&lsquo;&rsquo;foo&ldquo;&rdquo;"
('‘’foo“”', "&lsquo;&rsquo;foo&ldquo;&rdquo;"),
]
)
def test_substitute_html(self, original, substituted):
assert self.sub.substitute_html(original) == substituted
def test_html5_entity(self):
# Some HTML5 entities correspond to single- or multi-character
# Unicode sequences.
for entity, u in (
# A few spot checks of our ability to recognize
# special character sequences and convert them

View file

@ -1,27 +1,26 @@
"""Tests to ensure that the html5lib tree builder generates good trees."""
import pytest
import warnings
try:
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError as e:
HTML5LIB_PRESENT = False
from bs4 import BeautifulSoup
from bs4.element import SoupStrainer
from . import (
HTML5LIB_PRESENT,
HTML5TreeBuilderSmokeTest,
SoupTest,
skipIf,
)
@skipIf(
@pytest.mark.skipif(
not HTML5LIB_PRESENT,
"html5lib seems not to be present, not testing its tree builder.")
reason="html5lib seems not to be present, not testing its tree builder."
)
class TestHTML5LibBuilder(SoupTest, HTML5TreeBuilderSmokeTest):
"""See ``HTML5TreeBuilderSmokeTest``."""
@property
def default_builder(self):
from bs4.builder import HTML5TreeBuilder
return HTML5TreeBuilder
def test_soupstrainer(self):
@ -29,10 +28,12 @@ class TestHTML5LibBuilder(SoupTest, HTML5TreeBuilderSmokeTest):
strainer = SoupStrainer("b")
markup = "<p>A <b>bold</b> statement.</p>"
with warnings.catch_warnings(record=True) as w:
soup = self.soup(markup, parse_only=strainer)
soup = BeautifulSoup(markup, "html5lib", parse_only=strainer)
assert soup.decode() == self.document_for(markup)
assert "the html5lib tree builder doesn't support parse_only" in str(w[0].message)
[warning] = w
assert warning.filename == __file__
assert "the html5lib tree builder doesn't support parse_only" in str(warning.message)
def test_correctly_nested_tables(self):
"""html5lib inserts <tbody> tags where other parsers don't."""

View file

@ -122,15 +122,3 @@ class TestHTMLParserTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
with_element = div.encode(formatter="html")
expect = b"<div>%s</div>" % output_element
assert with_element == expect
class TestHTMLParserSubclass(SoupTest):
def test_error(self):
"""Verify that our HTMLParser subclass implements error() in a way
that doesn't cause a crash.
"""
parser = BeautifulSoupHTMLParser()
with warnings.catch_warnings(record=True) as warns:
parser.error("don't crash")
[warning] = warns
assert "don't crash" == str(warning.message)

View file

@ -1,16 +1,10 @@
"""Tests to ensure that the lxml tree builder generates good trees."""
import pickle
import pytest
import re
import warnings
try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError as e:
LXML_PRESENT = False
LXML_VERSION = (0,)
from . import LXML_PRESENT, LXML_VERSION
if LXML_PRESENT:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
@ -23,13 +17,14 @@ from bs4.element import Comment, Doctype, SoupStrainer
from . import (
HTMLTreeBuilderSmokeTest,
XMLTreeBuilderSmokeTest,
SOUP_SIEVE_PRESENT,
SoupTest,
skipIf,
)
@skipIf(
@pytest.mark.skipif(
not LXML_PRESENT,
"lxml seems not to be present, not testing its tree builder.")
reason="lxml seems not to be present, not testing its tree builder."
)
class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``."""
@ -54,9 +49,10 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
# In lxml < 2.3.5, an empty doctype causes a segfault. Skip this
# test if an old version of lxml is installed.
@skipIf(
@pytest.mark.skipif(
not LXML_PRESENT or LXML_VERSION < (2,3,5,0),
"Skipping doctype test for old version of lxml to avoid segfault.")
reason="Skipping doctype test for old version of lxml to avoid segfault."
)
def test_empty_doctype(self):
soup = self.soup("<!DOCTYPE>")
doctype = soup.contents[0]
@ -68,7 +64,9 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
with warnings.catch_warnings(record=True) as w:
soup = BeautifulStoneSoup("<b />")
assert "<b/>" == str(soup.b)
assert "BeautifulStoneSoup class is deprecated" in str(w[0].message)
[warning] = w
assert warning.filename == __file__
assert "BeautifulStoneSoup class is deprecated" in str(warning.message)
def test_tracking_line_numbers(self):
# The lxml TreeBuilder cannot keep track of line numbers from
@ -85,9 +83,10 @@ class TestLXMLTreeBuilder(SoupTest, HTMLTreeBuilderSmokeTest):
assert "sourceline" == soup.p.sourceline.name
assert "sourcepos" == soup.p.sourcepos.name
@skipIf(
@pytest.mark.skipif(
not LXML_PRESENT,
"lxml seems not to be present, not testing its XML tree builder.")
reason="lxml seems not to be present, not testing its XML tree builder."
)
class TestLXMLXMLTreeBuilder(SoupTest, XMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``."""
@ -148,6 +147,9 @@ class TestLXMLXMLTreeBuilder(SoupTest, XMLTreeBuilderSmokeTest):
}
@pytest.mark.skipif(
not SOUP_SIEVE_PRESENT, reason="Soup Sieve not installed"
)
def test_namespace_interaction_with_select_and_find(self):
# Demonstrate how namespaces interact with select* and
# find* methods.

View file

@ -3,15 +3,18 @@ import copy
import pickle
import pytest
from soupsieve import SelectorSyntaxError
from bs4 import BeautifulSoup
from bs4.element import (
Comment,
SoupStrainer,
)
from . import SoupTest
from . import (
SoupTest,
SOUP_SIEVE_PRESENT,
)
if SOUP_SIEVE_PRESENT:
from soupsieve import SelectorSyntaxError
class TestEncoding(SoupTest):
"""Test the ability to encode objects into strings."""
@ -213,6 +216,7 @@ class TestFormatters(SoupTest):
assert soup.contents[0].name == 'pre'
@pytest.mark.skipif(not SOUP_SIEVE_PRESENT, reason="Soup Sieve not installed")
class TestCSSSelectors(SoupTest):
"""Test basic CSS selector functionality.
@ -694,6 +698,7 @@ class TestPersistence(SoupTest):
assert tag.can_be_empty_element == copied.can_be_empty_element
assert tag.cdata_list_attributes == copied.cdata_list_attributes
assert tag.preserve_whitespace_tags == copied.preserve_whitespace_tags
assert tag.interesting_string_types == copied.interesting_string_types
def test_unicode_pickle(self):
# A tree containing Unicode characters can be pickled.

View file

@ -30,19 +30,11 @@ from bs4.element import (
from . import (
default_builder,
LXML_PRESENT,
SoupTest,
skipIf,
)
import warnings
try:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
LXML_PRESENT = True
except ImportError as e:
LXML_PRESENT = False
PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
class TestConstructor(SoupTest):
def test_short_unicode_input(self):
@ -139,7 +131,7 @@ class TestConstructor(SoupTest):
assert " an id " == a['id']
assert ["a", "class"] == a['class']
# TreeBuilder takes an argument called 'mutli_valued_attributes' which lets
# TreeBuilder takes an argument called 'multi_valued_attributes' which lets
# you customize or disable this. As always, you can customize the TreeBuilder
# by passing in a keyword argument to the BeautifulSoup constructor.
soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None)
@ -219,10 +211,17 @@ class TestConstructor(SoupTest):
class TestWarnings(SoupTest):
# Note that some of the tests in this class create BeautifulSoup
# objects directly rather than using self.soup(). That's
# because SoupTest.soup is defined in a different file,
# which will throw off the assertion in _assert_warning
# that the code that triggered the warning is in the same
# file as the test.
def _assert_warning(self, warnings, cls):
for w in warnings:
if isinstance(w.message, cls):
assert w.filename == __file__
return w
raise Exception("%s warning not found in %r" % (cls, warnings))
@ -243,13 +242,17 @@ class TestWarnings(SoupTest):
def test_no_warning_if_explicit_parser_specified(self):
with warnings.catch_warnings(record=True) as w:
soup = BeautifulSoup("<a><b></b></a>", "html.parser")
soup = self.soup("<a><b></b></a>")
assert [] == w
def test_parseOnlyThese_renamed_to_parse_only(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", parseOnlyThese=SoupStrainer("b"))
msg = str(w[0].message)
soup = BeautifulSoup(
"<a><b></b></a>", "html.parser",
parseOnlyThese=SoupStrainer("b"),
)
warning = self._assert_warning(w, DeprecationWarning)
msg = str(warning.message)
assert "parseOnlyThese" in msg
assert "parse_only" in msg
assert b"<b></b>" == soup.encode()
@ -257,8 +260,11 @@ class TestWarnings(SoupTest):
def test_fromEncoding_renamed_to_from_encoding(self):
with warnings.catch_warnings(record=True) as w:
utf8 = b"\xc3\xa9"
soup = self.soup(utf8, fromEncoding="utf8")
msg = str(w[0].message)
soup = BeautifulSoup(
utf8, "html.parser", fromEncoding="utf8"
)
warning = self._assert_warning(w, DeprecationWarning)
msg = str(warning.message)
assert "fromEncoding" in msg
assert "from_encoding" in msg
assert "utf8" == soup.original_encoding
@ -276,7 +282,7 @@ class TestWarnings(SoupTest):
# A warning is issued if the "markup" looks like the name of
# an HTML or text file, or a full path to a file on disk.
with warnings.catch_warnings(record=True) as w:
soup = self.soup("markup" + extension)
soup = BeautifulSoup("markup" + extension, "html.parser")
warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
assert "looks more like a filename" in str(warning.message)
@ -295,7 +301,7 @@ class TestWarnings(SoupTest):
def test_url_warning_with_bytes_url(self):
url = b"http://www.crummybytes.com/"
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(url)
soup = BeautifulSoup(url, "html.parser")
warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning
)
@ -307,7 +313,7 @@ class TestWarnings(SoupTest):
with warnings.catch_warnings(record=True) as warning_list:
# note - this url must differ from the bytes one otherwise
# python's warnings system swallows the second warning
soup = self.soup(url)
soup = BeautifulSoup(url, "html.parser")
warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning
)
@ -348,9 +354,12 @@ class TestNewTag(SoupTest):
assert dict(bar="baz", name="a name") == new_tag.attrs
assert None == new_tag.parent
def test_tag_inherits_self_closing_rules_from_builder(self):
if LXML_PRESENT:
xml_soup = BeautifulSoup("", "lxml-xml")
@pytest.mark.skipif(
not LXML_PRESENT,
reason="lxml not installed, cannot parse XML document"
)
def test_xml_tag_inherits_self_closing_rules_from_builder(self):
xml_soup = BeautifulSoup("", "xml")
xml_br = xml_soup.new_tag("br")
xml_p = xml_soup.new_tag("p")
@ -359,6 +368,7 @@ class TestNewTag(SoupTest):
assert b"<br/>" == xml_br.encode()
assert b"<p/>" == xml_p.encode()
def test_tag_inherits_self_closing_rules_from_builder(self):
html_soup = BeautifulSoup("", "html.parser")
html_br = html_soup.new_tag("br")
html_p = html_soup.new_tag("p")
@ -450,13 +460,3 @@ class TestEncodingConversion(SoupTest):
# The internal data structures can be encoded as UTF-8.
soup_from_unicode = self.soup(self.unicode_data)
assert soup_from_unicode.encode('utf-8') == self.utf8_data
@skipIf(
PYTHON_3_PRE_3_2,
"Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.")
def test_attribute_name_containing_unicode_characters(self):
markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
assert self.soup(markup).div.encode("utf8") == markup.encode("utf8")

View file

@ -33,7 +33,6 @@ from bs4.element import (
)
from . import (
SoupTest,
skipIf,
)
class TestFind(SoupTest):
@ -910,12 +909,16 @@ class TestTreeModification(SoupTest):
soup.a.extend(l)
assert "<a><g></g><f></f><e></e><d></d><c></c><b></b></a>" == soup.decode()
def test_extend_with_another_tags_contents(self):
@pytest.mark.parametrize(
"get_tags", [lambda tag: tag, lambda tag: tag.contents]
)
def test_extend_with_another_tags_contents(self, get_tags):
data = '<body><div id="d1"><a>1</a><a>2</a><a>3</a><a>4</a></div><div id="d2"></div></body>'
soup = self.soup(data)
d1 = soup.find('div', id='d1')
d2 = soup.find('div', id='d2')
d2.extend(d1)
tags = get_tags(d1)
d2.extend(tags)
assert '<div id="d1"></div>' == d1.decode()
assert '<div id="d2"><a>1</a><a>2</a><a>3</a><a>4</a></div>' == d2.decode()
@ -1272,19 +1275,30 @@ class TestTreeModification(SoupTest):
class TestDeprecatedArguments(SoupTest):
def test_find_type_method_string(self):
@pytest.mark.parametrize(
"method_name", [
"find", "find_all", "find_parent", "find_parents",
"find_next", "find_all_next", "find_previous",
"find_all_previous", "find_next_sibling", "find_next_siblings",
"find_previous_sibling", "find_previous_siblings",
]
)
def test_find_type_method_string(self, method_name):
soup = self.soup("<a>some</a><b>markup</b>")
method = getattr(soup.b, method_name)
with warnings.catch_warnings(record=True) as w:
[result] = soup.find_all(text='markup')
assert result == 'markup'
assert result.parent.name == 'b'
msg = str(w[0].message)
method(text='markup')
[warning] = w
assert warning.filename == __file__
msg = str(warning.message)
assert msg == "The 'text' argument to find()-type methods is deprecated. Use 'string' instead."
def test_soupstrainer_constructor_string(self):
with warnings.catch_warnings(record=True) as w:
strainer = SoupStrainer(text="text")
assert strainer.text == 'text'
msg = str(w[0].message)
[warning] = w
msg = str(warning.message)
assert warning.filename == __file__
assert msg == "The 'text' argument to the SoupStrainer constructor is deprecated. Use 'string' instead."

View file

@ -21,14 +21,8 @@ at <https://github.com/Ousret/charset_normalizer>.
"""
import logging
from .api import from_bytes, from_fp, from_path, normalize
from .legacy import (
CharsetDetector,
CharsetDoctor,
CharsetNormalizerMatch,
CharsetNormalizerMatches,
detect,
)
from .api import from_bytes, from_fp, from_path
from .legacy import detect
from .models import CharsetMatch, CharsetMatches
from .utils import set_logging_handler
from .version import VERSION, __version__
@ -37,14 +31,9 @@ __all__ = (
"from_fp",
"from_path",
"from_bytes",
"normalize",
"detect",
"CharsetMatch",
"CharsetMatches",
"CharsetNormalizerMatch",
"CharsetNormalizerMatches",
"CharsetDetector",
"CharsetDoctor",
"__version__",
"VERSION",
"set_logging_handler",

View file

@ -1,7 +1,5 @@
import logging
import warnings
from os import PathLike
from os.path import basename, splitext
from typing import Any, BinaryIO, List, Optional, Set
from .cd import (
@ -41,11 +39,12 @@ def from_bytes(
cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True,
explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches:
"""
Given a raw bytes sequence, return the best possibles charset usable to render str objects.
If there is no results, it is a strong indicator that the source is binary/not text.
By default, the process will extract 5 blocs of 512o each to assess the mess and coherence of a given sequence.
By default, the process will extract 5 blocks of 512o each to assess the mess and coherence of a given sequence.
And will give up a particular code page after 20% of measured mess. Those criteria are customizable at will.
The preemptive behavior DOES NOT replace the traditional detection workflow, it prioritize a particular code page
@ -176,7 +175,6 @@ def from_bytes(
prioritized_encodings.append("utf_8")
for encoding_iana in prioritized_encodings + IANA_SUPPORTED:
if cp_isolation and encoding_iana not in cp_isolation:
continue
@ -197,7 +195,14 @@ def from_bytes(
if encoding_iana in {"utf_16", "utf_32"} and not bom_or_sig_available:
logger.log(
TRACE,
"Encoding %s wont be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.",
"Encoding %s won't be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.",
encoding_iana,
)
continue
if encoding_iana in {"utf_7"} and not bom_or_sig_available:
logger.log(
TRACE,
"Encoding %s won't be tested as-is because detection is unreliable without BOM/SIG.",
encoding_iana,
)
continue
@ -297,7 +302,13 @@ def from_bytes(
):
md_chunks.append(chunk)
md_ratios.append(mess_ratio(chunk, threshold))
md_ratios.append(
mess_ratio(
chunk,
threshold,
explain is True and 1 <= len(cp_isolation) <= 2,
)
)
if md_ratios[-1] >= threshold:
early_stop_count += 1
@ -306,7 +317,9 @@ def from_bytes(
bom_or_sig_available and strip_sig_or_bom is False
):
break
except UnicodeDecodeError as e: # Lazy str loading may have missed something there
except (
UnicodeDecodeError
) as e: # Lazy str loading may have missed something there
logger.log(
TRACE,
"LazyStr Loading: After MD chunk decode, code page %s does not fit given bytes sequence at ALL. %s",
@ -389,7 +402,9 @@ def from_bytes(
if encoding_iana != "ascii":
for chunk in md_chunks:
chunk_languages = coherence_ratio(
chunk, 0.1, ",".join(target_languages) if target_languages else None
chunk,
language_threshold,
",".join(target_languages) if target_languages else None,
)
cd_ratios.append(chunk_languages)
@ -491,6 +506,7 @@ def from_fp(
cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True,
explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches:
"""
Same thing than the function from_bytes but using a file pointer that is already ready.
@ -505,6 +521,7 @@ def from_fp(
cp_exclusion,
preemptive_behaviour,
explain,
language_threshold,
)
@ -517,6 +534,7 @@ def from_path(
cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True,
explain: bool = False,
language_threshold: float = 0.1,
) -> CharsetMatches:
"""
Same thing than the function from_bytes but with one extra step. Opening and reading given file path in binary mode.
@ -532,53 +550,5 @@ def from_path(
cp_exclusion,
preemptive_behaviour,
explain,
language_threshold,
)
def normalize(
path: "PathLike[Any]",
steps: int = 5,
chunk_size: int = 512,
threshold: float = 0.20,
cp_isolation: Optional[List[str]] = None,
cp_exclusion: Optional[List[str]] = None,
preemptive_behaviour: bool = True,
) -> CharsetMatch:
"""
Take a (text-based) file path and try to create another file next to it, this time using UTF-8.
"""
warnings.warn(
"normalize is deprecated and will be removed in 3.0",
DeprecationWarning,
)
results = from_path(
path,
steps,
chunk_size,
threshold,
cp_isolation,
cp_exclusion,
preemptive_behaviour,
)
filename = basename(path)
target_extensions = list(splitext(filename))
if len(results) == 0:
raise IOError(
'Unable to normalize "{}", no encoding charset seems to fit.'.format(
filename
)
)
result = results.best()
target_extensions[0] += "-" + result.encoding # type: ignore
with open(
"{}".format(str(path).replace(filename, "".join(target_extensions))), "wb"
) as fp:
fp.write(result.output()) # type: ignore
return result # type: ignore
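The language_threshold parameter threaded through from_bytes, from_fp and from_path above sets the minimum coherence ratio a language must reach before it is reported for a match. A short sketch, assuming charset-normalizer 3.x as merged here (the sample text is arbitrary):

import charset_normalizer

payload = "Ceci est un exemple de texte accentué.".encode("cp1252")
# Raise the bar from the 0.1 default: languages below 0.2 coherence are dropped.
results = charset_normalizer.from_bytes(payload, language_threshold=0.2)
best = results.best()
if best is not None:
    print(best.encoding, best.language)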

View file

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Dict, List
# Language label that contain the em dash "—"
# character are to be considered alternative seq to origin
FREQUENCIES: Dict[str, List[str]] = {
"English": [
"e",
@ -30,6 +32,34 @@ FREQUENCIES: Dict[str, List[str]] = {
"z",
"q",
],
"English—": [
"e",
"a",
"t",
"i",
"o",
"n",
"s",
"r",
"h",
"l",
"d",
"c",
"m",
"u",
"f",
"p",
"g",
"w",
"b",
"y",
"v",
"k",
"j",
"x",
"z",
"q",
],
"German": [
"e",
"n",
@ -226,33 +256,303 @@ FREQUENCIES: Dict[str, List[str]] = {
"ж",
"ц",
],
# Jap-Kanji
"Japanese": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"丿",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"广",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
# Jap-Katakana
"Japanese—": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
# Jap-Hiragana
"Japanese——": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
"Portuguese": [
"a",
@ -340,6 +640,77 @@ FREQUENCIES: Dict[str, List[str]] = {
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
"Ukrainian": [
"о",
@ -956,34 +1327,6 @@ FREQUENCIES: Dict[str, List[str]] = {
"ö",
"y",
],
"Simple English": [
"e",
"a",
"t",
"i",
"o",
"n",
"s",
"r",
"h",
"l",
"d",
"c",
"m",
"u",
"f",
"p",
"g",
"w",
"b",
"y",
"v",
"k",
"j",
"x",
"z",
"q",
],
"Thai": [
"",
"",
@ -1066,31 +1409,6 @@ FREQUENCIES: Dict[str, List[str]] = {
"",
"",
],
"Classical Chinese": [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
],
"Kazakh": [
"а",
"ы",

View file

@ -105,7 +105,7 @@ def mb_encoding_languages(iana_name: str) -> List[str]:
):
return ["Japanese"]
if iana_name.startswith("gb") or iana_name in ZH_NAMES:
return ["Chinese", "Classical Chinese"]
return ["Chinese"]
if iana_name.startswith("iso2022_kr") or iana_name in KO_NAMES:
return ["Korean"]
@ -140,7 +140,6 @@ def alphabet_languages(
source_have_accents = any(is_accentuated(character) for character in characters)
for language, language_characters in FREQUENCIES.items():
target_have_accents, target_pure_latin = get_target_features(language)
if ignore_non_latin and target_pure_latin is False:
@ -179,22 +178,45 @@ def characters_popularity_compare(
character_approved_count: int = 0
FREQUENCIES_language_set = set(FREQUENCIES[language])
for character in ordered_characters:
ordered_characters_count: int = len(ordered_characters)
target_language_characters_count: int = len(FREQUENCIES[language])
large_alphabet: bool = target_language_characters_count > 26
for character, character_rank in zip(
ordered_characters, range(0, ordered_characters_count)
):
if character not in FREQUENCIES_language_set:
continue
character_rank_in_language: int = FREQUENCIES[language].index(character)
expected_projection_ratio: float = (
target_language_characters_count / ordered_characters_count
)
character_rank_projection: int = int(character_rank * expected_projection_ratio)
if (
large_alphabet is False
and abs(character_rank_projection - character_rank_in_language) > 4
):
continue
if (
large_alphabet is True
and abs(character_rank_projection - character_rank_in_language)
< target_language_characters_count / 3
):
character_approved_count += 1
continue
characters_before_source: List[str] = FREQUENCIES[language][
0 : FREQUENCIES[language].index(character)
0:character_rank_in_language
]
characters_after_source: List[str] = FREQUENCIES[language][
FREQUENCIES[language].index(character) :
]
characters_before: List[str] = ordered_characters[
0 : ordered_characters.index(character)
]
characters_after: List[str] = ordered_characters[
ordered_characters.index(character) :
character_rank_in_language:
]
characters_before: List[str] = ordered_characters[0:character_rank]
characters_after: List[str] = ordered_characters[character_rank:]
before_match_count: int = len(
set(characters_before) & set(characters_before_source)
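A worked instance of the rank-projection arithmetic above, with illustrative numbers rather than library data:

# 20 observed characters measured against a 26-letter target alphabet.
ordered_characters_count = 20
target_language_characters_count = 26
expected_projection_ratio = target_language_characters_count / ordered_characters_count  # 1.3
# The character observed at rank 4 is expected near rank 5 in the language table.
character_rank = 4
character_rank_projection = int(character_rank * expected_projection_ratio)
assert character_rank_projection == 5
# For small (<= 26 character) alphabets a gap larger than 4 ranks is rejected;
# large alphabets instead approve anything within a third of the table size.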
@ -289,6 +311,33 @@ def merge_coherence_ratios(results: List[CoherenceMatches]) -> CoherenceMatches:
return sorted(merge, key=lambda x: x[1], reverse=True)
def filter_alt_coherence_matches(results: CoherenceMatches) -> CoherenceMatches:
"""
We shall NOT return "English—" in CoherenceMatches because it is an alternative
of "English". This function only keeps the best match and remove the em-dash in it.
"""
index_results: Dict[str, List[float]] = dict()
for result in results:
language, ratio = result
no_em_name: str = language.replace("—", "")
if no_em_name not in index_results:
index_results[no_em_name] = []
index_results[no_em_name].append(ratio)
if any(len(index_results[e]) > 1 for e in index_results):
filtered_results: CoherenceMatches = []
for language in index_results:
filtered_results.append((language, max(index_results[language])))
return filtered_results
return results
@lru_cache(maxsize=2048)
def coherence_ratio(
decoded_sequence: str, threshold: float = 0.1, lg_inclusion: Optional[str] = None
@ -336,4 +385,6 @@ def coherence_ratio(
if sufficient_match_count >= 3:
break
return sorted(results, key=lambda x: x[1], reverse=True)
return sorted(
filter_alt_coherence_matches(results), key=lambda x: x[1], reverse=True
)
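A toy rendering of the merge rule implemented by filter_alt_coherence_matches above (standalone code, not the library's):

# "\u2014" is the em dash used to tag alternative frequency tables.
matches = [("English", 0.42), ("English\u2014", 0.57), ("German", 0.30)]
merged = {}
for language, ratio in matches:
    base = language.replace("\u2014", "")
    merged[base] = max(merged.get(base, 0.0), ratio)
print(sorted(merged.items(), key=lambda kv: kv[1], reverse=True))
# -> [('English', 0.57), ('German', 0.3)]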

View file

@ -1,15 +1,12 @@
import argparse
import sys
from json import dumps
from os.path import abspath
from os.path import abspath, basename, dirname, join, realpath
from platform import python_version
from typing import List, Optional
from unicodedata import unidata_version
try:
from unicodedata2 import unidata_version
except ImportError:
from unicodedata import unidata_version
import charset_normalizer.md as md_module
from charset_normalizer import from_fp
from charset_normalizer.models import CliDetectionResult
from charset_normalizer.version import __version__
@ -124,8 +121,11 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
parser.add_argument(
"--version",
action="version",
version="Charset-Normalizer {} - Python {} - Unicode {}".format(
__version__, python_version(), unidata_version
version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
__version__,
python_version(),
unidata_version,
"OFF" if md_module.__file__.lower().endswith(".py") else "ON",
),
help="Show version information and exit.",
)
@ -147,7 +147,6 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
x_ = []
for my_file in args.files:
matches = from_fp(my_file, threshold=args.threshold, explain=args.verbose)
best_guess = matches.best()
@ -222,7 +221,6 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
)
if args.normalize is True:
if best_guess.encoding.startswith("utf") is True:
print(
'"{}" file does not need to be normalized, as it already came from unicode.'.format(
@ -234,7 +232,10 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
my_file.close()
continue
o_: List[str] = my_file.name.split(".")
dir_path = dirname(realpath(my_file.name))
file_name = basename(realpath(my_file.name))
o_: List[str] = file_name.split(".")
if args.replace is False:
o_.insert(-1, best_guess.encoding)
@ -255,7 +256,7 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
continue
try:
x_[0].unicode_path = abspath("./{}".format(".".join(o_)))
x_[0].unicode_path = join(dir_path, ".".join(o_))
with open(x_[0].unicode_path, "w", encoding="utf-8") as fp:
fp.write(str(best_guess))

View file

@ -489,9 +489,7 @@ COMMON_SAFE_ASCII_CHARACTERS: Set[str] = {
KO_NAMES: Set[str] = {"johab", "cp949", "euc_kr"}
ZH_NAMES: Set[str] = {"big5", "cp950", "big5hkscs", "hz"}
NOT_PRINTABLE_PATTERN = re_compile(r"[0-9\W\n\r\t]+")
LANGUAGE_SUPPORTED_COUNT: int = len(FREQUENCIES)
# Logging LEVEL bellow DEBUG
# Logging LEVEL below DEBUG
TRACE: int = 5

View file

@ -1,12 +1,13 @@
import warnings
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
from warnings import warn
from .api import from_bytes, from_fp, from_path, normalize
from .api import from_bytes
from .constant import CHARDET_CORRESPONDENCE
from .models import CharsetMatch, CharsetMatches
def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
def detect(
byte_str: bytes, should_rename_legacy: bool = False, **kwargs: Any
) -> Dict[str, Optional[Union[str, float]]]:
"""
chardet legacy method
Detect the encoding of the given byte string. It should be mostly backward-compatible.
@ -15,7 +16,14 @@ def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
further information. Not planned for removal.
:param byte_str: The byte sequence to examine.
:param should_rename_legacy: Should we rename legacy encodings
to their more modern equivalents?
"""
if len(kwargs):
warn(
f"charset-normalizer disregard arguments '{','.join(list(kwargs.keys()))}' in legacy function detect()"
)
if not isinstance(byte_str, (bytearray, bytes)):
raise TypeError( # pragma: nocover
"Expected object of type bytes or bytearray, got: "
@ -36,60 +44,11 @@ def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
if r is not None and encoding == "utf_8" and r.bom:
encoding += "_sig"
if should_rename_legacy is False and encoding in CHARDET_CORRESPONDENCE:
encoding = CHARDET_CORRESPONDENCE[encoding]
return {
"encoding": encoding
if encoding not in CHARDET_CORRESPONDENCE
else CHARDET_CORRESPONDENCE[encoding],
"encoding": encoding,
"language": language,
"confidence": confidence,
}
class CharsetNormalizerMatch(CharsetMatch):
pass
class CharsetNormalizerMatches(CharsetMatches):
@staticmethod
def from_fp(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_fp(*args, **kwargs) # pragma: nocover
@staticmethod
def from_bytes(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_bytes(*args, **kwargs) # pragma: nocover
@staticmethod
def from_path(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return from_path(*args, **kwargs) # pragma: nocover
@staticmethod
def normalize(*args, **kwargs): # type: ignore
warnings.warn( # pragma: nocover
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
"and scheduled to be removed in 3.0",
DeprecationWarning,
)
return normalize(*args, **kwargs) # pragma: nocover
class CharsetDetector(CharsetNormalizerMatches):
pass
class CharsetDoctor(CharsetNormalizerMatches):
pass
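A short sketch of the slimmed-down chardet shim above, assuming charset-normalizer 3.x (the sample string is arbitrary):

from charset_normalizer import detect

raw = "Zażółć gęślą jaźń".encode("iso-8859-2")
print(detect(raw))  # chardet-style dict: encoding, language, confidence
# Opt out of the chardet-era aliases and keep modern encoding names:
print(detect(raw, should_rename_legacy=True))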

View file

@ -1,7 +1,12 @@
from functools import lru_cache
from logging import getLogger
from typing import List, Optional
from .constant import COMMON_SAFE_ASCII_CHARACTERS, UNICODE_SECONDARY_RANGE_KEYWORD
from .constant import (
COMMON_SAFE_ASCII_CHARACTERS,
TRACE,
UNICODE_SECONDARY_RANGE_KEYWORD,
)
from .utils import (
is_accentuated,
is_ascii,
@ -123,7 +128,7 @@ class TooManyAccentuatedPlugin(MessDetectorPlugin):
@property
def ratio(self) -> float:
if self._character_count == 0:
if self._character_count == 0 or self._character_count < 8:
return 0.0
ratio_of_accentuation: float = self._accentuated_count / self._character_count
return ratio_of_accentuation if ratio_of_accentuation >= 0.35 else 0.0
@ -547,7 +552,20 @@ def mess_ratio(
break
if debug:
logger = getLogger("charset_normalizer")
logger.log(
TRACE,
"Mess-detector extended-analysis start. "
f"intermediary_mean_mess_ratio_calc={intermediary_mean_mess_ratio_calc} mean_mess_ratio={mean_mess_ratio} "
f"maximum_threshold={maximum_threshold}",
)
if len(decoded_sequence) > 16:
logger.log(TRACE, f"Starting with: {decoded_sequence[:16]}")
logger.log(TRACE, f"Ending with: {decoded_sequence[-16::]}")
for dt in detectors: # pragma: nocover
print(dt.__class__, dt.ratio)
logger.log(TRACE, f"{dt.__class__}: {dt.ratio}")
return round(mean_mess_ratio, 3)
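The extended-analysis output above is emitted at the custom TRACE level (5, one notch below DEBUG) and, per the guard added to from_bytes earlier in this diff, only when explain is enabled together with one or two isolated code pages. A hedged sketch of surfacing it:

import charset_normalizer

charset_normalizer.from_bytes(
    "héllo wörld".encode("cp1252"),
    cp_isolation=["cp1252"],  # detector-level logs need 1-2 isolated code pages
    explain=True,             # attaches a handler and lowers the logger to TRACE
)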

View file

@ -1,22 +1,9 @@
import warnings
from collections import Counter
from encodings.aliases import aliases
from hashlib import sha256
from json import dumps
from re import sub
from typing import (
Any,
Counter as TypeCounter,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from .constant import NOT_PRINTABLE_PATTERN, TOO_BIG_SEQUENCE
from .md import mess_ratio
from .constant import TOO_BIG_SEQUENCE
from .utils import iana_name, is_multi_byte_encoding, unicode_range
@ -65,7 +52,7 @@ class CharsetMatch:
chaos_difference: float = abs(self.chaos - other.chaos)
coherence_difference: float = abs(self.coherence - other.coherence)
# Bellow 1% difference --> Use Coherence
# Below 1% difference --> Use Coherence
if chaos_difference < 0.01 and coherence_difference > 0.02:
# When having a tough decision, use the result that decoded as many multi-byte as possible.
if chaos_difference == 0.0 and self.coherence == other.coherence:
@ -78,45 +65,6 @@ class CharsetMatch:
def multi_byte_usage(self) -> float:
return 1.0 - len(str(self)) / len(self.raw)
@property
def chaos_secondary_pass(self) -> float:
"""
Check once again chaos in decoded text, except this time, with full content.
Use with caution, this can be very slow.
Notice: Will be removed in 3.0
"""
warnings.warn(
"chaos_secondary_pass is deprecated and will be removed in 3.0",
DeprecationWarning,
)
return mess_ratio(str(self), 1.0)
@property
def coherence_non_latin(self) -> float:
"""
Coherence ratio on the first non-latin language detected if ANY.
Notice: Will be removed in 3.0
"""
warnings.warn(
"coherence_non_latin is deprecated and will be removed in 3.0",
DeprecationWarning,
)
return 0.0
@property
def w_counter(self) -> TypeCounter[str]:
"""
Word counter instance on decoded text.
Notice: Will be removed in 3.0
"""
warnings.warn(
"w_counter is deprecated and will be removed in 3.0", DeprecationWarning
)
string_printable_only = sub(NOT_PRINTABLE_PATTERN, " ", str(self).lower())
return Counter(string_printable_only.split())
def __str__(self) -> str:
# Lazy Str Loading
if self._string is None:
@ -252,18 +200,6 @@ class CharsetMatch:
"""
return [self._encoding] + [m.encoding for m in self._leaves]
def first(self) -> "CharsetMatch":
"""
Kept for BC reasons. Will be removed in 3.0.
"""
return self
def best(self) -> "CharsetMatch":
"""
Kept for BC reasons. Will be removed in 3.0.
"""
return self
def output(self, encoding: str = "utf_8") -> bytes:
"""
Method to get re-encoded bytes payload using given target encoding. Default to UTF-8.

View file

@ -1,12 +1,6 @@
try:
# WARNING: unicodedata2 support is going to be removed in 3.0
# Python is quickly catching up.
import unicodedata2 as unicodedata
except ImportError:
import unicodedata # type: ignore[no-redef]
import importlib
import logging
import unicodedata
from codecs import IncrementalDecoder
from encodings.aliases import aliases
from functools import lru_cache
@ -317,7 +311,6 @@ def range_scan(decoded_sequence: str) -> List[str]:
def cp_similarity(iana_name_a: str, iana_name_b: str) -> float:
if is_multi_byte_encoding(iana_name_a) or is_multi_byte_encoding(iana_name_b):
return 0.0
@ -357,7 +350,6 @@ def set_logging_handler(
level: int = logging.INFO,
format_string: str = "%(asctime)s | %(levelname)s | %(message)s",
) -> None:
logger = logging.getLogger(name)
logger.setLevel(level)
@ -377,7 +369,6 @@ def cut_sequence_chunks(
is_multi_byte_decoder: bool,
decoded_payload: Optional[str] = None,
) -> Generator[str, None, None]:
if decoded_payload and is_multi_byte_decoder is False:
for i in offsets:
chunk = decoded_payload[i : i + chunk_size]
@ -402,8 +393,7 @@ def cut_sequence_chunks(
# multi-byte bad cutting detector and adjustment
# not the cleanest way to perform that fix but clever enough for now.
if is_multi_byte_decoder and i > 0 and sequences[i] >= 0x80:
if is_multi_byte_decoder and i > 0:
chunk_partial_size_chk: int = min(chunk_size, 16)
if (

View file

@ -2,5 +2,5 @@
Expose version
"""
__version__ = "2.1.1"
__version__ = "3.1.0"
VERSION = __version__.split(".")

View file

@ -38,7 +38,7 @@ CL_BLANK = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAA
URI_SCHEME = "cloudinary"
API_VERSION = "v1_1"
VERSION = "1.30.0"
VERSION = "1.32.0"
_USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version())))
@ -234,6 +234,7 @@ _http_client = HttpClient()
# FIXME: circular import issue
from cloudinary.search import Search
from cloudinary.search_folders import SearchFolders
@python_2_unicode_compatible

View file

@ -188,9 +188,9 @@ def _prepare_asset_details_params(**options):
:internal
"""
return only(options, "exif", "faces", "colors", "image_metadata", "cinemagraph_analysis",
return only(options, "exif", "faces", "colors", "image_metadata", "media_metadata", "cinemagraph_analysis",
"pages", "phash", "coordinates", "max_results", "quality_analysis", "derived_next_cursor",
"accessibility_analysis", "versions")
"accessibility_analysis", "versions", "related", "related_next_cursor")
def update(public_id, **options):
@ -223,6 +223,8 @@ def update(public_id, **options):
params["display_name"] = options.get("display_name")
if "unique_display_name" in options:
params["unique_display_name"] = options.get("unique_display_name")
if "clear_invalid" in options:
params["clear_invalid"] = options.get("clear_invalid")
return call_api("post", uri, params, **options)
@ -293,6 +295,50 @@ def delete_derived_by_transformation(public_ids, transformations,
return call_api("delete", uri, params, **options)
def add_related_assets(public_id, assets_to_relate, resource_type="image", type="upload", **options):
"""
Relates an asset to other assets by public IDs.
:param public_id: The public ID of the asset to update.
:type public_id: str
:param assets_to_relate: The array of up to 10 fully_qualified_public_ids given as resource_type/type/public_id.
:type assets_to_relate: list[str]
:param type: The upload type. Defaults to "upload".
:type type: str
:param resource_type: The type of the resource. Defaults to "image".
:type resource_type: str
:param options: Additional options.
:type options: dict, optional
:return: The result of the command.
:rtype: dict
"""
uri = ["resources", "related_assets", resource_type, type, public_id]
params = {"assets_to_relate": utils.build_array(assets_to_relate)}
return call_json_api("post", uri, params, **options)
def delete_related_assets(public_id, assets_to_unrelate, resource_type="image", type="upload", **options):
"""
Unrelates an asset from other assets by public IDs.
:param public_id: The public ID of the asset to update.
:type public_id: str
:param assets_to_unrelate: The array of up to 10 fully_qualified_public_ids given as resource_type/type/public_id.
:type assets_to_unrelate: list[str]
:param type: The upload type.
:type type: str
:param resource_type: The type of the resource: defaults to "image".
:type resource_type: str
:param options: Additional options.
:type options: dict, optional
:return: The result of the command.
:rtype: dict
"""
uri = ["resources", "related_assets", resource_type, type, public_id]
params = {"assets_to_unrelate": utils.build_array(assets_to_unrelate)}
return call_json_api("delete", uri, params, **options)
def tags(**options):
resource_type = options.pop("resource_type", "image")
uri = ["tags", resource_type]

View file

@ -68,9 +68,9 @@ def execute_request(http_connector, method, params, headers, auth, api_url, **op
response = http_connector.request(method.upper(), api_url, processed_params, req_headers, **kw)
body = response.data
except HTTPError as e:
raise GeneralError("Unexpected error {0}", e.message)
raise GeneralError("Unexpected error %s" % str(e))
except socket.error as e:
raise GeneralError("Socket Error: %s" % (str(e)))
raise GeneralError("Socket Error: %s" % str(e))
try:
result = json.loads(body.decode('utf-8'))

View file

@ -4,7 +4,11 @@ from cloudinary.api_client.call_api import call_json_api
from cloudinary.utils import unique
class Search:
class Search(object):
ASSETS = 'resources'
_endpoint = ASSETS
_KEYS_WITH_UNIQUE_VALUES = {
'sort_by': lambda x: next(iter(x)),
'aggregate': None,
@ -53,7 +57,7 @@ class Search:
def execute(self, **options):
"""Execute the search and return results."""
options["content_type"] = 'application/json'
uri = ['resources', 'search']
uri = [self._endpoint, 'search']
return call_json_api('post', uri, self.as_dict(), **options)
def _add(self, name, value):
@ -72,3 +76,7 @@ class Search:
to_return[key] = value
return to_return
def endpoint(self, endpoint):
self._endpoint = endpoint
return self

View file

@ -0,0 +1,10 @@
from cloudinary import Search
class SearchFolders(Search):
FOLDERS = 'folders'
def __init__(self):
super(SearchFolders, self).__init__()
self.endpoint(self.FOLDERS)
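An illustrative use of the endpoint() hook added above: SearchFolders reuses the whole query builder and merely redirects it at the folders search API (the expressions are made up; execute() needs configured Cloudinary credentials):

from cloudinary import Search
from cloudinary.search_folders import SearchFolders

asset_search = Search().expression("resource_type:image").max_results(10)
folder_search = SearchFolders().expression("path:samples*")
# Search posts to resources/search, SearchFolders to folders/search:
# result = folder_search.execute()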

View file

@ -168,7 +168,8 @@ def update_metadata(metadata, public_ids, **options):
"timestamp": utils.now(),
"metadata": utils.encode_context(metadata),
"public_ids": utils.build_array(public_ids),
"type": options.get("type")
"type": options.get("type"),
"clear_invalid": options.get("clear_invalid")
}
return call_api("metadata", params, **options)
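A sketch of the new clear_invalid pass-through (field names and public IDs are made up; by the flag's name, the API presumably clears values that no longer validate instead of rejecting the update):

import cloudinary.uploader

cloudinary.uploader.update_metadata(
    {"release_date": "2023-07-11"},  # external_id -> value pairs
    ["sample_one", "sample_two"],    # public IDs to update
    clear_invalid=True,
)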

View file

@ -78,6 +78,7 @@ __SIMPLE_UPLOAD_PARAMS = [
"backup",
"faces",
"image_metadata",
"media_metadata",
"exif",
"colors",
"use_filename",
@ -1052,7 +1053,8 @@ def build_custom_headers(headers):
def build_upload_params(**options):
params = {param_name: options.get(param_name) for param_name in __SIMPLE_UPLOAD_PARAMS}
params = {param_name: options.get(param_name) for param_name in __SIMPLE_UPLOAD_PARAMS if param_name in options}
params["upload_preset"] = params.pop("upload_preset", cloudinary.config().upload_preset)
serialized_params = {
"timestamp": now(),
@ -1577,3 +1579,19 @@ def unique(collection, key=None):
to_return[key(element)] = element
return list(to_return.values())
def fq_public_id(public_id, resource_type="image", type="upload"):
"""
Returns the fully qualified public id of form resource_type/type/public_id.
:param public_id: The public ID of the asset.
:type public_id: str
:param resource_type: The type of the asset. Defaults to "image".
:type resource_type: str
:param type: The upload type. Defaults to "upload".
:type type: str
:return: The fully qualified public ID string.
"""
return "{resource_type}/{type}/{public_id}".format(resource_type=resource_type, type=type, public_id=public_id)

View file

@ -18,49 +18,52 @@
"""dnspython DNS toolkit"""
__all__ = [
'asyncbackend',
'asyncquery',
'asyncresolver',
'dnssec',
'e164',
'edns',
'entropy',
'exception',
'flags',
'immutable',
'inet',
'ipv4',
'ipv6',
'message',
'name',
'namedict',
'node',
'opcode',
'query',
'rcode',
'rdata',
'rdataclass',
'rdataset',
'rdatatype',
'renderer',
'resolver',
'reversename',
'rrset',
'serial',
'set',
'tokenizer',
'transaction',
'tsig',
'tsigkeyring',
'ttl',
'rdtypes',
'update',
'version',
'versioned',
'wire',
'xfr',
'zone',
'zonefile',
"asyncbackend",
"asyncquery",
"asyncresolver",
"dnssec",
"dnssectypes",
"e164",
"edns",
"entropy",
"exception",
"flags",
"immutable",
"inet",
"ipv4",
"ipv6",
"message",
"name",
"namedict",
"node",
"opcode",
"query",
"quic",
"rcode",
"rdata",
"rdataclass",
"rdataset",
"rdatatype",
"renderer",
"resolver",
"reversename",
"rrset",
"serial",
"set",
"tokenizer",
"transaction",
"tsig",
"tsigkeyring",
"ttl",
"rdtypes",
"update",
"version",
"versioned",
"wire",
"xfr",
"zone",
"zonetypes",
"zonefile",
]
from dns.version import version as __version__ # noqa

View file

@ -3,6 +3,7 @@
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use.
class NullContext:
def __init__(self, enter_result=None):
self.enter_result = enter_result
@ -23,6 +24,7 @@ class NullContext:
# These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover
async def close(self):
pass
@ -41,6 +43,9 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout):
raise NotImplementedError
@ -58,12 +63,23 @@ class StreamSocket(Socket): # pragma: no cover
class Backend: # pragma: no cover
def name(self):
return 'unknown'
return "unknown"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError
def datagram_connection_required(self):
return False
async def sleep(self, interval):
raise NotImplementedError

View file

@ -10,7 +10,8 @@ import dns._asyncbackend
import dns.exception
_is_win32 = sys.platform == 'win32'
_is_win32 = sys.platform == "win32"
def _get_running_loop():
try:
@ -30,7 +31,6 @@ class _DatagramProtocol:
def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
self.family = family
super().__init__(family)
self.transport = transport
self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto
self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future()
try:
assert self.protocol.recvfrom is None
self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout)
return done.result()
finally:
self.protocol.recvfrom = None
async def close(self):
self.protocol.close()
async def getpeername(self):
return self.transport.get_extra_info('peername')
return self.transport.get_extra_info("peername")
async def getsockname(self):
return self.transport.get_extra_info('sockname')
return self.transport.get_extra_info("sockname")
class StreamSocket(dns._asyncbackend.StreamSocket):
@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size),
timeout)
return await _maybe_wait_for(self.reader.read(size), timeout)
async def close(self):
self.writer.close()
@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
pass
async def getpeername(self):
return self.writer.get_extra_info('peername')
return self.writer.get_extra_info("peername")
async def getsockname(self):
return self.writer.get_extra_info('sockname')
return self.writer.get_extra_info("sockname")
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'asyncio'
return "asyncio"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
if destination is None and socktype == socket.SOCK_DGRAM and \
_is_win32:
raise NotImplementedError('destinationless datagram sockets '
'are not supported by asyncio '
'on Windows')
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af,
proto=proto, remote_addr=destination)
_DatagramProtocol,
source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for(
asyncio.open_connection(destination[0],
asyncio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
family=af,
proto=proto,
local_addr=source,
server_hostname=server_hostname),
timeout)
server_hostname=server_hostname,
),
timeout,
)
return StreamSocket(af, r, w)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await asyncio.sleep(interval)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'curio'
return "curio"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto)
try:
@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend):
else:
source_addr = None
async with _maybe_timeout(timeout):
s = await curio.open_connection(destination[0], destination[1],
s = await curio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname)
server_hostname=server_hostname,
)
return StreamSocket(s)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await curio.sleep(interval)

View file

@ -1,84 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator is for python 3.6,
# which doesn't have Context Variables. This implementation is somewhat
# costly for classes with slots, as it adds a __dict__ to them.
import inspect
class _Immutable:
"""Immutable mixin class"""
# Note we MUST NOT have __slots__ as that causes
#
# TypeError: multiple bases have instance lay-out conflict
#
# when we get mixed in with another class with slots. When we
# get mixed into something with slots, it effectively adds __dict__ to
# the slots of the other class, which allows attribute setting to work,
# albeit at the cost of the dictionary.
def __setattr__(self, name, value):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
try:
# Are we already initializing an immutable class?
previous = args[0]._immutable_init
except AttributeError:
# We are the first!
previous = None
object.__setattr__(args[0], '_immutable_init', args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
if not previous:
# If we started the initialization, establish immutability
# by removing the attribute that allows mutation
object.__delattr__(args[0], '_immutable_init')
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View file

@ -8,7 +8,7 @@ import contextvars
import inspect
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable:
@ -41,6 +41,7 @@ def _immutable_init(f):
f(*args, **kwargs)
finally:
_in__init__.reset(previous)
nf.__signature__ = inspect.signature(f)
return nf
@ -50,7 +51,7 @@ def immutable(cls):
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
@ -63,7 +64,8 @@ def immutable(cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
if hasattr(cls, "__setstate__"):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
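With the Python 3.6 fallback removed earlier in this diff, the contextvars-based decorator above is the only implementation left. A tiny sketch of the behavior it enforces:

import dns.immutable

@dns.immutable.immutable
class Point:
    def __init__(self, x, y):
        self.x = x  # allowed: assignment is permitted while __init__ runs
        self.y = y

p = Point(1, 2)
try:
    p.x = 3
except TypeError as exc:
    print(exc)  # object doesn't support attribute assignment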

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
self.socket.close()
@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout):
with _maybe_timeout(timeout):
return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
with _maybe_timeout(timeout):
return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.stream.aclose()
@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'trio'
return "trio"
async def make_socket(self, af, socktype, proto=0, source=None,
destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto)
stream = None
try:
@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend):
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s)
s = None
tls = False
if ssl_context:
tls = True
try:
stream = trio.SSLStream(stream, ssl_context,
server_hostname=server_hostname)
stream = trio.SSLStream(
stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover
await stream.aclose()
raise
return StreamSocket(af, stream, tls)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await trio.sleep(interval)

View file

@ -1,26 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception
# pylint: disable=unused-import
from dns._asyncbackend import Socket, DatagramSocket, \
StreamSocket, Backend # noqa:
from dns._asyncbackend import (
Socket,
DatagramSocket,
StreamSocket,
Backend,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import
_default_backend = None
_backends = {}
_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass
def get_backend(name):
def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio",
@ -32,22 +39,25 @@ def get_backend(name):
backend = _backends.get(name)
if backend:
return backend
if name == 'trio':
if name == "trio":
import dns._trio_backend
backend = dns._trio_backend.Backend()
elif name == 'curio':
elif name == "curio":
import dns._curio_backend
backend = dns._curio_backend.Backend()
elif name == 'asyncio':
elif name == "asyncio":
import dns._asyncio_backend
backend = dns._asyncio_backend.Backend()
else:
raise NotImplementedError(f'unimplemented async backend {name}')
raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend
return backend
def sniff():
def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
@ -59,35 +69,32 @@ def sniff():
if _no_sniffio:
raise ImportError
import sniffio
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
'async library')
raise AsyncLibraryNotFoundError(
"sniffio cannot determine " + "async library"
)
except ImportError:
import asyncio
try:
asyncio.get_running_loop()
return 'asyncio'
return "asyncio"
except RuntimeError:
raise AsyncLibraryNotFoundError('no async library detected')
except AttributeError: # pragma: no cover
# we have to check current_task on 3.6
if not asyncio.Task.current_task():
raise AsyncLibraryNotFoundError('no async library detected')
return 'asyncio'
raise AsyncLibraryNotFoundError("no async library detected")
def get_default_backend():
"""Get the default backend, initializing it if necessary.
"""
def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary."""
if _default_backend:
return _default_backend
return set_default_backend(sniff())
def set_default_backend(name):
def set_default_backend(name: str) -> Backend:
"""Set the default backend.
It's not normally necessary to call this method, as

View file

@ -1,13 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
class Backend:
...
def get_backend(name: str) -> Backend:
...
def sniff() -> str:
...
def get_default_backend() -> Backend:
...
def set_default_backend(name: str) -> Backend:
...

View file

@ -17,7 +17,10 @@
"""Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64
import contextlib
import socket
import struct
import time
@ -27,12 +30,24 @@ import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.quic
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.transaction
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
UDPMode, _have_httpx, _have_http2, NoDOH
from dns._asyncbackend import NullContext
from dns.query import (
_compute_times,
_matches_destination,
BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH,
NoDOQ,
)
if _have_httpx:
import httpx
@ -47,11 +62,11 @@ def _source_tuple(af, address, port):
if address or port:
if address is None:
if af == socket.AF_INET:
address = '0.0.0.0'
address = "0.0.0.0"
elif af == socket.AF_INET6:
address = '::'
address = "::"
else:
raise NotImplementedError(f'unknown address family {af}')
raise NotImplementedError(f"unknown address family {af}")
return (address, port)
else:
return None
@ -66,7 +81,12 @@ def _timeout(expiration, now=None):
return None
async def send_udp(sock, what, destination, expiration=None):
async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: Union[dns.message.Message, bytes],
destination: Any,
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None):
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur.
occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None):
return (n, sent_time)
async def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
async def receive_udp(
sock: dns.asyncbackend.DatagramSocket,
destination: Optional[Any] = None,
expiration: Optional[float] = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
"""
wire = b''
wire = b""
while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination(sock.family, from_address, destination,
ignore_unexpected):
if _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
break
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
raise_on_truncation=raise_on_truncation,
)
return (r, received_time, from_address)
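Editor's note: a minimal sketch of driving send_udp() and receive_udp() directly over a backend socket. The resolver address 8.8.8.8 and the 5-second budget are illustrative only; most callers should use the higher-level udp() below.
import asyncio
import socket
import time
import dns.asyncbackend
import dns.asyncquery
import dns.message
async def manual_udp_roundtrip():
    q = dns.message.make_query("dnspython.org", "A")
    backend = dns.asyncbackend.get_default_backend()
    sock = await backend.make_socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    async with sock as s:
        expiration = time.time() + 5.0  # absolute deadline, as the docstrings describe
        await dns.asyncquery.send_udp(s, q, ("8.8.8.8", 53), expiration)
        (r, _, _) = await dns.asyncquery.receive_udp(s, ("8.8.8.8", 53), expiration)
    return r
# asyncio.run(manual_udp_roundtrip())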
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
ignore_trailing=False, raise_on_truncation=False, sock=None,
backend=None):
async def udp(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@ -134,13 +181,10 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af)
if sock:
s = sock
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if not backend:
backend = dns.asyncbackend.get_default_backend()
@ -149,27 +193,40 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
dtuple = (where, port)
else:
dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
dtuple)
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration,
(r, received_time, _) = await receive_udp(
s,
destination,
expiration,
ignore_unexpected,
one_rr_per_rrset,
q.keyring, q.mac,
q.keyring,
q.mac,
ignore_trailing,
raise_on_truncation)
raise_on_truncation,
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
finally:
if not sock and s:
await s.close()
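Editor's note: for reference, a typical call to the rewritten coroutine looks like the sketch below; the nameserver address is a placeholder.
import asyncio
import dns.asyncquery
import dns.message
async def simple_udp_query():
    q = dns.message.make_query("dnspython.org", "MX")
    return await dns.asyncquery.udp(q, "8.8.8.8", timeout=5)
# asyncio.run(simple_udp_query())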
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
source_port=0, ignore_unexpected=False,
one_rr_per_rrset=False, ignore_trailing=False,
udp_sock=None, tcp_sock=None, backend=None):
async def udp_with_fallback(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
method.
"""
try:
response = await udp(q, where, timeout, port, source, source_port,
ignore_unexpected, one_rr_per_rrset,
ignore_trailing, True, udp_sock, backend)
response = await udp(
q,
where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
)
return (response, False)
except dns.message.Truncated:
response = await tcp(q, where, timeout, port, source, source_port,
one_rr_per_rrset, ignore_trailing, tcp_sock,
backend)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True)
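Editor's note: a short sketch of the fallback behavior; the second element of the returned tuple reports whether TCP was used. The server address is a placeholder.
import dns.asyncquery
import dns.message
async def query_with_fallback():
    q = dns.message.make_query("dnspython.org", "TXT")
    (response, used_tcp) = await dns.asyncquery.udp_with_fallback(q, "8.8.8.8", timeout=5)
    if used_tcp:
        print("UDP answer was truncated; retried over TCP")
    return response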
async def send_tcp(sock, what, expiration=None):
async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: Union[dns.message.Message, bytes],
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None):
"""
if isinstance(what, dns.message.Message):
what = what.to_wire()
l = len(what)
wire = what.to_wire()
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + what
tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF.
"""
s = b''
s = b""
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b'':
if n == b"":
raise EOFError
count = count - len(n)
s = s + n
return s
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False):
async def receive_tcp(
sock: dns.asyncbackend.StreamSocket,
expiration: Optional[float] = None,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
(l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration)
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
ignore_trailing=ignore_trailing,
)
return (r, received_time)
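Editor's note: the framing handled by send_tcp() and _read_exactly()/receive_tcp() is the RFC 1035 TCP convention, a two-byte big-endian length prefix before each DNS message. A toy illustration of the struct packing used above:
import struct
wire = b"\x12\x34"  # stand-in for a DNS message in wire form
tcpmsg = struct.pack("!H", len(wire)) + wire
assert tcpmsg == b"\x00\x02\x12\x34"
(l,) = struct.unpack("!H", tcpmsg[:2])
assert l == len(wire)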
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
backend=None):
async def tcp(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
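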
@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
if sock:
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
await sock.getpeername()
s = sock
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
# These are simple (address, port) pairs, not
# family-dependent tuples you pass to lowlevel socket
# code.
# These are simple (address, port) pairs, not family-dependent tuples
# you pass to low-level socket code.
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout)
cm = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac,
ignore_trailing)
(r, received_time) = await receive_tcp(
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
finally:
if not sock and s:
await s.close()
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
backend=None, ssl_context=None, server_hostname=None):
async def tls(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
# After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout)
if not sock:
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
ssl_context = ssl.create_default_context()
# See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
else:
@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout, ssl_context,
server_hostname)
else:
s = sock
try:
cm = await backend.make_socket(
af,
socket.SOCK_STREAM,
0,
stuple,
dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration)
response = await tcp(q, where, timeout, port, source, source_port,
one_rr_per_rrset, ignore_trailing, s, backend)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time()
response.time = end_time - begin_time
return response
finally:
if not sock and s:
await s.close()
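Editor's note: a hedged sketch of a DNS-over-TLS query with this coroutine. The 1.1.1.1/cloudflare-dns.com pair is only an example server; verification follows the default SSL context built above.
import asyncio
import dns.asyncquery
import dns.message
async def dot_query():
    q = dns.message.make_query("dnspython.org", "A")
    return await dns.asyncquery.tls(q, "1.1.1.1", server_hostname="cloudflare-dns.com")
# asyncio.run(dot_query())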
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, client=None,
path='/dns-query', post=True, verify=True):
async def https(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 443,
source: Optional[str] = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient"] = None,
path: str = "/dns-query",
post: bool = True,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
"""
if not _have_httpx:
raise NoDOH('httpx is not available.') # pragma: no cover
raise NoDOH("httpx is not available.") # pragma: no cover
wire = q.to_wire()
try:
@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
except ValueError:
af = None
transport = None
headers = {
"accept": "application/dns-message"
}
headers = {"accept": "application/dns-message"}
if af is not None:
if af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path)
url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path)
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0])
# After 3.6 is no longer supported, this can use an AsyncExitStack
client_to_close = None
try:
if not client:
client = httpx.AsyncClient(http1=True, http2=_have_http2,
verify=verify, transport=transport)
client_to_close = client
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=_have_http2, verify=verify, transport=transport
)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
headers.update({
headers.update(
{
"content-type": "application/dns-message",
"content-length": str(len(wire))
})
response = await client.post(url, headers=headers, content=wire,
timeout=timeout)
"content-length": str(len(wire)),
}
)
response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
wire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout,
params={"dns": wire})
finally:
if client_to_close:
await client.aclose()
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await the_client.get(
url, headers=headers, timeout=timeout, params={"dns": twire}
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}'
'\nResponse body: {}'.format(where,
response.status_code,
response.content))
r = dns.message.from_wire(response.content,
raise ValueError(
"{} responded with status code {}"
"\nResponse body: {!r}".format(
where, response.status_code, response.content
)
)
r = dns.message.from_wire(
response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
r.time = response.elapsed
ignore_trailing=ignore_trailing,
)
r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
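Editor's note: a sketch of a DoH query, assuming the httpx dependency is installed (per the NoDOH check above). As the code shows, *where* may be a bare address (a URL is then built from *port* and *path*) or a full URL; post=False selects the RFC 8484 GET form. The endpoint below is illustrative.
import dns.asyncquery
import dns.message
async def doh_query():
    q = dns.message.make_query("dnspython.org", "A")
    return await dns.asyncquery.https(q, "https://dns.google/dns-query", post=False)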
async def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None,
source_port=0, udp_mode=UDPMode.NEVER, backend=None):
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
query: Optional[dns.message.Message] = None,
port: int = 53,
timeout: Optional[float] = None,
lifetime: Optional[float] = None,
source: Optional[str] = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None,
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
_timeout(expiration))
s = await backend.make_socket(
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration))
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial,
is_udp) as inbound:
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
(expiration is not None and mexpiration > expiration):
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535,
timeout)
if _matches_destination(af, from_address,
destination, True):
(rwire, from_address) = await s.recvfrom(65535, timeout)
if _matches_destination(
af, from_address, destination, True
):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = (rdtype == dns.rdatatype.IXFR)
r = dns.message.from_wire(rwire, keyring=query.keyring,
request_mac=query.mac, xfr=True,
origin=origin, tsig_ctx=tsig_ctx,
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr)
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None,
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
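Editor's note: a sketch of driving inbound_xfr(); in current dnspython a dns.zone.Zone can serve as the transaction manager, and 10.0.0.53 is a placeholder primary that must permit the transfer.
import dns.asyncquery
import dns.zone
async def fetch_zone():
    zone = dns.zone.Zone("example.")  # a Zone is a TransactionManager
    await dns.asyncquery.inbound_xfr("10.0.0.53", zone)
    return zone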
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
stream = await the_connection.make_stream()
async with stream:
await stream.send(wire, True)
wire = await stream.receive(timeout)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
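Editor's note: a minimal sketch of a DoQ query, assuming dnspython was installed with QUIC support (the aioquic extra); the AdGuard resolver address is only an example endpoint.
import dns.asyncquery
import dns.message
async def doq_query():
    q = dns.message.make_query("dnspython.org", "A")
    return await dns.asyncquery.quic(q, "94.140.14.14")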

View file

@ -1,43 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
async def udp(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.DatagramSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tls(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -17,13 +17,18 @@
"""Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union
import time
import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.name
import dns.query
import dns.resolver
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
async def resolve(self, qname, rdtype=dns.rdatatype.A,
rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime=None, search=None,
backend=None):
async def resolve(
self,
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver):
type of this method.
"""
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
raise_on_no_answer, search)
resolution = dns.resolver._Resolution(
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend:
backend = dns.asyncbackend.get_default_backend()
start = time.time()
@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None:
# cache hit!
return answer
assert request is not None # needed for type checking
done = False
while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff:
await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime,
resolution.errors)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
if dns.inet.is_address(nameserver):
if tcp:
response = await _tcp(request, nameserver,
timeout, port,
source, source_port,
backend=backend)
else:
response = await _udp(request, nameserver,
timeout, port,
source, source_port,
raise_on_truncation=True,
backend=backend)
else:
response = await dns.asyncquery.https(request,
response = await _tcp(
request,
nameserver,
timeout=timeout)
timeout,
port,
source,
source_port,
backend=backend,
)
else:
response = await _udp(
request,
nameserver,
timeout,
port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else:
response = await dns.asyncquery.https(
request, nameserver, timeout=timeout
)
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None:
return answer
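Editor's note: a usage sketch for the retyped resolve() method with an explicitly configured resolver; the nameserver is a placeholder.
import dns.asyncresolver
async def custom_resolve():
    resolver = dns.asyncresolver.Resolver(configure=False)
    resolver.nameservers = ["8.8.8.8"]
    return await resolver.resolve("dnspython.org", "AAAA", tcp=True)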
async def resolve_address(self, ipaddr, *args, **kwargs):
async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver):
function.
"""
return await self.resolve(dns.reversename.from_address(ipaddr),
rdtype=dns.rdatatype.PTR,
rdclass=dns.rdataclass.IN,
*args, **kwargs)
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
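Editor's note: for example (the PTR answer shown is what Google's resolver typically returns; treat it as illustrative):
import dns.asyncresolver
async def reverse_lookup():
    answer = await dns.asyncresolver.resolve_address("8.8.8.8")
    return answer[0].target  # e.g. dns.google.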
# pylint: disable=redefined-outer-name
async def canonical_name(self, name):
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver):
default_resolver = None
def get_default_resolver():
def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
assert default_resolver is not None
return default_resolver
def reset_default_resolver():
def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@ -167,9 +199,18 @@ def reset_default_resolver():
default_resolver = Resolver()
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime=None, search=None, backend=None):
async def resolve(
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
information on the parameters.
"""
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp,
source, raise_on_no_answer,
source_port, lifetime, search,
backend)
return await get_default_resolver().resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(ipaddr, *args, **kwargs):
async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs):
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def canonical_name(name):
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
@ -203,8 +255,14 @@ async def canonical_name(name):
return await get_default_resolver().canonical_name(name)
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
resolver=None, backend=None):
async def zone_for_name(
name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Optional[Resolver] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
raise NotAbsolute(name)
while True:
try:
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
tcp, backend=backend)
answer = await resolver.resolve(
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
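Editor's note: a brief sketch of zone_for_name(); it walks upward from *name*, issuing SOA queries until the answer's owner matches.
import dns.asyncresolver
async def find_zone():
    zone = await dns.asyncresolver.zone_for_name("www.dnspython.org")
    return zone  # a dns.name.Name, e.g. dnspython.org.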

View file

@ -1,26 +0,0 @@
from typing import Union, Optional, List, Any, Dict
from . import exception, rdataclass, name, rdatatype, asyncbackend
async def resolve(qname : str, rdtype : Union[int,str] = 0,
rdclass : Union[int,str] = 0,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...
async def resolve_address(self, ipaddr: str,
*args: Any, **kwargs: Optional[Dict]):
...
class Resolver:
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
configure : Optional[bool] = True):
self.nameservers : List[str]
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
rdclass : Union[int,str] = rdataclass.IN,
tcp : bool = False, source : Optional[str] = None,
raise_on_no_answer=True, source_port : int = 0,
lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...

File diff suppressed because it is too large

View file

@ -1,21 +0,0 @@
from typing import Union, Dict, Tuple, Optional
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
import dns.rdtypes.ANY.DS as DS
import dns.rdtypes.ANY.DNSKEY as DNSKEY
_have_pyca : bool
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
...
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
...
class ValidationFailure(exception.DNSException):
...
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
...
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
...

71
lib/dns/dnssectypes.py Normal file
View file

@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255
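Editor's note: these enums behave like any dns.enum.IntEnum; a small sketch of round-tripping an algorithm code (make() is the IntEnum classmethod shown later in this diff).
import dns.dnssectypes
alg = dns.dnssectypes.Algorithm.ECDSAP256SHA256
assert int(alg) == 13
assert dns.dnssectypes.Algorithm.make("RSASHA256") == 8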

View file

@ -17,15 +17,19 @@
"""DNS E.164 helpers."""
from typing import Iterable, Optional, Union
import dns.exception
import dns.name
import dns.resolver
#: The public E.164 domain.
public_enum_domain = dns.name.from_text('e164.arpa.')
public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(text, origin=public_enum_domain):
def from_e164(
text: str, origin: Optional[dns.name.Name] = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain):
parts = [d for d in text if d.isdigit()]
parts.reverse()
return dns.name.from_text('.'.join(parts), origin=origin)
return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
def to_e164(
name: dns.name.Name,
origin: Optional[dns.name.Name] = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError('non-digit labels in ENUM domain name')
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse()
text = b''.join(dlabels)
text = b"".join(dlabels)
if want_plus_prefix:
text = b'+' + text
text = b"+" + text
return text.decode()
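Editor's note: a worked example of the two conversions (digits reversed under e164.arpa, per the docstrings above):
import dns.e164
n = dns.e164.from_e164("+1 650 555 1212")
assert str(n) == "2.1.2.1.5.5.5.0.5.6.1.e164.arpa."
assert dns.e164.to_e164(n) == "+16505551212"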
def query(number, domains, resolver=None):
def query(
number: str,
domains: Iterable[Union[dns.name.Name, str]],
resolver: Optional[dns.resolver.Resolver] = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
@ -98,7 +110,7 @@ def query(number, domains, resolver=None):
domain = dns.name.from_text(domain)
qname = dns.e164.from_e164(number, domain)
try:
return resolver.resolve(qname, 'NAPTR')
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
e_nx += e
raise e_nx

View file

@ -1,10 +0,0 @@
from typing import Optional, Iterable
from . import name, resolver
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
...
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
...
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
...

View file

@ -17,6 +17,8 @@
"""EDNS Options"""
from typing import Any, Dict, Optional, Union
import math
import socket
import struct
@ -24,6 +26,7 @@ import struct
import dns.enum
import dns.inet
import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum):
@ -59,14 +62,14 @@ class Option:
"""Base class for all EDNS option types."""
def __init__(self, otype):
def __init__(self, otype: Union[OptionType, str]):
"""Initialize an option.
*otype*, an ``int``, is the option type.
*otype*, a ``dns.edns.OptionType``, is the option type.
"""
self.otype = OptionType.make(otype)
def to_wire(self, file=None):
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
@ -75,10 +78,10 @@ class Option:
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, otype, parser):
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the option length.
@ -115,26 +118,22 @@ class Option:
return self._cmp(other) != 0
def __lt__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) > 0
@ -142,7 +141,7 @@ class Option:
return self.to_text()
class GenericOption(Option):
class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
@ -150,28 +149,31 @@ class GenericOption(Option):
implementation.
"""
def __init__(self, otype, data):
def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file=None):
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
if file:
file.write(self.data)
return None
else:
return self.data
def to_text(self):
def to_text(self) -> str:
return "Generic %d" % self.otype
@classmethod
def from_wire_parser(cls, otype, parser):
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining())
class ECSOption(Option):
class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address, srclen=None, scopelen=0):
def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
@ -200,8 +202,9 @@ class ECSOption(Option):
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family')
raise ValueError("Bad address family")
assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
@ -214,16 +217,14 @@ class ECSOption(Option):
self.addrdata = addrdata[:nbytes]
nbits = srclen % 8
if nbits != 0:
last = struct.pack('B',
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
def to_text(self):
return "ECS {}/{} scope/{}".format(self.address, self.srclen,
self.scopelen)
def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
@staticmethod
def from_text(text):
def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
@ -246,7 +247,7 @@ class ECSOption(Option):
>>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
"""
optional_prefix = 'ECS'
optional_prefix = "ECS"
tokens = text.split()
ecs_text = None
if len(tokens) == 1:
@ -257,47 +258,53 @@ class ECSOption(Option):
ecs_text = tokens[1]
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
n_slashes = ecs_text.count('/')
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, srclen = ecs_text.split('/')
scope = 0
address, tsrclen = ecs_text.split("/")
tscope = "0"
elif n_slashes == 2:
address, srclen, scope = ecs_text.split('/')
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
try:
scope = int(scope)
scope = int(tscope)
except ValueError:
raise ValueError('invalid scope ' +
'"{}": scope must be an integer'.format(scope))
raise ValueError(
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try:
srclen = int(srclen)
srclen = int(tsrclen)
except ValueError:
raise ValueError('invalid srclen ' +
'"{}": srclen must be an integer'.format(srclen))
raise ValueError(
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
)
return ECSOption(address, srclen, scope)
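Editor's note: a short sketch of attaching a parsed ECS option to a query via the standard make_query parameters:
import dns.edns
import dns.message
ecs = dns.edns.ECSOption.from_text("1.2.3.0/24")
q = dns.message.make_query("dnspython.org", "A", use_edns=0, options=[ecs])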
def to_wire(self, file=None):
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
self.addrdata)
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = (
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(cls, otype, parser):
family, src, scope = parser.get_struct('!HBB')
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen)
if family == 1:
pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad)
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2:
pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad)
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else:
raise ValueError('unsupported family')
raise ValueError("unsupported family")
return cls(addr, src, scope)
@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum):
return 65535
class EDEOption(Option):
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
def __init__(self, code, text=None):
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
@ -349,49 +356,50 @@ class EDEOption(Option):
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None')
self.code = code
raise ValueError("text must be string or None")
self.text = text
def to_text(self):
output = f'EDE {self.code}'
def to_text(self) -> str:
output = f"EDE {self.code}"
if self.text is not None:
output += f': {self.text}'
output += f": {self.text}"
return output
def to_wire(self, file=None):
value = struct.pack('!H', self.code)
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = struct.pack("!H", self.code)
if self.text is not None:
value += self.text.encode('utf8')
value += self.text.encode("utf8")
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(cls, otype, parser):
code = parser.get_uint16()
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
the_code = EDECode.make(parser.get_uint16())
text = parser.get_remaining()
if text:
if text[-1] == 0: # text MAY be null-terminated
text = text[:-1]
text = text.decode('utf8')
btext = text.decode("utf8")
else:
text = None
btext = None
return cls(code, text)
return cls(the_code, btext)
_type_to_class = {
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
}
def get_option_class(otype):
def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
@ -404,7 +412,9 @@ def get_option_class(otype):
return cls
def option_from_wire_parser(otype, parser):
def option_from_wire_parser(
otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser):
Returns an instance of a subclass of ``dns.edns.Option``.
"""
cls = get_option_class(otype)
otype = OptionType.make(otype)
the_otype = OptionType.make(otype)
cls = get_option_class(the_otype)
return cls.from_wire_parser(otype, parser)
def option_from_wire(otype, wire, current, olen):
def option_from_wire(
otype: Union[OptionType, str], wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen):
with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser)
def register_type(implementation, otype):
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
@ -447,6 +460,7 @@ def register_type(implementation, otype):
_type_to_class[otype] = implementation
### BEGIN generated OptionType constants
NSID = OptionType.NSID

View file

@ -15,14 +15,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os
import hashlib
import random
import threading
import time
try:
import threading as _threading
except ImportError: # pragma: no cover
import dummy_threading as _threading # type: ignore
class EntropyPool:
@ -32,51 +31,51 @@ class EntropyPool:
# leaving this code doesn't hurt anything as the library code
# is used if present.
def __init__(self, seed=None):
def __init__(self, seed: Optional[bytes] = None):
self.pool_index = 0
self.digest = None
self.digest: Optional[bytearray] = None
self.next_byte = 0
self.lock = _threading.Lock()
self.lock = threading.Lock()
self.hash = hashlib.sha1()
self.hash_len = 20
self.pool = bytearray(b'\0' * self.hash_len)
self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None:
self._stir(bytearray(seed))
self._stir(seed)
self.seeded = True
self.seed_pid = os.getpid()
else:
self.seeded = False
self.seed_pid = 0
def _stir(self, entropy):
def _stir(self, entropy: bytes) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
b = c & 0xff
b = c & 0xFF
self.pool[self.pool_index] ^= b
self.pool_index += 1
def stir(self, entropy):
def stir(self, entropy: bytes) -> None:
with self.lock:
self._stir(entropy)
def _maybe_seed(self):
def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid():
try:
seed = os.urandom(16)
except Exception: # pragma: no cover
try:
with open('/dev/urandom', 'rb', 0) as r:
with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16)
except Exception:
seed = str(time.time())
seed = str(time.time()).encode()
self.seeded = True
self.seed_pid = os.getpid()
self.digest = None
seed = bytearray(seed)
self._stir(seed)
def random_8(self):
def random_8(self) -> int:
with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
@ -88,16 +87,16 @@ class EntropyPool:
self.next_byte += 1
return value
def random_16(self):
def random_16(self) -> int:
return self.random_8() * 256 + self.random_8()
def random_32(self):
def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16()
def random_between(self, first, last):
def random_between(self, first: int, last: int) -> int:
size = last - first + 1
if size > 4294967296:
raise ValueError('too big')
raise ValueError("too big")
if size > 65536:
rand = self.random_32
max = 4294967295
@ -109,20 +108,24 @@ class EntropyPool:
max = 255
return first + size * rand() // (max + 1)
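Editor's note: the final line scales a uniform draw in [0, max] onto [first, last]; e.g. with an 8-bit draw and size 6, every value 1..6 is reachable (with a slight bias whenever size does not divide max + 1).
size, top = 6, 255
values = {1 + size * r // (top + 1) for r in range(top + 1)}
assert values == {1, 2, 3, 4, 5, 6}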
pool = EntropyPool()
system_random: Optional[Any]
try:
system_random = random.SystemRandom()
except Exception: # pragma: no cover
system_random = None
def random_16():
def random_16() -> int:
if system_random is not None:
return system_random.randrange(0, 65536)
else:
return pool.random_16()
def between(first, last):
def between(first: int, last: int) -> int:
if system_random is not None:
return system_random.randrange(first, last + 1)
else:

View file

@ -1,10 +0,0 @@
from typing import Optional
from random import SystemRandom
system_random : Optional[SystemRandom]
def random_16() -> int:
pass
def between(first: int, last: int) -> int:
pass

View file

@ -17,6 +17,7 @@
import enum
class IntEnum(enum.IntEnum):
@classmethod
def _check_value(cls, value):
@ -32,9 +33,12 @@ class IntEnum(enum.IntEnum):
return cls[text]
except KeyError:
pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix):].isdigit():
value = int(text[len(prefix):])
if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :])
cls._check_value(value)
try:
return cls(value)
@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum):
def to_text(cls, value):
cls._check_value(value)
try:
return cls(value).name
text = cls(value).name
except ValueError:
return f"{cls._prefix()}{value}"
text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
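Editor's note: the _prefix() hook is what lets subclasses round-trip unknown values; dns.rdatatype, for instance, uses the TYPE prefix (illustrative, relying on standard dnspython behavior).
import dns.rdatatype
rdt = dns.rdatatype.from_text("TYPE65280")
assert dns.rdatatype.to_text(rdt) == "TYPE65280"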
@classmethod
def make(cls, value):
@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum):
@classmethod
def _prefix(cls):
return ''
return ""
@classmethod
def _extra_from_text(cls, text): # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod
def _unknown_exception_class(cls):

View file

@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``.
"""
from typing import Optional, Set
class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions.
@ -44,14 +48,15 @@ class DNSException(Exception):
and ``fmt`` class variables to get nice parametrized messages.
"""
msg = None # non-parametrized message
supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt = None # message parametrized with results from _fmt_kwargs
msg: Optional[str] = None # non-parametrized message
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
if kwargs:
self.kwargs = self._check_kwargs(**kwargs)
# This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self)
else:
self.kwargs = dict() # defined but empty for old mode exceptions
@ -68,14 +73,15 @@ class DNSException(Exception):
For sanity we do not allow mixing old and new behavior."""
if args or kwargs:
assert bool(args) != bool(kwargs), \
'keyword arguments are mutually exclusive with positional args'
assert bool(args) != bool(
kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs):
if kwargs:
assert set(kwargs.keys()) == self.supp_kwargs, \
'following set of keyword args is required: %s' % (
self.supp_kwargs)
assert (
set(kwargs.keys()) == self.supp_kwargs
), "following set of keyword args is required: %s" % (self.supp_kwargs)
return kwargs
def _fmt_kwargs(self, **kwargs):
@ -124,9 +130,15 @@ class TooBig(DNSException):
class Timeout(DNSException):
"""The DNS operation timed out."""
supp_kwargs = {'timeout'}
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
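Editor's note: the supp_kwargs/fmt machinery yields parametrized messages, as in this sketch built from the fmt string above:
import dns.exception
try:
    raise dns.exception.Timeout(timeout=2.5)
except dns.exception.Timeout as e:
    assert str(e) == "The DNS operation timed out after 2.500 seconds"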
class ExceptionWrapper:
def __init__(self, exception_class):
@ -136,7 +148,6 @@ class ExceptionWrapper:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val,
self.exception_class):
if exc_type is not None and not isinstance(exc_val, self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False

View file

@ -1,12 +0,0 @@
from typing import Set, Optional, Dict
class DNSException(Exception):
supp_kwargs : Set[str]
kwargs : Optional[Dict]
fmt : Optional[str]
class SyntaxError(DNSException): ...
class FormError(DNSException): ...
class Timeout(DNSException): ...
class TooBig(DNSException): ...
class UnexpectedEnd(SyntaxError): ...

View file

@ -17,10 +17,13 @@
"""DNS Message Flags."""
from typing import Any
import enum
# Standard DNS flags
class Flag(enum.IntFlag):
#: Query Response
QR = 0x8000
@ -40,12 +43,13 @@ class Flag(enum.IntFlag):
# EDNS flags
class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK
DO = 0x8000
def _from_text(text, enum_class):
def _from_text(text: str, enum_class: Any) -> int:
flags = 0
tokens = text.split()
for t in tokens:
@ -53,15 +57,15 @@ def _from_text(text, enum_class):
return flags
def _to_text(flags, enum_class):
def _to_text(flags: int, enum_class: Any) -> str:
text_flags = []
for k, v in enum_class.__members__.items():
if flags & v != 0:
text_flags.append(k)
return ' '.join(text_flags)
return " ".join(text_flags)
def from_text(text):
def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
@ -71,7 +75,7 @@ def from_text(text):
return _from_text(text, Flag)
def to_text(flags):
def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
@ -81,7 +85,7 @@ def to_text(flags):
return _to_text(flags, Flag)
def edns_from_text(text):
def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
@ -91,7 +95,7 @@ def edns_from_text(text):
return _from_text(text, EDNSFlag)
def edns_to_text(flags):
def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
@ -100,6 +104,7 @@ def edns_to_text(flags):
return _to_text(flags, EDNSFlag)
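Editor's note: a quick sketch of the text round-trip for the standard flags (to_text emits members in definition order):
import dns.flags
f = dns.flags.from_text("QR RD RA")
assert f == dns.flags.QR | dns.flags.RD | dns.flags.RA
assert dns.flags.to_text(f) == "QR RD RA"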
### BEGIN generated Flag constants
QR = Flag.QR

Some files were not shown because too many files have changed in this diff