Compare commits


No commits in common. "master" and "v2.14.5" have entirely different histories.

343 changed files with 17,999 additions and 19,510 deletions

View file

@@ -100,24 +100,6 @@ jobs:
name: Tautulli-${{ matrix.os }}-installer
path: Tautulli-${{ matrix.os }}-${{ steps.get_version.outputs.RELEASE_VERSION }}-${{ matrix.arch }}.${{ matrix.ext }}
virus-total:
name: VirusTotal Scan
needs: build-installer
if: needs.build-installer.result == 'success' && !contains(github.event.head_commit.message, '[skip ci]')
runs-on: ubuntu-latest
steps:
- name: Download Installers
if: needs.build-installer.result == 'success'
uses: actions/download-artifact@v4
- name: Upload to VirusTotal
uses: crazy-max/ghaction-virustotal@v4
with:
vt_api_key: ${{ secrets.VT_API_KEY }}
files: |
Tautulli-windows-installer/Tautulli-windows-*-x64.exe
Tautulli-macos-installer/Tautulli-macos-*-universal.pkg
release:
name: Release Installers
needs: build-installer

View file

@@ -23,17 +23,3 @@ jobs:
# getting latest wingetcreate file
iwr https://aka.ms/wingetcreate/latest -OutFile wingetcreate.exe
.\wingetcreate.exe update $wingetPackage -s -v $version -u $installerUrl -t $gitToken
virus-total:
name: VirusTotal Scan
runs-on: ubuntu-latest
steps:
- name: Upload to VirusTotal
uses: crazy-max/ghaction-virustotal@v4
with:
vt_api_key: ${{ secrets.VT_API_KEY }}
github_token: ${{ secrets.GHACTIONS_TOKEN }}
update_release_body: true
files: |
.exe$
.pkg$

View file

@@ -1,76 +1,5 @@
# Changelog
## v2.15.2 (2025-04-12)
* Activity:
* New: Added link to library by clicking media type icon.
* New: Added stream count to tab title on homepage. (#2517)
* History:
* Fix: Check stream watched status before stream stopped status. (#2506)
* Notifications:
* Fix: ntfy notifications failing to send if provider link is blank.
* Fix: Check Pushover notification attachment is under 5MB limit. (#2396)
* Fix: Track URLs redirecting to the correct media page. (#2513)
* New: Added audio profile notification parameters.
* New: Added PATCH method for Webhook notifications.
* Graphs:
* New: Added Total line to daily streams graph. (Thanks @zdimension) (#2497)
* UI:
* Fix: Do not redirect API requests to the login page. (#2490)
* Change: Swap source and stream columns in stream info modal.
* Other:
* Fix: Various typos. (Thanks @luzpaz) (#2520)
* Fix: CherryPy CORS response header not being set correctly. (#2279)
## v2.15.1 (2025-01-11)
* Activity:
* Fix: Detection of HDR transcodes. (Thanks @cdecker08) (#2412, #2466)
* Newsletters:
* Fix: Disable basic authentication for /newsletter and /image endpoints. (#2472)
* Exporter:
* New: Added logos to season and episode exports.
* Other:
* Fix: Docker container https health check.
## v2.15.0 (2024-11-24)
* Notes:
* Support for Python 3.8 has been dropped. The minimum Python version is now 3.9.
* Notifications:
* New: Allow Telegram blockquote and tg-emoji HTML tags. (Thanks @MythodeaLoL) (#2427)
* New: Added Plex slug and Plex Watch URL notification parameters. (#2420)
* Change: Update OneSignal API calls to use the new API endpoint for Tautulli Remote App notifications.
* Newsletters:
* Fix: Dumping custom dates in raw newsletter json.
* History:
* Fix: Unable to fix match for artists. (#2429)
* Exporter:
* New: Added movie and episode hasVoiceActivity attribute to exporter fields.
* New: Added subtitle canAutoSync attribute to exporter fields.
* New: Added logos to the exporter fields.
* UI:
* New: Add friendly name to the top bar of config modals. (Thanks @peagravel) (#2432)
* API:
* New: Added plex slugs to metadata in the get_metadata API command.
* Other:
* Fix: Tautulli failing to start with Python 3.13. (#2426)
## v2.14.6 (2024-10-12)
* Newsletters:
* Fix: Allow formatting newsletter date parameters.
* Change: Support apscheduler compatible cron expressions.
* UI:
* Fix: Round runtime before converting to human duration.
* Fix: Make recently added/watched rows touch scrollable.
* Other:
* Fix: Auto-updater not running.
## v2.14.5 (2024-09-20)
* Activity:

View file

@@ -25,4 +25,4 @@ CMD [ "python", "Tautulli.py", "--datadir", "/config" ]
ENTRYPOINT [ "./start.sh" ]
EXPOSE 8181
HEALTHCHECK --start-period=90s CMD curl -ILfks https://localhost:8181/status > /dev/null || curl -ILfs http://localhost:8181/status > /dev/null || exit 1
HEALTHCHECK --start-period=90s CMD curl -ILfSs http://localhost:8181/status > /dev/null || curl -ILfkSs https://localhost:8181/status > /dev/null || exit 1

View file

@@ -36,7 +36,7 @@ and [PlexWatchWeb](https://github.com/ecleese/plexWatchWeb).
[![Docker Stars][badge-docker-stars]][DockerHub]
[![Downloads][badge-downloads]][Releases Latest]
[badge-python]: https://img.shields.io/badge/python->=3.9-blue?style=flat-square
[badge-python]: https://img.shields.io/badge/python->=3.8-blue?style=flat-square
[badge-docker-pulls]: https://img.shields.io/docker/pulls/tautulli/tautulli?style=flat-square
[badge-docker-stars]: https://img.shields.io/docker/stars/tautulli/tautulli?style=flat-square
[badge-downloads]: https://img.shields.io/github/downloads/Tautulli/Tautulli/total?style=flat-square
@@ -129,7 +129,7 @@ This is free software under the GPL v3 open source license. Feel free to do with
but any modification must be open sourced. A copy of the license is included.
This software includes Highsoft software libraries which you may freely distribute for
non-commercial use. Commercial users must licence this software, for more information visit
non-commercial use. Commerical users must licence this software, for more information visit
https://shop.highsoft.com/faq/non-commercial#non-commercial-redistribution.

View file

@@ -129,7 +129,7 @@ def main():
if args.quiet:
plexpy.QUIET = True
# Do an initial setup of the logger.
# Do an intial setup of the logger.
# Require verbose for pre-initilization to see critical errors
logger.initLogger(console=not plexpy.QUIET, log_dir=False, verbose=True)

View file

@@ -1478,8 +1478,7 @@ a:hover .dashboard-stats-square {
text-align: center;
position: relative;
z-index: 0;
overflow: auto;
scrollbar-width: none;
overflow: hidden;
}
.dashboard-recent-media {
width: 100%;
@@ -4325,10 +4324,6 @@ a:hover .overlay-refresh-image:hover {
.stream-info tr:nth-child(even) td {
background-color: rgba(255,255,255,0.010);
}
.stream-info td:nth-child(3),
.stream-info th:nth-child(3) {
width: 25px;
}
.number-input {
margin: 0 !important;
width: 55px !important;

View file

@@ -74,7 +74,6 @@ DOCUMENTATION :: END
parent_href = page('info', data['parent_rating_key'])
grandparent_href = page('info', data['grandparent_rating_key'])
user_href = page('user', data['user_id']) if data['user_id'] else '#'
library_href = page('library', data['section_id']) if data['section_id'] else '#'
season = short_season(data['parent_title'])
%>
<div class="dashboard-activity-instance" id="activity-instance-${sk}" data-key="${sk}" data-id="${data['session_id']}"
@@ -464,27 +463,21 @@ DOCUMENTATION :: END
<div class="dashboard-activity-metadata-subtitle-container">
% if data['live']:
<div id="media-type-${sk}" class="dashboard-activity-metadata-media_type-icon" title="Live TV">
<a href="${library_href}">
<i class="fa fa-fw fa-broadcast-tower"></i>
</a>&nbsp;
<i class="fa fa-fw fa-broadcast-tower"></i>&nbsp;
</div>
% elif data['channel_stream'] == 0:
<div id="media-type-${sk}" class="dashboard-activity-metadata-media_type-icon" title="${data['media_type'].capitalize()}">
<a href="${library_href}">
% if data['media_type'] == 'movie':
<i class="fa fa-fw fa-film"></i>
<i class="fa fa-fw fa-film"></i>&nbsp;
% elif data['media_type'] == 'episode':
<i class="fa fa-fw fa-television"></i>
<i class="fa fa-fw fa-television"></i>&nbsp;
% elif data['media_type'] == 'track':
<i class="fa fa-fw fa-music"></i>
<i class="fa fa-fw fa-music"></i>&nbsp;
% elif data['media_type'] == 'photo':
<i class="fa fa-fw fa-picture-o"></i>
<i class="fa fa-fw fa-picture-o"></i>&nbsp;
% elif data['media_type'] == 'clip':
<i class="fa fa-fw fa-video-camera"></i>
% else:
<i class="fa fa-fw fa-question-circle"></i>
<i class="fa fa-fw fa-video-camera"></i>&nbsp;
% endif
</a>&nbsp;
</div>
% else:
<div id="media-type-${sk}" class="dashboard-activity-metadata-media_type-icon" title="Channel">

View file

@@ -20,7 +20,6 @@ DOCUMENTATION :: END
export = exporter.Export()
thumb_media_types = ', '.join([export.PLURAL_MEDIA_TYPES[k] for k, v in export.MEDIA_TYPES.items() if v[0]])
art_media_types = ', '.join([export.PLURAL_MEDIA_TYPES[k] for k, v in export.MEDIA_TYPES.items() if v[1]])
logo_media_types = ', '.join([export.PLURAL_MEDIA_TYPES[k] for k, v in export.MEDIA_TYPES.items() if v[2]])
%>
<div class="modal-dialog" role="document">
<div class="modal-content">
@@ -145,22 +144,6 @@ DOCUMENTATION :: END
Select the level to export background artwork image files.<br>Note: Only applies to ${art_media_types}.
</p>
</div>
<div class="form-group">
<label for="export_logo_level">Logo Image Export Level</label>
<div class="row">
<div class="col-md-12">
<select class="form-control" id="export_logo_level" name="export_logo_level">
<option value="0" selected>Level 0 - None / Custom</option>
<option value="1">Level 1 - Uploaded and Selected Logos Only</option>
<option value="2">Level 2 - Selected and Locked Logos Only</option>
<option value="9">Level 9 - All Selected Logos</option>
</select>
</div>
</div>
<p class="help-block">
Select the level to export logo image files.<br>Note: Only applies to ${logo_media_types}.
</p>
</div>
<p class="help-block">
Warning: Exporting images may take a long time! Images will be saved to a folder alongside the data file.
</p>
@@ -248,7 +231,6 @@ DOCUMENTATION :: END
$('#export_media_info_level').prop('disabled', true);
$("#export_thumb_level").prop('disabled', true);
$("#export_art_level").prop('disabled', true);
$("#export_logo_level").prop('disabled', true);
export_custom_metadata_fields.disable();
export_custom_media_info_fields.disable();
} else {
@@ -256,7 +238,6 @@ DOCUMENTATION :: END
$('#export_media_info_level').prop('disabled', false);
$("#export_thumb_level").prop('disabled', false);
$("#export_art_level").prop('disabled', false);
$("#export_logo_level").prop('disabled', false);
export_custom_metadata_fields.enable();
export_custom_media_info_fields.enable();
}
@@ -271,7 +252,6 @@ DOCUMENTATION :: END
var file_format = $('#export_file_format option:selected').val();
var thumb_level = $("#export_thumb_level option:selected").val();
var art_level = $("#export_art_level option:selected").val();
var logo_level = $("#export_logo_level option:selected").val();
var custom_fields = [
$('#export_custom_metadata_fields').val(),
$('#export_custom_media_info_fields').val()
@@ -290,7 +270,6 @@ DOCUMENTATION :: END
file_format: file_format,
thumb_level: thumb_level,
art_level: art_level,
logo_level: logo_level,
custom_fields: custom_fields,
export_type: export_type,
individual_files: individual_files

View file

@@ -301,10 +301,6 @@
return obj;
}, {});
if (!("Total" in chart_visibility)) {
chart_visibility["Total"] = false;
}
return data_series.map(function(s) {
var obj = Object.assign({}, s);
obj.visible = (chart_visibility[s.name] !== false);
@@ -331,8 +327,7 @@
'Direct Play': '#E5A00D',
'Direct Stream': '#FFFFFF',
'Transcode': '#F06464',
'Max. Concurrent Streams': '#96C83C',
'Total': '#96C83C'
'Max. Concurrent Streams': '#96C83C'
};
var series_colors = [];
$.each(data_series, function(index, series) {

View file

@@ -92,10 +92,10 @@
<h3 class="pull-left"><span id="recently-added-xml">Recently Added</span></h3>
<ul class="nav nav-header nav-dashboard pull-right" style="margin-top: -3px;">
<li>
<a href="#" id="recently-added-page-left" class="paginate-added btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-left"></i></a>
<a href="#" id="recently-added-page-left" class="paginate btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-left"></i></a>
</li>
<li>
<a href="#" id="recently-added-page-right" class="paginate-added btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-right"></i></a>
<a href="#" id="recently-added-page-right" class="paginate btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-right"></i></a>
</li>
</ul>
<div class="button-bar">
@@ -298,8 +298,6 @@
$('#currentActivityHeader-bandwidth-tooltip').tooltip({ container: 'body', placement: 'right', delay: 50 });
var title = document.title;
function getCurrentActivity() {
activity_ready = false;
@@ -370,8 +368,6 @@
$('#currentActivityHeader').show();
document.title = stream_count + ' stream' + (stream_count > 1 ? 's' : '') + ' | ' + title;
sessions.forEach(function (session) {
var s = (typeof Proxy === "function") ? new Proxy(session, defaultHandler) : session;
var key = s.session_key;
@@ -604,8 +600,6 @@
} else {
$('#currentActivityHeader').hide();
$('#currentActivity').html('<div id="dashboard-no-activity" class="text-muted">Nothing is currently being played.</div>');
document.title = title;
}
activity_ready = true;
@@ -942,14 +936,10 @@
count: recently_added_count,
media_type: recently_added_type
},
beforeSend: function () {
$(".dashboard-recent-media-row").animate({ scrollLeft: 0 }, 1000);
},
complete: function (xhr, status) {
$("#recentlyAdded").html(xhr.responseText);
$('#ajaxMsg').fadeOut();
highlightScrollerButton("#recently-added");
paginateScroller("#recently-added", ".paginate-added");
highlightAddedScrollerButton();
}
});
}
@@ -965,11 +955,57 @@
recentlyAdded(recently_added_count, recently_added_type);
}
function highlightAddedScrollerButton() {
var scroller = $("#recently-added-row-scroller");
var numElems = scroller.find("li:visible").length;
scroller.width(numElems * 175);
if (scroller.width() > $("body").find(".container-fluid").width()) {
$("#recently-added-page-right").removeClass("disabled");
} else {
$("#recently-added-page-right").addClass("disabled");
}
}
$(window).resize(function () {
highlightAddedScrollerButton();
});
function resetScroller() {
leftTotal = 0;
$("#recently-added-row-scroller").animate({ left: leftTotal }, 1000);
$("#recently-added-page-left").addClass("disabled").blur();
}
var leftTotal = 0;
$(".paginate").click(function (e) {
e.preventDefault();
var scroller = $("#recently-added-row-scroller");
var containerWidth = $("body").find(".container-fluid").width();
var scrollAmount = $(this).data("id") * parseInt((containerWidth - 15) / 175) * 175;
var leftMax = Math.min(-parseInt(scroller.width()) + Math.abs(scrollAmount), 0);
leftTotal = Math.max(Math.min(leftTotal + scrollAmount, 0), leftMax);
scroller.animate({ left: leftTotal }, 250);
if (leftTotal === 0) {
$("#recently-added-page-left").addClass("disabled").blur();
} else {
$("#recently-added-page-left").removeClass("disabled");
}
if (leftTotal === leftMax) {
$("#recently-added-page-right").addClass("disabled").blur();
} else {
$("#recently-added-page-right").removeClass("disabled");
}
});
$('#recently-added-toggles').on('change', function () {
$('#recently-added-toggles > label').removeClass('active');
selected_filter = $('input[name=recently-added-toggle]:checked', '#recently-added-toggles');
$(selected_filter).closest('label').addClass('active');
recently_added_type = $(selected_filter).val();
resetScroller();
setLocalStorage('home_stats_recently_added_type', recently_added_type);
recentlyAdded(recently_added_count, recently_added_type);
});
@@ -977,6 +1013,7 @@
$('#recently-added-count').change(function () {
forceMinMax($(this));
recently_added_count = $(this).val();
resetScroller();
setLocalStorage('home_stats_recently_added_count', recently_added_count);
recentlyAdded(recently_added_count, recently_added_type);
});

View file

@@ -360,8 +360,7 @@ function humanDuration(ms, sig='dhm', units='ms', return_seconds=300000) {
sig = 'dhms'
}
r = factors[sig.slice(-1)];
ms = Math.round(ms * factors[units] / r) * r;
ms = ms * factors[units];
h = ms % factors['d'];
d = Math.trunc(ms / factors['d']);
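The line removed here is the changelog's v2.14.6 "Round runtime before converting to human duration" fix: master rounds the duration to a whole multiple of the smallest significant unit before splitting it into days/hours/minutes. A minimal Python sketch of that step, with a hypothetical factors table mirroring the JS one:

factors = {"d": 86400000, "h": 3600000, "m": 60000, "s": 1000, "ms": 1}

def round_to_last_unit(ms, sig="dhm", units="ms"):
    # r is one unit of the least significant place in sig (one minute for
    # "dhm"); rounding ms to a multiple of r keeps, e.g., 89.6 minutes from
    # rendering as "1 hr 29 mins" after the truncating split.
    r = factors[sig[-1]]
    return round(ms * factors[units] / r) * r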
@@ -930,50 +929,3 @@ $('.modal').on('hide.bs.modal', function (e) {
$.fn.hasScrollBar = function() {
return this.get(0).scrollHeight > this.get(0).clientHeight;
}
function paginateScroller(scrollerId, buttonClass) {
$(buttonClass).click(function (e) {
e.preventDefault();
var scroller = $(scrollerId + "-row-scroller");
var scrollerParent = scroller.parent();
var containerWidth = scrollerParent.width();
var scrollCurrent = scrollerParent.scrollLeft();
var scrollAmount = $(this).data("id") * parseInt(containerWidth / 175) * 175;
var scrollMax = scroller.width() - Math.abs(scrollAmount);
var scrollTotal = Math.min(parseInt(scrollCurrent / 175) * 175 + scrollAmount, scrollMax);
scrollerParent.animate({ scrollLeft: scrollTotal }, 250);
});
}
function highlightScrollerButton(scrollerId) {
var scroller = $(scrollerId + "-row-scroller");
var scrollerParent = scroller.parent();
var buttonLeft = $(scrollerId + "-page-left");
var buttonRight = $(scrollerId + "-page-right");
var numElems = scroller.find("li").length;
scroller.width(numElems * 175);
$(buttonLeft).addClass("disabled").blur();
if (scroller.width() > scrollerParent.width()) {
$(buttonRight).removeClass("disabled");
} else {
$(buttonRight).addClass("disabled");
}
scrollerParent.scroll(function () {
var scrollCurrent = $(this).scrollLeft();
var scrollMax = scroller.width() - $(this).width();
if (scrollCurrent == 0) {
$(buttonLeft).addClass("disabled").blur();
} else {
$(buttonLeft).removeClass("disabled");
}
if (scrollCurrent >= scrollMax) {
$(buttonRight).addClass("disabled").blur();
} else {
$(buttonRight).removeClass("disabled");
}
});
}

View file

@@ -100,7 +100,7 @@ export_table_options = {
"createdCell": function (td, cellData, rowData, row, col) {
if (cellData !== '') {
var images = '';
if (rowData['thumb_level'] || rowData['art_level'] || rowData['logo_level']) {
if (rowData['thumb_level'] || rowData['art_level']) {
images = ' + images';
}
$(td).html(cellData + images);
@@ -161,14 +161,14 @@ export_table_options = {
if (cellData === 1 && rowData['exists']) {
var tooltip_title = '';
var icon = '';
if (rowData['thumb_level'] || rowData['art_level'] || rowData['logo_level'] || rowData['individual_files']) {
if (rowData['thumb_level'] || rowData['art_level'] || rowData['individual_files']) {
tooltip_title = 'Zip Archive';
icon = 'fa-file-archive';
} else {
tooltip_title = rowData['file_format'].toUpperCase() + ' File';
icon = 'fa-file-download';
}
var icon = (rowData['thumb_level'] || rowData['art_level'] || rowData['logo_level'] || rowData['individual_files']) ? 'fa-file-archive' : 'fa-file-download';
var icon = (rowData['thumb_level'] || rowData['art_level'] || rowData['individual_files']) ? 'fa-file-archive' : 'fa-file-download';
$(td).html('<button class="btn btn-xs btn-success pull-left" data-id="' + rowData['export_id'] + '"><span data-toggle="tooltip" data-placement="left" title="' + tooltip_title + '"><i class="fa ' + icon + ' fa-fw"></i> Download</span></button>');
} else if (cellData === 0) {
var percent = Math.min(getPercent(rowData['exported_items'], rowData['total_items']), 99)

View file

@@ -149,10 +149,10 @@ DOCUMENTATION :: END
<div class="table-card-header">
<ul class="nav nav-header nav-dashboard pull-right">
<li>
<a href="#" id="recently-watched-page-left" class="paginate-watched btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-left"></i></a>
<a href="#" id="recently-watched-page-left" class="paginate-watched btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-left"></i></a>
</li>
<li>
<a href="#" id="recently-watched-page-right" class="paginate-watched btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-right"></i></a>
<a href="#" id="recently-watched-page-right" class="paginate-watched btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-right"></i></a>
</li>
</ul>
<div class="header-bar">
@@ -175,10 +175,10 @@ DOCUMENTATION :: END
<div class="table-card-header">
<ul class="nav nav-header nav-dashboard pull-right">
<li>
<a href="#" id="recently-added-page-left" class="paginate-added btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-left"></i></a>
<a href="#" id="recently-added-page-left" class="paginate-added btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-left"></i></a>
</li>
<li>
<a href="#" id="recently-added-page-right" class="paginate-added btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-right"></i></a>
<a href="#" id="recently-added-page-right" class="paginate-added btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-right"></i></a>
</li>
</ul>
<div class="header-bar">
@@ -690,8 +690,7 @@ DOCUMENTATION :: END
},
complete: function(xhr, status) {
$("#library-recently-watched").html(xhr.responseText);
highlightScrollerButton("#recently-watched");
paginateScroller("#recently-watched", ".paginate-watched");
highlightWatchedScrollerButton();
}
});
}
@@ -707,8 +706,7 @@ DOCUMENTATION :: END
},
complete: function(xhr, status) {
$("#library-recently-added").html(xhr.responseText);
highlightScrollerButton("#recently-added");
paginateScroller("#recently-added", ".paginate-added");
highlightAddedScrollerButton();
}
});
}
@@ -718,8 +716,83 @@ DOCUMENTATION :: END
recentlyAdded();
% endif
function highlightWatchedScrollerButton() {
var scroller = $("#recently-watched-row-scroller");
var numElems = scroller.find("li").length;
scroller.width(numElems * 175);
if (scroller.width() > $("#library-recently-watched").width()) {
$("#recently-watched-page-right").removeClass("disabled");
} else {
$("#recently-watched-page-right").addClass("disabled");
}
}
function highlightAddedScrollerButton() {
var scroller = $("#recently-added-row-scroller");
var numElems = scroller.find("li").length;
scroller.width(numElems * 175);
if (scroller.width() > $("#library-recently-added").width()) {
$("#recently-added-page-right").removeClass("disabled");
} else {
$("#recently-added-page-right").addClass("disabled");
}
}
$(window).resize(function() {
highlightWatchedScrollerButton();
highlightAddedScrollerButton();
});
$('div.art-face').animate({ opacity: 0.2 }, { duration: 1000 });
var leftTotalWatched = 0;
$(".paginate-watched").click(function (e) {
e.preventDefault();
var scroller = $("#recently-watched-row-scroller");
var containerWidth = $("#library-recently-watched").width();
var scrollAmount = $(this).data("id") * parseInt(containerWidth / 175) * 175;
var leftMax = Math.min(-parseInt(scroller.width()) + Math.abs(scrollAmount), 0);
leftTotalWatched = Math.max(Math.min(leftTotalWatched + scrollAmount, 0), leftMax);
scroller.animate({ left: leftTotalWatched }, 250);
if (leftTotalWatched == 0) {
$("#recently-watched-page-left").addClass("disabled").blur();
} else {
$("#recently-watched-page-left").removeClass("disabled");
}
if (leftTotalWatched == leftMax) {
$("#recently-watched-page-right").addClass("disabled").blur();
} else {
$("#recently-watched-page-right").removeClass("disabled");
}
});
var leftTotalAdded = 0;
$(".paginate-added").click(function (e) {
e.preventDefault();
var scroller = $("#recently-added-row-scroller");
var containerWidth = $("#library-recently-added").width();
var scrollAmount = $(this).data("id") * parseInt(containerWidth / 175) * 175;
var leftMax = Math.min(-parseInt(scroller.width()) + Math.abs(scrollAmount), 0);
leftTotalAdded = Math.max(Math.min(leftTotalAdded + scrollAmount, 0), leftMax);
scroller.animate({ left: leftTotalAdded }, 250);
if (leftTotalAdded == 0) {
$("#recently-added-page-left").addClass("disabled").blur();
} else {
$("#recently-added-page-left").removeClass("disabled");
}
if (leftTotalAdded == leftMax) {
$("#recently-added-page-right").addClass("disabled").blur();
} else {
$("#recently-added-page-right").removeClass("disabled");
}
});
$(document).ready(function () {
// Javascript to enable link to tab

View file

@@ -36,7 +36,7 @@ DOCUMENTATION :: END
%>
<div class="dashboard-recent-media-row">
<div id="recently-added-row-scroller">
<div id="recently-added-row-scroller" style="left: 0;">
<ul class="dashboard-recent-media list-unstyled">
% for item in data:
<li>

View file

@@ -3,7 +3,7 @@
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-hidden="true"><i class="fa fa-remove"></i></button>
<h4 class="modal-title" id="mobile-device-config-modal-header">${device['device_name']} Settings &nbsp;<small><span class="device_id">(Device ID: ${device['id']}${' - ' + device['friendly_name'] if device['friendly_name'] else ''})</span></small></h4>
<h4 class="modal-title" id="mobile-device-config-modal-header">${device['device_name']} Settings &nbsp;<small><span class="device_id">(Device ID: ${device['id']})</span></small></h4>
</div>
<div class="modal-body">
<div class="container-fluid">

View file

@@ -13,7 +13,7 @@
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-hidden="true"><i class="fa fa-remove"></i></button>
<h4 class="modal-title" id="newsletter-config-modal-header">${newsletter['agent_label']} Newsletter Settings &nbsp;<small><span class="newsletter_id">(Newsletter ID: ${newsletter['id']}${' - ' + newsletter['friendly_name'] if newsletter['friendly_name'] else ''})</span></small></h4>
<h4 class="modal-title" id="newsletter-config-modal-header">${newsletter['agent_label']} Newsletter Settings &nbsp;<small><span class="newsletter_id">(Newsletter ID: ${newsletter['id']})</span></small></h4>
</div>
<div class="modal-body">
<div class="container-fluid">
@@ -50,10 +50,7 @@
</div>
<p class="help-block">
<span id="simple_cron_message">Set the schedule for the newsletter.</span>
<span id="custom_cron_message">
Set the schedule for the newsletter using a <a href="${anon_url('https://crontab.guru')}" target="_blank" rel="noreferrer">custom crontab</a>.
<a href="${anon_url('https://apscheduler.readthedocs.io/en/3.x/modules/triggers/cron.html#expression-types')}" target="_blank" rel="noreferrer">Click here</a> for a list of supported expressions.
</span>
<span id="custom_cron_message">Set the schedule for the newsletter using a <a href="${anon_url('https://crontab.guru')}" target="_blank" rel="noreferrer">custom crontab</a>. Only standard cron values are valid.</span>
</p>
</div>
<div class="form-group">
@@ -484,7 +481,7 @@
});
if (${newsletter['config']['custom_cron']}) {
$('#cron_value').val('${newsletter['cron'] | n}');
$('#cron_value').val('${newsletter['cron']}');
} else {
try {
cron_widget.cron('value', '${newsletter['cron']}');

View file

@@ -12,7 +12,7 @@
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-hidden="true"><i class="fa fa-remove"></i></button>
<h4 class="modal-title" id="notifier-config-modal-header">${notifier['agent_label']} Settings &nbsp;<small><span class="notifier_id">(Notifier ID: ${notifier['id']}${' - ' + notifier['friendly_name'] if notifier['friendly_name'] else ''})</span></small></h4>
<h4 class="modal-title" id="notifier-config-modal-header">${notifier['agent_label']} Settings &nbsp;<small><span class="notifier_id">(Notifier ID: ${notifier['id']})</span></small></h4>
</div>
<div class="modal-body">
<div class="container-fluid">

View file

@@ -36,7 +36,7 @@ DOCUMENTATION :: END
%>
% if data:
<div class="dashboard-recent-media-row">
<div id="recently-added-row-scroller">
<div id="recently-added-row-scroller" style="left: 0;">
<ul class="dashboard-recent-media list-unstyled">
% for item in data:
<div class="dashboard-recent-media-instance">

View file

@@ -68,14 +68,14 @@ DOCUMENTATION :: END
<table class="stream-info" style="margin-top: 0;">
<thead>
<tr>
<th></th>
<th class="heading">
Source Details
<th>
</th>
<th><i class="fa fa-long-arrow-right"></i></th>
<th class="heading">
Stream Details
</th>
<th class="heading">
Source Details
</th>
</tr>
</thead>
</table>
@@ -85,46 +85,38 @@ DOCUMENTATION :: END
<th>
Media
</th>
<th></th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<td>Bitrate</td>
<td>${data['bitrate']} ${'kbps' if data['bitrate'] else ''}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_bitrate']} ${'kbps' if data['stream_bitrate'] else ''}</td>
<td>${data['bitrate']} ${'kbps' if data['bitrate'] else ''}</td>
</tr>
% if data['media_type'] != 'track':
<tr>
<td>Resolution</td>
<td>${data['video_full_resolution']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_full_resolution']}</td>
<td>${data['video_full_resolution']}</td>
</tr>
% endif
<tr>
<td>Quality</td>
<td>-</td>
<td></td>
<td>${data['quality_profile']}</td>
<td>-</td>
</tr>
% if data['optimized_version'] == 1:
<tr>
<td>Optimized Version</td>
<td>${data['optimized_version_profile']}<br>(${data['optimized_version_title']})</td>
<td></td>
<td>-</td>
<td>${data['optimized_version_profile']}<br>(${data['optimized_version_title']})</td>
</tr>
% endif
% if data['synced_version'] == 1:
<tr>
<td>Synced Version</td>
<td>${data['synced_version_profile']}</td>
<td></td>
<td>-</td>
<td>${data['synced_version_profile']}</td>
</tr>
% endif
</tbody>
@@ -135,8 +127,6 @@ DOCUMENTATION :: END
<th>
Container
</th>
<th></th>
<th></th>
<th>
${data['stream_container_decision']}
</th>
@@ -145,9 +135,8 @@ DOCUMENTATION :: END
<tbody>
<tr>
<td>Container</td>
<td>${data['container'].upper()}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_container'].upper()}</td>
<td>${data['container'].upper()}</td>
</tr>
</tbody>
</table>
@@ -158,8 +147,6 @@ DOCUMENTATION :: END
<th>
Video
</th>
<th></th>
<th></th>
<th>
${data['stream_video_decision']}
</th>
@@ -168,45 +155,38 @@ DOCUMENTATION :: END
<tbody>
<tr>
<td>Codec</td>
<td>${data['video_codec'].upper()} ${'(HW)' if data['transcode_hw_decoding'] else ''}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_codec'].upper()} ${'(HW)' if data['transcode_hw_encoding'] else ''}</td>
<td>${data['video_codec'].upper()} ${'(HW)' if data['transcode_hw_decoding'] else ''}</td>
</tr>
<tr>
<td>Bitrate</td>
<td>${data['video_bitrate']} ${'kbps' if data['video_bitrate'] else ''}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_bitrate']} ${'kbps' if data['stream_video_bitrate'] else ''}</td>
<td>${data['video_bitrate']} ${'kbps' if data['video_bitrate'] else ''}</td>
</tr>
<tr>
<td>Width</td>
<td>${data['video_width']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_width']}</td>
<td>${data['video_width']}</td>
</tr>
<tr>
<td>Height</td>
<td>${data['video_height']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_height']}</td>
<td>${data['video_height']}</td>
</tr>
<tr>
<td>Framerate</td>
<td>${data['video_framerate']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_framerate']}</td>
<td>${data['video_framerate']}</td>
</tr>
<tr>
<td>Dynamic Range</td>
<td>${data['video_dynamic_range']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_video_dynamic_range']}</td>
<td>${data['video_dynamic_range']}</td>
</tr>
<tr>
<td>Aspect Ratio</td>
<td>${data['aspect_ratio']}</td>
<td></td>
<td>-</td>
<td>${data['aspect_ratio']}</td>
</tr>
</tbody>
</table>
@@ -217,8 +197,6 @@ DOCUMENTATION :: END
<th>
Audio
</th>
<th></th>
<th></th>
<th>
${data['stream_audio_decision']}
</th>
@@ -227,27 +205,23 @@ DOCUMENTATION :: END
<tbody>
<tr>
<td>Codec</td>
<td>${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${AUDIO_CODEC_OVERRIDES.get(data['stream_audio_codec'], data['stream_audio_codec'].upper())}</td>
<td>${AUDIO_CODEC_OVERRIDES.get(data['audio_codec'], data['audio_codec'].upper())}</td>
</tr>
<tr>
<td>Bitrate</td>
<td>${data['audio_bitrate']} ${'kbps' if data['audio_bitrate'] else ''}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_audio_bitrate']} ${'kbps' if data['stream_audio_bitrate'] else ''}</td>
<td>${data['audio_bitrate']} ${'kbps' if data['audio_bitrate'] else ''}</td>
</tr>
<tr>
<td>Channels</td>
<td>${data['audio_channels']}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_audio_channels']}</td>
<td>${data['audio_channels']}</td>
</tr>
<tr>
<td>Language</td>
<td>${data['audio_language'] or 'Unknown'}</td>
<td></td>
<td>-</td>
<td>${data['audio_language'] or 'Unknown'}</td>
</tr>
</tbody>
@@ -259,8 +233,6 @@ DOCUMENTATION :: END
<th>
Subtitles
</th>
<th></th>
<th></th>
<th>
${'direct play' if data['stream_subtitle_decision'] not in ('transcode', 'copy', 'burn') else data['stream_subtitle_decision']}
</th>
@@ -269,22 +241,19 @@ DOCUMENTATION :: END
<tbody>
<tr>
<td>Codec</td>
<td>${data['subtitle_codec'].upper()}</td>
<td><i class="fa fa-long-arrow-right"></i></td>
<td>${data['stream_subtitle_codec'].upper() or '-'}</td>
<td>${data['subtitle_codec'].upper()}</td>
</tr>
<tr>
<td>Language</td>
<td>${data['subtitle_language'] or 'Unknown'}</td>
<td></td>
<td>-</td>
<td>${data['subtitle_language'] or 'Unknown'}</td>
</tr>
% if data['subtitle_forced']:
<tr>
<td>Forced</td>
<td>${bool(data['subtitle_forced'])}</td>
<td></td>
<td>-</td>
<td>${bool(data['subtitle_forced'])}</td>
</tr>
% endif
</tbody>

View file

@@ -125,10 +125,10 @@ DOCUMENTATION :: END
<div class="table-card-header">
<ul class="nav nav-header nav-dashboard pull-right">
<li>
<a href="#" id="recently-watched-page-left" class="paginate-watched btn-gray disabled" data-id="-1"><i class="fa fa-lg fa-chevron-left"></i></a>
<a href="#" id="recently-watched-page-left" class="paginate btn-gray disabled" data-id="+1"><i class="fa fa-lg fa-chevron-left"></i></a>
</li>
<li>
<a href="#" id="recently-watched-page-right" class="paginate-watched btn-gray" data-id="+1"><i class="fa fa-lg fa-chevron-right"></i></a>
<a href="#" id="recently-watched-page-right" class="paginate btn-gray" data-id="-1"><i class="fa fa-lg fa-chevron-right"></i></a>
</li>
</ul>
<div class="header-bar">
@@ -666,14 +666,52 @@ DOCUMENTATION :: END
},
complete: function(xhr, status) {
$("#user-recently-watched").html(xhr.responseText);
highlightScrollerButton("#recently-watched");
paginateScroller("#recently-watched", ".paginate-watched");
highlightWatchedScrollerButton();
}
});
}
recentlyWatched();
function highlightWatchedScrollerButton() {
var scroller = $("#recently-watched-row-scroller");
var numElems = scroller.find("li").length;
scroller.width(numElems * 175);
if (scroller.width() > $("#user-recently-watched").width()) {
$("#recently-watched-page-right").removeClass("disabled");
} else {
$("#recently-watched-page-right").addClass("disabled");
}
}
$(window).resize(function() {
highlightWatchedScrollerButton();
});
var leftTotal = 0;
$(".paginate").click(function (e) {
e.preventDefault();
var scroller = $("#recently-watched-row-scroller");
var containerWidth = $("#user-recently-watched").width();
var scrollAmount = $(this).data("id") * parseInt(containerWidth / 175) * 175;
var leftMax = Math.min(-parseInt(scroller.width()) + Math.abs(scrollAmount), 0);
leftTotal = Math.max(Math.min(leftTotal + scrollAmount, 0), leftMax);
scroller.animate({ left: leftTotal }, 250);
if (leftTotal == 0) {
$("#recently-watched-page-left").addClass("disabled").blur();
} else {
$("#recently-watched-page-left").removeClass("disabled");
}
if (leftTotal == leftMax) {
$("#recently-watched-page-right").addClass("disabled").blur();
} else {
$("#recently-watched-page-right").removeClass("disabled");
}
});
$(document).ready(function () {
// Javascript to enable link to tab
var hash = document.location.hash;

View file

@@ -31,7 +31,7 @@ DOCUMENTATION :: END
from plexpy.helpers import page, short_season
%>
<div class="dashboard-recent-media-row">
<div id="recently-watched-row-scroller">
<div id="recently-watched-row-scroller" style="left: 0;">
<ul class="dashboard-recent-media list-unstyled">
% for item in data:
<li>

View file

@@ -0,0 +1,49 @@
__all__ = [
"ZoneInfo",
"reset_tzpath",
"available_timezones",
"TZPATH",
"ZoneInfoNotFoundError",
"InvalidTZPathWarning",
]
import sys
from . import _tzpath
from ._common import ZoneInfoNotFoundError
from ._version import __version__
try:
from ._czoneinfo import ZoneInfo
except ImportError: # pragma: nocover
from ._zoneinfo import ZoneInfo
reset_tzpath = _tzpath.reset_tzpath
available_timezones = _tzpath.available_timezones
InvalidTZPathWarning = _tzpath.InvalidTZPathWarning
if sys.version_info < (3, 7):
# Module-level __getattr__ was added in Python 3.7, so instead of lazily
# populating TZPATH on every access, we will register a callback with
# reset_tzpath to update the top-level tuple.
TZPATH = _tzpath.TZPATH
def _tzpath_callback(new_tzpath):
global TZPATH
TZPATH = new_tzpath
_tzpath.TZPATH_CALLBACKS.append(_tzpath_callback)
del _tzpath_callback
else:
def __getattr__(name):
if name == "TZPATH":
return _tzpath.TZPATH
else:
raise AttributeError(
f"module {__name__!r} has no attribute {name!r}"
)
def __dir__():
return sorted(list(globals()) + ["TZPATH"])
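A quick sketch of what the two branches above guarantee (the import name is an assumption for this vendored copy): TZPATH always reflects the current _tzpath.TZPATH, whether through the 3.6 callback or the module-level __getattr__.

import backports.zoneinfo as zi  # hypothetical import path

zi.reset_tzpath(to=["/usr/share/zoneinfo"])
print(zi.TZPATH)  # ('/usr/share/zoneinfo',), looked up fresh rather than cached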

View file

@@ -0,0 +1,45 @@
import os
import typing
from datetime import datetime, tzinfo
from typing import (
Any,
Iterable,
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
)
_T = typing.TypeVar("_T", bound="ZoneInfo")
class _IOBytes(Protocol):
def read(self, __size: int) -> bytes: ...
def seek(self, __size: int, __whence: int = ...) -> Any: ...
class ZoneInfo(tzinfo):
@property
def key(self) -> str: ...
def __init__(self, key: str) -> None: ...
@classmethod
def no_cache(cls: Type[_T], key: str) -> _T: ...
@classmethod
def from_file(
cls: Type[_T], __fobj: _IOBytes, key: Optional[str] = ...
) -> _T: ...
@classmethod
def clear_cache(cls, *, only_keys: Iterable[str] = ...) -> None: ...
# Note: Both here and in clear_cache, the types allow the use of `str` where
# a sequence of strings is required. This should be remedied if a solution
# to this typing bug is found: https://github.com/python/typing/issues/256
def reset_tzpath(
to: Optional[Sequence[Union[os.PathLike, str]]] = ...
) -> None: ...
def available_timezones() -> Set[str]: ...
TZPATH: Sequence[str]
class ZoneInfoNotFoundError(KeyError): ...
class InvalidTZPathWarning(RuntimeWarning): ...

View file

@@ -0,0 +1,171 @@
import struct
def load_tzdata(key):
try:
import importlib.resources as importlib_resources
except ImportError:
import importlib_resources
components = key.split("/")
package_name = ".".join(["tzdata.zoneinfo"] + components[:-1])
resource_name = components[-1]
try:
return importlib_resources.open_binary(package_name, resource_name)
except (ImportError, FileNotFoundError, UnicodeEncodeError):
# There are three types of exception that can be raised that all amount
# to "we cannot find this key":
#
# ImportError: If package_name doesn't exist (e.g. if tzdata is not
# installed, or if there's an error in the folder name like
# Amrica/New_York)
# FileNotFoundError: If resource_name doesn't exist in the package
# (e.g. Europe/Krasnoy)
# UnicodeEncodeError: If package_name or resource_name are not UTF-8,
# such as keys containing a surrogate character.
raise ZoneInfoNotFoundError(f"No time zone found with key {key}")
def load_data(fobj):
header = _TZifHeader.from_file(fobj)
if header.version == 1:
time_size = 4
time_type = "l"
else:
# Version 2+ has 64-bit integer transition times
time_size = 8
time_type = "q"
# Version 2+ also starts with a Version 1 header and data, which
# we need to skip now
skip_bytes = (
header.timecnt * 5 # Transition times and types
+ header.typecnt * 6 # Local time type records
+ header.charcnt # Time zone designations
+ header.leapcnt * 8 # Leap second records
+ header.isstdcnt # Standard/wall indicators
+ header.isutcnt # UT/local indicators
)
fobj.seek(skip_bytes, 1)
# Now we need to read the second header, which is not the same
# as the first
header = _TZifHeader.from_file(fobj)
typecnt = header.typecnt
timecnt = header.timecnt
charcnt = header.charcnt
# The data portion starts with timecnt transitions and indices
if timecnt:
trans_list_utc = struct.unpack(
f">{timecnt}{time_type}", fobj.read(timecnt * time_size)
)
trans_idx = struct.unpack(f">{timecnt}B", fobj.read(timecnt))
else:
trans_list_utc = ()
trans_idx = ()
# Read the ttinfo struct, (utoff, isdst, abbrind)
if typecnt:
utcoff, isdst, abbrind = zip(
*(struct.unpack(">lbb", fobj.read(6)) for i in range(typecnt))
)
else:
utcoff = ()
isdst = ()
abbrind = ()
# Now read the abbreviations. They are null-terminated strings, indexed
# not by position in the array but by position in the unsplit
# abbreviation string. I suppose this makes more sense in C, which uses
# null to terminate the strings, but it's inconvenient here...
abbr_vals = {}
abbr_chars = fobj.read(charcnt)
def get_abbr(idx):
# Gets a string starting at idx and running until the next \x00
#
# We cannot pre-populate abbr_vals by splitting on \x00 because there
# are some zones that use subsets of longer abbreviations, like so:
#
# LMT\x00AHST\x00HDT\x00
#
# Where the idx to abbr mapping should be:
#
# {0: "LMT", 4: "AHST", 5: "HST", 9: "HDT"}
if idx not in abbr_vals:
span_end = abbr_chars.find(b"\x00", idx)
abbr_vals[idx] = abbr_chars[idx:span_end].decode()
return abbr_vals[idx]
abbr = tuple(get_abbr(idx) for idx in abbrind)
# The remainder of the file consists of leap seconds (currently unused) and
# the standard/wall and ut/local indicators, which are metadata we don't need.
# In version 2 files, we need to skip the unnecessary data to get at the TZ string:
if header.version >= 2:
# Each leap second record has size (time_size + 4)
skip_bytes = header.isutcnt + header.isstdcnt + header.leapcnt * 12
fobj.seek(skip_bytes, 1)
c = fobj.read(1) # Should be \n
assert c == b"\n", c
tz_bytes = b""
while True:
c = fobj.read(1)
if c == b"\n":
break
tz_bytes += c
tz_str = tz_bytes
else:
tz_str = None
return trans_idx, trans_list_utc, utcoff, isdst, abbr, tz_str
class _TZifHeader:
__slots__ = [
"version",
"isutcnt",
"isstdcnt",
"leapcnt",
"timecnt",
"typecnt",
"charcnt",
]
def __init__(self, *args):
assert len(self.__slots__) == len(args)
for attr, val in zip(self.__slots__, args):
setattr(self, attr, val)
@classmethod
def from_file(cls, stream):
# The header starts with a 4-byte "magic" value
if stream.read(4) != b"TZif":
raise ValueError("Invalid TZif file: magic not found")
_version = stream.read(1)
if _version == b"\x00":
version = 1
else:
version = int(_version)
stream.read(15)
args = (version,)
# Slots are defined in the order that the bytes are arranged
args = args + struct.unpack(">6l", stream.read(24))
return cls(*args)
class ZoneInfoNotFoundError(KeyError):
"""Exception raised when a ZoneInfo key is not found."""

View file

@@ -0,0 +1,207 @@
import os
import sys
PY36 = sys.version_info < (3, 7)
def reset_tzpath(to=None):
global TZPATH
tzpaths = to
if tzpaths is not None:
if isinstance(tzpaths, (str, bytes)):
raise TypeError(
f"tzpaths must be a list or tuple, "
+ f"not {type(tzpaths)}: {tzpaths!r}"
)
if not all(map(os.path.isabs, tzpaths)):
raise ValueError(_get_invalid_paths_message(tzpaths))
base_tzpath = tzpaths
else:
env_var = os.environ.get("PYTHONTZPATH", None)
if env_var is not None:
base_tzpath = _parse_python_tzpath(env_var)
elif sys.platform != "win32":
base_tzpath = [
"/usr/share/zoneinfo",
"/usr/lib/zoneinfo",
"/usr/share/lib/zoneinfo",
"/etc/zoneinfo",
]
base_tzpath.sort(key=lambda x: not os.path.exists(x))
else:
base_tzpath = ()
TZPATH = tuple(base_tzpath)
if TZPATH_CALLBACKS:
for callback in TZPATH_CALLBACKS:
callback(TZPATH)
def _parse_python_tzpath(env_var):
if not env_var:
return ()
raw_tzpath = env_var.split(os.pathsep)
new_tzpath = tuple(filter(os.path.isabs, raw_tzpath))
# If anything has been filtered out, we will warn about it
if len(new_tzpath) != len(raw_tzpath):
import warnings
msg = _get_invalid_paths_message(raw_tzpath)
warnings.warn(
"Invalid paths specified in PYTHONTZPATH environment variable."
+ msg,
InvalidTZPathWarning,
)
return new_tzpath
def _get_invalid_paths_message(tzpaths):
invalid_paths = (path for path in tzpaths if not os.path.isabs(path))
prefix = "\n "
indented_str = prefix + prefix.join(invalid_paths)
return (
"Paths should be absolute but found the following relative paths:"
+ indented_str
)
if sys.version_info < (3, 8):
def _isfile(path):
# bpo-33721: In Python 3.8 non-UTF8 paths return False rather than
# raising an error. See https://bugs.python.org/issue33721
try:
return os.path.isfile(path)
except ValueError:
return False
else:
_isfile = os.path.isfile
def find_tzfile(key):
"""Retrieve the path to a TZif file from a key."""
_validate_tzfile_path(key)
for search_path in TZPATH:
filepath = os.path.join(search_path, key)
if _isfile(filepath):
return filepath
return None
_TEST_PATH = os.path.normpath(os.path.join("_", "_"))[:-1]
def _validate_tzfile_path(path, _base=_TEST_PATH):
if os.path.isabs(path):
raise ValueError(
f"ZoneInfo keys may not be absolute paths, got: {path}"
)
# We only care about the kinds of path normalizations that would change the
# length of the key - e.g. a/../b -> a/b, or a/b/ -> a/b. On Windows,
# normpath will also change from a/b to a\b, but that would still preserve
# the length.
new_path = os.path.normpath(path)
if len(new_path) != len(path):
raise ValueError(
f"ZoneInfo keys must be normalized relative paths, got: {path}"
)
resolved = os.path.normpath(os.path.join(_base, new_path))
if not resolved.startswith(_base):
raise ValueError(
f"ZoneInfo keys must refer to subdirectories of TZPATH, got: {path}"
)
del _TEST_PATH
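Illustrative inputs for the three checks above (hypothetical examples, not from the source):

_validate_tzfile_path("America/New_York")    # passes: normalized relative path
_validate_tzfile_path("/etc/zoneinfo/UTC")   # ValueError: absolute path
_validate_tzfile_path("America/../Etc/UTC")  # ValueError: not normalized
_validate_tzfile_path("../outside")          # ValueError: escapes TZPATH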
def available_timezones():
"""Returns a set containing all available time zones.
.. caution::
This may attempt to open a large number of files, since the best way to
determine if a given file on the time zone search path is a valid time zone is to open it
and check for the "magic string" at the beginning.
"""
try:
from importlib import resources
except ImportError:
import importlib_resources as resources
valid_zones = set()
# Start with loading from the tzdata package if it exists: this has a
# pre-assembled list of zones that only requires opening one file.
try:
with resources.open_text("tzdata", "zones") as f:
for zone in f:
zone = zone.strip()
if zone:
valid_zones.add(zone)
except (ImportError, FileNotFoundError):
pass
def valid_key(fpath):
try:
with open(fpath, "rb") as f:
return f.read(4) == b"TZif"
except Exception: # pragma: nocover
return False
for tz_root in TZPATH:
if not os.path.exists(tz_root):
continue
for root, dirnames, files in os.walk(tz_root):
if root == tz_root:
# right/ and posix/ are special directories and shouldn't be
# included in the output of available zones
if "right" in dirnames:
dirnames.remove("right")
if "posix" in dirnames:
dirnames.remove("posix")
for file in files:
fpath = os.path.join(root, file)
key = os.path.relpath(fpath, start=tz_root)
if os.sep != "/": # pragma: nocover
key = key.replace(os.sep, "/")
if not key or key in valid_zones:
continue
if valid_key(fpath):
valid_zones.add(key)
if "posixrules" in valid_zones:
# posixrules is a special symlink-only time zone where it exists, it
# should not be included in the output
valid_zones.remove("posixrules")
return valid_zones
class InvalidTZPathWarning(RuntimeWarning):
"""Warning raised if an invalid path is specified in PYTHONTZPATH."""
TZPATH = ()
TZPATH_CALLBACKS = []
reset_tzpath()
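A short sketch of the resulting behaviour (the directory is hypothetical):

reset_tzpath(to=["/opt/tz/zoneinfo"])  # entries must be absolute paths
print(TZPATH)                          # ('/opt/tz/zoneinfo',)
print(find_tzfile("Europe/Paris"))     # full path if the file exists, else None
reset_tzpath()                         # back to PYTHONTZPATH or platform defaults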

View file

@@ -0,0 +1 @@
__version__ = "0.2.1"

View file

@@ -0,0 +1,754 @@
import bisect
import calendar
import collections
import functools
import re
import weakref
from datetime import datetime, timedelta, tzinfo
from . import _common, _tzpath
EPOCH = datetime(1970, 1, 1)
EPOCHORDINAL = datetime(1970, 1, 1).toordinal()
# It is relatively expensive to construct new timedelta objects, and in most
# cases we're looking at the same deltas, like integer numbers of hours, etc.
# To improve speed and memory use, we'll keep a dictionary with references
# to the ones we've already used so far.
#
# Loading every time zone in the 2020a version of the time zone database
# requires 447 timedeltas, which requires approximately the amount of space
# that ZoneInfo("America/New_York") with 236 transitions takes up, so we will
# set the cache size to 512 so that in the common case we always get cache
# hits, but specifically crafted ZoneInfo objects don't leak arbitrary amounts
# of memory.
@functools.lru_cache(maxsize=512)
def _load_timedelta(seconds):
return timedelta(seconds=seconds)
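One consequence of the cache worth noting: equal offsets share a single object.

assert _load_timedelta(3600) is _load_timedelta(3600)  # same cached timedelta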
class ZoneInfo(tzinfo):
_strong_cache_size = 8
_strong_cache = collections.OrderedDict()
_weak_cache = weakref.WeakValueDictionary()
__module__ = "backports.zoneinfo"
def __init_subclass__(cls):
cls._strong_cache = collections.OrderedDict()
cls._weak_cache = weakref.WeakValueDictionary()
def __new__(cls, key):
instance = cls._weak_cache.get(key, None)
if instance is None:
instance = cls._weak_cache.setdefault(key, cls._new_instance(key))
instance._from_cache = True
# Update the "strong" cache
cls._strong_cache[key] = cls._strong_cache.pop(key, instance)
if len(cls._strong_cache) > cls._strong_cache_size:
cls._strong_cache.popitem(last=False)
return instance
@classmethod
def no_cache(cls, key):
obj = cls._new_instance(key)
obj._from_cache = False
return obj
@classmethod
def _new_instance(cls, key):
obj = super().__new__(cls)
obj._key = key
obj._file_path = obj._find_tzfile(key)
if obj._file_path is not None:
file_obj = open(obj._file_path, "rb")
else:
file_obj = _common.load_tzdata(key)
with file_obj as f:
obj._load_file(f)
return obj
@classmethod
def from_file(cls, fobj, key=None):
obj = super().__new__(cls)
obj._key = key
obj._file_path = None
obj._load_file(fobj)
obj._file_repr = repr(fobj)
# Disable pickling for objects created from files
obj.__reduce__ = obj._file_reduce
return obj
@classmethod
def clear_cache(cls, *, only_keys=None):
if only_keys is not None:
for key in only_keys:
cls._weak_cache.pop(key, None)
cls._strong_cache.pop(key, None)
else:
cls._weak_cache.clear()
cls._strong_cache.clear()
@property
def key(self):
return self._key
def utcoffset(self, dt):
return self._find_trans(dt).utcoff
def dst(self, dt):
return self._find_trans(dt).dstoff
def tzname(self, dt):
return self._find_trans(dt).tzname
def fromutc(self, dt):
"""Convert from datetime in UTC to datetime in local time"""
if not isinstance(dt, datetime):
raise TypeError("fromutc() requires a datetime argument")
if dt.tzinfo is not self:
raise ValueError("dt.tzinfo is not self")
timestamp = self._get_local_timestamp(dt)
num_trans = len(self._trans_utc)
if num_trans >= 1 and timestamp < self._trans_utc[0]:
tti = self._tti_before
fold = 0
elif (
num_trans == 0 or timestamp > self._trans_utc[-1]
) and not isinstance(self._tz_after, _ttinfo):
tti, fold = self._tz_after.get_trans_info_fromutc(
timestamp, dt.year
)
elif num_trans == 0:
tti = self._tz_after
fold = 0
else:
idx = bisect.bisect_right(self._trans_utc, timestamp)
if num_trans > 1 and timestamp >= self._trans_utc[1]:
tti_prev, tti = self._ttinfos[idx - 2 : idx]
elif timestamp > self._trans_utc[-1]:
tti_prev = self._ttinfos[-1]
tti = self._tz_after
else:
tti_prev = self._tti_before
tti = self._ttinfos[0]
# Detect fold
shift = tti_prev.utcoff - tti.utcoff
fold = shift.total_seconds() > timestamp - self._trans_utc[idx - 1]
dt += tti.utcoff
if fold:
return dt.replace(fold=1)
else:
return dt
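fromutc() is the hook datetime.astimezone() calls; a sketch of the fold handling during a DST fallback (assumes the America/New_York key is resolvable):

from datetime import datetime, timezone

tz = ZoneInfo("America/New_York")
# 06:30 UTC on 2020-11-01 lands in the repeated 01:30 local hour (EDT -> EST)
dt = datetime(2020, 11, 1, 6, 30, tzinfo=timezone.utc).astimezone(tz)
print(dt, dt.fold)  # 2020-11-01 01:30:00-05:00 1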
def _find_trans(self, dt):
if dt is None:
if self._fixed_offset:
return self._tz_after
else:
return _NO_TTINFO
ts = self._get_local_timestamp(dt)
lt = self._trans_local[dt.fold]
num_trans = len(lt)
if num_trans and ts < lt[0]:
return self._tti_before
elif not num_trans or ts > lt[-1]:
if isinstance(self._tz_after, _TZStr):
return self._tz_after.get_trans_info(ts, dt.year, dt.fold)
else:
return self._tz_after
else:
# idx is the transition that occurs after this timestamp, so we
# subtract off 1 to get the current ttinfo
idx = bisect.bisect_right(lt, ts) - 1
assert idx >= 0
return self._ttinfos[idx]
def _get_local_timestamp(self, dt):
return (
(dt.toordinal() - EPOCHORDINAL) * 86400
+ dt.hour * 3600
+ dt.minute * 60
+ dt.second
)
def __str__(self):
if self._key is not None:
return f"{self._key}"
else:
return repr(self)
def __repr__(self):
if self._key is not None:
return f"{self.__class__.__name__}(key={self._key!r})"
else:
return f"{self.__class__.__name__}.from_file({self._file_repr})"
def __reduce__(self):
return (self.__class__._unpickle, (self._key, self._from_cache))
def _file_reduce(self):
import pickle
raise pickle.PicklingError(
"Cannot pickle a ZoneInfo file created from a file stream."
)
@classmethod
def _unpickle(cls, key, from_cache):
if from_cache:
return cls(key)
else:
return cls.no_cache(key)
def _find_tzfile(self, key):
return _tzpath.find_tzfile(key)
def _load_file(self, fobj):
# Retrieve all the data as it exists in the zoneinfo file
trans_idx, trans_utc, utcoff, isdst, abbr, tz_str = _common.load_data(
fobj
)
# Infer the DST offsets (needed for .dst()) from the data
dstoff = self._utcoff_to_dstoff(trans_idx, utcoff, isdst)
# Convert all the transition times (UTC) into "seconds since 1970-01-01 local time"
trans_local = self._ts_to_local(trans_idx, trans_utc, utcoff)
# Construct `_ttinfo` objects for each transition in the file
_ttinfo_list = [
_ttinfo(
_load_timedelta(utcoffset), _load_timedelta(dstoffset), tzname
)
for utcoffset, dstoffset, tzname in zip(utcoff, dstoff, abbr)
]
self._trans_utc = trans_utc
self._trans_local = trans_local
self._ttinfos = [_ttinfo_list[idx] for idx in trans_idx]
# Find the first non-DST transition
for i in range(len(isdst)):
if not isdst[i]:
self._tti_before = _ttinfo_list[i]
break
else:
if self._ttinfos:
self._tti_before = self._ttinfos[0]
else:
self._tti_before = None
# Set the "fallback" time zone
if tz_str is not None and tz_str != b"":
self._tz_after = _parse_tz_str(tz_str.decode())
else:
if not self._ttinfos and not _ttinfo_list:
raise ValueError("No time zone information found.")
if self._ttinfos:
self._tz_after = self._ttinfos[-1]
else:
self._tz_after = _ttinfo_list[-1]
# Determine if this is a "fixed offset" zone, meaning that the output
# of the utcoffset, dst and tzname functions does not depend on the
# specific datetime passed.
#
# We make three simplifying assumptions here:
#
# 1. If _tz_after is not a _ttinfo, it has transitions that might
# actually occur (it is possible to construct TZ strings that
# specify STD and DST but no transitions ever occur, such as
# AAA0BBB,0/0,J365/25).
# 2. If _ttinfo_list contains more than one _ttinfo object, the objects
# represent different offsets.
# 3. _ttinfo_list contains no unused _ttinfos (in which case an
# otherwise fixed-offset zone with extra _ttinfos defined may
# appear to *not* be a fixed offset zone).
#
# Violations to these assumptions would be fairly exotic, and exotic
# zones should almost certainly not be used with datetime.time (the
# only thing that would be affected by this).
if len(_ttinfo_list) > 1 or not isinstance(self._tz_after, _ttinfo):
self._fixed_offset = False
elif not _ttinfo_list:
self._fixed_offset = True
else:
self._fixed_offset = _ttinfo_list[0] == self._tz_after
@staticmethod
def _utcoff_to_dstoff(trans_idx, utcoffsets, isdsts):
# Now we must transform our ttis and abbrs into `_ttinfo` objects,
# but there is an issue: .dst() must return a timedelta with the
# difference between utcoffset() and the "standard" offset, but
# the "base offset" and "DST offset" are not encoded in the file;
# we can infer what they are from the isdst flag, but it is not
# sufficient to just look at the last standard offset, because
# occasionally countries will shift both DST offset and base offset.
typecnt = len(isdsts)
dstoffs = [0] * typecnt # Provisionally assign all to 0.
dst_cnt = sum(isdsts)
dst_found = 0
for i in range(1, len(trans_idx)):
if dst_cnt == dst_found:
break
idx = trans_idx[i]
dst = isdsts[idx]
# We're only going to look at daylight saving time
if not dst:
continue
# Skip any offsets that have already been assigned
if dstoffs[idx] != 0:
continue
dstoff = 0
utcoff = utcoffsets[idx]
comp_idx = trans_idx[i - 1]
if not isdsts[comp_idx]:
dstoff = utcoff - utcoffsets[comp_idx]
if not dstoff and idx < (typecnt - 1):
comp_idx = trans_idx[i + 1]
# If the following transition is also DST and we couldn't
# find the DST offset by this point, we're going to have to
# skip it and hope this transition gets assigned later
if isdsts[comp_idx]:
continue
dstoff = utcoff - utcoffsets[comp_idx]
if dstoff:
dst_found += 1
dstoffs[idx] = dstoff
else:
# If we didn't find a valid value for a given index, we'll end up
# with dstoff = 0 for something where `isdst=1`. This is obviously
# wrong - one hour will be a much better guess than 0
for idx in range(typecnt):
if not dstoffs[idx] and isdsts[idx]:
dstoffs[idx] = 3600
return dstoffs
@staticmethod
def _ts_to_local(trans_idx, trans_list_utc, utcoffsets):
"""Generate number of seconds since 1970 *in the local time*.
This is necessary to easily find the transition times in local time"""
if not trans_list_utc:
return [[], []]
# Start with the timestamps and modify in-place
trans_list_wall = [list(trans_list_utc), list(trans_list_utc)]
if len(utcoffsets) > 1:
offset_0 = utcoffsets[0]
offset_1 = utcoffsets[trans_idx[0]]
if offset_1 > offset_0:
offset_1, offset_0 = offset_0, offset_1
else:
offset_0 = offset_1 = utcoffsets[0]
trans_list_wall[0][0] += offset_0
trans_list_wall[1][0] += offset_1
for i in range(1, len(trans_idx)):
offset_0 = utcoffsets[trans_idx[i - 1]]
offset_1 = utcoffsets[trans_idx[i]]
if offset_1 > offset_0:
offset_1, offset_0 = offset_0, offset_1
trans_list_wall[0][i] += offset_0
trans_list_wall[1][i] += offset_1
return trans_list_wall
class _ttinfo:
__slots__ = ["utcoff", "dstoff", "tzname"]
def __init__(self, utcoff, dstoff, tzname):
self.utcoff = utcoff
self.dstoff = dstoff
self.tzname = tzname
def __eq__(self, other):
return (
self.utcoff == other.utcoff
and self.dstoff == other.dstoff
and self.tzname == other.tzname
)
def __repr__(self): # pragma: nocover
return (
f"{self.__class__.__name__}"
+ f"({self.utcoff}, {self.dstoff}, {self.tzname})"
)
_NO_TTINFO = _ttinfo(None, None, None)
class _TZStr:
__slots__ = (
"std",
"dst",
"start",
"end",
"get_trans_info",
"get_trans_info_fromutc",
"dst_diff",
)
def __init__(
self, std_abbr, std_offset, dst_abbr, dst_offset, start=None, end=None
):
self.dst_diff = dst_offset - std_offset
std_offset = _load_timedelta(std_offset)
self.std = _ttinfo(
utcoff=std_offset, dstoff=_load_timedelta(0), tzname=std_abbr
)
self.start = start
self.end = end
dst_offset = _load_timedelta(dst_offset)
delta = _load_timedelta(self.dst_diff)
self.dst = _ttinfo(utcoff=dst_offset, dstoff=delta, tzname=dst_abbr)
# These are assertions because the constructor should only be called
# by functions that would fail before passing start or end
assert start is not None, "No transition start specified"
assert end is not None, "No transition end specified"
self.get_trans_info = self._get_trans_info
self.get_trans_info_fromutc = self._get_trans_info_fromutc
def transitions(self, year):
start = self.start.year_to_epoch(year)
end = self.end.year_to_epoch(year)
return start, end
def _get_trans_info(self, ts, year, fold):
"""Get the information about the current transition - tti"""
start, end = self.transitions(year)
# With fold = 0, the period (denominated in local time) with the
# smaller offset starts at the end of the gap and ends at the end of
# the fold; with fold = 1, it runs from the start of the gap to the
# beginning of the fold.
#
# So in order to determine the DST boundaries we need to know both
# the fold and whether DST is positive or negative (rare), and it
# turns out that this boils down to fold XOR is_positive.
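# Concretely, for ordinary positive DST with fold=0, a wall time in the
# spring-forward gap resolves to the standard offset and a wall time in
# the fall-back fold resolves to the DST offset (its first occurrence);
# fold=1 selects the other side in both cases.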
if fold == (self.dst_diff >= 0):
end -= self.dst_diff
else:
start += self.dst_diff
if start < end:
isdst = start <= ts < end
else:
isdst = not (end <= ts < start)
return self.dst if isdst else self.std
def _get_trans_info_fromutc(self, ts, year):
start, end = self.transitions(year)
start -= self.std.utcoff.total_seconds()
end -= self.dst.utcoff.total_seconds()
if start < end:
isdst = start <= ts < end
else:
isdst = not (end <= ts < start)
# For positive DST, the ambiguous period is one dst_diff after the end
# of DST; for negative DST, the ambiguous period is one dst_diff before
# the start of DST.
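# e.g. with one hour of positive DST, the UTC interval
# [end, end + dst_diff) replays the same local hour, which is exactly
# the window computed below.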
if self.dst_diff > 0:
ambig_start = end
ambig_end = end + self.dst_diff
else:
ambig_start = start
ambig_end = start - self.dst_diff
fold = ambig_start <= ts < ambig_end
return (self.dst if isdst else self.std, fold)
def _post_epoch_days_before_year(year):
"""Get the number of days between 1970-01-01 and YEAR-01-01"""
y = year - 1
return y * 365 + y // 4 - y // 100 + y // 400 - EPOCHORDINAL
class _DayOffset:
__slots__ = ["d", "julian", "hour", "minute", "second"]
def __init__(self, d, julian, hour=2, minute=0, second=0):
if not (0 + julian) <= d <= 365:
min_day = 0 + julian
raise ValueError(f"d must be in [{min_day}, 365], not: {d}")
self.d = d
self.julian = julian
self.hour = hour
self.minute = minute
self.second = second
def year_to_epoch(self, year):
days_before_year = _post_epoch_days_before_year(year)
d = self.d
if self.julian and d >= 59 and calendar.isleap(year):
d += 1
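# (J59 is always Feb 28 and J60 always Mar 1: the 1-based Julian form
# never counts Feb 29, so in leap years the raw day is bumped past it.)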
epoch = (days_before_year + d) * 86400
epoch += self.hour * 3600 + self.minute * 60 + self.second
return epoch
class _CalendarOffset:
__slots__ = ["m", "w", "d", "hour", "minute", "second"]
_DAYS_BEFORE_MONTH = (
-1,
0,
31,
59,
90,
120,
151,
181,
212,
243,
273,
304,
334,
)
def __init__(self, m, w, d, hour=2, minute=0, second=0):
if not 0 < m <= 12:
raise ValueError("m must be in (0, 12]")
if not 0 < w <= 5:
raise ValueError("w must be in (0, 5]")
if not 0 <= d <= 6:
raise ValueError("d must be in [0, 6]")
self.m = m
self.w = w
self.d = d
self.hour = hour
self.minute = minute
self.second = second
@classmethod
def _ymd2ord(cls, year, month, day):
return (
_post_epoch_days_before_year(year)
+ cls._DAYS_BEFORE_MONTH[month]
+ (month > 2 and calendar.isleap(year))
+ day
)
# TODO: These are not actually epoch dates as they are expressed in local time
def year_to_epoch(self, year):
"""Calculates the datetime of the occurrence from the year"""
# We know year and month, we need to convert w, d into day of month
#
# Week 1 is the first week in which day `d` (where 0 = Sunday) appears.
# Week 5 represents the last occurrence of day `d`, so we need to know
# the range of the month.
first_day, days_in_month = calendar.monthrange(year, self.m)
# This equation seems magical, so I'll break it down:
# 1. calendar says 0 = Monday, POSIX says 0 = Sunday
# so we need first_day + 1 to get 1 = Monday -> 7 = Sunday,
# which is still equivalent because this math is mod 7
# 2. Get first day - desired day mod 7: -1 % 7 = 6, so we don't need
# to do anything to adjust negative numbers.
# 3. Add 1 because month days are a 1-based index.
month_day = (self.d - (first_day + 1)) % 7 + 1
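# Worked example: March 2021 begins on a Monday, so first_day = 0; for
# d = 0 (Sunday) this gives (0 - 1) % 7 + 1 = 7, i.e. the first Sunday
# of March 2021 is the 7th (and w = 2 lands on the 14th, the US DST
# start rule M3.2.0).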
# Now use a 0-based index version of `w` to calculate the w-th
# occurrence of `d`
month_day += (self.w - 1) * 7
# month_day will only be > days_in_month if w was 5, and `w` means
# "last occurrence of `d`", so now we just check if we over-shot the
# end of the month and if so knock off 1 week.
if month_day > days_in_month:
month_day -= 7
ordinal = self._ymd2ord(year, self.m, month_day)
epoch = ordinal * 86400
epoch += self.hour * 3600 + self.minute * 60 + self.second
return epoch
def _parse_tz_str(tz_str):
# The tz string has the format:
#
# std[offset[dst[offset],start[/time],end[/time]]]
#
# std and dst must be 3 or more characters long and must not contain
# a leading colon, embedded digits, commas, nor plus or minus signs;
# The spaces between "std" and "offset" are only for display and are
# not actually present in the string.
#
# The format of the offset is ``[+|-]hh[:mm[:ss]]``
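#
# For example, "PST8PDT,M3.2.0,M11.1.0" specifies standard time "PST"
# at UTC-8 and DST "PDT" (implicitly one hour ahead), starting the
# second Sunday in March and ending the first Sunday in November, each
# at the default 02:00 local time.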
offset_str, *start_end_str = tz_str.split(",", 1)
# fmt: off
parser_re = re.compile(
r"(?P<std>[^<0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
r"((?P<stdoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?)" +
r"((?P<dst>[^0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
r"((?P<dstoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?))?" +
r")?" + # dst
r")?$" # stdoff
)
# fmt: on
m = parser_re.match(offset_str)
if m is None:
raise ValueError(f"{tz_str} is not a valid TZ string")
std_abbr = m.group("std")
dst_abbr = m.group("dst")
dst_offset = None
std_abbr = std_abbr.strip("<>")
if dst_abbr:
dst_abbr = dst_abbr.strip("<>")
std_offset = m.group("stdoff")
if std_offset:
try:
std_offset = _parse_tz_delta(std_offset)
except ValueError as e:
raise ValueError(f"Invalid STD offset in {tz_str}") from e
else:
std_offset = 0
if dst_abbr is not None:
dst_offset = m.group("dstoff")
if dst_offset:
try:
dst_offset = _parse_tz_delta(dst_offset)
except ValueError as e:
raise ValueError(f"Invalid DST offset in {tz_str}") from e
else:
dst_offset = std_offset + 3600
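# (Per POSIX, an omitted DST offset defaults to one hour ahead of
# standard time.)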
if not start_end_str:
raise ValueError(f"Missing transition rules: {tz_str}")
start_end_strs = start_end_str[0].split(",", 1)
try:
start, end = (_parse_dst_start_end(x) for x in start_end_strs)
except ValueError as e:
raise ValueError(f"Invalid TZ string: {tz_str}") from e
return _TZStr(std_abbr, std_offset, dst_abbr, dst_offset, start, end)
elif start_end_str:
raise ValueError(f"Transition rule present without DST: {tz_str}")
else:
# This is a static ttinfo, don't return _TZStr
return _ttinfo(
_load_timedelta(std_offset), _load_timedelta(0), std_abbr
)
def _parse_dst_start_end(dststr):
date, *time = dststr.split("/")
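# Three date forms are accepted: "M3.2.0" (month 3, week 2, Sunday),
# "J60" (1-based Julian day, Feb 29 never counted) and a bare "59"
# (zero-based day of year, Feb 29 counted in leap years).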
if date[0] == "M":
n_is_julian = False
m = re.match(r"M(\d{1,2})\.(\d).(\d)$", date)
if m is None:
raise ValueError(f"Invalid dst start/end date: {dststr}")
date_offset = tuple(map(int, m.groups()))
offset = _CalendarOffset(*date_offset)
else:
if date[0] == "J":
n_is_julian = True
date = date[1:]
else:
n_is_julian = False
doy = int(date)
offset = _DayOffset(doy, n_is_julian)
if time:
time_components = list(map(int, time[0].split(":")))
n_components = len(time_components)
if n_components < 3:
time_components.extend([0] * (3 - n_components))
offset.hour, offset.minute, offset.second = time_components
return offset
def _parse_tz_delta(tz_delta):
match = re.match(
r"(?P<sign>[+-])?(?P<h>\d{1,2})(:(?P<m>\d{2})(:(?P<s>\d{2}))?)?",
tz_delta,
)
# Anything passed to this function should already have hit an equivalent
# regular expression to find the section to parse.
assert match is not None, tz_delta
h, m, s = (
int(v) if v is not None else 0
for v in map(match.group, ("h", "m", "s"))
)
total = h * 3600 + m * 60 + s
if not -86400 < total < 86400:
raise ValueError(
"Offset must be strictly between -24h and +24h:" + tz_delta
)
# Yes, +5 maps to an offset of -5h
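# e.g. the "5" in "EST5EDT" denotes UTC-5: POSIX offsets are positive
# west of Greenwich, the reverse of the ISO 8601 convention.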
if match.group("sign") != "-":
total *= -1
return total

View file

View file

@@ -11,9 +11,9 @@ from bleach.sanitizer import (
# yyyymmdd
__releasedate__ = "20241029"
__releasedate__ = "20231006"
# x.y.z or x.y.z.dev0 -- semver
__version__ = "6.2.0"
__version__ = "6.1.0"
__all__ = ["clean", "linkify"]

View file

@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from bleach.six_shim import http_client, urllib
from six import text_type
from six.moves import http_client, urllib
import codecs
import re

View file

@@ -1,6 +1,6 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import unichr as chr
from six import unichr as chr
from collections import deque, OrderedDict
from sys import version_info

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from six import text_type
from bisect import bisect_left

View file

@@ -7,7 +7,7 @@ try:
except ImportError:
from collections import Mapping
from bleach.six_shim import text_type, PY3
from six import text_type, PY3
if PY3:
import xml.etree.ElementTree as default_etree

View file

@@ -1,6 +1,6 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from six import text_type
from . import base
from ..constants import namespaces, voidElements

View file

@@ -12,7 +12,7 @@ import re
import warnings
from xml.sax.saxutils import escape, unescape
from bleach.six_shim import urllib_parse as urlparse
from six.moves import urllib_parse as urlparse
from . import base
from ..constants import namespaces, prefixes

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import viewkeys
from six import with_metaclass, viewkeys
import types
@@ -423,7 +423,7 @@ def getPhases(debug):
return type
# pylint:disable=unused-argument
class Phase(metaclass=getMetaclass(debug, log)):
class Phase(with_metaclass(getMetaclass(debug, log))):
"""Base class for helper object that implements each phase of processing
"""
__slots__ = ("parser", "tree", "__startTagCache", "__endTagCache")

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from six import text_type
import re

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from six import text_type
from ..constants import scopingElements, tableInsertModeElements, namespaces

View file

@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, unicode_literals
# pylint:disable=protected-access
from bleach.six_shim import text_type
from six import text_type
import re

View file

@@ -28,7 +28,7 @@ from . import etree as etree_builders
from .. import _ihatexml
import lxml.etree as etree
from bleach.six_shim import PY3, binary_type
from six import PY3, binary_type
fullTree = True

View file

@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, unicode_literals
from collections import OrderedDict
import re
from bleach.six_shim import string_types
from six import string_types
from . import base
from .._utils import moduleFactoryFactory

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
from bleach.six_shim import text_type
from six import text_type
from collections import OrderedDict

View file

@@ -7,12 +7,8 @@ set -o pipefail
BLEACH_VENDOR_DIR=${BLEACH_VENDOR_DIR:-"."}
DEST=${DEST:-"."}
# Install with no dependencies
pip install --no-binary all --no-compile --no-deps -r "${BLEACH_VENDOR_DIR}/vendor.txt" --target "${DEST}"
# Apply patches
(cd "${DEST}" && patch -p2 < 01_html5lib_six.patch)
# install Python 3.6.14 urllib.urlparse for #536
curl --proto '=https' --tlsv1.2 -o "${DEST}/parse.py" https://raw.githubusercontent.com/python/cpython/v3.6.14/Lib/urllib/parse.py
(cd "${DEST}" && sha256sum parse.py > parse.py.SHA256SUM)

View file

@@ -396,25 +396,16 @@ class BleachHTMLTokenizer(HTMLTokenizer):
# name that abruptly ends, but we should treat that like
# character data
yield {"type": TAG_TOKEN_TYPE_CHARACTERS, "data": self.stream.get_tag()}
elif last_error_token["data"] in (
"duplicate-attribute",
"eof-in-attribute-name",
"eof-in-attribute-value-no-quotes",
"expected-end-of-tag-but-got-eof",
):
# Handle the case where the text being parsed ends with <
# followed by characters and then space and then:
#
# * more characters
# * more characters repeated with a space between (e.g. "abc abc")
# * more characters and then a space and then an EOF (e.g. "abc def ")
#
# These cases are treated as a tag name followed by an
# followed by a series of characters and then space and then
# more characters. It's treated as a tag name followed by an
# attribute that abruptly ends, but we should treat that like
# character data instead.
# character data.
yield {"type": TAG_TOKEN_TYPE_CHARACTERS, "data": self.stream.get_tag()}
else:
yield last_error_token

View file

@@ -1,19 +0,0 @@
"""
Replacement module for what html5lib uses six for.
"""
import http.client
import operator
import urllib
PY3 = True
binary_type = bytes
string_types = (str,)
text_type = str
unichr = chr
viewkeys = operator.methodcaller("keys")
http_client = http.client
urllib = urllib
urllib_parse = urllib.parse

View file

@@ -1,4 +1,4 @@
from .core import contents, where
__all__ = ["contents", "where"]
__version__ = "2024.08.30"
__version__ = "2024.07.04"

View file

@@ -4796,134 +4796,3 @@ PQQDAwNoADBlAjAdfKR7w4l1M+E7qUW/Runpod3JIha3RxEL2Jq68cgLcFBTApFw
hVmpHqTm6iMxoAACMQD94vizrxa5HnPEluPBMBnYfubDl94cT7iJLzPrSA8Z94dG
XSaQpYXFuXqUPoeovQA=
-----END CERTIFICATE-----
# Issuer: CN=TWCA CYBER Root CA O=TAIWAN-CA OU=Root CA
# Subject: CN=TWCA CYBER Root CA O=TAIWAN-CA OU=Root CA
# Label: "TWCA CYBER Root CA"
# Serial: 85076849864375384482682434040119489222
# MD5 Fingerprint: 0b:33:a0:97:52:95:d4:a9:fd:bb:db:6e:a3:55:5b:51
# SHA1 Fingerprint: f6:b1:1c:1a:83:38:e9:7b:db:b3:a8:c8:33:24:e0:2d:9c:7f:26:66
# SHA256 Fingerprint: 3f:63:bb:28:14:be:17:4e:c8:b6:43:9c:f0:8d:6d:56:f0:b7:c4:05:88:3a:56:48:a3:34:42:4d:6b:3e:c5:58
-----BEGIN CERTIFICATE-----
MIIFjTCCA3WgAwIBAgIQQAE0jMIAAAAAAAAAATzyxjANBgkqhkiG9w0BAQwFADBQ
MQswCQYDVQQGEwJUVzESMBAGA1UEChMJVEFJV0FOLUNBMRAwDgYDVQQLEwdSb290
IENBMRswGQYDVQQDExJUV0NBIENZQkVSIFJvb3QgQ0EwHhcNMjIxMTIyMDY1NDI5
WhcNNDcxMTIyMTU1OTU5WjBQMQswCQYDVQQGEwJUVzESMBAGA1UEChMJVEFJV0FO
LUNBMRAwDgYDVQQLEwdSb290IENBMRswGQYDVQQDExJUV0NBIENZQkVSIFJvb3Qg
Q0EwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDG+Moe2Qkgfh1sTs6P
40czRJzHyWmqOlt47nDSkvgEs1JSHWdyKKHfi12VCv7qze33Kc7wb3+szT3vsxxF
avcokPFhV8UMxKNQXd7UtcsZyoC5dc4pztKFIuwCY8xEMCDa6pFbVuYdHNWdZsc/
34bKS1PE2Y2yHer43CdTo0fhYcx9tbD47nORxc5zb87uEB8aBs/pJ2DFTxnk684i
JkXXYJndzk834H/nY62wuFm40AZoNWDTNq5xQwTxaWV4fPMf88oon1oglWa0zbfu
j3ikRRjpJi+NmykosaS3Om251Bw4ckVYsV7r8Cibt4LK/c/WMw+f+5eesRycnupf
Xtuq3VTpMCEobY5583WSjCb+3MX2w7DfRFlDo7YDKPYIMKoNM+HvnKkHIuNZW0CP
2oi3aQiotyMuRAlZN1vH4xfyIutuOVLF3lSnmMlLIJXcRolftBL5hSmO68gnFSDA
S9TMfAxsNAwmmyYxpjyn9tnQS6Jk/zuZQXLB4HCX8SS7K8R0IrGsayIyJNN4KsDA
oS/xUgXJP+92ZuJF2A09rZXIx4kmyA+upwMu+8Ff+iDhcK2wZSA3M2Cw1a/XDBzC
kHDXShi8fgGwsOsVHkQGzaRP6AzRwyAQ4VRlnrZR0Bp2a0JaWHY06rc3Ga4udfmW
5cFZ95RXKSWNOkyrTZpB0F8mAwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCAQYwDwYD
VR0TAQH/BAUwAwEB/zAfBgNVHSMEGDAWgBSdhWEUfMFib5do5E83QOGt4A1WNzAd
BgNVHQ4EFgQUnYVhFHzBYm+XaORPN0DhreANVjcwDQYJKoZIhvcNAQEMBQADggIB
AGSPesRiDrWIzLjHhg6hShbNcAu3p4ULs3a2D6f/CIsLJc+o1IN1KriWiLb73y0t
tGlTITVX1olNc79pj3CjYcya2x6a4CD4bLubIp1dhDGaLIrdaqHXKGnK/nZVekZn
68xDiBaiA9a5F/gZbG0jAn/xX9AKKSM70aoK7akXJlQKTcKlTfjF/biBzysseKNn
TKkHmvPfXvt89YnNdJdhEGoHK4Fa0o635yDRIG4kqIQnoVesqlVYL9zZyvpoBJ7t
RCT5dEA7IzOrg1oYJkK2bVS1FmAwbLGg+LhBoF1JSdJlBTrq/p1hvIbZv97Tujqx
f36SNI7JAG7cmL3c7IAFrQI932XtCwP39xaEBDG6k5TY8hL4iuO/Qq+n1M0RFxbI
Qh0UqEL20kCGoE8jypZFVmAGzbdVAaYBlGX+bgUJurSkquLvWL69J1bY73NxW0Qz
8ppy6rBePm6pUlvscG21h483XjyMnM7k8M4MZ0HMzvaAq07MTFb1wWFZk7Q+ptq4
NxKfKjLji7gh7MMrZQzvIt6IKTtM1/r+t+FHvpw+PoP7UV31aPcuIYXcv/Fa4nzX
xeSDwWrruoBa3lwtcHb4yOWHh8qgnaHlIhInD0Q9HWzq1MKLL295q39QpsQZp6F6
t5b5wR9iWqJDB0BeJsas7a5wFsWqynKKTbDPAYsDP27X
-----END CERTIFICATE-----
# Issuer: CN=SecureSign Root CA12 O=Cybertrust Japan Co., Ltd.
# Subject: CN=SecureSign Root CA12 O=Cybertrust Japan Co., Ltd.
# Label: "SecureSign Root CA12"
# Serial: 587887345431707215246142177076162061960426065942
# MD5 Fingerprint: c6:89:ca:64:42:9b:62:08:49:0b:1e:7f:e9:07:3d:e8
# SHA1 Fingerprint: 7a:22:1e:3d:de:1b:06:ac:9e:c8:47:70:16:8e:3c:e5:f7:6b:06:f4
# SHA256 Fingerprint: 3f:03:4b:b5:70:4d:44:b2:d0:85:45:a0:20:57:de:93:eb:f3:90:5f:ce:72:1a:cb:c7:30:c0:6d:da:ee:90:4e
-----BEGIN CERTIFICATE-----
MIIDcjCCAlqgAwIBAgIUZvnHwa/swlG07VOX5uaCwysckBYwDQYJKoZIhvcNAQEL
BQAwUTELMAkGA1UEBhMCSlAxIzAhBgNVBAoTGkN5YmVydHJ1c3QgSmFwYW4gQ28u
LCBMdGQuMR0wGwYDVQQDExRTZWN1cmVTaWduIFJvb3QgQ0ExMjAeFw0yMDA0MDgw
NTM2NDZaFw00MDA0MDgwNTM2NDZaMFExCzAJBgNVBAYTAkpQMSMwIQYDVQQKExpD
eWJlcnRydXN0IEphcGFuIENvLiwgTHRkLjEdMBsGA1UEAxMUU2VjdXJlU2lnbiBS
b290IENBMTIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6OcE3emhF
KxS06+QT61d1I02PJC0W6K6OyX2kVzsqdiUzg2zqMoqUm048luT9Ub+ZyZN+v/mt
p7JIKwccJ/VMvHASd6SFVLX9kHrko+RRWAPNEHl57muTH2SOa2SroxPjcf59q5zd
J1M3s6oYwlkm7Fsf0uZlfO+TvdhYXAvA42VvPMfKWeP+bl+sg779XSVOKik71gur
FzJ4pOE+lEa+Ym6b3kaosRbnhW70CEBFEaCeVESE99g2zvVQR9wsMJvuwPWW0v4J
hscGWa5Pro4RmHvzC1KqYiaqId+OJTN5lxZJjfU+1UefNzFJM3IFTQy2VYzxV4+K
h9GtxRESOaCtAgMBAAGjQjBAMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQD
AgEGMB0GA1UdDgQWBBRXNPN0zwRL1SXm8UC2LEzZLemgrTANBgkqhkiG9w0BAQsF
AAOCAQEAPrvbFxbS8hQBICw4g0utvsqFepq2m2um4fylOqyttCg6r9cBg0krY6Ld
mmQOmFxv3Y67ilQiLUoT865AQ9tPkbeGGuwAtEGBpE/6aouIs3YIcipJQMPTw4WJ
mBClnW8Zt7vPemVV2zfrPIpyMpcemik+rY3moxtt9XUa5rBouVui7mlHJzWhhpmA
8zNL4WukJsPvdFlseqJkth5Ew1DgDzk9qTPxpfPSvWKErI4cqc1avTc7bgoitPQV
55FYxTpE05Uo2cBl6XLK0A+9H7MV2anjpEcJnuDLN/v9vZfVvhgaaaI5gdka9at/
yOPiZwud9AzqVN/Ssq+xIvEg37xEHA==
-----END CERTIFICATE-----
# Issuer: CN=SecureSign Root CA14 O=Cybertrust Japan Co., Ltd.
# Subject: CN=SecureSign Root CA14 O=Cybertrust Japan Co., Ltd.
# Label: "SecureSign Root CA14"
# Serial: 575790784512929437950770173562378038616896959179
# MD5 Fingerprint: 71:0d:72:fa:92:19:65:5e:89:04:ac:16:33:f0:bc:d5
# SHA1 Fingerprint: dd:50:c0:f7:79:b3:64:2e:74:a2:b8:9d:9f:d3:40:dd:bb:f0:f2:4f
# SHA256 Fingerprint: 4b:00:9c:10:34:49:4f:9a:b5:6b:ba:3b:a1:d6:27:31:fc:4d:20:d8:95:5a:dc:ec:10:a9:25:60:72:61:e3:38
-----BEGIN CERTIFICATE-----
MIIFcjCCA1qgAwIBAgIUZNtaDCBO6Ncpd8hQJ6JaJ90t8sswDQYJKoZIhvcNAQEM
BQAwUTELMAkGA1UEBhMCSlAxIzAhBgNVBAoTGkN5YmVydHJ1c3QgSmFwYW4gQ28u
LCBMdGQuMR0wGwYDVQQDExRTZWN1cmVTaWduIFJvb3QgQ0ExNDAeFw0yMDA0MDgw
NzA2MTlaFw00NTA0MDgwNzA2MTlaMFExCzAJBgNVBAYTAkpQMSMwIQYDVQQKExpD
eWJlcnRydXN0IEphcGFuIENvLiwgTHRkLjEdMBsGA1UEAxMUU2VjdXJlU2lnbiBS
b290IENBMTQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDF0nqh1oq/
FjHQmNE6lPxauG4iwWL3pwon71D2LrGeaBLwbCRjOfHw3xDG3rdSINVSW0KZnvOg
vlIfX8xnbacuUKLBl422+JX1sLrcneC+y9/3OPJH9aaakpUqYllQC6KxNedlsmGy
6pJxaeQp8E+BgQQ8sqVb1MWoWWd7VRxJq3qdwudzTe/NCcLEVxLbAQ4jeQkHO6Lo
/IrPj8BGJJw4J+CDnRugv3gVEOuGTgpa/d/aLIJ+7sr2KeH6caH3iGicnPCNvg9J
kdjqOvn90Ghx2+m1K06Ckm9mH+Dw3EzsytHqunQG+bOEkJTRX45zGRBdAuVwpcAQ
0BB8b8VYSbSwbprafZX1zNoCr7gsfXmPvkPx+SgojQlD+Ajda8iLLCSxjVIHvXib
y8posqTdDEx5YMaZ0ZPxMBoH064iwurO8YQJzOAUbn8/ftKChazcqRZOhaBgy/ac
18izju3Gm5h1DVXoX+WViwKkrkMpKBGk5hIwAUt1ax5mnXkvpXYvHUC0bcl9eQjs
0Wq2XSqypWa9a4X0dFbD9ed1Uigspf9mR6XU/v6eVL9lfgHWMI+lNpyiUBzuOIAB
SMbHdPTGrMNASRZhdCyvjG817XsYAFs2PJxQDcqSMxDxJklt33UkN4Ii1+iW/RVL
ApY+B3KVfqs9TC7XyvDf4Fg/LS8EmjijAQIDAQABo0IwQDAPBgNVHRMBAf8EBTAD
AQH/MA4GA1UdDwEB/wQEAwIBBjAdBgNVHQ4EFgQUBpOjCl4oaTeqYR3r6/wtbyPk
86AwDQYJKoZIhvcNAQEMBQADggIBAJaAcgkGfpzMkwQWu6A6jZJOtxEaCnFxEM0E
rX+lRVAQZk5KQaID2RFPeje5S+LGjzJmdSX7684/AykmjbgWHfYfM25I5uj4V7Ib
ed87hwriZLoAymzvftAj63iP/2SbNDefNWWipAA9EiOWWF3KY4fGoweITedpdopT
zfFP7ELyk+OZpDc8h7hi2/DsHzc/N19DzFGdtfCXwreFamgLRB7lUe6TzktuhsHS
DCRZNhqfLJGP4xjblJUK7ZGqDpncllPjYYPGFrojutzdfhrGe0K22VoF3Jpf1d+4
2kd92jjbrDnVHmtsKheMYc2xbXIBw8MgAGJoFjHVdqqGuw6qnsb58Nn4DSEC5MUo
FlkRudlpcyqSeLiSV5sI8jrlL5WwWLdrIBRtFO8KvH7YVdiI2i/6GaX7i+B/OfVy
K4XELKzvGUWSTLNhB9xNH27SgRNcmvMSZ4PPmz+Ln52kuaiWA3rF7iDeM9ovnhp6
dB7h7sxaOgTdsxoEqBRjrLdHEoOabPXm6RUVkRqEGQ6UROcSjiVbgGcZ3GOTEAtl
Lor6CZpO2oYofaphNdgOpygau1LgePhsumywbrmHXumZNTfxPWQrqaA0k89jL9WB
365jJ6UeTo3cKXhZ+PmhIIynJkBugnLNeLLIjzwec+fBH7/PzqUqm9tEZDKgu39c
JRNItX+S
-----END CERTIFICATE-----
# Issuer: CN=SecureSign Root CA15 O=Cybertrust Japan Co., Ltd.
# Subject: CN=SecureSign Root CA15 O=Cybertrust Japan Co., Ltd.
# Label: "SecureSign Root CA15"
# Serial: 126083514594751269499665114766174399806381178503
# MD5 Fingerprint: 13:30:fc:c4:62:a6:a9:de:b5:c1:68:af:b5:d2:31:47
# SHA1 Fingerprint: cb:ba:83:c8:c1:5a:5d:f1:f9:73:6f:ca:d7:ef:28:13:06:4a:07:7d
# SHA256 Fingerprint: e7:78:f0:f0:95:fe:84:37:29:cd:1a:00:82:17:9e:53:14:a9:c2:91:44:28:05:e1:fb:1d:8f:b6:b8:88:6c:3a
-----BEGIN CERTIFICATE-----
MIICIzCCAamgAwIBAgIUFhXHw9hJp75pDIqI7fBw+d23PocwCgYIKoZIzj0EAwMw
UTELMAkGA1UEBhMCSlAxIzAhBgNVBAoTGkN5YmVydHJ1c3QgSmFwYW4gQ28uLCBM
dGQuMR0wGwYDVQQDExRTZWN1cmVTaWduIFJvb3QgQ0ExNTAeFw0yMDA0MDgwODMy
NTZaFw00NTA0MDgwODMyNTZaMFExCzAJBgNVBAYTAkpQMSMwIQYDVQQKExpDeWJl
cnRydXN0IEphcGFuIENvLiwgTHRkLjEdMBsGA1UEAxMUU2VjdXJlU2lnbiBSb290
IENBMTUwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAQLUHSNZDKZmbPSYAi4Io5GdCx4
wCtELW1fHcmuS1Iggz24FG1Th2CeX2yF2wYUleDHKP+dX+Sq8bOLbe1PL0vJSpSR
ZHX+AezB2Ot6lHhWGENfa4HL9rzatAy2KZMIaY+jQjBAMA8GA1UdEwEB/wQFMAMB
Af8wDgYDVR0PAQH/BAQDAgEGMB0GA1UdDgQWBBTrQciu/NWeUUj1vYv0hyCTQSvT
9DAKBggqhkjOPQQDAwNoADBlAjEA2S6Jfl5OpBEHvVnCB96rMjhTKkZEBhd6zlHp
4P9mLQlO4E/0BdGF9jVg3PVys0Z9AjBEmEYagoUeYWmJSwdLZrWeqrqgHkHZAXQ6
bkU6iYAZezKYVWOr62Nuk22rGwlgMU4=
-----END CERTIFICATE-----

View file

@@ -159,8 +159,6 @@ def from_bytes(
results: CharsetMatches = CharsetMatches()
early_stop_results: CharsetMatches = CharsetMatches()
sig_encoding, sig_payload = identify_sig_or_bom(sequences)
if sig_encoding is not None:
@@ -223,20 +221,16 @@ def from_bytes(
try:
if is_too_large_sequence and is_multi_byte_decoder is False:
str(
(
sequences[: int(50e4)]
if strip_sig_or_bom is False
else sequences[len(sig_payload) : int(50e4)]
),
else sequences[len(sig_payload) : int(50e4)],
encoding=encoding_iana,
)
else:
decoded_payload = str(
(
sequences
if strip_sig_or_bom is False
else sequences[len(sig_payload) :]
),
else sequences[len(sig_payload) :],
encoding=encoding_iana,
)
except (UnicodeDecodeError, LookupError) as e:
@@ -373,13 +367,7 @@ def from_bytes(
and not lazy_str_hard_failure
):
fallback_entry = CharsetMatch(
sequences,
encoding_iana,
threshold,
False,
[],
decoded_payload,
preemptive_declaration=specified_encoding,
sequences, encoding_iana, threshold, False, [], decoded_payload
)
if encoding_iana == specified_encoding:
fallback_specified = fallback_entry
@@ -433,58 +421,28 @@ def from_bytes(
),
)
current_match = CharsetMatch(
results.append(
CharsetMatch(
sequences,
encoding_iana,
mean_mess_ratio,
bom_or_sig_available,
cd_ratios_merged,
(
decoded_payload
if (
is_too_large_sequence is False
or encoding_iana in [specified_encoding, "ascii", "utf_8"]
decoded_payload,
)
else None
),
preemptive_declaration=specified_encoding,
)
results.append(current_match)
if (
encoding_iana in [specified_encoding, "ascii", "utf_8"]
and mean_mess_ratio < 0.1
):
# If md says nothing to worry about, then... stop immediately!
if mean_mess_ratio == 0.0:
logger.debug(
"Encoding detection: %s is most likely the one.",
current_match.encoding,
"Encoding detection: %s is most likely the one.", encoding_iana
)
if explain:
logger.removeHandler(explain_handler)
logger.setLevel(previous_logger_level)
return CharsetMatches([current_match])
early_stop_results.append(current_match)
if (
len(early_stop_results)
and (specified_encoding is None or specified_encoding in tested)
and "ascii" in tested
and "utf_8" in tested
):
probable_result: CharsetMatch = early_stop_results.best() # type: ignore[assignment]
logger.debug(
"Encoding detection: %s is most likely the one.",
probable_result.encoding,
)
if explain:
logger.removeHandler(explain_handler)
logger.setLevel(previous_logger_level)
return CharsetMatches([probable_result])
return CharsetMatches([results[encoding_iana]])
if encoding_iana == sig_encoding:
logger.debug(

View file

@@ -109,14 +109,6 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
dest="force",
help="Replace file without asking if you are sure, use this flag with caution.",
)
parser.add_argument(
"-i",
"--no-preemptive",
action="store_true",
default=False,
dest="no_preemptive",
help="Disable looking at a charset declaration to hint the detector.",
)
parser.add_argument(
"-t",
"--threshold",
@@ -141,35 +133,21 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
args = parser.parse_args(argv)
if args.replace is True and args.normalize is False:
if args.files:
for my_file in args.files:
my_file.close()
print("Use --replace in addition of --normalize only.", file=sys.stderr)
return 1
if args.force is True and args.replace is False:
if args.files:
for my_file in args.files:
my_file.close()
print("Use --force in addition of --replace only.", file=sys.stderr)
return 1
if args.threshold < 0.0 or args.threshold > 1.0:
if args.files:
for my_file in args.files:
my_file.close()
print("--threshold VALUE should be between 0. AND 1.", file=sys.stderr)
return 1
x_ = []
for my_file in args.files:
matches = from_fp(
my_file,
threshold=args.threshold,
explain=args.verbose,
preemptive_behaviour=args.no_preemptive is False,
)
matches = from_fp(my_file, threshold=args.threshold, explain=args.verbose)
best_guess = matches.best()
@@ -177,11 +155,9 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
print(
'Unable to identify originating encoding for "{}". {}'.format(
my_file.name,
(
"Maybe try increasing maximum amount of chaos."
if args.threshold < 1.0
else ""
),
else "",
),
file=sys.stderr,
)
@@ -282,8 +258,8 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
try:
x_[0].unicode_path = join(dir_path, ".".join(o_))
with open(x_[0].unicode_path, "wb") as fp:
fp.write(best_guess.output())
with open(x_[0].unicode_path, "w", encoding="utf-8") as fp:
fp.write(str(best_guess))
except IOError as e:
print(str(e), file=sys.stderr)
if my_file.closed is False:

View file

@@ -544,8 +544,6 @@ COMMON_SAFE_ASCII_CHARACTERS: Set[str] = {
"|",
'"',
"-",
"(",
")",
}

View file

@@ -1,24 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Dict, Optional, Union
from warnings import warn
from .api import from_bytes
from .constant import CHARDET_CORRESPONDENCE
# TODO: remove this check when dropping Python 3.7 support
if TYPE_CHECKING:
from typing_extensions import TypedDict
class ResultDict(TypedDict):
encoding: Optional[str]
language: str
confidence: Optional[float]
def detect(
byte_str: bytes, should_rename_legacy: bool = False, **kwargs: Any
) -> ResultDict:
) -> Dict[str, Optional[Union[str, float]]]:
"""
chardet legacy method
Detect the encoding of the given byte string. It should be mostly backward-compatible.

View file

@@ -236,7 +236,7 @@ class SuspiciousRange(MessDetectorPlugin):
@property
def ratio(self) -> float:
if self._character_count <= 13:
if self._character_count <= 24:
return 0.0
ratio_of_suspicious_range_usage: float = (
@@ -260,7 +260,6 @@ class SuperWeirdWordPlugin(MessDetectorPlugin):
self._buffer: str = ""
self._buffer_accent_count: int = 0
self._buffer_glyph_count: int = 0
def eligible(self, character: str) -> bool:
return True
@@ -280,14 +279,6 @@ class SuperWeirdWordPlugin(MessDetectorPlugin):
and is_thai(character) is False
):
self._foreign_long_watch = True
if (
is_cjk(character)
or is_hangul(character)
or is_katakana(character)
or is_hiragana(character)
or is_thai(character)
):
self._buffer_glyph_count += 1
return
if not self._buffer:
return
@@ -300,20 +291,17 @@ class SuperWeirdWordPlugin(MessDetectorPlugin):
self._character_count += buffer_length
if buffer_length >= 4:
if self._buffer_accent_count / buffer_length >= 0.5:
if self._buffer_accent_count / buffer_length > 0.34:
self._is_current_word_bad = True
# Word/Buffer ending with an upper case accentuated letter are so rare,
# that we will consider them all as suspicious. Same weight as foreign_long suspicious.
elif (
if (
is_accentuated(self._buffer[-1])
and self._buffer[-1].isupper()
and all(_.isupper() for _ in self._buffer) is False
):
self._foreign_long_count += 1
self._is_current_word_bad = True
elif self._buffer_glyph_count == 1:
self._is_current_word_bad = True
self._foreign_long_count += 1
if buffer_length >= 24 and self._foreign_long_watch:
camel_case_dst = [
i
@@ -337,7 +325,6 @@ class SuperWeirdWordPlugin(MessDetectorPlugin):
self._foreign_long_watch = False
self._buffer = ""
self._buffer_accent_count = 0
self._buffer_glyph_count = 0
elif (
character not in {"<", ">", "-", "=", "~", "|", "_"}
and character.isdigit() is False

View file

@@ -1,10 +1,9 @@
from encodings.aliases import aliases
from hashlib import sha256
from json import dumps
from re import sub
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from .constant import RE_POSSIBLE_ENCODING_INDICATION, TOO_BIG_SEQUENCE
from .constant import TOO_BIG_SEQUENCE
from .utils import iana_name, is_multi_byte_encoding, unicode_range
@@ -17,7 +16,6 @@ class CharsetMatch:
has_sig_or_bom: bool,
languages: "CoherenceMatches",
decoded_payload: Optional[str] = None,
preemptive_declaration: Optional[str] = None,
):
self._payload: bytes = payload
@@ -35,13 +33,13 @@ class CharsetMatch:
self._string: Optional[str] = decoded_payload
self._preemptive_declaration: Optional[str] = preemptive_declaration
def __eq__(self, other: object) -> bool:
if not isinstance(other, CharsetMatch):
if isinstance(other, str):
return iana_name(other) == self.encoding
return False
raise TypeError(
"__eq__ cannot be invoked on {} and {}.".format(
str(other.__class__), str(self.__class__)
)
)
return self.encoding == other.encoding and self.fingerprint == other.fingerprint
def __lt__(self, other: object) -> bool:
@@ -212,24 +210,7 @@ class CharsetMatch:
"""
if self._output_encoding is None or self._output_encoding != encoding:
self._output_encoding = encoding
decoded_string = str(self)
if (
self._preemptive_declaration is not None
and self._preemptive_declaration.lower()
not in ["utf-8", "utf8", "utf_8"]
):
patched_header = sub(
RE_POSSIBLE_ENCODING_INDICATION,
lambda m: m.string[m.span()[0] : m.span()[1]].replace(
m.groups()[0], iana_name(self._output_encoding) # type: ignore[arg-type]
),
decoded_string[:8192],
1,
)
decoded_string = patched_header + decoded_string[8192:]
self._output_payload = decoded_string.encode(encoding, "replace")
self._output_payload = str(self).encode(encoding, "replace")
return self._output_payload # type: ignore
@@ -285,7 +266,7 @@ class CharsetMatches:
)
)
# We should disable the submatch factoring when the input file is too heavy (conserve RAM usage)
if len(item.raw) < TOO_BIG_SEQUENCE:
if len(item.raw) <= TOO_BIG_SEQUENCE:
for match in self._results:
if match.fingerprint == item.fingerprint and match.chaos == item.chaos:
match.add_submatch(item)

View file

@@ -2,5 +2,5 @@
Expose version
"""
__version__ = "3.4.0"
__version__ = "3.3.2"
VERSION = __version__.split(".")

View file

@@ -1,255 +0,0 @@
import re
import cherrypy
from cherrypy.lib import set_vary_header
import httpagentparser
CORS_ALLOW_METHODS = 'Access-Control-Allow-Methods'
CORS_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
CORS_ALLOW_CREDENTIALS = 'Access-Control-Allow-Credentials'
CORS_EXPOSE_HEADERS = 'Access-Control-Expose-Headers'
CORS_REQUEST_METHOD = 'Access-Control-Request-Method'
CORS_REQUEST_HEADERS = 'Access-Control-Request-Headers'
CORS_MAX_AGE = 'Access-Control-Max-Age'
CORS_ALLOW_HEADERS = 'Access-Control-Allow-Headers'
PUBLIC_ORIGIN = '*'
def expose(allow_credentials=False, expose_headers=None, origins=None):
"""Adds CORS support to the resource.
If the resource is allowed to be exposed, the value of the
`Access-Control-Allow-Origin`_ header in the response will echo
the `Origin`_ request header, and `Origin` will be
appended to the `Vary`_ response header.
:param allow_credentials: Use credentials to make cookies work
(see `Access-Control-Allow-Credentials`_).
:type allow_credentials: bool
:param expose_headers: List of headers clients will be able to access
(see `Access-Control-Expose-Headers`_).
:type expose_headers: list or None
:param origins: List of allowed origins clients must reference.
:type origins: list or None
:returns: Whether the resource is being exposed.
:rtype: bool
- Configuration example:
.. code-block:: python
config = {
'/static': {
'tools.staticdir.on': True,
'cors.expose.on': True,
}
}
- Decorator example:
.. code-block:: python
@cherrypy_cors.tools.expose()
def DELETE(self):
self._delete()
"""
if _get_cors().expose(allow_credentials, expose_headers, origins):
_safe_caching_headers()
return True
return False
def expose_public(expose_headers=None):
"""Adds CORS support to the resource from any origin.
If the resource is allowed to be exposed, the value of the
`Access-Control-Allow-Origin`_ header in the response will be `*`.
:param expose_headers: List of headers clients will be able to access
(see `Access-Control-Expose-Headers`_).
:type expose_headers: list or None
:rtype: None
"""
_get_cors().expose_public(expose_headers)
def preflight(
allowed_methods,
allowed_headers=None,
allow_credentials=False,
max_age=None,
origins=None,
):
"""Adds CORS `preflight`_ support to a `HTTP OPTIONS` request.
:param allowed_methods: List of supported `HTTP` methods
(see `Access-Control-Allow-Methods`_).
:type allowed_methods: list or None
:param allowed_headers: List of supported `HTTP` headers
(see `Access-Control-Allow-Headers`_).
:type allowed_headers: list or None
:param allow_credentials: Use credentials to make cookies work
(see `Access-Control-Allow-Credentials`_).
:type allow_credentials: bool
:param max_age: Seconds to cache the preflight request
(see `Access-Control-Max-Age`_).
:type max_age: int
:param origins: List of allowed origins clients must reference.
:type origins: list or None
:returns: Whether the preflight is allowed.
:rtype: bool
- Used as a decorator with the `Method Dispatcher`_
.. code-block:: python
@cherrypy_cors.tools.preflight(
allowed_methods=["GET", "DELETE", "PUT"])
def OPTIONS(self):
pass
- Function call with the `Object Dispatcher`_
.. code-block:: python
@cherrypy.expose
@cherrypy.tools.allow(
methods=["GET", "DELETE", "PUT", "OPTIONS"])
def thing(self):
if cherrypy.request.method == "OPTIONS":
cherrypy_cors.preflight(
allowed_methods=["GET", "DELETE", "PUT"])
else:
self._do_other_things()
"""
if _get_cors().preflight(
allowed_methods, allowed_headers, allow_credentials, max_age, origins
):
_safe_caching_headers()
return True
return False
def install():
"""Install the toolbox such that it's available in all applications."""
cherrypy._cptree.Application.toolboxes.update(cors=tools)
class CORS:
"""A generic CORS handler."""
def __init__(self, req_headers, resp_headers):
self.req_headers = req_headers
self.resp_headers = resp_headers
def expose(self, allow_credentials, expose_headers, origins):
if self._is_valid_origin(origins):
self._add_origin_and_credentials_headers(allow_credentials)
self._add_expose_headers(expose_headers)
return True
return False
def expose_public(self, expose_headers):
self._add_public_origin()
self._add_expose_headers(expose_headers)
def preflight(
self, allowed_methods, allowed_headers, allow_credentials, max_age, origins
):
if self._is_valid_preflight_request(allowed_headers, allowed_methods, origins):
self._add_origin_and_credentials_headers(allow_credentials)
self._add_preflight_headers(allowed_methods, max_age)
return True
return False
@property
def origin(self):
return self.req_headers.get('Origin')
def _is_valid_origin(self, origins):
if origins is None:
origins = [self.origin]
origins = map(self._make_regex, origins)
return self.origin is not None and any(
origin.match(self.origin) for origin in origins
)
@staticmethod
def _make_regex(pattern):
if isinstance(pattern, str):
pattern = re.compile(re.escape(pattern) + '$')
return pattern
def _add_origin_and_credentials_headers(self, allow_credentials):
self.resp_headers[CORS_ALLOW_ORIGIN] = self.origin
if allow_credentials:
self.resp_headers[CORS_ALLOW_CREDENTIALS] = 'true'
def _add_public_origin(self):
self.resp_headers[CORS_ALLOW_ORIGIN] = PUBLIC_ORIGIN
def _add_expose_headers(self, expose_headers):
if expose_headers:
self.resp_headers[CORS_EXPOSE_HEADERS] = expose_headers
@property
def requested_method(self):
return self.req_headers.get(CORS_REQUEST_METHOD)
@property
def requested_headers(self):
return self.req_headers.get(CORS_REQUEST_HEADERS)
def _has_valid_method(self, allowed_methods):
return self.requested_method and self.requested_method in allowed_methods
def _valid_headers(self, allowed_headers):
if self.requested_headers and allowed_headers:
for header in self.requested_headers.split(','):
if header.strip() not in allowed_headers:
return False
return True
def _is_valid_preflight_request(self, allowed_headers, allowed_methods, origins):
return (
self._is_valid_origin(origins)
and self._has_valid_method(allowed_methods)
and self._valid_headers(allowed_headers)
)
def _add_preflight_headers(self, allowed_methods, max_age):
rh = self.resp_headers
rh[CORS_ALLOW_METHODS] = ', '.join(allowed_methods)
if max_age:
rh[CORS_MAX_AGE] = max_age
if self.requested_headers:
rh[CORS_ALLOW_HEADERS] = self.requested_headers
def _get_cors():
return CORS(cherrypy.serving.request.headers, cherrypy.serving.response.headers)
def _safe_caching_headers():
"""Adds `Origin`_ to the `Vary`_ header to ensure caching works properly.
Except in IE because it will disable caching completely. The caching
strategy in that case is out of the scope of this library.
https://blogs.msdn.microsoft.com/ieinternals/2009/06/17/vary-with-care/
"""
uah = cherrypy.serving.request.headers.get('User-Agent', '')
ua = httpagentparser.detect(uah)
IE = 'Microsoft Internet Explorer'
if ua.get('browser', {}).get('name') != IE:
set_vary_header(cherrypy.serving.response, "Origin")
tools = cherrypy._cptools.Toolbox("cors")
tools.expose = cherrypy.Tool('before_handler', expose)
tools.expose_public = cherrypy.Tool('before_handler', expose_public)
tools.preflight = cherrypy.Tool('before_handler', preflight)

View file

@@ -26,10 +26,6 @@ class NullContext:
class Socket: # pragma: no cover
def __init__(self, family: int, type: int):
self.family = family
self.type = type
async def close(self):
pass
@@ -50,6 +46,9 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout):
raise NotImplementedError

View file

@@ -42,7 +42,7 @@ class _DatagramProtocol:
if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError("EOF")
raise EOFError
except EOFError as e:
self.recvfrom.set_exception(e)
else:
@@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
super().__init__(family, socket.SOCK_DGRAM)
super().__init__(family)
self.transport = transport
self.protocol = protocol
@@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
super().__init__(af, socket.SOCK_STREAM)
self.family = af
self.reader = reader
self.writer = writer
@@ -197,7 +197,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver

View file

@@ -32,9 +32,6 @@ def _version_check(
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
# This shouldn't happen, but it apparently can.
if version is None:
return False
except Exception:
return False
t_version = _tuple_from_text(version)
@@ -85,10 +82,10 @@ def force(feature: str, enabled: bool) -> None:
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=43"],
"dnssec": ["cryptography>=41"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=1.0.0"],
"idna": ["idna>=3.7"],
"doq": ["aioquic>=0.9.25"],
"idna": ["idna>=3.6"],
"trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements

View file

@@ -30,15 +30,12 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, sock):
super().__init__(sock.family, socket.SOCK_DGRAM)
self.socket = sock
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
if destination is None:
return await self.socket.send(what)
else:
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
@@ -64,7 +61,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
super().__init__(family, socket.SOCK_STREAM)
self.family = family
self.stream = stream
self.tls = tls
@@ -174,7 +171,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
@@ -208,7 +205,7 @@ class Backend(dns._asyncbackend.Backend):
try:
if source:
await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM or destination is not None:
if socktype == socket.SOCK_STREAM:
connected = False
with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af))

View file

@@ -19,12 +19,10 @@
import base64
import contextlib
import random
import socket
import struct
import time
import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
from typing import Any, Dict, Optional, Tuple, Union
import dns.asyncbackend
import dns.exception
@@ -39,11 +37,9 @@ import dns.transaction
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
HTTPVersion,
NoDOH,
NoDOQ,
UDPMode,
_check_status,
_compute_times,
_make_dot_ssl_context,
_matches_destination,
@@ -342,7 +338,7 @@ async def _read_exactly(sock, count, expiration):
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b"":
raise EOFError("EOF")
raise EOFError
count = count - len(n)
s = s + n
return s
@@ -504,20 +500,6 @@ async def tls(
return response
def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"],
) -> "dns.asyncresolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
return resolver
async def https(
q: dns.message.Message,
where: str,
@@ -533,8 +515,7 @@ async def https(
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
@@ -548,65 +529,26 @@ async def https(
parameters, exceptions, and return type of this method.
"""
if not have_doh:
raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
wire = q.to_wire()
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
transport = None
headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = await resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
wire = q.to_wire()
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
backend = dns.asyncbackend.get_default_backend()
if source is None:
@@ -615,14 +557,10 @@ async def https(
else:
local_address = source
local_port = source_port
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
transport = backend.get_transport_class()(
local_address=local_address,
http1=h1,
http2=h2,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
@@ -630,7 +568,12 @@ async def https(
family=family,
)
cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=True, verify=verify, transport=transport
)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
@@ -643,33 +586,23 @@ async def https(
}
)
response = await backend.wait_for(
the_client.post(
url,
headers=headers,
content=wire,
extensions=extensions,
),
timeout,
the_client.post(url, headers=headers, content=wire), timeout
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
the_client.get(
url,
headers=headers,
params={"dns": twire},
extensions=extensions,
),
timeout,
the_client.get(url, headers=headers, params={"dns": twire}), timeout
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content!r}"
"{} responded with status code {}"
"\nResponse body: {!r}".format(
where, response.status_code, response.content
)
)
r = dns.message.from_wire(
response.content,
@@ -684,181 +617,6 @@ async def https(
return r
async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: dns.asyncbackend.Socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
await udp_sock.sendto(wire, None, _timeout(expiration))
else:
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
tcpmsg = struct.pack("!H", len(wire)) + wire
await tcp_sock.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout)
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
@@ -884,30 +642,139 @@ async def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
(_, expiration) = _compute_times(lifetime)
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
(_, expiration) = _compute_times(lifetime)
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
s = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
try:
async for _ in _inbound_xfr(
txn_manager, s, query, serial, timeout, expiration
if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration))
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
pass
return
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535, timeout)
if _matches_destination(
af, from_address, destination, True
):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
s = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=server_hostname
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
async with s:
async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
pass
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
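A brief usage sketch for the coroutine above; the address names a public DoQ resolver purely for illustration and assumes the optional QUIC dependencies are installed.

import asyncio

import dns.asyncquery
import dns.message

async def main():
    q = dns.message.make_query("example.com.", "A")
    # 94.140.14.14 is used only as an illustrative DoQ-capable server.
    r = await dns.asyncquery.quic(q, "94.140.14.14", timeout=5)
    for rrset in r.answer:
        print(rrset)

asyncio.run(main())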

View file

@ -118,7 +118,6 @@ def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
"""
rdata = key.to_wire()
assert rdata is not None # for mypy
if key.algorithm == Algorithm.RSAMD5:
return (rdata[-3] << 8) + rdata[-2]
else:
@ -225,7 +224,7 @@ def make_ds(
if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
if validating:
check = policy.ok_to_validate_ds
else:
@ -241,15 +240,14 @@ def make_ds(
elif algorithm == DSDigest.SHA384:
dshash = hashlib.sha384()
else:
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
if isinstance(name, str):
name = dns.name.from_text(name, origin)
wire = name.canonicalize().to_wire()
kwire = key.to_wire(origin=origin)
assert wire is not None and kwire is not None # for mypy
assert wire is not None
dshash.update(wire)
dshash.update(kwire)
dshash.update(key.to_wire(origin=origin))
digest = dshash.digest()
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest
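A self-contained sketch tying ``key_id()`` and ``make_ds()`` together, assuming the ``cryptography`` package is installed; the zone name is a placeholder.

from cryptography.hazmat.primitives.asymmetric import ed25519

import dns.dnssec

# Generate a throwaway Ed25519 key, wrap its public half in a DNSKEY, then
# derive the RFC 4034 key tag and a SHA-256 DS record from it.
private_key = ed25519.Ed25519PrivateKey.generate()
dnskey = dns.dnssec.make_dnskey(private_key.public_key(), algorithm="ED25519")
print(dns.dnssec.key_id(dnskey))
print(dns.dnssec.make_ds("example.", dnskey, "SHA256").to_text())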
@ -325,7 +323,6 @@ def _get_rrname_rdataset(
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
# pylint: disable=possibly-used-before-assignment
public_cls = get_algorithm_cls_from_dnskey(key).public_cls
try:
public_key = public_cls.from_dnskey(key)
@ -390,7 +387,6 @@ def _validate_rrsig(
data = _make_rrsig_signature_data(rrset, rrsig, origin)
# pylint: disable=possibly-used-before-assignment
for candidate_key in candidate_keys:
if not policy.ok_to_validate(candidate_key):
continue
@ -488,7 +484,6 @@ def _sign(
verify: bool = False,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> RRSIG:
"""Sign RRset using private key.
@ -528,10 +523,6 @@ def _sign(
names in the rrset (including its owner name) must be absolute; otherwise the
specified origin will be used to make names absolute when signing.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Raises ``DeniedByPolicy`` if the signature is denied by policy.
"""
@ -589,7 +580,6 @@ def _sign(
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
# pylint: disable=possibly-used-before-assignment
if isinstance(private_key, GenericPrivateKey):
signing_key = private_key
else:
@ -599,7 +589,7 @@ def _sign(
except UnsupportedAlgorithm:
raise TypeError("Unsupported key algorithm")
signature = signing_key.sign(data, verify, deterministic)
signature = signing_key.sign(data, verify)
return cast(RRSIG, rrsig_template.replace(signature=signature))
@ -639,9 +629,7 @@ def _make_rrsig_signature_data(
rrname, rdataset = _get_rrname_rdataset(rrset)
data = b""
wire = rrsig.to_wire(origin=signer)
assert wire is not None # for mypy
data += wire[:18]
data += rrsig.to_wire(origin=signer)[:18]
data += rrsig.signer.to_digestable(signer)
# Derelativize the name before considering labels.
@ -698,7 +686,6 @@ def _make_dnskey(
algorithm = Algorithm.make(algorithm)
# pylint: disable=possibly-used-before-assignment
if isinstance(public_key, GenericPublicKey):
return public_key.to_dnskey(flags=flags, protocol=protocol)
else:
@ -845,7 +832,7 @@ def make_ds_rdataset(
if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
_algorithms.add(algorithm)
if rdataset.rdtype == dns.rdatatype.CDS:
@ -963,7 +950,6 @@ def default_rrset_signer(
lifetime: Optional[int] = None,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> None:
"""Default RRset signer"""
@ -989,7 +975,6 @@ def default_rrset_signer(
signer=signer,
policy=policy,
origin=origin,
deterministic=deterministic,
)
txn.add(rrset.name, rrset.ttl, rrsig)
@ -1006,7 +991,6 @@ def sign_zone(
nsec3: Optional[NSEC3PARAM] = None,
rrset_signer: Optional[RRsetSigner] = None,
policy: Optional[Policy] = None,
deterministic: bool = True,
) -> None:
"""Sign zone.
@ -1046,10 +1030,6 @@ def sign_zone(
function requires two arguments: transaction and RRset. If not specified,
``dns.dnssec.default_rrset_signer`` will be used.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Returns ``None``.
"""
@ -1076,9 +1056,6 @@ def sign_zone(
else:
cm = zone.writer()
if zone.origin is None:
raise ValueError("no zone origin")
with cm as _txn:
if add_dnskey:
if dnskey_ttl is None:
@ -1104,7 +1081,6 @@ def sign_zone(
lifetime=lifetime,
policy=policy,
origin=zone.origin,
deterministic=deterministic,
)
return _sign_zone_nsec(zone, _txn, _rrset_signer)
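A minimal sketch of calling ``sign_zone()`` end to end, assuming the ``cryptography`` package is available; the zone data is illustrative.

from cryptography.hazmat.primitives.asymmetric import ed25519

import dns.dnssec
import dns.zone

zone = dns.zone.from_text(
    "example. 300 IN SOA ns.example. hostmaster.example. 1 7200 900 1209600 86400\n"
    "example. 300 IN NS ns.example.\n"
    "ns.example. 300 IN A 192.0.2.1\n",
    origin="example.",
    relativize=False,
)
key = ed25519.Ed25519PrivateKey.generate()
dnskey = dns.dnssec.make_dnskey(key.public_key(), algorithm="ED25519")
# One call adds the DNSKEY, the RRSIGs, and the NSEC chain to the zone.
dns.dnssec.sign_zone(zone, keys=[(key, dnskey)], lifetime=3600)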

View file

@ -26,7 +26,6 @@ AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
# pylint: disable=possibly-used-before-assignment
algorithms.update(
{
(Algorithm.RSAMD5, None): PrivateRSAMD5,
@ -60,7 +59,7 @@ def get_algorithm_cls(
if cls:
return cls
raise UnsupportedAlgorithm(
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
)

View file

@ -65,12 +65,7 @@ class GenericPrivateKey(ABC):
pass
@abstractmethod
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign DNSSEC data"""
@abstractmethod

View file

@ -68,12 +68,7 @@ class PrivateDSA(CryptographyPrivateKey):
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024:

View file

@ -47,17 +47,9 @@ class PrivateECDSA(CryptographyPrivateKey):
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
algorithm = ec.ECDSA(
self.public_cls.chosen_hash, deterministic_signing=deterministic
)
der_signature = self.key.sign(data, algorithm)
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big"

View file

@ -29,12 +29,7 @@ class PublicEDDSA(CryptographyPublicKey):
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA]
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data)
if verify:

View file

@ -56,12 +56,7 @@ class PrivateRSA(CryptographyPrivateKey):
public_cls = PublicRSA
default_public_exponent = 65537
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
if verify:

View file

@ -52,8 +52,6 @@ class OptionType(dns.enum.IntEnum):
CHAIN = 13
#: EDE (extended-dns-error)
EDE = 15
#: REPORTCHANNEL
REPORTCHANNEL = 18
@classmethod
def _maximum(cls):
@ -224,7 +222,7 @@ class ECSOption(Option): # lgtm[py/missing-equals]
self.addrdata = self.addrdata[:-1] + last
def to_text(self) -> str:
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
@staticmethod
def from_text(text: str) -> Option:
@ -257,10 +255,10 @@ class ECSOption(Option): # lgtm[py/missing-equals]
ecs_text = tokens[0]
elif len(tokens) == 2:
if tokens[0] != optional_prefix:
raise ValueError(f'could not parse ECS from "{text}"')
raise ValueError('could not parse ECS from "{}"'.format(text))
ecs_text = tokens[1]
else:
raise ValueError(f'could not parse ECS from "{text}"')
raise ValueError('could not parse ECS from "{}"'.format(text))
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, tsrclen = ecs_text.split("/")
@ -268,16 +266,18 @@ class ECSOption(Option): # lgtm[py/missing-equals]
elif n_slashes == 2:
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError(f'could not parse ECS from "{text}"')
raise ValueError('could not parse ECS from "{}"'.format(text))
try:
scope = int(tscope)
except ValueError:
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
raise ValueError(
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try:
srclen = int(tsrclen)
except ValueError:
raise ValueError(
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
)
return ECSOption(address, srclen, scope)
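For illustration, the parser above accepts the address with or without the leading ``ECS`` tag; 192.0.2.0/24 is a documentation prefix.

import dns.edns
import dns.message

ecs = dns.edns.ECSOption.from_text("ECS 192.0.2.0/24")
# Attach the option to a query; the scope defaults to 0 as shown above.
q = dns.message.make_query("example.com.", "A", use_edns=0, options=[ecs])
print(ecs.to_text())  # ECS 192.0.2.0/24 scope/0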
@ -430,65 +430,10 @@ class NSIDOption(Option):
return cls(parser.get_remaining())
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(dns.edns.OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
raise ValueError("client cookie must be 8 bytes")
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
def to_wire(self, file: Any = None) -> Optional[bytes]:
if file:
file.write(self.client)
if len(self.server) > 0:
file.write(self.server)
return None
else:
return self.client + self.server
def to_text(self) -> str:
client = binascii.hexlify(self.client).decode()
if len(self.server) > 0:
server = binascii.hexlify(self.server).decode()
else:
server = ""
return f"COOKIE {client}{server}"
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_bytes(8), parser.get_remaining())
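A sketch of building the option above for a client that has not yet learned a server cookie; per the validation in the constructor, an empty server part is allowed (RFC 7873).

import secrets

import dns.edns
import dns.message

# Eight random client-cookie bytes, no server cookie yet.
cookie = dns.edns.CookieOption(secrets.token_bytes(8), b"")
q = dns.message.make_query("example.com.", "A", use_edns=0, options=[cookie])
print(cookie.to_text())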
class ReportChannelOption(Option):
# RFC 9567
def __init__(self, agent_domain: dns.name.Name):
super().__init__(OptionType.REPORTCHANNEL)
self.agent_domain = agent_domain
def to_wire(self, file: Any = None) -> Optional[bytes]:
return self.agent_domain.to_wire(file)
def to_text(self) -> str:
return "REPORTCHANNEL " + self.agent_domain.to_text()
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_name())
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
OptionType.COOKIE: CookieOption,
OptionType.REPORTCHANNEL: ReportChannelOption,
}
@ -567,6 +512,5 @@ KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
REPORTCHANNEL = OptionType.REPORTCHANNEL
### END generated OptionType constants

View file

@ -81,7 +81,7 @@ class DNSException(Exception):
if kwargs:
assert (
set(kwargs.keys()) == self.supp_kwargs
), f"following set of keyword args is required: {self.supp_kwargs}"
), "following set of keyword args is required: %s" % (self.supp_kwargs)
return kwargs
def _fmt_kwargs(self, **kwargs):

View file

@ -54,7 +54,7 @@ def from_text(text: str) -> Tuple[int, int, int]:
elif c.isdigit():
cur += c
else:
raise dns.exception.SyntaxError(f"Could not parse {c}")
raise dns.exception.SyntaxError("Could not parse %s" % (c))
if state == 0:
raise dns.exception.SyntaxError("no stop value specified")

View file

@ -143,7 +143,9 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
btext = (
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
"{}:{:02x}{:02x}:{:02x}{:02x}".format(
m.group(1).decode(), b[0], b[1], b[2], b[3]
)
).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to

View file

@ -18,10 +18,9 @@
"""DNS Messages"""
import contextlib
import enum
import io
import time
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union
import dns.edns
import dns.entropy
@ -162,7 +161,6 @@ class Message:
self.index: IndexType = {}
self.errors: List[MessageError] = []
self.time = 0.0
self.wire: Optional[bytes] = None
@property
def question(self) -> List[dns.rrset.RRset]:
@ -222,16 +220,16 @@ class Message:
s = io.StringIO()
s.write("id %d\n" % self.id)
s.write(f"opcode {dns.opcode.to_text(self.opcode())}\n")
s.write(f"rcode {dns.rcode.to_text(self.rcode())}\n")
s.write(f"flags {dns.flags.to_text(self.flags)}\n")
s.write("opcode %s\n" % dns.opcode.to_text(self.opcode()))
s.write("rcode %s\n" % dns.rcode.to_text(self.rcode()))
s.write("flags %s\n" % dns.flags.to_text(self.flags))
if self.edns >= 0:
s.write(f"edns {self.edns}\n")
s.write("edns %s\n" % self.edns)
if self.ednsflags != 0:
s.write(f"eflags {dns.flags.edns_to_text(self.ednsflags)}\n")
s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags))
s.write("payload %d\n" % self.payload)
for opt in self.options:
s.write(f"option {opt.to_text()}\n")
s.write("option %s\n" % opt.to_text())
for name, which in self._section_enum.__members__.items():
s.write(f";{name}\n")
for rrset in self.section_from_number(which):
@ -647,7 +645,6 @@ class Message:
if multi:
self.tsig_ctx = ctx
wire = r.get_wire()
self.wire = wire
if prepend_length:
wire = len(wire).to_bytes(2, "big") + wire
return wire
@ -915,14 +912,6 @@ class Message:
self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode)
def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
"""Return the list of options of the specified type."""
return [option for option in self.options if option.otype == otype]
def extended_errors(self) -> List[dns.edns.EDEOption]:
"""Return the list of Extended DNS Error (EDE) options in the message"""
return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))
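A sketch of how a caller might use the accessors above on a parsed response; ``response`` stands in for any received ``dns.message.Message``.

import dns.message

def report_extended_errors(response: dns.message.Message) -> None:
    # Each EDE option carries an INFO-CODE and optional EXTRA-TEXT (RFC 8914).
    for ede in response.extended_errors():
        print(ede.code, ede.text)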
def _get_one_rr_per_rrset(self, value):
# What the caller picked is fine.
return value
@ -1203,9 +1192,9 @@ class _WireReader:
if rdtype == dns.rdatatype.OPT:
self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
elif rdtype == dns.rdatatype.TSIG:
if self.keyring is None or self.keyring is True:
if self.keyring is None:
raise UnknownTSIGKey("got signed message without keyring")
elif isinstance(self.keyring, dict):
if isinstance(self.keyring, dict):
key = self.keyring.get(absolute_name)
if isinstance(key, bytes):
key = dns.tsig.Key(absolute_name, key, rd.algorithm)
@ -1214,8 +1203,7 @@ class _WireReader:
else:
key = self.keyring
if key is None:
raise UnknownTSIGKey(f"key '{name}' unknown")
if key:
raise UnknownTSIGKey("key '%s' unknown" % name)
self.message.keyring = key
self.message.tsig_ctx = dns.tsig.validate(
self.parser.wire,
@ -1263,7 +1251,6 @@ class _WireReader:
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
self.message = factory(id=id)
self.message.flags = dns.flags.Flag(flags)
self.message.wire = self.parser.wire
self.initialize_message(self.message)
self.one_rr_per_rrset = self.message._get_one_rr_per_rrset(
self.one_rr_per_rrset
@ -1303,10 +1290,8 @@ def from_wire(
) -> Message:
"""Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key``, ``dict``, ``bool``, or ``None``, the key or keyring
to use if the message is signed. If ``None`` or ``True``, then trying to decode
a message with a TSIG will fail as it cannot be validated. If ``False``, then
TSIG validation is disabled.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message
is signed.
*request_mac*, a ``bytes`` or ``None``. If the message is a response to a
TSIG-signed request, *request_mac* should be set to the MAC of that request.
@ -1826,16 +1811,6 @@ def make_query(
return m
class CopyMode(enum.Enum):
"""
How should sections be copied when making an update response?
"""
NOTHING = 0
QUESTION = 1
EVERYTHING = 2
def make_response(
query: Message,
recursion_available: bool = False,
@ -1843,14 +1818,13 @@ def make_response(
fudge: int = 300,
tsig_error: int = 0,
pad: Optional[int] = None,
copy_mode: Optional[CopyMode] = None,
) -> Message:
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all of the infrastructure
required of a response, but none of the content.
Response section(s) which are copied are shallow copies of the matching section(s)
in the query, so the query's RRsets should not be changed.
The response's question section is a shallow copy of the query's question section,
so the query's question RRsets should not be changed.
*query*, a ``dns.message.Message``, the query to respond to.
@ -1863,44 +1837,25 @@ def make_response(
*tsig_error*, an ``int``, the TSIG error.
*pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note
that if padding is non-zero, an EDNS PADDING option will always be added to the
if not ``None`` add padding bytes to make the message size a multiple of *pad*.
Note that if padding is non-zero, an EDNS PADDING option will always be added to the
message. If ``None``, add padding following RFC 8467, namely if the request is
padded, pad the response to a multiple of 468, otherwise do not pad.
*copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are
copied. The default, ``None`` copies sections according to the default for the
message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all
opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section.
``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG
records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING``
copies no sections; note that this mode is for server testing purposes and is
otherwise not recommended for use. In particular, ``dns.message.is_response()``
will be ``False`` if you create a response this way and the rcode is not
``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``.
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
query. For example, if query is a ``dns.update.UpdateMessage``, the response will
be one too.
query. For example, if query is a ``dns.update.UpdateMessage``, the response will
be too.
"""
if query.flags & dns.flags.QR:
raise dns.exception.FormError("specified query message is not a query")
opcode = query.opcode()
factory = _message_factory_from_opcode(opcode)
factory = _message_factory_from_opcode(query.opcode())
response = factory(id=query.id)
response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
if recursion_available:
response.flags |= dns.flags.RA
response.set_opcode(opcode)
if copy_mode is None:
copy_mode = CopyMode.QUESTION
if copy_mode != CopyMode.NOTHING:
response.set_opcode(query.opcode())
response.question = list(query.question)
if copy_mode == CopyMode.EVERYTHING:
response.answer = list(query.answer)
response.authority = list(query.authority)
response.additional = list(query.additional)
if query.edns >= 0:
if pad is None:
# Set response padding per RFC 8467

View file

@ -59,11 +59,11 @@ class NameRelation(dns.enum.IntEnum):
@classmethod
def _maximum(cls):
return cls.COMMONANCESTOR # pragma: no cover
return cls.COMMONANCESTOR
@classmethod
def _short_name(cls):
return cls.__name__ # pragma: no cover
return cls.__name__
# Backwards compatibility
@ -277,7 +277,6 @@ class IDNA2008Codec(IDNACodec):
raise NoIDNA2008
try:
if self.uts_46:
# pylint: disable=possibly-used-before-assignment
label = idna.uts46_remap(label, False, self.transitional)
return idna.alabel(label)
except idna.IDNAError as e:

View file

@ -168,14 +168,12 @@ class DoHNameserver(Nameserver):
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.http_version = http_version
def kind(self):
return "DoH"
@ -216,7 +214,6 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
async def async_query(
@ -241,7 +238,6 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)

View file

@ -23,13 +23,11 @@ import enum
import errno
import os
import os.path
import random
import selectors
import socket
import struct
import time
import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
from typing import Any, Dict, Optional, Tuple, Union
import dns._features
import dns.exception
@ -131,7 +129,7 @@ if _have_httpx:
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
@ -219,7 +217,7 @@ def _wait_for(fd, readable, writable, _, expiration):
if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
return True
sel = selectors.DefaultSelector()
sel = _selector_class()
events = 0
if readable:
events |= selectors.EVENT_READ
@ -237,6 +235,26 @@ def _wait_for(fd, readable, writable, _, expiration):
raise dns.exception.Timeout
def _set_selector_class(selector_class):
# Internal API. Do not use.
global _selector_class
_selector_class = selector_class
if hasattr(selectors, "PollSelector"):
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
#
# We ignore typing here as we can't say _selector_class is Any
# on python < 3.8 due to a bug.
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
def _wait_for_readable(s, expiration):
_wait_for(s, True, False, True, expiration)
@ -337,36 +355,6 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
raise
def _maybe_get_resolver(
resolver: Optional["dns.resolver.Resolver"],
) -> "dns.resolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
resolver = dns.resolver.Resolver()
return resolver
class HTTPVersion(enum.IntEnum):
"""Which version of HTTP should be used?
DEFAULT will select the first version from the list [2, 1.1, 3] that
is available.
"""
DEFAULT = 0
HTTP_1 = 1
H1 = 1
HTTP_2 = 2
H2 = 2
HTTP_3 = 3
H3 = 3
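A usage sketch for the enum above; the URL names a public resolver purely for illustration, and HTTP/3 additionally requires the optional QUIC dependencies.

import dns.message
import dns.query

q = dns.message.make_query("example.com.", "A")
# Force HTTP/3; HTTPVersion.DEFAULT instead picks the first available
# of 2, 1.1, and 3.
r = dns.query.https(
    q, "https://dns.google/dns-query", http_version=dns.query.HTTPVersion.H3
)
print(r.rcode())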
def https(
q: dns.message.Message,
where: str,
@ -382,8 +370,7 @@ def https(
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
@ -433,66 +420,27 @@ def https(
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
Returns a ``dns.message.Message``.
"""
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
else:
url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh:
raise NoDOH # pragma: no cover
if session and not isinstance(session, httpx.Client):
raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire()
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
transport = None
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
# set source port and source address
@ -502,14 +450,10 @@ def https(
else:
local_address = the_source[0]
local_port = the_source[1]
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
transport = _HTTPTransport(
local_address=local_address,
http1=h1,
http2=h2,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
@ -517,7 +461,10 @@ def https(
family=family,
)
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
@ -528,30 +475,20 @@ def https(
"content-length": str(len(wire)),
}
)
response = session.post(
url,
headers=headers,
content=wire,
timeout=timeout,
extensions=extensions,
)
response = session.post(url, headers=headers, content=wire, timeout=timeout)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = session.get(
url,
headers=headers,
timeout=timeout,
params={"dns": twire},
extensions=extensions,
url, headers=headers, timeout=timeout, params={"dns": twire}
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content}"
"{} responded with status code {}"
"\nResponse body: {}".format(where, response.status_code, response.content)
)
r = dns.message.from_wire(
response.content,
@ -566,81 +503,6 @@ def https(
return r
def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
if headers is None:
raise KeyError
for header, value in headers:
if header == name:
return value
raise KeyError
def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
value = _find_header(headers, b":status")
if value is None:
raise SyntaxError("no :status header in response")
status = int(value)
if status < 0:
raise SyntaxError("status is negative")
if status < 200 or status > 299:
error = ""
if len(wire) > 0:
try:
error = ": " + wire.decode()
except Exception:
pass
raise ValueError(f"{peer} responded with status code {status}{error}")
def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=hostname, h3=True
)
with manager:
connection = manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
with connection.make_stream(timeout) as stream:
stream.send_h3(url, wire, post)
wire = stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
def _udp_recv(sock, max_size, expiration):
"""Reads a datagram from the socket.
A Timeout exception will be raised if the operation is not completed
@ -993,7 +855,7 @@ def _net_read(sock, count, expiration):
try:
n = sock.recv(count)
if n == b"":
raise EOFError("EOF")
raise EOFError
count -= len(n)
s += n
except (BlockingIOError, ssl.SSLWantReadError):
@ -1161,7 +1023,6 @@ def tcp(
cm = _make_socket(af, socket.SOCK_STREAM, source)
with cm as s:
if not sock:
# pylint: disable=possibly-used-before-assignment
_connect(s, destination, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(
@ -1327,7 +1188,6 @@ def quic(
ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC.
@ -1352,21 +1212,17 @@ def quic(
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message.
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
to send the query.
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
connection to use to send the query.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification.
*hostname*, a ``str`` containing the server's hostname or ``None``. The default is
``None``, which means that no hostname is known, and if an SSL context is created,
hostname checking will be disabled. This value is ignored if *url* is not
``None``.
*server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
only, and has the same meaning as *hostname*.
*server_hostname*, a ``str`` containing the server's hostname. The
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``.
"""
@ -1374,9 +1230,6 @@ def quic(
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.SyncQuicConnection
@ -1385,7 +1238,9 @@ def quic(
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection
else:
manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=server_hostname
)
the_manager = manager # for type checking happiness
with manager:
@ -1409,70 +1264,6 @@ def quic(
return r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
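A sketch of how these modes are used with the synchronous transfer API; the address and zone are placeholders, and ``udp_mode`` only takes effect when the generated query is an IXFR.

import dns.query
import dns.zone

zone = dns.zone.Zone("example.")
# Try the transfer over UDP first, falling back to TCP if the server
# signals that TCP is required.
dns.query.inbound_xfr("192.0.2.1", zone, udp_mode=dns.query.UDPMode.TRY_FIRST)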
def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: socket.socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
def xfr(
where: str,
zone: Union[dns.name.Name, str],
@ -1542,52 +1333,134 @@ def xfr(
Returns a generator of ``dns.message.Message`` objects.
"""
class DummyTransactionManager(dns.transaction.TransactionManager):
def __init__(self, origin, relativize):
self.info = (origin, relativize, dns.name.empty if relativize else origin)
def origin_information(self):
return self.info
def get_class(self) -> dns.rdataclass.RdataClass:
raise NotImplementedError # pragma: no cover
def reader(self):
raise NotImplementedError # pragma: no cover
def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
class DummyTransaction:
def nop(self, *args, **kw):
pass
def __getattr__(self, _):
return self.nop
return cast(dns.transaction.Transaction, DummyTransaction())
if isinstance(zone, str):
zone = dns.name.from_text(zone)
rdtype = dns.rdatatype.RdataType.make(rdtype)
q = dns.message.make_query(zone, rdtype, rdclass)
if rdtype == dns.rdatatype.IXFR:
rrset = q.find_rrset(
q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
)
soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
rrset.add(soa, 0)
rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
q.authority.append(rrset)
if keyring is not None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
wire = q.to_wire()
(af, destination, source) = _destination_and_source(
where, port, source, source_port
)
(_, expiration) = _compute_times(lifetime)
tm = DummyTransactionManager(zone, relativize)
if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError("cannot do a UDP AXFR")
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
with _make_socket(af, sock_type, source) as s:
(_, expiration) = _compute_times(lifetime)
_connect(s, destination, expiration)
yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
l = len(wire)
if use_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
done = False
delete_mode = True
expecting_SOA = False
soa_rrset = None
if relativize:
origin = zone
oname = dns.name.empty
else:
origin = None
oname = zone
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if use_udp:
(wire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
wire = _net_read(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=True,
one_rr_per_rrset=is_ixfr,
)
rcode = r.rcode()
if rcode != dns.rcode.NOERROR:
raise TransferError(rcode)
tsig_ctx = r.tsig_ctx
answer_index = 0
if soa_rrset is None:
if not r.answer or r.answer[0].name != oname:
raise dns.exception.FormError("No answer or RRset not for qname")
rrset = r.answer[0]
if rrset.rdtype != dns.rdatatype.SOA:
raise dns.exception.FormError("first RRset is not an SOA")
answer_index = 1
soa_rrset = rrset.copy()
if rdtype == dns.rdatatype.IXFR:
if dns.serial.Serial(soa_rrset[0].serial) <= serial:
#
# We're already up-to-date.
#
done = True
else:
expecting_SOA = True
#
# Process SOAs in the answer section (other than the initial
# SOA in the first message).
#
for rrset in r.answer[answer_index:]:
if done:
raise dns.exception.FormError("answers after final SOA")
if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
if expecting_SOA:
if rrset[0].serial != serial:
raise dns.exception.FormError("IXFR base serial mismatch")
expecting_SOA = False
elif rdtype == dns.rdatatype.IXFR:
delete_mode = not delete_mode
#
# If this SOA RRset is equal to the first we saw then we're
# finished. If this is an IXFR we also check that we're
# seeing the record in the expected part of the response.
#
if rrset == soa_rrset and (
rdtype == dns.rdatatype.AXFR
or (rdtype == dns.rdatatype.IXFR and delete_mode)
):
done = True
elif expecting_SOA:
#
# We made an IXFR request and are expecting another
# SOA RR, but saw something else, so this must be an
# AXFR response.
#
rdtype = dns.rdatatype.AXFR
expecting_SOA = False
if done and q.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
yield r
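The generator above pairs naturally with ``dns.zone.from_xfr()``; a minimal sketch with a placeholder server address:

import dns.query
import dns.zone

# Drive the AXFR generator and assemble the stream of messages into a Zone.
messages = dns.query.xfr("192.0.2.1", "example.")
zone = dns.zone.from_xfr(messages)
print(zone.origin)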
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
def inbound_xfr(
@ -1641,25 +1514,65 @@ def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
(af, destination, source) = _destination_and_source(
where, port, source, source_port
)
(_, expiration) = _compute_times(lifetime)
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
with _make_socket(af, socket.SOCK_DGRAM, source) as s:
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
with _make_socket(af, sock_type, source) as s:
_connect(s, destination, expiration)
try:
for _ in _inbound_xfr(
txn_manager, s, query, serial, timeout, expiration
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
pass
return
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
with _make_socket(af, socket.SOCK_STREAM, source) as s:
_connect(s, destination, expiration)
for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
pass
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")

View file

@ -1,7 +1,5 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import List, Tuple
import dns._features
import dns.asyncbackend
@ -75,6 +73,3 @@ else: # pragma: no cover
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

View file

@ -43,22 +43,8 @@ class AsyncioQuicStream(BaseQuicStream):
raise dns.exception.Timeout
self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
await self.wait_for_end(expiration)
return self._buffer.get_all()
else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
@ -97,7 +83,6 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._wake_pending = False
async def _receiver(self):
try:
@ -119,24 +104,19 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
await self._wakeup()
async with self._wake_timer:
self._wake_timer.notify_all()
except Exception:
pass
finally:
self._done = True
await self._wakeup()
self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
self._handshake_complete.set()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self):
await self._socket_created.wait()
@ -160,25 +140,6 @@ class AsyncioQuicConnection(AsyncQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
@ -200,7 +161,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
await self._wakeup()
async with self._wake_timer:
self._wake_timer.notify_all()
def run(self):
if self._closed:
@ -227,7 +189,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._wakeup()
async with self._wake_timer:
self._wake_timer.notify_all()
try:
await self._receiver_task
except asyncio.CancelledError:
@ -240,10 +203,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
class AsyncioQuicManager(AsyncQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True

View file

@ -1,16 +1,12 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib
from typing import Any, Optional
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
@ -55,12 +51,6 @@ class Buffer:
self._buffer = self._buffer[amount:]
return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
@ -68,18 +58,10 @@ class BaseQuicStream:
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
self._headers = None
self._trailers = None
def id(self):
return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
@ -95,51 +77,16 @@ class BaseQuicStream:
return timeout
# Subclasses must implement receive() as sync / async, returning a message
# or raising an exception.
# Subclasses must implement send() as sync / async, taking a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
# or raises UnexpectedEOF.
def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end()
return self._expecting > 0 and self._buffer.have(self._expecting)
except UnexpectedEOF:
return True
@ -150,13 +97,7 @@ class BaseQuicStream:
class BaseQuicConnection:
def __init__(
self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
self, connection, address, port, source=None, source_port=0, manager=None
):
self._done = False
self._connection = connection
@ -165,10 +106,6 @@ class BaseQuicConnection:
self._closed = False
self._manager = manager
self._streams = {}
if manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
@ -183,18 +120,9 @@ class BaseQuicConnection:
else:
self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id):
del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
@ -220,25 +148,17 @@ class AsyncQuicConnection(BaseQuicConnection):
class BaseQuicManager:
def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=alpn_protocols,
alpn_protocols=["doq", "doq-i03"],
verify_mode=verify_mode,
server_name=server_name,
)
@ -247,13 +167,7 @@ class BaseQuicManager:
self._conf = conf
def _connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
connection = self._connections.get((address, port))
if connection is not None:
@ -275,24 +189,9 @@ class BaseQuicManager:
)
else:
session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
@ -308,9 +207,6 @@ class BaseQuicManager:
except KeyError:
pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries' keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
@ -322,17 +218,6 @@ class BaseQuicManager:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries' keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):

View file

@ -21,9 +21,11 @@ from dns.quic._common import (
UnexpectedEOF,
)
# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket
# Avoid circularity with dns.query
if hasattr(selectors, "PollSelector"):
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
class SyncQuicStream(BaseQuicStream):
@ -44,23 +46,8 @@ class SyncQuicStream(BaseQuicStream):
raise dns.exception.Timeout
self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
self.wait_for_end(expiration)
with self._lock:
return self._buffer.get_all()
else:
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
@ -94,7 +81,7 @@ class SyncQuicStream(BaseQuicStream):
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
self._socket.bind(
@ -131,7 +118,7 @@ class SyncQuicConnection(BaseQuicConnection):
def _worker(self):
try:
sel = selectors.DefaultSelector()
sel = _selector_class()
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done:
@ -153,7 +140,6 @@ class SyncQuicConnection(BaseQuicConnection):
finally:
with self._lock:
self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up.
self._handshake_complete.set()
@ -164,25 +150,6 @@ class SyncQuicConnection(BaseQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
@ -203,18 +170,6 @@ class SyncQuicConnection(BaseQuicConnection):
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
@ -248,24 +203,16 @@ class SyncQuicConnection(BaseQuicConnection):
class SyncQuicManager(BaseQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock()
def connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
with self._lock:
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket, want_token
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
@ -279,10 +226,6 @@ class SyncQuicManager(BaseQuicManager):
with self._lock:
super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self):
return self

View file

@ -36,23 +36,12 @@ class TrioQuicStream(BaseQuicStream):
await self._wake_up.wait()
self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
if self._connection.is_h3():
await self.wait_for_end()
return self._buffer.get_all()
else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
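
The two branches of `receive()` above correspond to the two protocols: the H3 branch waits for the end of the stream and returns the whole body, while the non-H3 branch implements DoQ framing (RFC 9250), where each DNS message on a stream carries a 2-byte network-order length prefix. A standalone sketch of that parse over an in-memory buffer:

```python
import struct

def parse_doq_message(buffer: bytes) -> tuple[bytes, bytes]:
    """Split one length-prefixed message off the front of `buffer`."""
    (size,) = struct.unpack("!H", buffer[:2])
    message, rest = buffer[2 : 2 + size], buffer[2 + size :]
    if len(message) != size:
        raise ValueError("short read")
    return message, rest

# Example: a 3-byte payload framed and then recovered.
framed = struct.pack("!H", 3) + b"abc"
assert parse_doq_message(framed) == (b"abc", b"")
```
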
@ -126,7 +115,6 @@ class TrioQuicConnection(AsyncQuicConnection):
await self._socket.send(datagram)
finally:
self._done = True
self._socket.close()
self._handshake_complete.set()
async def _handle_events(self):
@ -136,25 +124,6 @@ class TrioQuicConnection(AsyncQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
@ -214,14 +183,9 @@ class TrioQuicConnection(AsyncQuicConnection):
class TrioQuicManager(AsyncQuicManager):
def __init__(
self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery
def connect(


@ -214,7 +214,7 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> None:
) -> bytes:
raise NotImplementedError # pragma: no cover
def to_wire(
@ -223,19 +223,14 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> Optional[bytes]:
) -> bytes:
"""Convert an rdata to wire format.
Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
Returns a ``bytes`` or ``None``.
"""
if file:
# We call _to_wire() and then return None explicitly instead of
# just returning the None from _to_wire() as mypy's func-returns-value
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
# a value (it only ever returns None)"
self._to_wire(file, compress, origin, canonicalize)
return None
return self._to_wire(file, compress, origin, canonicalize)
else:
f = io.BytesIO()
self._to_wire(f, compress, origin, canonicalize)
@ -258,9 +253,8 @@ class Rdata:
Returns a ``bytes``.
"""
wire = self.to_wire(origin=origin, canonicalize=True)
assert wire is not None # for mypy
return wire
return self.to_wire(origin=origin, canonicalize=True)
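
A usage sketch of the two `to_wire()` call styles this hunk is about, using a trivial A record. With no file the method returns `bytes`; with a file it writes into it and in practice returns `None` on both sides of the diff (the master-side `Optional[bytes]` annotation just makes that explicit):

```python
import io

import dns.rdata
import dns.rdataclass
import dns.rdatatype

rd = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "192.0.2.1")

wire = rd.to_wire()    # no file: returns the wire-format bytes
assert isinstance(wire, bytes)

f = io.BytesIO()
rd.to_wire(file=f)     # file given: writes into it instead
assert f.getvalue() == wire
```
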
def __repr__(self):
covers = self.covers()
@ -440,11 +434,15 @@ class Rdata:
continue
if key not in parameters:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{key}'"
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, key
)
)
if key in ("rdclass", "rdtype"):
raise AttributeError(
f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
"Cannot overwrite '{}' attribute '{}'".format(
self.__class__.__name__, key
)
)
# Construct the parameter list. For each field, use the value in
@ -648,14 +646,13 @@ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType],
{}
)
_module_prefix = "dns.rdtypes"
_dynamic_load_allowed = True
def get_rdata_class(rdclass, rdtype, use_generic=True):
def get_rdata_class(rdclass, rdtype):
cls = _rdata_classes.get((rdclass, rdtype))
if not cls:
cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
if not cls and _dynamic_load_allowed:
if not cls:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace("-", "_")
@ -673,36 +670,12 @@ def get_rdata_class(rdclass, rdtype, use_generic=True):
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
pass
if not cls and use_generic:
if not cls:
cls = GenericRdata
_rdata_classes[(rdclass, rdtype)] = cls
return cls
def load_all_types(disable_dynamic_load=True):
"""Load all rdata types for which dnspython has a non-generic implementation.
Normally dnspython loads DNS rdatatype implementations on demand, but in some
specialized cases loading all types at an application-controlled time is preferred.
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
and will simply use the ``GenericRdata`` class.
"""
# Load class IN and ANY types.
for rdtype in dns.rdatatype.RdataType:
get_rdata_class(dns.rdataclass.IN, rdtype, False)
# Load the one non-ANY implementation we have in CH. Everything
# else in CH is an ANY type, and we'll discover those on demand but won't
# have to import anything.
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
if disable_dynamic_load:
# Now disable dynamic loading so any subsequent unknown type immediately becomes
# GenericRdata without a load attempt.
global _dynamic_load_allowed
_dynamic_load_allowed = False
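
A usage sketch of the master-side `load_all_types()` helper that this hunk removes, per its own docstring; the sandboxing motivation is an assumption:

```python
import dns.rdata
import dns.rdataclass
import dns.rdatatype

# Eagerly import every implemented rdata type, then freeze dynamic
# loading (disable_dynamic_load defaults to True).
dns.rdata.load_all_types()

# Implemented types now come straight from the cache...
assert dns.rdata.get_rdata_class(dns.rdataclass.IN, dns.rdatatype.A) is not None

# ...and an unknown type falls through to GenericRdata with no import
# attempt, which matters if e.g. filesystem access is dropped later.
```
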
def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],


@ -160,7 +160,7 @@ class Rdataset(dns.set.Set):
return s[:100] + "..."
return s
return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]"
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
def __repr__(self):
if self.covers == 0:
@ -248,8 +248,12 @@ class Rdataset(dns.set.Set):
# (which is meaningless anyway).
#
s.write(
f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} "
f"{dns.rdatatype.to_text(self.rdtype)}\n"
"{}{}{} {}\n".format(
ntext,
pad,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
)
)
else:
for rd in self:


@ -105,8 +105,6 @@ class RdataType(dns.enum.IntEnum):
CAA = 257
AVC = 258
AMTRELAY = 260
RESINFO = 261
WALLET = 262
TA = 32768
DLV = 32769
@ -127,7 +125,7 @@ class RdataType(dns.enum.IntEnum):
if text.find("-") >= 0:
try:
return cls[text.replace("-", "_")]
except KeyError: # pragma: no cover
except KeyError:
pass
return _registered_by_text.get(text)
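
A worked example of the hyphen fallback in `from_text()` above: Python enum members cannot contain `-`, so a hyphenated mnemonic is looked up with the hyphens mapped to underscores:

```python
import dns.rdatatype

# "NSAP-PTR" resolves to the enum member NSAP_PTR.
assert dns.rdatatype.from_text("NSAP-PTR") == dns.rdatatype.NSAP_PTR
```
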
@ -328,8 +326,6 @@ URI = RdataType.URI
CAA = RdataType.CAA
AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY
RESINFO = RdataType.RESINFO
WALLET = RdataType.WALLET
TA = RdataType.TA
DLV = RdataType.DLV


@ -75,9 +75,8 @@ class GPOS(dns.rdata.Rdata):
raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw):
return (
f"{self.latitude.decode()} {self.longitude.decode()} "
f"{self.altitude.decode()}"
return "{} {} {}".format(
self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
)
@classmethod


@ -37,7 +37,9 @@ class HINFO(dns.rdata.Rdata):
self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"'
return '"{}" "{}"'.format(
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
)
@classmethod
def from_text(


@ -48,7 +48,7 @@ class HIP(dns.rdata.Rdata):
for server in self.servers:
servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0:
text += " " + " ".join(x.to_unicode() for x in servers)
text += " " + " ".join((x.to_unicode() for x in servers))
return "%u %s %s%s" % (self.algorithm, hit, key, text)
@classmethod


@ -38,12 +38,11 @@ class ISDN(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress:
return (
f'"{dns.rdata._escapify(self.address)}" '
f'"{dns.rdata._escapify(self.subaddress)}"'
return '"{}" "{}"'.format(
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
)
else:
return f'"{dns.rdata._escapify(self.address)}"'
return '"%s"' % dns.rdata._escapify(self.address)
@classmethod
def from_text(


@ -44,7 +44,7 @@ def _exponent_of(what, desc):
exp = i - 1
break
if exp is None or exp < 0:
raise dns.exception.SyntaxError(f"{desc} value out of bounds")
raise dns.exception.SyntaxError("%s value out of bounds" % desc)
return exp
@ -83,10 +83,10 @@ def _encode_size(what, desc):
def _decode_size(what, desc):
exponent = what & 0x0F
if exponent > 9:
raise dns.exception.FormError(f"bad {desc} exponent")
raise dns.exception.FormError("bad %s exponent" % desc)
base = (what & 0xF0) >> 4
if base > 9:
raise dns.exception.FormError(f"bad {desc} base")
raise dns.exception.FormError("bad %s base" % desc)
return base * pow(10, exponent)
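
A worked example of the LOC size encoding decoded above (RFC 1876): one byte packs a base digit in the high nibble and a power-of-ten exponent in the low nibble, giving `base * 10**exponent` centimeters:

```python
def decode_size(what: int) -> int:
    exponent = what & 0x0F
    base = (what & 0xF0) >> 4
    return base * pow(10, exponent)

# 0x12 -> base 1, exponent 2 -> 100 cm = 1 m, the LOC default size.
assert decode_size(0x12) == 100
```
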
@ -184,9 +184,10 @@ class LOC(dns.rdata.Rdata):
or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec
):
text += (
f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m"
f" {self.vertical_precision / 100.0:0.2f}m"
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
self.size / 100.0,
self.horizontal_precision / 100.0,
self.vertical_precision / 100.0,
)
return text


@ -44,7 +44,7 @@ class NSEC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text()
return f"{next}{text}"
return "{}{}".format(next, text)
@classmethod
def from_text(


@ -1,24 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class RESINFO(dns.rdtypes.txtbase.TXTBase):
"""RESINFO record"""


@ -37,7 +37,7 @@ class RP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
mbox = self.mbox.choose_relativity(origin, relativize)
txt = self.txt.choose_relativity(origin, relativize)
return f"{str(mbox)} {str(txt)}"
return "{} {}".format(str(mbox), str(txt))
@classmethod
def from_text(


@ -69,7 +69,7 @@ class TKEY(dns.rdata.Rdata):
dns.rdata._base64ify(self.key, 0),
)
if len(self.other) > 0:
text += f" {dns.rdata._base64ify(self.other, 0)}"
text += " %s" % (dns.rdata._base64ify(self.other, 0))
return text


@ -1,9 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class WALLET(dns.rdtypes.txtbase.TXTBase):
"""WALLET record"""


@ -36,7 +36,7 @@ class X25(dns.rdata.Rdata):
self.address = self._as_bytes(address, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return f'"{dns.rdata._escapify(self.address)}"'
return '"%s"' % dns.rdata._escapify(self.address)
@classmethod
def from_text(


@ -51,7 +51,6 @@ __all__ = [
"OPENPGPKEY",
"OPT",
"PTR",
"RESINFO",
"RP",
"RRSIG",
"RT",
@ -64,7 +63,6 @@ __all__ = [
"TSIG",
"TXT",
"URI",
"WALLET",
"X25",
"ZONEMD",
]


@ -37,7 +37,7 @@ class A(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize)
return f"{domain} {self.address:o}"
return "%s %o" % (domain, self.address)
@classmethod
def from_text(


@ -36,7 +36,7 @@ class NSAP(dns.rdata.Rdata):
self.address = self._as_bytes(address)
def to_text(self, origin=None, relativize=True, **kw):
return f"0x{binascii.hexlify(self.address).decode()}"
return "0x%s" % binascii.hexlify(self.address).decode()
@classmethod
def from_text(


@ -36,7 +36,7 @@ class EUIBase(dns.rdata.Rdata):
self.eui = self._as_bytes(eui)
if len(self.eui) != self.byte_len:
raise dns.exception.FormError(
f"EUI{self.byte_len * 8} rdata has to have {self.byte_len} bytes"
"EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len)
)
def to_text(self, origin=None, relativize=True, **kw):
@ -49,16 +49,16 @@ class EUIBase(dns.rdata.Rdata):
text = tok.get_string()
if len(text) != cls.text_len:
raise dns.exception.SyntaxError(
f"Input text must have {cls.text_len} characters"
"Input text must have %s characters" % cls.text_len
)
for i in range(2, cls.byte_len * 3 - 1, 3):
if text[i] != "-":
raise dns.exception.SyntaxError(f"Dash expected at position {i}")
raise dns.exception.SyntaxError("Dash expected at position %s" % i)
text = text.replace("-", "")
try:
data = binascii.unhexlify(text.encode())
except (ValueError, TypeError) as ex:
raise dns.exception.SyntaxError(f"Hex decoding error: {str(ex)}")
raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex))
return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
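
A worked example of the EUI text validation above, instantiated for EUI48 (`byte_len` 6, `text_len` 17): dashes must sit at every third position, and the remainder hex-decodes to the raw bytes:

```python
import binascii

text = "01-23-45-67-89-ab"
assert len(text) == 17                                   # cls.text_len
assert all(text[i] == "-" for i in range(2, 6 * 3 - 1, 3))  # dash positions
data = binascii.unhexlify(text.replace("-", "").encode())
assert data == bytes([0x01, 0x23, 0x45, 0x67, 0x89, 0xAB])  # cls.byte_len bytes
```
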


@ -35,7 +35,6 @@ class ParamKey(dns.enum.IntEnum):
ECH = 5
IPV6HINT = 6
DOHPATH = 7
OHTTP = 8
@classmethod
def _maximum(cls):
@ -397,36 +396,6 @@ class ECHParam(Param):
file.write(self.ech)
@dns.immutable.immutable
class OHTTPParam(Param):
# We don't ever expect to instantiate this class, but we need
# a from_value() and a from_wire_parser(), so we just return None
# from the class methods when things are OK.
@classmethod
def emptiness(cls):
return Emptiness.ALWAYS
@classmethod
def from_value(cls, value):
if value is None or value == "":
return None
else:
raise ValueError("ohttp with non-empty value")
def to_text(self):
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
if parser.remaining() != 0:
raise dns.exception.FormError
return None
def to_wire(self, file, origin=None): # pylint: disable=W0613
raise NotImplementedError # pragma: no cover
_class_for_key = {
ParamKey.MANDATORY: MandatoryParam,
ParamKey.ALPN: ALPNParam,
@ -435,7 +404,6 @@ _class_for_key = {
ParamKey.IPV4HINT: IPv4HintParam,
ParamKey.ECH: ECHParam,
ParamKey.IPV6HINT: IPv6HintParam,
ParamKey.OHTTP: OHTTPParam,
}
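
The removed `OHTTPParam` models an SVCB key whose presence alone carries the meaning: `emptiness()` is `ALWAYS`, so a value is never accepted in text or on the wire, and the parsed representation is simply `None`. A minimal standalone sketch of that "always empty" rule; the class and method names here are illustrative, not dnspython API:

```python
class AlwaysEmptyParam:
    @classmethod
    def from_value(cls, value):
        # In a zone file the key appears bare, with no "=value" part.
        if value not in (None, ""):
            raise ValueError("this key takes no value")
        return None  # presence is recorded by the key alone

    @classmethod
    def from_wire(cls, wire: bytes):
        if wire != b"":
            raise ValueError("non-empty wire value for an empty-only key")
        return None

assert AlwaysEmptyParam.from_value(None) is None
```
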
