mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-08-22 14:13:40 -07:00
Merge branch 'nightly' into dependabot/pip/nightly/apscheduler-3.10.0
This commit is contained in:
commit
bcb9ced903
245 changed files with 11444 additions and 6640 deletions
3
.github/workflows/publish-docker.yml
vendored
3
.github/workflows/publish-docker.yml
vendored
|
@ -70,7 +70,7 @@ jobs:
|
|||
password: ${{ secrets.GHCR_TOKEN }}
|
||||
|
||||
- name: Docker Build and Push
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v4
|
||||
if: success()
|
||||
with:
|
||||
context: .
|
||||
|
@ -87,7 +87,6 @@ jobs:
|
|||
ghcr.io/${{ steps.prepare.outputs.docker_image }}:${{ steps.prepare.outputs.tag }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache
|
||||
provenance: false
|
||||
|
||||
discord:
|
||||
name: Discord Notification
|
||||
|
|
|
@ -79,7 +79,6 @@ select.form-control {
|
|||
color: #eee !important;
|
||||
border: 0px solid #444 !important;
|
||||
background: #555 !important;
|
||||
padding: 1px 2px;
|
||||
transition: background-color .3s;
|
||||
}
|
||||
.selectize-control.form-control .selectize-input {
|
||||
|
@ -87,7 +86,6 @@ select.form-control {
|
|||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
margin-bottom: 4px;
|
||||
padding-left: 5px;
|
||||
}
|
||||
.selectize-control.form-control.selectize-pms-ip .selectize-input {
|
||||
padding-left: 12px !important;
|
||||
|
|
|
@ -12,6 +12,7 @@ data :: Usable parameters (if not applicable for media type, blank value will be
|
|||
== Global keys ==
|
||||
rating_key Returns the unique identifier for the media item.
|
||||
media_type Returns the type of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
|
||||
sub_media_type Returns the subtype of media. Either 'movie', 'show', 'season', 'episode', 'artist', 'album', or 'track'.
|
||||
art Returns the location of the item's artwork
|
||||
title Returns the name of the movie, show, episode, artist, album, or track.
|
||||
edition_title Returns the edition title of a movie.
|
||||
|
@ -213,7 +214,7 @@ DOCUMENTATION :: END
|
|||
% if _session['user_group'] == 'admin':
|
||||
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
|
||||
% endif
|
||||
% elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip'):
|
||||
% elif data['media_type'] in ('artist', 'album', 'track', 'playlist', 'photo_album', 'photo', 'clip') or data['sub_media_type'] in ('artist', 'album', 'track'):
|
||||
<div class="summary-poster-face-track" style="background-image: url(${page('pms_image_proxy', data['thumb'], data['rating_key'], 300, 300, fallback='cover')});">
|
||||
<div class="summary-poster-face-overlay">
|
||||
<span></span>
|
||||
|
@ -267,7 +268,7 @@ DOCUMENTATION :: END
|
|||
<h1><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a></h1>
|
||||
<h2>${data['title']}</h2>
|
||||
% elif data['media_type'] == 'track':
|
||||
<h1><a href="${page('info', data['grandparent_rating_key'])}">${data['original_title'] or data['grandparent_title']}</a></h1>
|
||||
<h1><a href="${page('info', data['grandparent_rating_key'])}">${data['grandparent_title']}</a></h1>
|
||||
<h2><a href="${page('info', data['parent_rating_key'])}">${data['parent_title']}</a> - ${data['title']}</h2>
|
||||
<h3 class="hidden-xs">T${data['media_index']}</h3>
|
||||
% elif data['media_type'] in ('photo', 'clip'):
|
||||
|
@ -283,14 +284,14 @@ DOCUMENTATION :: END
|
|||
padding_height = ''
|
||||
if data['media_type'] == 'movie' or data['live']:
|
||||
padding_height = 'height: 305px;'
|
||||
elif data['media_type'] in ('show', 'season', 'collection'):
|
||||
padding_height = 'height: 270px;'
|
||||
elif data['media_type'] == 'episode':
|
||||
padding_height = 'height: 70px;'
|
||||
elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo'):
|
||||
elif data['media_type'] in ('artist', 'album', 'playlist', 'photo_album', 'photo') or data['sub_media_type'] in ('artist', 'album', 'track'):
|
||||
padding_height = 'height: 150px;'
|
||||
elif data['media_type'] in ('track', 'clip'):
|
||||
padding_height = 'height: 180px;'
|
||||
elif data['media_type'] == 'episode':
|
||||
padding_height = 'height: 70px;'
|
||||
elif data['media_type'] in ('show', 'season', 'collection'):
|
||||
padding_height = 'height: 270px;'
|
||||
%>
|
||||
<div class="summary-content-padding hidden-xs hidden-sm" style="${padding_height}">
|
||||
% if data['media_type'] in ('movie', 'episode', 'track', 'clip'):
|
||||
|
@ -369,6 +370,11 @@ DOCUMENTATION :: END
|
|||
Studio <strong> ${data['studio']}</strong>
|
||||
% endif
|
||||
</div>
|
||||
<div class="summary-content-details-tag">
|
||||
% if data['media_type'] == 'track' and data['original_title']:
|
||||
Track Artists <strong> ${data['original_title']}</strong>
|
||||
% endif
|
||||
</div>
|
||||
<div class="summary-content-details-tag">
|
||||
% if data['media_type'] == 'movie':
|
||||
Year <strong> ${data['year']}</strong>
|
||||
|
@ -548,7 +554,7 @@ DOCUMENTATION :: END
|
|||
</div>
|
||||
</div>
|
||||
% endif
|
||||
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'):
|
||||
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
|
||||
<div class="col-md-12">
|
||||
<div class="table-card-header">
|
||||
<div class="header-bar">
|
||||
|
@ -812,7 +818,7 @@ DOCUMENTATION :: END
|
|||
% elif data['media_type'] == 'album':
|
||||
${data['parent_title']}<br />${data['title']}
|
||||
% elif data['media_type'] == 'track':
|
||||
${data['original_title'] or data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']}
|
||||
${data['grandparent_title']}<br />${data['title']}<br />${data['parent_title']}
|
||||
% endif
|
||||
</strong>
|
||||
</p>
|
||||
|
@ -931,13 +937,16 @@ DOCUMENTATION :: END
|
|||
});
|
||||
</script>
|
||||
% endif
|
||||
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track'):
|
||||
% if data['media_type'] in ('movie', 'show', 'season', 'episode', 'artist', 'album', 'track', 'collection', 'playlist'):
|
||||
<script>
|
||||
// Populate watch time stats
|
||||
$.ajax({
|
||||
url: 'item_watch_time_stats',
|
||||
async: true,
|
||||
data: { rating_key: "${data['rating_key']}" },
|
||||
data: {
|
||||
rating_key: "${data['rating_key']}",
|
||||
media_type: "${data['media_type']}"
|
||||
},
|
||||
complete: function(xhr, status) {
|
||||
$("#watch-time-stats").html(xhr.responseText);
|
||||
}
|
||||
|
@ -946,7 +955,10 @@ DOCUMENTATION :: END
|
|||
$.ajax({
|
||||
url: 'item_user_stats',
|
||||
async: true,
|
||||
data: { rating_key: "${data['rating_key']}" },
|
||||
data: {
|
||||
rating_key: "${data['rating_key']}",
|
||||
media_type: "${data['media_type']}"
|
||||
},
|
||||
complete: function(xhr, status) {
|
||||
$("#user-stats").html(xhr.responseText);
|
||||
}
|
||||
|
|
|
@ -160,6 +160,16 @@ DOCUMENTATION :: END
|
|||
% endif
|
||||
</div>
|
||||
</a>
|
||||
<div class="item-children-instance-text-wrapper poster-item">
|
||||
<h3>
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
|
||||
</h3>
|
||||
% if media_type == 'collection':
|
||||
<h3 class="text-muted">
|
||||
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
|
||||
</h3>
|
||||
% endif
|
||||
</div>
|
||||
% elif child['media_type'] == 'episode':
|
||||
<a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">
|
||||
<div class="item-children-poster">
|
||||
|
@ -179,6 +189,29 @@ DOCUMENTATION :: END
|
|||
<h3>
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
|
||||
</h3>
|
||||
% if media_type == 'collection':
|
||||
<h3 class="text-muted">
|
||||
<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">${child['grandparent_title']}</a>
|
||||
</h3>
|
||||
<h3 class="text-muted">
|
||||
<a href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${short_season(child['parent_title'])}</a>
|
||||
· <a href="${page('info', child['rating_key'])}" title="Episode ${child['media_index']}">E${child['media_index']}</a>
|
||||
</h3>
|
||||
% endif
|
||||
</div>
|
||||
% elif child['media_type'] == 'artist':
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">
|
||||
<div class="item-children-poster">
|
||||
<div class="item-children-poster-face cover-item" style="background-image: url(${page('pms_image_proxy', child['thumb'], child['rating_key'], 300, 300, fallback='cover')});"></div>
|
||||
% if _session['user_group'] == 'admin':
|
||||
<span class="overlay-refresh-image" title="Refresh image"><i class="fa fa-refresh refresh_pms_image"></i></span>
|
||||
% endif
|
||||
</div>
|
||||
</a>
|
||||
<div class="item-children-instance-text-wrapper cover-item">
|
||||
<h3>
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
|
||||
</h3>
|
||||
</div>
|
||||
% elif child['media_type'] == 'album':
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">
|
||||
|
@ -193,6 +226,11 @@ DOCUMENTATION :: END
|
|||
<h3>
|
||||
<a href="${page('info', child['rating_key'])}" title="${child['title']}">${child['title']}</a>
|
||||
</h3>
|
||||
% if media_type == 'collection':
|
||||
<h3 class="text-muted">
|
||||
<a class="text-muted" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>
|
||||
</h3>
|
||||
% endif
|
||||
</div>
|
||||
% elif child['media_type'] == 'track':
|
||||
<% e = 'even' if loop.index % 2 == 0 else 'odd' %>
|
||||
|
@ -205,7 +243,15 @@ DOCUMENTATION :: END
|
|||
${child['title']}
|
||||
</span>
|
||||
</a>
|
||||
% if child['original_title']:
|
||||
% if media_type == 'collection':
|
||||
-
|
||||
<a href="${page('info', child['grandparent_rating_key'])}" title="${child['grandparent_title']}">
|
||||
<span class="thumb-tooltip" data-toggle="popover" data-img="${page('pms_image_proxy', child['grandparent_thumb'], child['grandparent_rating_key'], 300, 300, fallback='cover')}" data-height="80" data-width="80">
|
||||
${child['grandparent_title']}
|
||||
</span>
|
||||
</a>
|
||||
<span class="text-muted"> (<a class="no-highlight" href="${page('info', child['parent_rating_key'])}" title="${child['parent_title']}">${child['parent_title']}</a>)</span>
|
||||
% elif child['original_title']:
|
||||
<span class="text-muted"> - ${child['original_title']}</span>
|
||||
% endif
|
||||
</span>
|
||||
|
|
|
@ -32,7 +32,12 @@ collections_table_options = {
|
|||
if (rowData['smart']) {
|
||||
smart = '<span class="media-type-tooltip" data-toggle="tooltip" title="Smart Collection"><i class="fa fa-cog fa-fw"></i></span> '
|
||||
}
|
||||
console.log(rowData['subtype'])
|
||||
if (rowData['subtype'] === 'artist' || rowData['subtype'] === 'album' || rowData['subtype'] === 'track') {
|
||||
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 300, null, null, null, 'cover') + '" data-height="80" data-width="80">' + rowData['title'] + '</span>';
|
||||
} else {
|
||||
var thumb_popover = '<span class="thumb-tooltip" data-toggle="popover" data-img="' + page('pms_image_proxy', rowData['thumb'], rowData['ratingKey'], 300, 450, null, null, null, 'poster') + '" data-height="120" data-width="80">' + rowData['title'] + '</span>';
|
||||
}
|
||||
$(td).html(smart + '<a href="' + page('info', rowData['ratingKey']) + '">' + thumb_popover + '</a>');
|
||||
}
|
||||
},
|
||||
|
|
|
@ -142,8 +142,10 @@
|
|||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<select class="form-control" id="${item['name']}" name="${item['name']}">
|
||||
% if item['select_all']:
|
||||
<option value="select-all">Select All</option>
|
||||
<option value="remove-all">Remove All</option>
|
||||
% endif
|
||||
% if isinstance(item['select_options'], dict):
|
||||
% for section, options in item['select_options'].items():
|
||||
<optgroup label="${section}">
|
||||
|
@ -153,7 +155,9 @@
|
|||
</optgroup>
|
||||
% endfor
|
||||
% else:
|
||||
% if item['select_all']:
|
||||
<option value="border-all"></option>
|
||||
% endif
|
||||
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
|
||||
<option value="${option['value']}">${option['text']}</option>
|
||||
% endfor
|
||||
|
|
|
@ -134,8 +134,10 @@
|
|||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<select class="form-control" id="${item['name']}" name="${item['name']}">
|
||||
% if item['select_all']:
|
||||
<option value="select-all">Select All</option>
|
||||
<option value="remove-all">Remove All</option>
|
||||
% endif
|
||||
% if isinstance(item['select_options'], dict):
|
||||
% for section, options in item['select_options'].items():
|
||||
<optgroup label="${section}">
|
||||
|
@ -145,7 +147,9 @@
|
|||
</optgroup>
|
||||
% endfor
|
||||
% else:
|
||||
% if item['select_all']:
|
||||
<option value="border-all"></option>
|
||||
% endif
|
||||
% for option in sorted(item['select_options'], key=lambda x: x['text'].lower()):
|
||||
<option value="${option['value']}">${option['text']}</option>
|
||||
% endfor
|
||||
|
@ -719,6 +723,12 @@
|
|||
pushoverPriority();
|
||||
});
|
||||
|
||||
var $pushover_sound = $('#pushover_sound').selectize({
|
||||
create: true
|
||||
});
|
||||
var pushover_sound = $pushover_sound[0].selectize;
|
||||
pushover_sound.setValue(${json.dumps(next((c['value'] for c in notifier['config_options'] if c['name'] == 'pushover_sound'), [])) | n});
|
||||
|
||||
% elif notifier['agent_name'] == 'plexmobileapp':
|
||||
var $plexmobileapp_user_ids = $('#plexmobileapp_user_ids').selectize({
|
||||
plugins: ['remove_button'],
|
||||
|
|
|
@ -213,6 +213,20 @@
|
|||
</div>
|
||||
<p class="help-block">Set the percentage for a music track to be considered as listened. Minimum 50, Maximum 95.</p>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="music_watched_percent">Video Watched Completion Behaviour</label>
|
||||
<div class="row">
|
||||
<div class="col-md-7">
|
||||
<select class="form-control" id="watched_marker" name="watched_marker">
|
||||
<option value="0" ${'selected' if config['watched_marker'] == 0 else ''}>At selected threshold percentage</option>
|
||||
<option value="1" ${'selected' if config['watched_marker'] == 1 else ''}>At final credits marker position</option>
|
||||
<option value="2" ${'selected' if config['watched_marker'] == 2 else ''}>At first credits marker position</option>
|
||||
<option value="3" ${'selected' if config['watched_marker'] == 3 else ''}>Earliest between threshold percent and first credits marker</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<p class="help-block">Decide whether to use end credits markers to determine the 'watched' state of video items. When markers are not available the selected threshold percentage will be used.</p>
|
||||
</div>
|
||||
<div class="form-group advanced-setting">
|
||||
<label>Flush Temporary Sessions</label>
|
||||
<p class="help-block">
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
###############################################################################
|
||||
# Formatting filter for urllib2's HTTPHandler(debuglevel=1) output
|
||||
# Copyright (c) 2013, Analytics Pros
|
||||
#
|
||||
# This project is free software, distributed under the BSD license.
|
||||
# Analytics Pros offers consulting and integration services if your firm needs
|
||||
# assistance in strategy, implementation, or auditing existing work.
|
||||
###############################################################################
|
||||
|
||||
|
||||
import sys, re, os
|
||||
from io import StringIO
|
||||
|
||||
|
||||
|
||||
class BufferTranslator(object):
|
||||
""" Provides a buffer-compatible interface for filtering buffer content.
|
||||
"""
|
||||
parsers = []
|
||||
|
||||
def __init__(self, output):
|
||||
self.output = output
|
||||
self.encoding = getattr(output, 'encoding', None)
|
||||
|
||||
def write(self, content):
|
||||
content = self.translate(content)
|
||||
self.output.write(content)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def stripslashes(content):
|
||||
return content.decode('string_escape')
|
||||
|
||||
@staticmethod
|
||||
def addslashes(content):
|
||||
return content.encode('string_escape')
|
||||
|
||||
def translate(self, line):
|
||||
for pattern, method in self.parsers:
|
||||
match = pattern.match(line)
|
||||
if match:
|
||||
return method(match)
|
||||
|
||||
return line
|
||||
|
||||
|
||||
|
||||
class LineBufferTranslator(BufferTranslator):
|
||||
""" Line buffer implementation supports translation of line-format input
|
||||
even when input is not already line-buffered. Caches input until newlines
|
||||
occur, and then dispatches translated input to output buffer.
|
||||
"""
|
||||
def __init__(self, *a, **kw):
|
||||
self._linepending = []
|
||||
super(LineBufferTranslator, self).__init__(*a, **kw)
|
||||
|
||||
def write(self, _input):
|
||||
lines = _input.splitlines(True)
|
||||
for i in range(0, len(lines)):
|
||||
last = i
|
||||
if lines[i].endswith('\n'):
|
||||
prefix = len(self._linepending) and ''.join(self._linepending) or ''
|
||||
self.output.write(self.translate(prefix + lines[i]))
|
||||
del self._linepending[0:]
|
||||
last = -1
|
||||
|
||||
if last >= 0:
|
||||
self._linepending.append(lines[ last ])
|
||||
|
||||
|
||||
def __del__(self):
|
||||
if len(self._linepending):
|
||||
self.output.write(self.translate(''.join(self._linepending)))
|
||||
|
||||
|
||||
class HTTPTranslator(LineBufferTranslator):
|
||||
""" Translates output from |urllib2| HTTPHandler(debuglevel = 1) into
|
||||
HTTP-compatible, readible text structures for human analysis.
|
||||
"""
|
||||
|
||||
RE_LINE_PARSER = re.compile(r'^(?:([a-z]+):)\s*(\'?)([^\r\n]*)\2(?:[\r\n]*)$')
|
||||
RE_LINE_BREAK = re.compile(r'(\r?\n|(?:\\r)?\\n)')
|
||||
RE_HTTP_METHOD = re.compile(r'^(POST|GET|HEAD|DELETE|PUT|TRACE|OPTIONS)')
|
||||
RE_PARAMETER_SPACER = re.compile(r'&([a-z0-9]+)=')
|
||||
|
||||
@classmethod
|
||||
def spacer(cls, line):
|
||||
return cls.RE_PARAMETER_SPACER.sub(r' &\1= ', line)
|
||||
|
||||
def translate(self, line):
|
||||
|
||||
parsed = self.RE_LINE_PARSER.match(line)
|
||||
|
||||
if parsed:
|
||||
value = parsed.group(3)
|
||||
stage = parsed.group(1)
|
||||
|
||||
if stage == 'send': # query string is rendered here
|
||||
return '\n# HTTP Request:\n' + self.stripslashes(value)
|
||||
elif stage == 'reply':
|
||||
return '\n\n# HTTP Response:\n' + self.stripslashes(value)
|
||||
elif stage == 'header':
|
||||
return value + '\n'
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
return line
|
||||
|
||||
|
||||
def consume(outbuffer = None): # Capture standard output
|
||||
sys.stdout = HTTPTranslator(outbuffer or sys.stdout)
|
||||
return sys.stdout
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
consume(sys.stdout).write(sys.stdin.read())
|
||||
print('\n')
|
||||
|
||||
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4
|
|
@ -1,424 +0,0 @@
|
|||
from future.moves.urllib.request import urlopen, build_opener, install_opener
|
||||
from future.moves.urllib.request import Request, HTTPSHandler
|
||||
from future.moves.urllib.error import URLError, HTTPError
|
||||
from future.moves.urllib.parse import urlencode
|
||||
|
||||
import random
|
||||
import datetime
|
||||
import time
|
||||
import uuid
|
||||
import hashlib
|
||||
import socket
|
||||
|
||||
|
||||
def generate_uuid(basedata=None):
|
||||
""" Provides a _random_ UUID with no input, or a UUID4-format MD5 checksum of any input data provided """
|
||||
if basedata is None:
|
||||
return str(uuid.uuid4())
|
||||
elif isinstance(basedata, str):
|
||||
checksum = hashlib.md5(str(basedata).encode('utf-8')).hexdigest()
|
||||
return '%8s-%4s-%4s-%4s-%12s' % (
|
||||
checksum[0:8], checksum[8:12], checksum[12:16], checksum[16:20], checksum[20:32])
|
||||
|
||||
|
||||
class Time(datetime.datetime):
|
||||
""" Wrappers and convenience methods for processing various time representations """
|
||||
|
||||
@classmethod
|
||||
def from_unix(cls, seconds, milliseconds=0):
|
||||
""" Produce a full |datetime.datetime| object from a Unix timestamp """
|
||||
base = list(time.gmtime(seconds))[0:6]
|
||||
base.append(milliseconds * 1000) # microseconds
|
||||
return cls(*base)
|
||||
|
||||
@classmethod
|
||||
def to_unix(cls, timestamp):
|
||||
""" Wrapper over time module to produce Unix epoch time as a float """
|
||||
if not isinstance(timestamp, datetime.datetime):
|
||||
raise TypeError('Time.milliseconds expects a datetime object')
|
||||
base = time.mktime(timestamp.timetuple())
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def milliseconds_offset(cls, timestamp, now=None):
|
||||
""" Offset time (in milliseconds) from a |datetime.datetime| object to now """
|
||||
if isinstance(timestamp, (int, float)):
|
||||
base = timestamp
|
||||
else:
|
||||
base = cls.to_unix(timestamp)
|
||||
base = base + (timestamp.microsecond / 1000000)
|
||||
if now is None:
|
||||
now = time.time()
|
||||
return (now - base) * 1000
|
||||
|
||||
|
||||
class HTTPRequest(object):
|
||||
""" URL Construction and request handling abstraction.
|
||||
This is not intended to be used outside this module.
|
||||
|
||||
Automates mapping of persistent state (i.e. query parameters)
|
||||
onto transcient datasets for each query.
|
||||
"""
|
||||
|
||||
endpoint = 'https://www.google-analytics.com/collect'
|
||||
|
||||
@staticmethod
|
||||
def debug():
|
||||
""" Activate debugging on urllib2 """
|
||||
handler = HTTPSHandler(debuglevel=1)
|
||||
opener = build_opener(handler)
|
||||
install_opener(opener)
|
||||
|
||||
# Store properties for all requests
|
||||
def __init__(self, user_agent=None, *args, **opts):
|
||||
self.user_agent = user_agent or 'Analytics Pros - Universal Analytics (Python)'
|
||||
|
||||
@classmethod
|
||||
def fixUTF8(cls, data): # Ensure proper encoding for UA's servers...
|
||||
""" Convert all strings to UTF-8 """
|
||||
for key in data:
|
||||
if isinstance(data[key], str):
|
||||
data[key] = data[key].encode('utf-8')
|
||||
return data
|
||||
|
||||
# Apply stored properties to the given dataset & POST to the configured endpoint
|
||||
def send(self, data):
|
||||
request = Request(
|
||||
self.endpoint + '?' + urlencode(self.fixUTF8(data)).encode('utf-8'),
|
||||
headers={
|
||||
'User-Agent': self.user_agent
|
||||
}
|
||||
)
|
||||
self.open(request)
|
||||
|
||||
def open(self, request):
|
||||
try:
|
||||
return urlopen(request)
|
||||
except HTTPError as e:
|
||||
return False
|
||||
except URLError as e:
|
||||
self.cache_request(request)
|
||||
return False
|
||||
|
||||
def cache_request(self, request):
|
||||
# TODO: implement a proper caching mechanism here for re-transmitting hits
|
||||
# record = (Time.now(), request.get_full_url(), request.get_data(), request.headers)
|
||||
pass
|
||||
|
||||
|
||||
class HTTPPost(HTTPRequest):
|
||||
|
||||
# Apply stored properties to the given dataset & POST to the configured endpoint
|
||||
def send(self, data):
|
||||
request = Request(
|
||||
self.endpoint,
|
||||
data=urlencode(self.fixUTF8(data)).encode('utf-8'),
|
||||
headers={
|
||||
'User-Agent': self.user_agent
|
||||
}
|
||||
)
|
||||
self.open(request)
|
||||
|
||||
|
||||
class Tracker(object):
|
||||
""" Primary tracking interface for Universal Analytics """
|
||||
params = None
|
||||
parameter_alias = {}
|
||||
valid_hittypes = ('pageview', 'event', 'social', 'screenview', 'transaction', 'item', 'exception', 'timing')
|
||||
|
||||
@classmethod
|
||||
def alias(cls, typemap, base, *names):
|
||||
""" Declare an alternate (humane) name for a measurement protocol parameter """
|
||||
cls.parameter_alias[base] = (typemap, base)
|
||||
for i in names:
|
||||
cls.parameter_alias[i] = (typemap, base)
|
||||
|
||||
@classmethod
|
||||
def coerceParameter(cls, name, value=None):
|
||||
if isinstance(name, str) and name[0] == '&':
|
||||
return name[1:], str(value)
|
||||
elif name in cls.parameter_alias:
|
||||
typecast, param_name = cls.parameter_alias.get(name)
|
||||
return param_name, typecast(value)
|
||||
else:
|
||||
raise KeyError('Parameter "{0}" is not recognized'.format(name))
|
||||
|
||||
def payload(self, data):
|
||||
for key, value in data.items():
|
||||
try:
|
||||
yield self.coerceParameter(key, value)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
option_sequence = {
|
||||
'pageview': [(str, 'dp')],
|
||||
'event': [(str, 'ec'), (str, 'ea'), (str, 'el'), (int, 'ev')],
|
||||
'social': [(str, 'sn'), (str, 'sa'), (str, 'st')],
|
||||
'timing': [(str, 'utc'), (str, 'utv'), (str, 'utt'), (str, 'utl')]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def consume_options(cls, data, hittype, args):
|
||||
""" Interpret sequential arguments related to known hittypes based on declared structures """
|
||||
opt_position = 0
|
||||
data['t'] = hittype # integrate hit type parameter
|
||||
if hittype in cls.option_sequence:
|
||||
for expected_type, optname in cls.option_sequence[hittype]:
|
||||
if opt_position < len(args) and isinstance(args[opt_position], expected_type):
|
||||
data[optname] = args[opt_position]
|
||||
opt_position += 1
|
||||
|
||||
@classmethod
|
||||
def hittime(cls, timestamp=None, age=None, milliseconds=None):
|
||||
""" Returns an integer represeting the milliseconds offset for a given hit (relative to now) """
|
||||
if isinstance(timestamp, (int, float)):
|
||||
return int(Time.milliseconds_offset(Time.from_unix(timestamp, milliseconds=milliseconds)))
|
||||
if isinstance(timestamp, datetime.datetime):
|
||||
return int(Time.milliseconds_offset(timestamp))
|
||||
if isinstance(age, (int, float)):
|
||||
return int(age * 1000) + (milliseconds or 0)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
return self.params.get('tid', None)
|
||||
|
||||
def __init__(self, account, name=None, client_id=None, hash_client_id=False, user_id=None, user_agent=None,
|
||||
use_post=True):
|
||||
|
||||
if use_post is False:
|
||||
self.http = HTTPRequest(user_agent=user_agent)
|
||||
else:
|
||||
self.http = HTTPPost(user_agent=user_agent)
|
||||
|
||||
self.params = {'v': 1, 'tid': account}
|
||||
|
||||
if client_id is None:
|
||||
client_id = generate_uuid()
|
||||
|
||||
self.params['cid'] = client_id
|
||||
|
||||
self.hash_client_id = hash_client_id
|
||||
|
||||
if user_id is not None:
|
||||
self.params['uid'] = user_id
|
||||
|
||||
def set_timestamp(self, data):
|
||||
""" Interpret time-related options, apply queue-time parameter as needed """
|
||||
if 'hittime' in data: # an absolute timestamp
|
||||
data['qt'] = self.hittime(timestamp=data.pop('hittime', None))
|
||||
if 'hitage' in data: # a relative age (in seconds)
|
||||
data['qt'] = self.hittime(age=data.pop('hitage', None))
|
||||
|
||||
def send(self, hittype, *args, **data):
|
||||
""" Transmit HTTP requests to Google Analytics using the measurement protocol """
|
||||
|
||||
if hittype not in self.valid_hittypes:
|
||||
raise KeyError('Unsupported Universal Analytics Hit Type: {0}'.format(repr(hittype)))
|
||||
|
||||
self.set_timestamp(data)
|
||||
self.consume_options(data, hittype, args)
|
||||
|
||||
for item in args: # process dictionary-object arguments of transcient data
|
||||
if isinstance(item, dict):
|
||||
for key, val in self.payload(item):
|
||||
data[key] = val
|
||||
|
||||
for k, v in self.params.items(): # update only absent parameters
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
|
||||
data = dict(self.payload(data))
|
||||
|
||||
if self.hash_client_id:
|
||||
data['cid'] = generate_uuid(data['cid'])
|
||||
|
||||
# Transmit the hit to Google...
|
||||
self.http.send(data)
|
||||
|
||||
# Setting persistent attibutes of the session/hit/etc (inc. custom dimensions/metrics)
|
||||
def set(self, name, value=None):
|
||||
if isinstance(name, dict):
|
||||
for key, value in name.items():
|
||||
try:
|
||||
param, value = self.coerceParameter(key, value)
|
||||
self.params[param] = value
|
||||
except KeyError:
|
||||
pass
|
||||
elif isinstance(name, str):
|
||||
try:
|
||||
param, value = self.coerceParameter(name, value)
|
||||
self.params[param] = value
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __getitem__(self, name):
|
||||
param, value = self.coerceParameter(name, None)
|
||||
return self.params.get(param, None)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
param, value = self.coerceParameter(name, value)
|
||||
self.params[param] = value
|
||||
|
||||
def __delitem__(self, name):
|
||||
param, value = self.coerceParameter(name, None)
|
||||
if param in self.params:
|
||||
del self.params[param]
|
||||
|
||||
|
||||
def safe_unicode(obj):
|
||||
""" Safe convertion to the Unicode string version of the object """
|
||||
try:
|
||||
return str(obj)
|
||||
except UnicodeDecodeError:
|
||||
return obj.decode('utf-8')
|
||||
|
||||
|
||||
# Declaring name mappings for Measurement Protocol parameters
|
||||
MAX_CUSTOM_DEFINITIONS = 200
|
||||
MAX_EC_LISTS = 11 # 1-based index
|
||||
MAX_EC_PRODUCTS = 11 # 1-based index
|
||||
MAX_EC_PROMOTIONS = 11 # 1-based index
|
||||
|
||||
Tracker.alias(int, 'v', 'protocol-version')
|
||||
Tracker.alias(safe_unicode, 'cid', 'client-id', 'clientId', 'clientid')
|
||||
Tracker.alias(safe_unicode, 'tid', 'trackingId', 'account')
|
||||
Tracker.alias(safe_unicode, 'uid', 'user-id', 'userId', 'userid')
|
||||
Tracker.alias(safe_unicode, 'uip', 'user-ip', 'userIp', 'ipaddr')
|
||||
Tracker.alias(safe_unicode, 'ua', 'userAgent', 'userAgentOverride', 'user-agent')
|
||||
Tracker.alias(safe_unicode, 'dp', 'page', 'path')
|
||||
Tracker.alias(safe_unicode, 'dt', 'title', 'pagetitle', 'pageTitle' 'page-title')
|
||||
Tracker.alias(safe_unicode, 'dl', 'location')
|
||||
Tracker.alias(safe_unicode, 'dh', 'hostname')
|
||||
Tracker.alias(safe_unicode, 'sc', 'sessioncontrol', 'session-control', 'sessionControl')
|
||||
Tracker.alias(safe_unicode, 'dr', 'referrer', 'referer')
|
||||
Tracker.alias(int, 'qt', 'queueTime', 'queue-time')
|
||||
Tracker.alias(safe_unicode, 't', 'hitType', 'hittype')
|
||||
Tracker.alias(int, 'aip', 'anonymizeIp', 'anonIp', 'anonymize-ip')
|
||||
Tracker.alias(safe_unicode, 'ds', 'dataSource', 'data-source')
|
||||
|
||||
# Campaign attribution
|
||||
Tracker.alias(safe_unicode, 'cn', 'campaign', 'campaignName', 'campaign-name')
|
||||
Tracker.alias(safe_unicode, 'cs', 'source', 'campaignSource', 'campaign-source')
|
||||
Tracker.alias(safe_unicode, 'cm', 'medium', 'campaignMedium', 'campaign-medium')
|
||||
Tracker.alias(safe_unicode, 'ck', 'keyword', 'campaignKeyword', 'campaign-keyword')
|
||||
Tracker.alias(safe_unicode, 'cc', 'content', 'campaignContent', 'campaign-content')
|
||||
Tracker.alias(safe_unicode, 'ci', 'campaignId', 'campaignID', 'campaign-id')
|
||||
|
||||
# Technical specs
|
||||
Tracker.alias(safe_unicode, 'sr', 'screenResolution', 'screen-resolution', 'resolution')
|
||||
Tracker.alias(safe_unicode, 'vp', 'viewport', 'viewportSize', 'viewport-size')
|
||||
Tracker.alias(safe_unicode, 'de', 'encoding', 'documentEncoding', 'document-encoding')
|
||||
Tracker.alias(int, 'sd', 'colors', 'screenColors', 'screen-colors')
|
||||
Tracker.alias(safe_unicode, 'ul', 'language', 'user-language', 'userLanguage')
|
||||
|
||||
# Mobile app
|
||||
Tracker.alias(safe_unicode, 'an', 'appName', 'app-name', 'app')
|
||||
Tracker.alias(safe_unicode, 'cd', 'contentDescription', 'screenName', 'screen-name', 'content-description')
|
||||
Tracker.alias(safe_unicode, 'av', 'appVersion', 'app-version', 'version')
|
||||
Tracker.alias(safe_unicode, 'aid', 'appID', 'appId', 'application-id', 'app-id', 'applicationId')
|
||||
Tracker.alias(safe_unicode, 'aiid', 'appInstallerId', 'app-installer-id')
|
||||
|
||||
# Ecommerce
|
||||
Tracker.alias(safe_unicode, 'ta', 'affiliation', 'transactionAffiliation', 'transaction-affiliation')
|
||||
Tracker.alias(safe_unicode, 'ti', 'transaction', 'transactionId', 'transaction-id')
|
||||
Tracker.alias(float, 'tr', 'revenue', 'transactionRevenue', 'transaction-revenue')
|
||||
Tracker.alias(float, 'ts', 'shipping', 'transactionShipping', 'transaction-shipping')
|
||||
Tracker.alias(float, 'tt', 'tax', 'transactionTax', 'transaction-tax')
|
||||
Tracker.alias(safe_unicode, 'cu', 'currency', 'transactionCurrency',
|
||||
'transaction-currency') # Currency code, e.g. USD, EUR
|
||||
Tracker.alias(safe_unicode, 'in', 'item-name', 'itemName')
|
||||
Tracker.alias(float, 'ip', 'item-price', 'itemPrice')
|
||||
Tracker.alias(float, 'iq', 'item-quantity', 'itemQuantity')
|
||||
Tracker.alias(safe_unicode, 'ic', 'item-code', 'sku', 'itemCode')
|
||||
Tracker.alias(safe_unicode, 'iv', 'item-variation', 'item-category', 'itemCategory', 'itemVariation')
|
||||
|
||||
# Events
|
||||
Tracker.alias(safe_unicode, 'ec', 'event-category', 'eventCategory', 'category')
|
||||
Tracker.alias(safe_unicode, 'ea', 'event-action', 'eventAction', 'action')
|
||||
Tracker.alias(safe_unicode, 'el', 'event-label', 'eventLabel', 'label')
|
||||
Tracker.alias(int, 'ev', 'event-value', 'eventValue', 'value')
|
||||
Tracker.alias(int, 'ni', 'noninteractive', 'nonInteractive', 'noninteraction', 'nonInteraction')
|
||||
|
||||
# Social
|
||||
Tracker.alias(safe_unicode, 'sa', 'social-action', 'socialAction')
|
||||
Tracker.alias(safe_unicode, 'sn', 'social-network', 'socialNetwork')
|
||||
Tracker.alias(safe_unicode, 'st', 'social-target', 'socialTarget')
|
||||
|
||||
# Exceptions
|
||||
Tracker.alias(safe_unicode, 'exd', 'exception-description', 'exceptionDescription', 'exDescription')
|
||||
Tracker.alias(int, 'exf', 'exception-fatal', 'exceptionFatal', 'exFatal')
|
||||
|
||||
# User Timing
|
||||
Tracker.alias(safe_unicode, 'utc', 'timingCategory', 'timing-category')
|
||||
Tracker.alias(safe_unicode, 'utv', 'timingVariable', 'timing-variable')
|
||||
Tracker.alias(float, 'utt', 'time', 'timingTime', 'timing-time')
|
||||
Tracker.alias(safe_unicode, 'utl', 'timingLabel', 'timing-label')
|
||||
Tracker.alias(float, 'dns', 'timingDNS', 'timing-dns')
|
||||
Tracker.alias(float, 'pdt', 'timingPageLoad', 'timing-page-load')
|
||||
Tracker.alias(float, 'rrt', 'timingRedirect', 'timing-redirect')
|
||||
Tracker.alias(safe_unicode, 'tcp', 'timingTCPConnect', 'timing-tcp-connect')
|
||||
Tracker.alias(safe_unicode, 'srt', 'timingServerResponse', 'timing-server-response')
|
||||
|
||||
# Custom dimensions and metrics
|
||||
for i in range(0, 200):
|
||||
Tracker.alias(safe_unicode, 'cd{0}'.format(i), 'dimension{0}'.format(i))
|
||||
Tracker.alias(int, 'cm{0}'.format(i), 'metric{0}'.format(i))
|
||||
|
||||
# Content groups
|
||||
for i in range(0, 5):
|
||||
Tracker.alias(safe_unicode, 'cg{0}'.format(i), 'contentGroup{0}'.format(i))
|
||||
|
||||
# Enhanced Ecommerce
|
||||
Tracker.alias(str, 'pa') # Product action
|
||||
Tracker.alias(str, 'tcc') # Coupon code
|
||||
Tracker.alias(str, 'pal') # Product action list
|
||||
Tracker.alias(int, 'cos') # Checkout step
|
||||
Tracker.alias(str, 'col') # Checkout step option
|
||||
|
||||
Tracker.alias(str, 'promoa') # Promotion action
|
||||
|
||||
for product_index in range(1, MAX_EC_PRODUCTS):
|
||||
Tracker.alias(str, 'pr{0}id'.format(product_index)) # Product SKU
|
||||
Tracker.alias(str, 'pr{0}nm'.format(product_index)) # Product name
|
||||
Tracker.alias(str, 'pr{0}br'.format(product_index)) # Product brand
|
||||
Tracker.alias(str, 'pr{0}ca'.format(product_index)) # Product category
|
||||
Tracker.alias(str, 'pr{0}va'.format(product_index)) # Product variant
|
||||
Tracker.alias(str, 'pr{0}pr'.format(product_index)) # Product price
|
||||
Tracker.alias(int, 'pr{0}qt'.format(product_index)) # Product quantity
|
||||
Tracker.alias(str, 'pr{0}cc'.format(product_index)) # Product coupon code
|
||||
Tracker.alias(int, 'pr{0}ps'.format(product_index)) # Product position
|
||||
|
||||
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
|
||||
Tracker.alias(str, 'pr{0}cd{1}'.format(product_index, custom_index)) # Product custom dimension
|
||||
Tracker.alias(int, 'pr{0}cm{1}'.format(product_index, custom_index)) # Product custom metric
|
||||
|
||||
for list_index in range(1, MAX_EC_LISTS):
|
||||
Tracker.alias(str, 'il{0}pi{1}id'.format(list_index, product_index)) # Product impression SKU
|
||||
Tracker.alias(str, 'il{0}pi{1}nm'.format(list_index, product_index)) # Product impression name
|
||||
Tracker.alias(str, 'il{0}pi{1}br'.format(list_index, product_index)) # Product impression brand
|
||||
Tracker.alias(str, 'il{0}pi{1}ca'.format(list_index, product_index)) # Product impression category
|
||||
Tracker.alias(str, 'il{0}pi{1}va'.format(list_index, product_index)) # Product impression variant
|
||||
Tracker.alias(int, 'il{0}pi{1}ps'.format(list_index, product_index)) # Product impression position
|
||||
Tracker.alias(int, 'il{0}pi{1}pr'.format(list_index, product_index)) # Product impression price
|
||||
|
||||
for custom_index in range(MAX_CUSTOM_DEFINITIONS):
|
||||
Tracker.alias(str, 'il{0}pi{1}cd{2}'.format(list_index, product_index,
|
||||
custom_index)) # Product impression custom dimension
|
||||
Tracker.alias(int, 'il{0}pi{1}cm{2}'.format(list_index, product_index,
|
||||
custom_index)) # Product impression custom metric
|
||||
|
||||
for list_index in range(1, MAX_EC_LISTS):
|
||||
Tracker.alias(str, 'il{0}nm'.format(list_index)) # Product impression list name
|
||||
|
||||
for promotion_index in range(1, MAX_EC_PROMOTIONS):
|
||||
Tracker.alias(str, 'promo{0}id'.format(promotion_index)) # Promotion ID
|
||||
Tracker.alias(str, 'promo{0}nm'.format(promotion_index)) # Promotion name
|
||||
Tracker.alias(str, 'promo{0}cr'.format(promotion_index)) # Promotion creative
|
||||
Tracker.alias(str, 'promo{0}ps'.format(promotion_index)) # Promotion position
|
||||
|
||||
|
||||
# Shortcut for creating trackers
|
||||
def create(account, *args, **kwargs):
|
||||
return Tracker(account, *args, **kwargs)
|
||||
|
||||
# vim: set nowrap tabstop=4 shiftwidth=4 softtabstop=0 expandtab textwidth=0 filetype=python foldmethod=indent foldcolumn=4
|
|
@ -1 +0,0 @@
|
|||
from . import Tracker
|
|
@ -11,9 +11,9 @@ from bleach.sanitizer import (
|
|||
|
||||
|
||||
# yyyymmdd
|
||||
__releasedate__ = "20220627"
|
||||
__releasedate__ = "20230123"
|
||||
# x.y.z or x.y.z.dev0 -- semver
|
||||
__version__ = "5.0.1"
|
||||
__version__ = "6.0.0"
|
||||
|
||||
|
||||
__all__ = ["clean", "linkify"]
|
||||
|
@ -52,7 +52,7 @@ def clean(
|
|||
|
||||
:arg str text: the text to clean
|
||||
|
||||
:arg list tags: allowed list of tags; defaults to
|
||||
:arg set tags: set of allowed tags; defaults to
|
||||
``bleach.sanitizer.ALLOWED_TAGS``
|
||||
|
||||
:arg dict attributes: allowed attributes; can be a callable, list or dict;
|
||||
|
|
|
@ -38,6 +38,9 @@ from bleach._vendor.html5lib.filters.sanitizer import (
|
|||
allowed_protocols,
|
||||
allowed_css_properties,
|
||||
allowed_svg_properties,
|
||||
attr_val_is_uri,
|
||||
svg_attr_val_allows_ref,
|
||||
svg_allow_local_href,
|
||||
) # noqa: E402 module level import not at top of file
|
||||
from bleach._vendor.html5lib.filters.sanitizer import (
|
||||
Filter as SanitizerFilter,
|
||||
|
@ -78,7 +81,8 @@ TAG_TOKEN_TYPE_PARSEERROR = constants.tokenTypes["ParseError"]
|
|||
|
||||
#: List of valid HTML tags, from WHATWG HTML Living Standard as of 2018-10-17
|
||||
#: https://html.spec.whatwg.org/multipage/indices.html#elements-3
|
||||
HTML_TAGS = [
|
||||
HTML_TAGS = frozenset(
|
||||
(
|
||||
"a",
|
||||
"abbr",
|
||||
"address",
|
||||
|
@ -191,14 +195,15 @@ HTML_TAGS = [
|
|||
"var",
|
||||
"video",
|
||||
"wbr",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
#: List of block level HTML tags, as per https://github.com/mozilla/bleach/issues/369
|
||||
#: from mozilla on 2019.07.11
|
||||
#: https://developer.mozilla.org/en-US/docs/Web/HTML/Block-level_elements#Elements
|
||||
HTML_TAGS_BLOCK_LEVEL = frozenset(
|
||||
[
|
||||
(
|
||||
"address",
|
||||
"article",
|
||||
"aside",
|
||||
|
@ -232,7 +237,7 @@ HTML_TAGS_BLOCK_LEVEL = frozenset(
|
|||
"section",
|
||||
"table",
|
||||
"ul",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -473,7 +478,7 @@ class BleachHTMLParser(HTMLParser):
|
|||
|
||||
def __init__(self, tags, strip, consume_entities, **kwargs):
|
||||
"""
|
||||
:arg tags: list of allowed tags--everything else is either stripped or
|
||||
:arg tags: set of allowed tags--everything else is either stripped or
|
||||
escaped; if None, then this doesn't look at tags at all
|
||||
:arg strip: whether to strip disallowed tags (True) or escape them (False);
|
||||
if tags=None, then this doesn't have any effect
|
||||
|
@ -481,7 +486,9 @@ class BleachHTMLParser(HTMLParser):
|
|||
leave them as is when tokenizing (BleachHTMLTokenizer-added behavior)
|
||||
|
||||
"""
|
||||
self.tags = [tag.lower() for tag in tags] if tags is not None else None
|
||||
self.tags = (
|
||||
frozenset((tag.lower() for tag in tags)) if tags is not None else None
|
||||
)
|
||||
self.strip = strip
|
||||
self.consume_entities = consume_entities
|
||||
super().__init__(**kwargs)
|
||||
|
@ -691,7 +698,7 @@ class BleachHTMLSerializer(HTMLSerializer):
|
|||
# Only leave entities in that are not ambiguous. If they're
|
||||
# ambiguous, then we escape the ampersand.
|
||||
if entity is not None and convert_entity(entity) is not None:
|
||||
yield "&" + entity + ";"
|
||||
yield f"&{entity};"
|
||||
|
||||
# Length of the entity plus 2--one for & at the beginning
|
||||
# and one for ; at the end
|
||||
|
|
|
@ -120,9 +120,10 @@ class Linker:
|
|||
:arg list callbacks: list of callbacks to run when adjusting tag attributes;
|
||||
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
|
||||
|
||||
:arg list skip_tags: list of tags that you don't want to linkify the
|
||||
contents of; for example, you could set this to ``['pre']`` to skip
|
||||
linkifying contents of ``pre`` tags
|
||||
:arg set skip_tags: set of tags that you don't want to linkify the
|
||||
contents of; for example, you could set this to ``{'pre'}`` to skip
|
||||
linkifying contents of ``pre`` tags; ``None`` means you don't
|
||||
want linkify to skip any tags
|
||||
|
||||
:arg bool parse_email: whether or not to linkify email addresses
|
||||
|
||||
|
@ -130,7 +131,7 @@ class Linker:
|
|||
|
||||
:arg email_re: email matching regex
|
||||
|
||||
:arg list recognized_tags: the list of tags that linkify knows about;
|
||||
:arg set recognized_tags: the set of tags that linkify knows about;
|
||||
everything else gets escaped
|
||||
|
||||
:returns: linkified text as unicode
|
||||
|
@ -145,15 +146,18 @@ class Linker:
|
|||
# Create a parser/tokenizer that allows all HTML tags and escapes
|
||||
# anything not in that list.
|
||||
self.parser = html5lib_shim.BleachHTMLParser(
|
||||
tags=recognized_tags,
|
||||
tags=frozenset(recognized_tags),
|
||||
strip=False,
|
||||
consume_entities=True,
|
||||
consume_entities=False,
|
||||
namespaceHTMLElements=False,
|
||||
)
|
||||
self.walker = html5lib_shim.getTreeWalker("etree")
|
||||
self.serializer = html5lib_shim.BleachHTMLSerializer(
|
||||
quote_attr_values="always",
|
||||
omit_optional_tags=False,
|
||||
# We want to leave entities as they are without escaping or
|
||||
# resolving or expanding
|
||||
resolve_entities=False,
|
||||
# linkify does not sanitize
|
||||
sanitize=False,
|
||||
# linkify preserves attr order
|
||||
|
@ -218,8 +222,8 @@ class LinkifyFilter(html5lib_shim.Filter):
|
|||
:arg list callbacks: list of callbacks to run when adjusting tag attributes;
|
||||
defaults to ``bleach.linkifier.DEFAULT_CALLBACKS``
|
||||
|
||||
:arg list skip_tags: list of tags that you don't want to linkify the
|
||||
contents of; for example, you could set this to ``['pre']`` to skip
|
||||
:arg set skip_tags: set of tags that you don't want to linkify the
|
||||
contents of; for example, you could set this to ``{'pre'}`` to skip
|
||||
linkifying contents of ``pre`` tags
|
||||
|
||||
:arg bool parse_email: whether or not to linkify email addresses
|
||||
|
@ -232,7 +236,7 @@ class LinkifyFilter(html5lib_shim.Filter):
|
|||
super().__init__(source)
|
||||
|
||||
self.callbacks = callbacks or []
|
||||
self.skip_tags = skip_tags or []
|
||||
self.skip_tags = skip_tags or {}
|
||||
self.parse_email = parse_email
|
||||
|
||||
self.url_re = url_re
|
||||
|
@ -510,6 +514,62 @@ class LinkifyFilter(html5lib_shim.Filter):
|
|||
yield {"type": "Characters", "data": str(new_text)}
|
||||
yield token_buffer[-1]
|
||||
|
||||
def extract_entities(self, token):
|
||||
"""Handles Characters tokens with entities
|
||||
|
||||
Our overridden tokenizer doesn't do anything with entities. However,
|
||||
that means that the serializer will convert all ``&`` in Characters
|
||||
tokens to ``&``.
|
||||
|
||||
Since we don't want that, we extract entities here and convert them to
|
||||
Entity tokens so the serializer will let them be.
|
||||
|
||||
:arg token: the Characters token to work on
|
||||
|
||||
:returns: generator of tokens
|
||||
|
||||
"""
|
||||
data = token.get("data", "")
|
||||
|
||||
# If there isn't a & in the data, we can return now
|
||||
if "&" not in data:
|
||||
yield token
|
||||
return
|
||||
|
||||
new_tokens = []
|
||||
|
||||
# For each possible entity that starts with a "&", we try to extract an
|
||||
# actual entity and re-tokenize accordingly
|
||||
for part in html5lib_shim.next_possible_entity(data):
|
||||
if not part:
|
||||
continue
|
||||
|
||||
if part.startswith("&"):
|
||||
entity = html5lib_shim.match_entity(part)
|
||||
if entity is not None:
|
||||
if entity == "amp":
|
||||
# LinkifyFilter can't match urls across token boundaries
|
||||
# which is problematic with & since that shows up in
|
||||
# querystrings all the time. This special-cases &
|
||||
# and converts it to a & and sticks it in as a
|
||||
# Characters token. It'll get merged with surrounding
|
||||
# tokens in the BleachSanitizerfilter.__iter__ and
|
||||
# escaped in the serializer.
|
||||
new_tokens.append({"type": "Characters", "data": "&"})
|
||||
else:
|
||||
new_tokens.append({"type": "Entity", "name": entity})
|
||||
|
||||
# Length of the entity plus 2--one for & at the beginning
|
||||
# and one for ; at the end
|
||||
remainder = part[len(entity) + 2 :]
|
||||
if remainder:
|
||||
new_tokens.append({"type": "Characters", "data": remainder})
|
||||
continue
|
||||
|
||||
new_tokens.append({"type": "Characters", "data": part})
|
||||
|
||||
yield from new_tokens
|
||||
|
||||
def __iter__(self):
|
||||
in_a = False
|
||||
in_skip_tag = None
|
||||
|
@ -564,8 +624,8 @@ class LinkifyFilter(html5lib_shim.Filter):
|
|||
|
||||
new_stream = self.handle_links(new_stream)
|
||||
|
||||
for token in new_stream:
|
||||
yield token
|
||||
for new_token in new_stream:
|
||||
yield from self.extract_entities(new_token)
|
||||
|
||||
# We've already yielded this token, so continue
|
||||
continue
|
||||
|
|
|
@ -8,8 +8,9 @@ from bleach import html5lib_shim
|
|||
from bleach import parse_shim
|
||||
|
||||
|
||||
#: List of allowed tags
|
||||
ALLOWED_TAGS = [
|
||||
#: Set of allowed tags
|
||||
ALLOWED_TAGS = frozenset(
|
||||
(
|
||||
"a",
|
||||
"abbr",
|
||||
"acronym",
|
||||
|
@ -22,7 +23,8 @@ ALLOWED_TAGS = [
|
|||
"ol",
|
||||
"strong",
|
||||
"ul",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
#: Map of allowed attributes by tag
|
||||
|
@ -33,7 +35,7 @@ ALLOWED_ATTRIBUTES = {
|
|||
}
|
||||
|
||||
#: List of allowed protocols
|
||||
ALLOWED_PROTOCOLS = ["http", "https", "mailto"]
|
||||
ALLOWED_PROTOCOLS = frozenset(("http", "https", "mailto"))
|
||||
|
||||
#: Invisible characters--0 to and including 31 except 9 (tab), 10 (lf), and 13 (cr)
|
||||
INVISIBLE_CHARACTERS = "".join(
|
||||
|
@ -48,6 +50,10 @@ INVISIBLE_CHARACTERS_RE = re.compile("[" + INVISIBLE_CHARACTERS + "]", re.UNICOD
|
|||
INVISIBLE_REPLACEMENT_CHAR = "?"
|
||||
|
||||
|
||||
class NoCssSanitizerWarning(UserWarning):
|
||||
pass
|
||||
|
||||
|
||||
class Cleaner:
|
||||
"""Cleaner for cleaning HTML fragments of malicious content
|
||||
|
||||
|
@ -89,7 +95,7 @@ class Cleaner:
|
|||
):
|
||||
"""Initializes a Cleaner
|
||||
|
||||
:arg list tags: allowed list of tags; defaults to
|
||||
:arg set tags: set of allowed tags; defaults to
|
||||
``bleach.sanitizer.ALLOWED_TAGS``
|
||||
|
||||
:arg dict attributes: allowed attributes; can be a callable, list or dict;
|
||||
|
@ -143,6 +149,25 @@ class Cleaner:
|
|||
alphabetical_attributes=False,
|
||||
)
|
||||
|
||||
if css_sanitizer is None:
|
||||
# FIXME(willkg): this doesn't handle when attributes or an
|
||||
# attributes value is a callable
|
||||
attributes_values = []
|
||||
if isinstance(attributes, list):
|
||||
attributes_values = attributes
|
||||
|
||||
elif isinstance(attributes, dict):
|
||||
attributes_values = []
|
||||
for values in attributes.values():
|
||||
if isinstance(values, (list, tuple)):
|
||||
attributes_values.extend(values)
|
||||
|
||||
if "style" in attributes_values:
|
||||
warnings.warn(
|
||||
"'style' attribute specified, but css_sanitizer not set.",
|
||||
category=NoCssSanitizerWarning,
|
||||
)
|
||||
|
||||
def clean(self, text):
|
||||
"""Cleans text and returns sanitized result as unicode
|
||||
|
||||
|
@ -155,9 +180,8 @@ class Cleaner:
|
|||
"""
|
||||
if not isinstance(text, str):
|
||||
message = (
|
||||
"argument cannot be of '{name}' type, must be of text type".format(
|
||||
name=text.__class__.__name__
|
||||
)
|
||||
f"argument cannot be of {text.__class__.__name__!r} type, "
|
||||
+ "must be of text type"
|
||||
)
|
||||
raise TypeError(message)
|
||||
|
||||
|
@ -167,13 +191,11 @@ class Cleaner:
|
|||
dom = self.parser.parseFragment(text)
|
||||
filtered = BleachSanitizerFilter(
|
||||
source=self.walker(dom),
|
||||
# Bleach-sanitizer-specific things
|
||||
allowed_tags=self.tags,
|
||||
attributes=self.attributes,
|
||||
strip_disallowed_elements=self.strip,
|
||||
strip_disallowed_tags=self.strip,
|
||||
strip_html_comments=self.strip_comments,
|
||||
css_sanitizer=self.css_sanitizer,
|
||||
# html5lib-sanitizer things
|
||||
allowed_elements=self.tags,
|
||||
allowed_protocols=self.protocols,
|
||||
)
|
||||
|
||||
|
@ -237,19 +259,21 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
def __init__(
|
||||
self,
|
||||
source,
|
||||
allowed_elements=ALLOWED_TAGS,
|
||||
allowed_tags=ALLOWED_TAGS,
|
||||
attributes=ALLOWED_ATTRIBUTES,
|
||||
allowed_protocols=ALLOWED_PROTOCOLS,
|
||||
strip_disallowed_elements=False,
|
||||
attr_val_is_uri=html5lib_shim.attr_val_is_uri,
|
||||
svg_attr_val_allows_ref=html5lib_shim.svg_attr_val_allows_ref,
|
||||
svg_allow_local_href=html5lib_shim.svg_allow_local_href,
|
||||
strip_disallowed_tags=False,
|
||||
strip_html_comments=True,
|
||||
css_sanitizer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Creates a BleachSanitizerFilter instance
|
||||
|
||||
:arg source: html5lib TreeWalker stream as an html5lib TreeWalker
|
||||
|
||||
:arg list allowed_elements: allowed list of tags; defaults to
|
||||
:arg set allowed_tags: set of allowed tags; defaults to
|
||||
``bleach.sanitizer.ALLOWED_TAGS``
|
||||
|
||||
:arg dict attributes: allowed attributes; can be a callable, list or dict;
|
||||
|
@ -258,8 +282,16 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
:arg list allowed_protocols: allowed list of protocols for links; defaults
|
||||
to ``bleach.sanitizer.ALLOWED_PROTOCOLS``
|
||||
|
||||
:arg bool strip_disallowed_elements: whether or not to strip disallowed
|
||||
elements
|
||||
:arg attr_val_is_uri: set of attributes that have URI values
|
||||
|
||||
:arg svg_attr_val_allows_ref: set of SVG attributes that can have
|
||||
references
|
||||
|
||||
:arg svg_allow_local_href: set of SVG elements that can have local
|
||||
hrefs
|
||||
|
||||
:arg bool strip_disallowed_tags: whether or not to strip disallowed
|
||||
tags
|
||||
|
||||
:arg bool strip_html_comments: whether or not to strip HTML comments
|
||||
|
||||
|
@ -267,24 +299,24 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
sanitizing style attribute values and style text; defaults to None
|
||||
|
||||
"""
|
||||
self.attr_filter = attribute_filter_factory(attributes)
|
||||
self.strip_disallowed_elements = strip_disallowed_elements
|
||||
self.strip_html_comments = strip_html_comments
|
||||
self.css_sanitizer = css_sanitizer
|
||||
# NOTE(willkg): This is the superclass of
|
||||
# html5lib.filters.sanitizer.Filter. We call this directly skipping the
|
||||
# __init__ for html5lib.filters.sanitizer.Filter because that does
|
||||
# things we don't need to do and kicks up the deprecation warning for
|
||||
# using Sanitizer.
|
||||
html5lib_shim.Filter.__init__(self, source)
|
||||
|
||||
# filter out html5lib deprecation warnings to use bleach from BleachSanitizerFilter init
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="html5lib's sanitizer is deprecated",
|
||||
category=DeprecationWarning,
|
||||
module="bleach._vendor.html5lib",
|
||||
)
|
||||
return super().__init__(
|
||||
source,
|
||||
allowed_elements=allowed_elements,
|
||||
allowed_protocols=allowed_protocols,
|
||||
**kwargs,
|
||||
)
|
||||
self.allowed_tags = frozenset(allowed_tags)
|
||||
self.allowed_protocols = frozenset(allowed_protocols)
|
||||
|
||||
self.attr_filter = attribute_filter_factory(attributes)
|
||||
self.strip_disallowed_tags = strip_disallowed_tags
|
||||
self.strip_html_comments = strip_html_comments
|
||||
|
||||
self.attr_val_is_uri = attr_val_is_uri
|
||||
self.svg_attr_val_allows_ref = svg_attr_val_allows_ref
|
||||
self.css_sanitizer = css_sanitizer
|
||||
self.svg_allow_local_href = svg_allow_local_href
|
||||
|
||||
def sanitize_stream(self, token_iterator):
|
||||
for token in token_iterator:
|
||||
|
@ -354,10 +386,10 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
"""
|
||||
token_type = token["type"]
|
||||
if token_type in ["StartTag", "EndTag", "EmptyTag"]:
|
||||
if token["name"] in self.allowed_elements:
|
||||
if token["name"] in self.allowed_tags:
|
||||
return self.allow_token(token)
|
||||
|
||||
elif self.strip_disallowed_elements:
|
||||
elif self.strip_disallowed_tags:
|
||||
return None
|
||||
|
||||
else:
|
||||
|
@ -570,7 +602,7 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
def disallowed_token(self, token):
|
||||
token_type = token["type"]
|
||||
if token_type == "EndTag":
|
||||
token["data"] = "</%s>" % token["name"]
|
||||
token["data"] = f"</{token['name']}>"
|
||||
|
||||
elif token["data"]:
|
||||
assert token_type in ("StartTag", "EmptyTag")
|
||||
|
@ -586,25 +618,19 @@ class BleachSanitizerFilter(html5lib_shim.SanitizerFilter):
|
|||
if ns is None or ns not in html5lib_shim.prefixes:
|
||||
namespaced_name = name
|
||||
else:
|
||||
namespaced_name = "{}:{}".format(html5lib_shim.prefixes[ns], name)
|
||||
namespaced_name = f"{html5lib_shim.prefixes[ns]}:{name}"
|
||||
|
||||
attrs.append(
|
||||
' %s="%s"'
|
||||
% (
|
||||
namespaced_name,
|
||||
# NOTE(willkg): HTMLSerializer escapes attribute values
|
||||
# already, so if we do it here (like HTMLSerializer does),
|
||||
# then we end up double-escaping.
|
||||
v,
|
||||
)
|
||||
)
|
||||
token["data"] = "<{}{}>".format(token["name"], "".join(attrs))
|
||||
attrs.append(f' {namespaced_name}="{v}"')
|
||||
token["data"] = f"<{token['name']}{''.join(attrs)}>"
|
||||
|
||||
else:
|
||||
token["data"] = "<%s>" % token["name"]
|
||||
token["data"] = f"<{token['name']}>"
|
||||
|
||||
if token.get("selfClosing"):
|
||||
token["data"] = token["data"][:-1] + "/>"
|
||||
token["data"] = f"{token['data'][:-1]}/>"
|
||||
|
||||
token["type"] = "Characters"
|
||||
|
||||
|
|
|
@ -21,14 +21,8 @@ at <https://github.com/Ousret/charset_normalizer>.
|
|||
"""
|
||||
import logging
|
||||
|
||||
from .api import from_bytes, from_fp, from_path, normalize
|
||||
from .legacy import (
|
||||
CharsetDetector,
|
||||
CharsetDoctor,
|
||||
CharsetNormalizerMatch,
|
||||
CharsetNormalizerMatches,
|
||||
detect,
|
||||
)
|
||||
from .api import from_bytes, from_fp, from_path
|
||||
from .legacy import detect
|
||||
from .models import CharsetMatch, CharsetMatches
|
||||
from .utils import set_logging_handler
|
||||
from .version import VERSION, __version__
|
||||
|
@ -37,14 +31,9 @@ __all__ = (
|
|||
"from_fp",
|
||||
"from_path",
|
||||
"from_bytes",
|
||||
"normalize",
|
||||
"detect",
|
||||
"CharsetMatch",
|
||||
"CharsetMatches",
|
||||
"CharsetNormalizerMatch",
|
||||
"CharsetNormalizerMatches",
|
||||
"CharsetDetector",
|
||||
"CharsetDoctor",
|
||||
"__version__",
|
||||
"VERSION",
|
||||
"set_logging_handler",
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
import warnings
|
||||
from os import PathLike
|
||||
from os.path import basename, splitext
|
||||
from typing import Any, BinaryIO, List, Optional, Set
|
||||
|
||||
from .cd import (
|
||||
|
@ -41,11 +39,12 @@ def from_bytes(
|
|||
cp_exclusion: Optional[List[str]] = None,
|
||||
preemptive_behaviour: bool = True,
|
||||
explain: bool = False,
|
||||
language_threshold: float = 0.1,
|
||||
) -> CharsetMatches:
|
||||
"""
|
||||
Given a raw bytes sequence, return the best possibles charset usable to render str objects.
|
||||
If there is no results, it is a strong indicator that the source is binary/not text.
|
||||
By default, the process will extract 5 blocs of 512o each to assess the mess and coherence of a given sequence.
|
||||
By default, the process will extract 5 blocks of 512o each to assess the mess and coherence of a given sequence.
|
||||
And will give up a particular code page after 20% of measured mess. Those criteria are customizable at will.
|
||||
|
||||
The preemptive behavior DOES NOT replace the traditional detection workflow, it prioritize a particular code page
|
||||
|
@ -197,7 +196,14 @@ def from_bytes(
|
|||
if encoding_iana in {"utf_16", "utf_32"} and not bom_or_sig_available:
|
||||
logger.log(
|
||||
TRACE,
|
||||
"Encoding %s wont be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.",
|
||||
"Encoding %s won't be tested as-is because it require a BOM. Will try some sub-encoder LE/BE.",
|
||||
encoding_iana,
|
||||
)
|
||||
continue
|
||||
if encoding_iana in {"utf_7"} and not bom_or_sig_available:
|
||||
logger.log(
|
||||
TRACE,
|
||||
"Encoding %s won't be tested as-is because detection is unreliable without BOM/SIG.",
|
||||
encoding_iana,
|
||||
)
|
||||
continue
|
||||
|
@ -297,7 +303,13 @@ def from_bytes(
|
|||
):
|
||||
md_chunks.append(chunk)
|
||||
|
||||
md_ratios.append(mess_ratio(chunk, threshold))
|
||||
md_ratios.append(
|
||||
mess_ratio(
|
||||
chunk,
|
||||
threshold,
|
||||
explain is True and 1 <= len(cp_isolation) <= 2,
|
||||
)
|
||||
)
|
||||
|
||||
if md_ratios[-1] >= threshold:
|
||||
early_stop_count += 1
|
||||
|
@ -389,7 +401,9 @@ def from_bytes(
|
|||
if encoding_iana != "ascii":
|
||||
for chunk in md_chunks:
|
||||
chunk_languages = coherence_ratio(
|
||||
chunk, 0.1, ",".join(target_languages) if target_languages else None
|
||||
chunk,
|
||||
language_threshold,
|
||||
",".join(target_languages) if target_languages else None,
|
||||
)
|
||||
|
||||
cd_ratios.append(chunk_languages)
|
||||
|
@ -491,6 +505,7 @@ def from_fp(
|
|||
cp_exclusion: Optional[List[str]] = None,
|
||||
preemptive_behaviour: bool = True,
|
||||
explain: bool = False,
|
||||
language_threshold: float = 0.1,
|
||||
) -> CharsetMatches:
|
||||
"""
|
||||
Same thing than the function from_bytes but using a file pointer that is already ready.
|
||||
|
@ -505,6 +520,7 @@ def from_fp(
|
|||
cp_exclusion,
|
||||
preemptive_behaviour,
|
||||
explain,
|
||||
language_threshold,
|
||||
)
|
||||
|
||||
|
||||
|
@ -517,6 +533,7 @@ def from_path(
|
|||
cp_exclusion: Optional[List[str]] = None,
|
||||
preemptive_behaviour: bool = True,
|
||||
explain: bool = False,
|
||||
language_threshold: float = 0.1,
|
||||
) -> CharsetMatches:
|
||||
"""
|
||||
Same thing than the function from_bytes but with one extra step. Opening and reading given file path in binary mode.
|
||||
|
@ -532,53 +549,5 @@ def from_path(
|
|||
cp_exclusion,
|
||||
preemptive_behaviour,
|
||||
explain,
|
||||
language_threshold,
|
||||
)
|
||||
|
||||
|
||||
def normalize(
|
||||
path: "PathLike[Any]",
|
||||
steps: int = 5,
|
||||
chunk_size: int = 512,
|
||||
threshold: float = 0.20,
|
||||
cp_isolation: Optional[List[str]] = None,
|
||||
cp_exclusion: Optional[List[str]] = None,
|
||||
preemptive_behaviour: bool = True,
|
||||
) -> CharsetMatch:
|
||||
"""
|
||||
Take a (text-based) file path and try to create another file next to it, this time using UTF-8.
|
||||
"""
|
||||
warnings.warn(
|
||||
"normalize is deprecated and will be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
results = from_path(
|
||||
path,
|
||||
steps,
|
||||
chunk_size,
|
||||
threshold,
|
||||
cp_isolation,
|
||||
cp_exclusion,
|
||||
preemptive_behaviour,
|
||||
)
|
||||
|
||||
filename = basename(path)
|
||||
target_extensions = list(splitext(filename))
|
||||
|
||||
if len(results) == 0:
|
||||
raise IOError(
|
||||
'Unable to normalize "{}", no encoding charset seems to fit.'.format(
|
||||
filename
|
||||
)
|
||||
)
|
||||
|
||||
result = results.best()
|
||||
|
||||
target_extensions[0] += "-" + result.encoding # type: ignore
|
||||
|
||||
with open(
|
||||
"{}".format(str(path).replace(filename, "".join(target_extensions))), "wb"
|
||||
) as fp:
|
||||
fp.write(result.output()) # type: ignore
|
||||
|
||||
return result # type: ignore
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from typing import Dict, List
|
||||
|
||||
# Language label that contain the em dash "—"
|
||||
# character are to be considered alternative seq to origin
|
||||
FREQUENCIES: Dict[str, List[str]] = {
|
||||
"English": [
|
||||
"e",
|
||||
|
@ -30,6 +32,34 @@ FREQUENCIES: Dict[str, List[str]] = {
|
|||
"z",
|
||||
"q",
|
||||
],
|
||||
"English—": [
|
||||
"e",
|
||||
"a",
|
||||
"t",
|
||||
"i",
|
||||
"o",
|
||||
"n",
|
||||
"s",
|
||||
"r",
|
||||
"h",
|
||||
"l",
|
||||
"d",
|
||||
"c",
|
||||
"m",
|
||||
"u",
|
||||
"f",
|
||||
"p",
|
||||
"g",
|
||||
"w",
|
||||
"b",
|
||||
"y",
|
||||
"v",
|
||||
"k",
|
||||
"j",
|
||||
"x",
|
||||
"z",
|
||||
"q",
|
||||
],
|
||||
"German": [
|
||||
"e",
|
||||
"n",
|
||||
|
@ -226,33 +256,303 @@ FREQUENCIES: Dict[str, List[str]] = {
|
|||
"ж",
|
||||
"ц",
|
||||
],
|
||||
# Jap-Kanji
|
||||
"Japanese": [
|
||||
"人",
|
||||
"一",
|
||||
"大",
|
||||
"亅",
|
||||
"丁",
|
||||
"丨",
|
||||
"竹",
|
||||
"笑",
|
||||
"口",
|
||||
"日",
|
||||
"今",
|
||||
"二",
|
||||
"彳",
|
||||
"行",
|
||||
"十",
|
||||
"土",
|
||||
"丶",
|
||||
"寸",
|
||||
"寺",
|
||||
"時",
|
||||
"乙",
|
||||
"丿",
|
||||
"乂",
|
||||
"气",
|
||||
"気",
|
||||
"冂",
|
||||
"巾",
|
||||
"亠",
|
||||
"市",
|
||||
"目",
|
||||
"儿",
|
||||
"見",
|
||||
"八",
|
||||
"小",
|
||||
"凵",
|
||||
"県",
|
||||
"月",
|
||||
"彐",
|
||||
"門",
|
||||
"間",
|
||||
"木",
|
||||
"東",
|
||||
"山",
|
||||
"出",
|
||||
"本",
|
||||
"中",
|
||||
"刀",
|
||||
"分",
|
||||
"耳",
|
||||
"又",
|
||||
"取",
|
||||
"最",
|
||||
"言",
|
||||
"田",
|
||||
"心",
|
||||
"思",
|
||||
"刂",
|
||||
"前",
|
||||
"京",
|
||||
"尹",
|
||||
"事",
|
||||
"生",
|
||||
"厶",
|
||||
"云",
|
||||
"会",
|
||||
"未",
|
||||
"来",
|
||||
"白",
|
||||
"冫",
|
||||
"楽",
|
||||
"灬",
|
||||
"馬",
|
||||
"尸",
|
||||
"尺",
|
||||
"駅",
|
||||
"明",
|
||||
"耂",
|
||||
"者",
|
||||
"了",
|
||||
"阝",
|
||||
"都",
|
||||
"高",
|
||||
"卜",
|
||||
"占",
|
||||
"厂",
|
||||
"广",
|
||||
"店",
|
||||
"子",
|
||||
"申",
|
||||
"奄",
|
||||
"亻",
|
||||
"俺",
|
||||
"上",
|
||||
"方",
|
||||
"冖",
|
||||
"学",
|
||||
"衣",
|
||||
"艮",
|
||||
"食",
|
||||
"自",
|
||||
],
|
||||
# Jap-Katakana
|
||||
"Japanese—": [
|
||||
"ー",
|
||||
"ン",
|
||||
"ス",
|
||||
"・",
|
||||
"ル",
|
||||
"ト",
|
||||
"リ",
|
||||
"イ",
|
||||
"ア",
|
||||
"ラ",
|
||||
"ッ",
|
||||
"ク",
|
||||
"ド",
|
||||
"シ",
|
||||
"レ",
|
||||
"ジ",
|
||||
"タ",
|
||||
"フ",
|
||||
"ロ",
|
||||
"カ",
|
||||
"テ",
|
||||
"マ",
|
||||
"ィ",
|
||||
"グ",
|
||||
"バ",
|
||||
"ム",
|
||||
"プ",
|
||||
"オ",
|
||||
"コ",
|
||||
"デ",
|
||||
"ニ",
|
||||
"ウ",
|
||||
"メ",
|
||||
"サ",
|
||||
"ビ",
|
||||
"ナ",
|
||||
"ブ",
|
||||
"ャ",
|
||||
"エ",
|
||||
"ュ",
|
||||
"チ",
|
||||
"キ",
|
||||
"ズ",
|
||||
"ダ",
|
||||
"パ",
|
||||
"ミ",
|
||||
"ェ",
|
||||
"ョ",
|
||||
"ハ",
|
||||
"セ",
|
||||
"ベ",
|
||||
"ガ",
|
||||
"モ",
|
||||
"ツ",
|
||||
"ネ",
|
||||
"ボ",
|
||||
"ソ",
|
||||
"ノ",
|
||||
"ァ",
|
||||
"ヴ",
|
||||
"ワ",
|
||||
"ポ",
|
||||
"ペ",
|
||||
"ピ",
|
||||
"ケ",
|
||||
"ゴ",
|
||||
"ギ",
|
||||
"ザ",
|
||||
"ホ",
|
||||
"ゲ",
|
||||
"ォ",
|
||||
"ヤ",
|
||||
"ヒ",
|
||||
"ユ",
|
||||
"ヨ",
|
||||
"ヘ",
|
||||
"ゼ",
|
||||
"ヌ",
|
||||
"ゥ",
|
||||
"ゾ",
|
||||
"ヶ",
|
||||
"ヂ",
|
||||
"ヲ",
|
||||
"ヅ",
|
||||
"ヵ",
|
||||
"ヱ",
|
||||
"ヰ",
|
||||
"ヮ",
|
||||
"ヽ",
|
||||
"゠",
|
||||
"ヾ",
|
||||
"ヷ",
|
||||
"ヿ",
|
||||
"ヸ",
|
||||
"ヹ",
|
||||
"ヺ",
|
||||
],
|
||||
# Jap-Hiragana
|
||||
"Japanese——": [
|
||||
"の",
|
||||
"に",
|
||||
"る",
|
||||
"た",
|
||||
"は",
|
||||
"ー",
|
||||
"と",
|
||||
"は",
|
||||
"し",
|
||||
"い",
|
||||
"を",
|
||||
"で",
|
||||
"て",
|
||||
"が",
|
||||
"い",
|
||||
"ン",
|
||||
"れ",
|
||||
"な",
|
||||
"年",
|
||||
"ス",
|
||||
"っ",
|
||||
"ル",
|
||||
"れ",
|
||||
"か",
|
||||
"ら",
|
||||
"あ",
|
||||
"さ",
|
||||
"も",
|
||||
"っ",
|
||||
"り",
|
||||
"す",
|
||||
"あ",
|
||||
"も",
|
||||
"こ",
|
||||
"ま",
|
||||
"う",
|
||||
"く",
|
||||
"よ",
|
||||
"き",
|
||||
"ん",
|
||||
"め",
|
||||
"お",
|
||||
"け",
|
||||
"そ",
|
||||
"つ",
|
||||
"だ",
|
||||
"や",
|
||||
"え",
|
||||
"ど",
|
||||
"わ",
|
||||
"ち",
|
||||
"み",
|
||||
"せ",
|
||||
"じ",
|
||||
"ば",
|
||||
"へ",
|
||||
"び",
|
||||
"ず",
|
||||
"ろ",
|
||||
"ほ",
|
||||
"げ",
|
||||
"む",
|
||||
"べ",
|
||||
"ひ",
|
||||
"ょ",
|
||||
"ゆ",
|
||||
"ぶ",
|
||||
"ご",
|
||||
"ゃ",
|
||||
"ね",
|
||||
"ふ",
|
||||
"ぐ",
|
||||
"ぎ",
|
||||
"ぼ",
|
||||
"ゅ",
|
||||
"づ",
|
||||
"ざ",
|
||||
"ぞ",
|
||||
"ぬ",
|
||||
"ぜ",
|
||||
"ぱ",
|
||||
"ぽ",
|
||||
"ぷ",
|
||||
"ぴ",
|
||||
"ぃ",
|
||||
"ぁ",
|
||||
"ぇ",
|
||||
"ぺ",
|
||||
"ゞ",
|
||||
"ぢ",
|
||||
"ぉ",
|
||||
"ぅ",
|
||||
"ゐ",
|
||||
"ゝ",
|
||||
"ゑ",
|
||||
"゛",
|
||||
"゜",
|
||||
"ゎ",
|
||||
"ゔ",
|
||||
"゚",
|
||||
"ゟ",
|
||||
"゙",
|
||||
"ゕ",
|
||||
"ゖ",
|
||||
],
|
||||
"Portuguese": [
|
||||
"a",
|
||||
|
@ -340,6 +640,77 @@ FREQUENCIES: Dict[str, List[str]] = {
|
|||
"就",
|
||||
"出",
|
||||
"会",
|
||||
"可",
|
||||
"也",
|
||||
"你",
|
||||
"对",
|
||||
"生",
|
||||
"能",
|
||||
"而",
|
||||
"子",
|
||||
"那",
|
||||
"得",
|
||||
"于",
|
||||
"着",
|
||||
"下",
|
||||
"自",
|
||||
"之",
|
||||
"年",
|
||||
"过",
|
||||
"发",
|
||||
"后",
|
||||
"作",
|
||||
"里",
|
||||
"用",
|
||||
"道",
|
||||
"行",
|
||||
"所",
|
||||
"然",
|
||||
"家",
|
||||
"种",
|
||||
"事",
|
||||
"成",
|
||||
"方",
|
||||
"多",
|
||||
"经",
|
||||
"么",
|
||||
"去",
|
||||
"法",
|
||||
"学",
|
||||
"如",
|
||||
"都",
|
||||
"同",
|
||||
"现",
|
||||
"当",
|
||||
"没",
|
||||
"动",
|
||||
"面",
|
||||
"起",
|
||||
"看",
|
||||
"定",
|
||||
"天",
|
||||
"分",
|
||||
"还",
|
||||
"进",
|
||||
"好",
|
||||
"小",
|
||||
"部",
|
||||
"其",
|
||||
"些",
|
||||
"主",
|
||||
"样",
|
||||
"理",
|
||||
"心",
|
||||
"她",
|
||||
"本",
|
||||
"前",
|
||||
"开",
|
||||
"但",
|
||||
"因",
|
||||
"只",
|
||||
"从",
|
||||
"想",
|
||||
"实",
|
||||
],
|
||||
"Ukrainian": [
|
||||
"о",
|
||||
|
@ -956,34 +1327,6 @@ FREQUENCIES: Dict[str, List[str]] = {
|
|||
"ö",
|
||||
"y",
|
||||
],
|
||||
"Simple English": [
|
||||
"e",
|
||||
"a",
|
||||
"t",
|
||||
"i",
|
||||
"o",
|
||||
"n",
|
||||
"s",
|
||||
"r",
|
||||
"h",
|
||||
"l",
|
||||
"d",
|
||||
"c",
|
||||
"m",
|
||||
"u",
|
||||
"f",
|
||||
"p",
|
||||
"g",
|
||||
"w",
|
||||
"b",
|
||||
"y",
|
||||
"v",
|
||||
"k",
|
||||
"j",
|
||||
"x",
|
||||
"z",
|
||||
"q",
|
||||
],
|
||||
"Thai": [
|
||||
"า",
|
||||
"น",
|
||||
|
@ -1066,31 +1409,6 @@ FREQUENCIES: Dict[str, List[str]] = {
|
|||
"ஒ",
|
||||
"ஸ",
|
||||
],
|
||||
"Classical Chinese": [
|
||||
"之",
|
||||
"年",
|
||||
"為",
|
||||
"也",
|
||||
"以",
|
||||
"一",
|
||||
"人",
|
||||
"其",
|
||||
"者",
|
||||
"國",
|
||||
"有",
|
||||
"二",
|
||||
"十",
|
||||
"於",
|
||||
"曰",
|
||||
"三",
|
||||
"不",
|
||||
"大",
|
||||
"而",
|
||||
"子",
|
||||
"中",
|
||||
"五",
|
||||
"四",
|
||||
],
|
||||
"Kazakh": [
|
||||
"а",
|
||||
"ы",
|
||||
|
|
|
@ -105,7 +105,7 @@ def mb_encoding_languages(iana_name: str) -> List[str]:
|
|||
):
|
||||
return ["Japanese"]
|
||||
if iana_name.startswith("gb") or iana_name in ZH_NAMES:
|
||||
return ["Chinese", "Classical Chinese"]
|
||||
return ["Chinese"]
|
||||
if iana_name.startswith("iso2022_kr") or iana_name in KO_NAMES:
|
||||
return ["Korean"]
|
||||
|
||||
|
@ -179,22 +179,45 @@ def characters_popularity_compare(
|
|||
character_approved_count: int = 0
|
||||
FREQUENCIES_language_set = set(FREQUENCIES[language])
|
||||
|
||||
for character in ordered_characters:
|
||||
ordered_characters_count: int = len(ordered_characters)
|
||||
target_language_characters_count: int = len(FREQUENCIES[language])
|
||||
|
||||
large_alphabet: bool = target_language_characters_count > 26
|
||||
|
||||
for character, character_rank in zip(
|
||||
ordered_characters, range(0, ordered_characters_count)
|
||||
):
|
||||
if character not in FREQUENCIES_language_set:
|
||||
continue
|
||||
|
||||
character_rank_in_language: int = FREQUENCIES[language].index(character)
|
||||
expected_projection_ratio: float = (
|
||||
target_language_characters_count / ordered_characters_count
|
||||
)
|
||||
character_rank_projection: int = int(character_rank * expected_projection_ratio)
|
||||
|
||||
if (
|
||||
large_alphabet is False
|
||||
and abs(character_rank_projection - character_rank_in_language) > 4
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
large_alphabet is True
|
||||
and abs(character_rank_projection - character_rank_in_language)
|
||||
< target_language_characters_count / 3
|
||||
):
|
||||
character_approved_count += 1
|
||||
continue
|
||||
|
||||
characters_before_source: List[str] = FREQUENCIES[language][
|
||||
0 : FREQUENCIES[language].index(character)
|
||||
0:character_rank_in_language
|
||||
]
|
||||
characters_after_source: List[str] = FREQUENCIES[language][
|
||||
FREQUENCIES[language].index(character) :
|
||||
]
|
||||
characters_before: List[str] = ordered_characters[
|
||||
0 : ordered_characters.index(character)
|
||||
]
|
||||
characters_after: List[str] = ordered_characters[
|
||||
ordered_characters.index(character) :
|
||||
character_rank_in_language:
|
||||
]
|
||||
characters_before: List[str] = ordered_characters[0:character_rank]
|
||||
characters_after: List[str] = ordered_characters[character_rank:]
|
||||
|
||||
before_match_count: int = len(
|
||||
set(characters_before) & set(characters_before_source)
|
||||
|
@ -289,6 +312,33 @@ def merge_coherence_ratios(results: List[CoherenceMatches]) -> CoherenceMatches:
|
|||
return sorted(merge, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
def filter_alt_coherence_matches(results: CoherenceMatches) -> CoherenceMatches:
|
||||
"""
|
||||
We shall NOT return "English—" in CoherenceMatches because it is an alternative
|
||||
of "English". This function only keeps the best match and remove the em-dash in it.
|
||||
"""
|
||||
index_results: Dict[str, List[float]] = dict()
|
||||
|
||||
for result in results:
|
||||
language, ratio = result
|
||||
no_em_name: str = language.replace("—", "")
|
||||
|
||||
if no_em_name not in index_results:
|
||||
index_results[no_em_name] = []
|
||||
|
||||
index_results[no_em_name].append(ratio)
|
||||
|
||||
if any(len(index_results[e]) > 1 for e in index_results):
|
||||
filtered_results: CoherenceMatches = []
|
||||
|
||||
for language in index_results:
|
||||
filtered_results.append((language, max(index_results[language])))
|
||||
|
||||
return filtered_results
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def coherence_ratio(
|
||||
decoded_sequence: str, threshold: float = 0.1, lg_inclusion: Optional[str] = None
|
||||
|
@ -336,4 +386,6 @@ def coherence_ratio(
|
|||
if sufficient_match_count >= 3:
|
||||
break
|
||||
|
||||
return sorted(results, key=lambda x: x[1], reverse=True)
|
||||
return sorted(
|
||||
filter_alt_coherence_matches(results), key=lambda x: x[1], reverse=True
|
||||
)
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
import argparse
|
||||
import sys
|
||||
from json import dumps
|
||||
from os.path import abspath
|
||||
from os.path import abspath, basename, dirname, join, realpath
|
||||
from platform import python_version
|
||||
from typing import List, Optional
|
||||
from unicodedata import unidata_version
|
||||
|
||||
try:
|
||||
from unicodedata2 import unidata_version
|
||||
except ImportError:
|
||||
from unicodedata import unidata_version
|
||||
|
||||
import charset_normalizer.md as md_module
|
||||
from charset_normalizer import from_fp
|
||||
from charset_normalizer.models import CliDetectionResult
|
||||
from charset_normalizer.version import __version__
|
||||
|
@ -124,8 +121,11 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
|
|||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version="Charset-Normalizer {} - Python {} - Unicode {}".format(
|
||||
__version__, python_version(), unidata_version
|
||||
version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
|
||||
__version__,
|
||||
python_version(),
|
||||
unidata_version,
|
||||
"OFF" if md_module.__file__.lower().endswith(".py") else "ON",
|
||||
),
|
||||
help="Show version information and exit.",
|
||||
)
|
||||
|
@ -234,7 +234,10 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
|
|||
my_file.close()
|
||||
continue
|
||||
|
||||
o_: List[str] = my_file.name.split(".")
|
||||
dir_path = dirname(realpath(my_file.name))
|
||||
file_name = basename(realpath(my_file.name))
|
||||
|
||||
o_: List[str] = file_name.split(".")
|
||||
|
||||
if args.replace is False:
|
||||
o_.insert(-1, best_guess.encoding)
|
||||
|
@ -255,7 +258,7 @@ def cli_detect(argv: Optional[List[str]] = None) -> int:
|
|||
continue
|
||||
|
||||
try:
|
||||
x_[0].unicode_path = abspath("./{}".format(".".join(o_)))
|
||||
x_[0].unicode_path = join(dir_path, ".".join(o_))
|
||||
|
||||
with open(x_[0].unicode_path, "w", encoding="utf-8") as fp:
|
||||
fp.write(str(best_guess))
|
||||
|
|
|
@ -489,9 +489,7 @@ COMMON_SAFE_ASCII_CHARACTERS: Set[str] = {
|
|||
KO_NAMES: Set[str] = {"johab", "cp949", "euc_kr"}
|
||||
ZH_NAMES: Set[str] = {"big5", "cp950", "big5hkscs", "hz"}
|
||||
|
||||
NOT_PRINTABLE_PATTERN = re_compile(r"[0-9\W\n\r\t]+")
|
||||
|
||||
LANGUAGE_SUPPORTED_COUNT: int = len(FREQUENCIES)
|
||||
|
||||
# Logging LEVEL bellow DEBUG
|
||||
# Logging LEVEL below DEBUG
|
||||
TRACE: int = 5
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import warnings
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from .api import from_bytes, from_fp, from_path, normalize
|
||||
from .api import from_bytes
|
||||
from .constant import CHARDET_CORRESPONDENCE
|
||||
from .models import CharsetMatch, CharsetMatches
|
||||
|
||||
|
||||
def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
|
||||
|
@ -43,53 +41,3 @@ def detect(byte_str: bytes) -> Dict[str, Optional[Union[str, float]]]:
|
|||
"language": language,
|
||||
"confidence": confidence,
|
||||
}
|
||||
|
||||
|
||||
class CharsetNormalizerMatch(CharsetMatch):
|
||||
pass
|
||||
|
||||
|
||||
class CharsetNormalizerMatches(CharsetMatches):
|
||||
@staticmethod
|
||||
def from_fp(*args, **kwargs): # type: ignore
|
||||
warnings.warn( # pragma: nocover
|
||||
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
|
||||
"and scheduled to be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return from_fp(*args, **kwargs) # pragma: nocover
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(*args, **kwargs): # type: ignore
|
||||
warnings.warn( # pragma: nocover
|
||||
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
|
||||
"and scheduled to be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return from_bytes(*args, **kwargs) # pragma: nocover
|
||||
|
||||
@staticmethod
|
||||
def from_path(*args, **kwargs): # type: ignore
|
||||
warnings.warn( # pragma: nocover
|
||||
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
|
||||
"and scheduled to be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return from_path(*args, **kwargs) # pragma: nocover
|
||||
|
||||
@staticmethod
|
||||
def normalize(*args, **kwargs): # type: ignore
|
||||
warnings.warn( # pragma: nocover
|
||||
"staticmethod from_fp, from_bytes, from_path and normalize are deprecated "
|
||||
"and scheduled to be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return normalize(*args, **kwargs) # pragma: nocover
|
||||
|
||||
|
||||
class CharsetDetector(CharsetNormalizerMatches):
|
||||
pass
|
||||
|
||||
|
||||
class CharsetDoctor(CharsetNormalizerMatches):
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
|
||||
from .constant import COMMON_SAFE_ASCII_CHARACTERS, UNICODE_SECONDARY_RANGE_KEYWORD
|
||||
from .constant import (
|
||||
COMMON_SAFE_ASCII_CHARACTERS,
|
||||
TRACE,
|
||||
UNICODE_SECONDARY_RANGE_KEYWORD,
|
||||
)
|
||||
from .utils import (
|
||||
is_accentuated,
|
||||
is_ascii,
|
||||
|
@ -123,7 +128,7 @@ class TooManyAccentuatedPlugin(MessDetectorPlugin):
|
|||
|
||||
@property
|
||||
def ratio(self) -> float:
|
||||
if self._character_count == 0:
|
||||
if self._character_count == 0 or self._character_count < 8:
|
||||
return 0.0
|
||||
ratio_of_accentuation: float = self._accentuated_count / self._character_count
|
||||
return ratio_of_accentuation if ratio_of_accentuation >= 0.35 else 0.0
|
||||
|
@ -547,7 +552,20 @@ def mess_ratio(
|
|||
break
|
||||
|
||||
if debug:
|
||||
logger = getLogger("charset_normalizer")
|
||||
|
||||
logger.log(
|
||||
TRACE,
|
||||
"Mess-detector extended-analysis start. "
|
||||
f"intermediary_mean_mess_ratio_calc={intermediary_mean_mess_ratio_calc} mean_mess_ratio={mean_mess_ratio} "
|
||||
f"maximum_threshold={maximum_threshold}",
|
||||
)
|
||||
|
||||
if len(decoded_sequence) > 16:
|
||||
logger.log(TRACE, f"Starting with: {decoded_sequence[:16]}")
|
||||
logger.log(TRACE, f"Ending with: {decoded_sequence[-16::]}")
|
||||
|
||||
for dt in detectors: # pragma: nocover
|
||||
print(dt.__class__, dt.ratio)
|
||||
logger.log(TRACE, f"{dt.__class__}: {dt.ratio}")
|
||||
|
||||
return round(mean_mess_ratio, 3)
|
||||
|
|
|
@ -1,22 +1,9 @@
|
|||
import warnings
|
||||
from collections import Counter
|
||||
from encodings.aliases import aliases
|
||||
from hashlib import sha256
|
||||
from json import dumps
|
||||
from re import sub
|
||||
from typing import (
|
||||
Any,
|
||||
Counter as TypeCounter,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from .constant import NOT_PRINTABLE_PATTERN, TOO_BIG_SEQUENCE
|
||||
from .md import mess_ratio
|
||||
from .constant import TOO_BIG_SEQUENCE
|
||||
from .utils import iana_name, is_multi_byte_encoding, unicode_range
|
||||
|
||||
|
||||
|
@ -65,7 +52,7 @@ class CharsetMatch:
|
|||
chaos_difference: float = abs(self.chaos - other.chaos)
|
||||
coherence_difference: float = abs(self.coherence - other.coherence)
|
||||
|
||||
# Bellow 1% difference --> Use Coherence
|
||||
# Below 1% difference --> Use Coherence
|
||||
if chaos_difference < 0.01 and coherence_difference > 0.02:
|
||||
# When having a tough decision, use the result that decoded as many multi-byte as possible.
|
||||
if chaos_difference == 0.0 and self.coherence == other.coherence:
|
||||
|
@ -78,45 +65,6 @@ class CharsetMatch:
|
|||
def multi_byte_usage(self) -> float:
|
||||
return 1.0 - len(str(self)) / len(self.raw)
|
||||
|
||||
@property
|
||||
def chaos_secondary_pass(self) -> float:
|
||||
"""
|
||||
Check once again chaos in decoded text, except this time, with full content.
|
||||
Use with caution, this can be very slow.
|
||||
Notice: Will be removed in 3.0
|
||||
"""
|
||||
warnings.warn(
|
||||
"chaos_secondary_pass is deprecated and will be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return mess_ratio(str(self), 1.0)
|
||||
|
||||
@property
|
||||
def coherence_non_latin(self) -> float:
|
||||
"""
|
||||
Coherence ratio on the first non-latin language detected if ANY.
|
||||
Notice: Will be removed in 3.0
|
||||
"""
|
||||
warnings.warn(
|
||||
"coherence_non_latin is deprecated and will be removed in 3.0",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def w_counter(self) -> TypeCounter[str]:
|
||||
"""
|
||||
Word counter instance on decoded text.
|
||||
Notice: Will be removed in 3.0
|
||||
"""
|
||||
warnings.warn(
|
||||
"w_counter is deprecated and will be removed in 3.0", DeprecationWarning
|
||||
)
|
||||
|
||||
string_printable_only = sub(NOT_PRINTABLE_PATTERN, " ", str(self).lower())
|
||||
|
||||
return Counter(string_printable_only.split())
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Lazy Str Loading
|
||||
if self._string is None:
|
||||
|
@ -252,18 +200,6 @@ class CharsetMatch:
|
|||
"""
|
||||
return [self._encoding] + [m.encoding for m in self._leaves]
|
||||
|
||||
def first(self) -> "CharsetMatch":
|
||||
"""
|
||||
Kept for BC reasons. Will be removed in 3.0.
|
||||
"""
|
||||
return self
|
||||
|
||||
def best(self) -> "CharsetMatch":
|
||||
"""
|
||||
Kept for BC reasons. Will be removed in 3.0.
|
||||
"""
|
||||
return self
|
||||
|
||||
def output(self, encoding: str = "utf_8") -> bytes:
|
||||
"""
|
||||
Method to get re-encoded bytes payload using given target encoding. Default to UTF-8.
|
||||
|
|
|
@ -1,12 +1,6 @@
|
|||
try:
|
||||
# WARNING: unicodedata2 support is going to be removed in 3.0
|
||||
# Python is quickly catching up.
|
||||
import unicodedata2 as unicodedata
|
||||
except ImportError:
|
||||
import unicodedata # type: ignore[no-redef]
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import unicodedata
|
||||
from codecs import IncrementalDecoder
|
||||
from encodings.aliases import aliases
|
||||
from functools import lru_cache
|
||||
|
@ -402,7 +396,7 @@ def cut_sequence_chunks(
|
|||
|
||||
# multi-byte bad cutting detector and adjustment
|
||||
# not the cleanest way to perform that fix but clever enough for now.
|
||||
if is_multi_byte_decoder and i > 0 and sequences[i] >= 0x80:
|
||||
if is_multi_byte_decoder and i > 0:
|
||||
|
||||
chunk_partial_size_chk: int = min(chunk_size, 16)
|
||||
|
||||
|
|
|
@ -2,5 +2,5 @@
|
|||
Expose version
|
||||
"""
|
||||
|
||||
__version__ = "2.1.1"
|
||||
__version__ = "3.0.1"
|
||||
VERSION = __version__.split(".")
|
||||
|
|
|
@ -18,49 +18,52 @@
|
|||
"""dnspython DNS toolkit"""
|
||||
|
||||
__all__ = [
|
||||
'asyncbackend',
|
||||
'asyncquery',
|
||||
'asyncresolver',
|
||||
'dnssec',
|
||||
'e164',
|
||||
'edns',
|
||||
'entropy',
|
||||
'exception',
|
||||
'flags',
|
||||
'immutable',
|
||||
'inet',
|
||||
'ipv4',
|
||||
'ipv6',
|
||||
'message',
|
||||
'name',
|
||||
'namedict',
|
||||
'node',
|
||||
'opcode',
|
||||
'query',
|
||||
'rcode',
|
||||
'rdata',
|
||||
'rdataclass',
|
||||
'rdataset',
|
||||
'rdatatype',
|
||||
'renderer',
|
||||
'resolver',
|
||||
'reversename',
|
||||
'rrset',
|
||||
'serial',
|
||||
'set',
|
||||
'tokenizer',
|
||||
'transaction',
|
||||
'tsig',
|
||||
'tsigkeyring',
|
||||
'ttl',
|
||||
'rdtypes',
|
||||
'update',
|
||||
'version',
|
||||
'versioned',
|
||||
'wire',
|
||||
'xfr',
|
||||
'zone',
|
||||
'zonefile',
|
||||
"asyncbackend",
|
||||
"asyncquery",
|
||||
"asyncresolver",
|
||||
"dnssec",
|
||||
"dnssectypes",
|
||||
"e164",
|
||||
"edns",
|
||||
"entropy",
|
||||
"exception",
|
||||
"flags",
|
||||
"immutable",
|
||||
"inet",
|
||||
"ipv4",
|
||||
"ipv6",
|
||||
"message",
|
||||
"name",
|
||||
"namedict",
|
||||
"node",
|
||||
"opcode",
|
||||
"query",
|
||||
"quic",
|
||||
"rcode",
|
||||
"rdata",
|
||||
"rdataclass",
|
||||
"rdataset",
|
||||
"rdatatype",
|
||||
"renderer",
|
||||
"resolver",
|
||||
"reversename",
|
||||
"rrset",
|
||||
"serial",
|
||||
"set",
|
||||
"tokenizer",
|
||||
"transaction",
|
||||
"tsig",
|
||||
"tsigkeyring",
|
||||
"ttl",
|
||||
"rdtypes",
|
||||
"update",
|
||||
"version",
|
||||
"versioned",
|
||||
"wire",
|
||||
"xfr",
|
||||
"zone",
|
||||
"zonetypes",
|
||||
"zonefile",
|
||||
]
|
||||
|
||||
from dns.version import version as __version__ # noqa
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
|
||||
# but it is only for sync use.
|
||||
|
||||
|
||||
class NullContext:
|
||||
def __init__(self, enter_result=None):
|
||||
self.enter_result = enter_result
|
||||
|
@ -23,6 +24,7 @@ class NullContext:
|
|||
# These are declared here so backends can import them without creating
|
||||
# circular dependencies with dns.asyncbackend.
|
||||
|
||||
|
||||
class Socket: # pragma: no cover
|
||||
async def close(self):
|
||||
pass
|
||||
|
@ -41,6 +43,9 @@ class Socket: # pragma: no cover
|
|||
|
||||
|
||||
class DatagramSocket(Socket): # pragma: no cover
|
||||
def __init__(self, family: int):
|
||||
self.family = family
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -58,12 +63,23 @@ class StreamSocket(Socket): # pragma: no cover
|
|||
|
||||
class Backend: # pragma: no cover
|
||||
def name(self):
|
||||
return 'unknown'
|
||||
return "unknown"
|
||||
|
||||
async def make_socket(self, af, socktype, proto=0,
|
||||
source=None, destination=None, timeout=None,
|
||||
ssl_context=None, server_hostname=None):
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def datagram_connection_required(self):
|
||||
return False
|
||||
|
||||
async def sleep(self, interval):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -10,7 +10,8 @@ import dns._asyncbackend
|
|||
import dns.exception
|
||||
|
||||
|
||||
_is_win32 = sys.platform == 'win32'
|
||||
_is_win32 = sys.platform == "win32"
|
||||
|
||||
|
||||
def _get_running_loop():
|
||||
try:
|
||||
|
@ -30,7 +31,6 @@ class _DatagramProtocol:
|
|||
def datagram_received(self, data, addr):
|
||||
if self.recvfrom and not self.recvfrom.done():
|
||||
self.recvfrom.set_result((data, addr))
|
||||
self.recvfrom = None
|
||||
|
||||
def error_received(self, exc): # pragma: no cover
|
||||
if self.recvfrom and not self.recvfrom.done():
|
||||
|
@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout):
|
|||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, family, transport, protocol):
|
||||
self.family = family
|
||||
super().__init__(family)
|
||||
self.transport = transport
|
||||
self.protocol = protocol
|
||||
|
||||
async def sendto(self, what, destination, timeout): # pragma: no cover
|
||||
# no timeout for asyncio sendto
|
||||
self.transport.sendto(what, destination)
|
||||
return len(what)
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
# ignore size as there's no way I know to tell protocol about it
|
||||
done = _get_running_loop().create_future()
|
||||
try:
|
||||
assert self.protocol.recvfrom is None
|
||||
self.protocol.recvfrom = done
|
||||
await _maybe_wait_for(done, timeout)
|
||||
return done.result()
|
||||
finally:
|
||||
self.protocol.recvfrom = None
|
||||
|
||||
async def close(self):
|
||||
self.protocol.close()
|
||||
|
||||
async def getpeername(self):
|
||||
return self.transport.get_extra_info('peername')
|
||||
return self.transport.get_extra_info("peername")
|
||||
|
||||
async def getsockname(self):
|
||||
return self.transport.get_extra_info('sockname')
|
||||
return self.transport.get_extra_info("sockname")
|
||||
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
|
@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
return await _maybe_wait_for(self.writer.drain(), timeout)
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
return await _maybe_wait_for(self.reader.read(size),
|
||||
timeout)
|
||||
return await _maybe_wait_for(self.reader.read(size), timeout)
|
||||
|
||||
async def close(self):
|
||||
self.writer.close()
|
||||
|
@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
pass
|
||||
|
||||
async def getpeername(self):
|
||||
return self.writer.get_extra_info('peername')
|
||||
return self.writer.get_extra_info("peername")
|
||||
|
||||
async def getsockname(self):
|
||||
return self.writer.get_extra_info('sockname')
|
||||
return self.writer.get_extra_info("sockname")
|
||||
|
||||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
return 'asyncio'
|
||||
return "asyncio"
|
||||
|
||||
async def make_socket(self, af, socktype, proto=0,
|
||||
source=None, destination=None, timeout=None,
|
||||
ssl_context=None, server_hostname=None):
|
||||
if destination is None and socktype == socket.SOCK_DGRAM and \
|
||||
_is_win32:
|
||||
raise NotImplementedError('destinationless datagram sockets '
|
||||
'are not supported by asyncio '
|
||||
'on Windows')
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
|
||||
raise NotImplementedError(
|
||||
"destinationless datagram sockets "
|
||||
"are not supported by asyncio "
|
||||
"on Windows"
|
||||
)
|
||||
loop = _get_running_loop()
|
||||
if socktype == socket.SOCK_DGRAM:
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
_DatagramProtocol, source, family=af,
|
||||
proto=proto, remote_addr=destination)
|
||||
_DatagramProtocol,
|
||||
source,
|
||||
family=af,
|
||||
proto=proto,
|
||||
remote_addr=destination,
|
||||
)
|
||||
return DatagramSocket(af, transport, protocol)
|
||||
elif socktype == socket.SOCK_STREAM:
|
||||
if destination is None:
|
||||
# This shouldn't happen, but we check to make code analysis software
|
||||
# happier.
|
||||
raise ValueError("destination required for stream sockets")
|
||||
(r, w) = await _maybe_wait_for(
|
||||
asyncio.open_connection(destination[0],
|
||||
asyncio.open_connection(
|
||||
destination[0],
|
||||
destination[1],
|
||||
ssl=ssl_context,
|
||||
family=af,
|
||||
proto=proto,
|
||||
local_addr=source,
|
||||
server_hostname=server_hostname),
|
||||
timeout)
|
||||
server_hostname=server_hostname,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
return StreamSocket(af, r, w)
|
||||
raise NotImplementedError('unsupported socket ' +
|
||||
f'type {socktype}') # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"unsupported socket " + f"type {socktype}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def sleep(self, interval):
|
||||
await asyncio.sleep(interval)
|
||||
|
|
|
@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
|
|||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, socket):
|
||||
super().__init__(socket.family)
|
||||
self.socket = socket
|
||||
self.family = socket.family
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
async with _maybe_timeout(timeout):
|
||||
return await self.socket.sendto(what, destination)
|
||||
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
async with _maybe_timeout(timeout):
|
||||
return await self.socket.recvfrom(size)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
await self.socket.close()
|
||||
|
@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
async def sendall(self, what, timeout):
|
||||
async with _maybe_timeout(timeout):
|
||||
return await self.socket.sendall(what)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
async with _maybe_timeout(timeout):
|
||||
return await self.socket.recv(size)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
await self.socket.close()
|
||||
|
@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
return 'curio'
|
||||
return "curio"
|
||||
|
||||
async def make_socket(self, af, socktype, proto=0,
|
||||
source=None, destination=None, timeout=None,
|
||||
ssl_context=None, server_hostname=None):
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
if socktype == socket.SOCK_DGRAM:
|
||||
s = curio.socket.socket(af, socktype, proto)
|
||||
try:
|
||||
|
@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend):
|
|||
else:
|
||||
source_addr = None
|
||||
async with _maybe_timeout(timeout):
|
||||
s = await curio.open_connection(destination[0], destination[1],
|
||||
s = await curio.open_connection(
|
||||
destination[0],
|
||||
destination[1],
|
||||
ssl=ssl_context,
|
||||
source_addr=source_addr,
|
||||
server_hostname=server_hostname)
|
||||
server_hostname=server_hostname,
|
||||
)
|
||||
return StreamSocket(s)
|
||||
raise NotImplementedError('unsupported socket ' +
|
||||
f'type {socktype}') # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"unsupported socket " + f"type {socktype}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def sleep(self, interval):
|
||||
await curio.sleep(interval)
|
||||
|
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# This implementation of the immutable decorator is for python 3.6,
|
||||
# which doesn't have Context Variables. This implementation is somewhat
|
||||
# costly for classes with slots, as it adds a __dict__ to them.
|
||||
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class _Immutable:
|
||||
"""Immutable mixin class"""
|
||||
|
||||
# Note we MUST NOT have __slots__ as that causes
|
||||
#
|
||||
# TypeError: multiple bases have instance lay-out conflict
|
||||
#
|
||||
# when we get mixed in with another class with slots. When we
|
||||
# get mixed into something with slots, it effectively adds __dict__ to
|
||||
# the slots of the other class, which allows attribute setting to work,
|
||||
# albeit at the cost of the dictionary.
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if not hasattr(self, '_immutable_init') or \
|
||||
self._immutable_init is not self:
|
||||
raise TypeError("object doesn't support attribute assignment")
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
if not hasattr(self, '_immutable_init') or \
|
||||
self._immutable_init is not self:
|
||||
raise TypeError("object doesn't support attribute assignment")
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
|
||||
def _immutable_init(f):
|
||||
def nf(*args, **kwargs):
|
||||
try:
|
||||
# Are we already initializing an immutable class?
|
||||
previous = args[0]._immutable_init
|
||||
except AttributeError:
|
||||
# We are the first!
|
||||
previous = None
|
||||
object.__setattr__(args[0], '_immutable_init', args[0])
|
||||
try:
|
||||
# call the actual __init__
|
||||
f(*args, **kwargs)
|
||||
finally:
|
||||
if not previous:
|
||||
# If we started the initialization, establish immutability
|
||||
# by removing the attribute that allows mutation
|
||||
object.__delattr__(args[0], '_immutable_init')
|
||||
nf.__signature__ = inspect.signature(f)
|
||||
return nf
|
||||
|
||||
|
||||
def immutable(cls):
|
||||
if _Immutable in cls.__mro__:
|
||||
# Some ancestor already has the mixin, so just make sure we keep
|
||||
# following the __init__ protocol.
|
||||
cls.__init__ = _immutable_init(cls.__init__)
|
||||
if hasattr(cls, '__setstate__'):
|
||||
cls.__setstate__ = _immutable_init(cls.__setstate__)
|
||||
ncls = cls
|
||||
else:
|
||||
# Mixin the Immutable class and follow the __init__ protocol.
|
||||
class ncls(_Immutable, cls):
|
||||
|
||||
@_immutable_init
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if hasattr(cls, '__setstate__'):
|
||||
@_immutable_init
|
||||
def __setstate__(self, *args, **kwargs):
|
||||
super().__setstate__(*args, **kwargs)
|
||||
|
||||
# make ncls have the same name and module as cls
|
||||
ncls.__name__ = cls.__name__
|
||||
ncls.__qualname__ = cls.__qualname__
|
||||
ncls.__module__ = cls.__module__
|
||||
return ncls
|
|
@ -8,7 +8,7 @@ import contextvars
|
|||
import inspect
|
||||
|
||||
|
||||
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
|
||||
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
|
||||
|
||||
|
||||
class _Immutable:
|
||||
|
@ -41,6 +41,7 @@ def _immutable_init(f):
|
|||
f(*args, **kwargs)
|
||||
finally:
|
||||
_in__init__.reset(previous)
|
||||
|
||||
nf.__signature__ = inspect.signature(f)
|
||||
return nf
|
||||
|
||||
|
@ -50,7 +51,7 @@ def immutable(cls):
|
|||
# Some ancestor already has the mixin, so just make sure we keep
|
||||
# following the __init__ protocol.
|
||||
cls.__init__ = _immutable_init(cls.__init__)
|
||||
if hasattr(cls, '__setstate__'):
|
||||
if hasattr(cls, "__setstate__"):
|
||||
cls.__setstate__ = _immutable_init(cls.__setstate__)
|
||||
ncls = cls
|
||||
else:
|
||||
|
@ -63,7 +64,8 @@ def immutable(cls):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if hasattr(cls, '__setstate__'):
|
||||
if hasattr(cls, "__setstate__"):
|
||||
|
||||
@_immutable_init
|
||||
def __setstate__(self, *args, **kwargs):
|
||||
super().__setstate__(*args, **kwargs)
|
||||
|
|
|
@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
|
|||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, socket):
|
||||
super().__init__(socket.family)
|
||||
self.socket = socket
|
||||
self.family = socket.family
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.socket.sendto(what, destination)
|
||||
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.socket.recvfrom(size)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
self.socket.close()
|
||||
|
@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
async def sendall(self, what, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.stream.send_all(what)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.stream.receive_some(size)
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
await self.stream.aclose()
|
||||
|
@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
return 'trio'
|
||||
return "trio"
|
||||
|
||||
async def make_socket(self, af, socktype, proto=0, source=None,
|
||||
destination=None, timeout=None,
|
||||
ssl_context=None, server_hostname=None):
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
s = trio.socket.socket(af, socktype, proto)
|
||||
stream = None
|
||||
try:
|
||||
|
@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend):
|
|||
return DatagramSocket(s)
|
||||
elif socktype == socket.SOCK_STREAM:
|
||||
stream = trio.SocketStream(s)
|
||||
s = None
|
||||
tls = False
|
||||
if ssl_context:
|
||||
tls = True
|
||||
try:
|
||||
stream = trio.SSLStream(stream, ssl_context,
|
||||
server_hostname=server_hostname)
|
||||
stream = trio.SSLStream(
|
||||
stream, ssl_context, server_hostname=server_hostname
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
await stream.aclose()
|
||||
raise
|
||||
return StreamSocket(af, stream, tls)
|
||||
raise NotImplementedError('unsupported socket ' +
|
||||
f'type {socktype}') # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"unsupported socket " + f"type {socktype}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def sleep(self, interval):
|
||||
await trio.sleep(interval)
|
||||
|
|
|
@ -1,26 +1,33 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import dns.exception
|
||||
|
||||
# pylint: disable=unused-import
|
||||
|
||||
from dns._asyncbackend import Socket, DatagramSocket, \
|
||||
StreamSocket, Backend # noqa:
|
||||
from dns._asyncbackend import (
|
||||
Socket,
|
||||
DatagramSocket,
|
||||
StreamSocket,
|
||||
Backend,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
_default_backend = None
|
||||
|
||||
_backends = {}
|
||||
_backends: Dict[str, Backend] = {}
|
||||
|
||||
# Allow sniffio import to be disabled for testing purposes
|
||||
_no_sniffio = False
|
||||
|
||||
|
||||
class AsyncLibraryNotFoundError(dns.exception.DNSException):
|
||||
pass
|
||||
|
||||
|
||||
def get_backend(name):
|
||||
def get_backend(name: str) -> Backend:
|
||||
"""Get the specified asynchronous backend.
|
||||
|
||||
*name*, a ``str``, the name of the backend. Currently the "trio",
|
||||
|
@ -32,22 +39,25 @@ def get_backend(name):
|
|||
backend = _backends.get(name)
|
||||
if backend:
|
||||
return backend
|
||||
if name == 'trio':
|
||||
if name == "trio":
|
||||
import dns._trio_backend
|
||||
|
||||
backend = dns._trio_backend.Backend()
|
||||
elif name == 'curio':
|
||||
elif name == "curio":
|
||||
import dns._curio_backend
|
||||
|
||||
backend = dns._curio_backend.Backend()
|
||||
elif name == 'asyncio':
|
||||
elif name == "asyncio":
|
||||
import dns._asyncio_backend
|
||||
|
||||
backend = dns._asyncio_backend.Backend()
|
||||
else:
|
||||
raise NotImplementedError(f'unimplemented async backend {name}')
|
||||
raise NotImplementedError(f"unimplemented async backend {name}")
|
||||
_backends[name] = backend
|
||||
return backend
|
||||
|
||||
|
||||
def sniff():
|
||||
def sniff() -> str:
|
||||
"""Attempt to determine the in-use asynchronous I/O library by using
|
||||
the ``sniffio`` module if it is available.
|
||||
|
||||
|
@ -59,35 +69,32 @@ def sniff():
|
|||
if _no_sniffio:
|
||||
raise ImportError
|
||||
import sniffio
|
||||
|
||||
try:
|
||||
return sniffio.current_async_library()
|
||||
except sniffio.AsyncLibraryNotFoundError:
|
||||
raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
|
||||
'async library')
|
||||
raise AsyncLibraryNotFoundError(
|
||||
"sniffio cannot determine " + "async library"
|
||||
)
|
||||
except ImportError:
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return 'asyncio'
|
||||
return "asyncio"
|
||||
except RuntimeError:
|
||||
raise AsyncLibraryNotFoundError('no async library detected')
|
||||
except AttributeError: # pragma: no cover
|
||||
# we have to check current_task on 3.6
|
||||
if not asyncio.Task.current_task():
|
||||
raise AsyncLibraryNotFoundError('no async library detected')
|
||||
return 'asyncio'
|
||||
raise AsyncLibraryNotFoundError("no async library detected")
|
||||
|
||||
|
||||
def get_default_backend():
|
||||
"""Get the default backend, initializing it if necessary.
|
||||
"""
|
||||
def get_default_backend() -> Backend:
|
||||
"""Get the default backend, initializing it if necessary."""
|
||||
if _default_backend:
|
||||
return _default_backend
|
||||
|
||||
return set_default_backend(sniff())
|
||||
|
||||
|
||||
def set_default_backend(name):
|
||||
def set_default_backend(name: str) -> Backend:
|
||||
"""Set the default backend.
|
||||
|
||||
It's not normally necessary to call this method, as
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
class Backend:
|
||||
...
|
||||
|
||||
def get_backend(name: str) -> Backend:
|
||||
...
|
||||
def sniff() -> str:
|
||||
...
|
||||
def get_default_backend() -> Backend:
|
||||
...
|
||||
def set_default_backend(name: str) -> Backend:
|
||||
...
|
|
@ -17,7 +17,10 @@
|
|||
|
||||
"""Talk to a DNS server."""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
@ -27,12 +30,24 @@ import dns.exception
|
|||
import dns.inet
|
||||
import dns.name
|
||||
import dns.message
|
||||
import dns.quic
|
||||
import dns.rcode
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.transaction
|
||||
|
||||
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
|
||||
UDPMode, _have_httpx, _have_http2, NoDOH
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.query import (
|
||||
_compute_times,
|
||||
_matches_destination,
|
||||
BadResponse,
|
||||
ssl,
|
||||
UDPMode,
|
||||
_have_httpx,
|
||||
_have_http2,
|
||||
NoDOH,
|
||||
NoDOQ,
|
||||
)
|
||||
|
||||
if _have_httpx:
|
||||
import httpx
|
||||
|
@ -47,11 +62,11 @@ def _source_tuple(af, address, port):
|
|||
if address or port:
|
||||
if address is None:
|
||||
if af == socket.AF_INET:
|
||||
address = '0.0.0.0'
|
||||
address = "0.0.0.0"
|
||||
elif af == socket.AF_INET6:
|
||||
address = '::'
|
||||
address = "::"
|
||||
else:
|
||||
raise NotImplementedError(f'unknown address family {af}')
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
return (address, port)
|
||||
else:
|
||||
return None
|
||||
|
@ -66,7 +81,12 @@ def _timeout(expiration, now=None):
|
|||
return None
|
||||
|
||||
|
||||
async def send_udp(sock, what, destination, expiration=None):
|
||||
async def send_udp(
|
||||
sock: dns.asyncbackend.DatagramSocket,
|
||||
what: Union[dns.message.Message, bytes],
|
||||
destination: Any,
|
||||
expiration: Optional[float] = None,
|
||||
) -> Tuple[int, float]:
|
||||
"""Send a DNS message to the specified UDP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``.
|
||||
|
@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None):
|
|||
|
||||
*expiration*, a ``float`` or ``None``, the absolute time at which
|
||||
a timeout exception should be raised. If ``None``, no timeout will
|
||||
occur.
|
||||
occur. The expiration value is meaningless for the asyncio backend, as
|
||||
asyncio's transport sendto() never blocks.
|
||||
|
||||
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
|
||||
"""
|
||||
|
@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None):
|
|||
return (n, sent_time)
|
||||
|
||||
|
||||
async def receive_udp(sock, destination=None, expiration=None,
|
||||
ignore_unexpected=False, one_rr_per_rrset=False,
|
||||
keyring=None, request_mac=b'', ignore_trailing=False,
|
||||
raise_on_truncation=False):
|
||||
async def receive_udp(
|
||||
sock: dns.asyncbackend.DatagramSocket,
|
||||
destination: Optional[Any] = None,
|
||||
expiration: Optional[float] = None,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
|
||||
request_mac: Optional[bytes] = b"",
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
) -> Any:
|
||||
"""Read a DNS message from a UDP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``.
|
||||
|
||||
See :py:func:`dns.query.receive_udp()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
parameters, and exceptions.
|
||||
|
||||
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
|
||||
received time, and the address where the message arrived from.
|
||||
"""
|
||||
|
||||
wire = b''
|
||||
wire = b""
|
||||
while 1:
|
||||
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
|
||||
if _matches_destination(sock.family, from_address, destination,
|
||||
ignore_unexpected):
|
||||
if _matches_destination(
|
||||
sock.family, from_address, destination, ignore_unexpected
|
||||
):
|
||||
break
|
||||
received_time = time.time()
|
||||
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation)
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
return (r, received_time, from_address)
|
||||
|
||||
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
|
||||
ignore_unexpected=False, one_rr_per_rrset=False,
|
||||
ignore_trailing=False, raise_on_truncation=False, sock=None,
|
||||
backend=None):
|
||||
|
||||
async def udp(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 53,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via UDP.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
|
||||
|
@ -134,13 +181,10 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
|
|||
"""
|
||||
wire = q.to_wire()
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
s = None
|
||||
# After 3.6 is no longer supported, this can use an AsyncExitStack.
|
||||
try:
|
||||
af = dns.inet.af_for_address(where)
|
||||
destination = _lltuple((where, port), af)
|
||||
if sock:
|
||||
s = sock
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
|
@ -149,27 +193,40 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
|
|||
dtuple = (where, port)
|
||||
else:
|
||||
dtuple = None
|
||||
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
|
||||
dtuple)
|
||||
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
|
||||
async with cm as s:
|
||||
await send_udp(s, wire, destination, expiration)
|
||||
(r, received_time, _) = await receive_udp(s, destination, expiration,
|
||||
(r, received_time, _) = await receive_udp(
|
||||
s,
|
||||
destination,
|
||||
expiration,
|
||||
ignore_unexpected,
|
||||
one_rr_per_rrset,
|
||||
q.keyring, q.mac,
|
||||
q.keyring,
|
||||
q.mac,
|
||||
ignore_trailing,
|
||||
raise_on_truncation)
|
||||
raise_on_truncation,
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
finally:
|
||||
if not sock and s:
|
||||
await s.close()
|
||||
|
||||
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
|
||||
source_port=0, ignore_unexpected=False,
|
||||
one_rr_per_rrset=False, ignore_trailing=False,
|
||||
udp_sock=None, tcp_sock=None, backend=None):
|
||||
|
||||
async def udp_with_fallback(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 53,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
|
||||
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> Tuple[dns.message.Message, bool]:
|
||||
"""Return the response to the query, trying UDP first and falling back
|
||||
to TCP if UDP results in a truncated response.
|
||||
|
||||
|
@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
|
|||
method.
|
||||
"""
|
||||
try:
|
||||
response = await udp(q, where, timeout, port, source, source_port,
|
||||
ignore_unexpected, one_rr_per_rrset,
|
||||
ignore_trailing, True, udp_sock, backend)
|
||||
response = await udp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
ignore_unexpected,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
True,
|
||||
udp_sock,
|
||||
backend,
|
||||
)
|
||||
return (response, False)
|
||||
except dns.message.Truncated:
|
||||
response = await tcp(q, where, timeout, port, source, source_port,
|
||||
one_rr_per_rrset, ignore_trailing, tcp_sock,
|
||||
backend)
|
||||
response = await tcp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
tcp_sock,
|
||||
backend,
|
||||
)
|
||||
return (response, True)
|
||||
|
||||
|
||||
async def send_tcp(sock, what, expiration=None):
|
||||
async def send_tcp(
|
||||
sock: dns.asyncbackend.StreamSocket,
|
||||
what: Union[dns.message.Message, bytes],
|
||||
expiration: Optional[float] = None,
|
||||
) -> Tuple[int, float]:
|
||||
"""Send a DNS message to the specified TCP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.StreamSocket``.
|
||||
|
@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None):
|
|||
"""
|
||||
|
||||
if isinstance(what, dns.message.Message):
|
||||
what = what.to_wire()
|
||||
l = len(what)
|
||||
wire = what.to_wire()
|
||||
else:
|
||||
wire = what
|
||||
l = len(wire)
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = struct.pack("!H", l) + what
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
sent_time = time.time()
|
||||
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
|
||||
return (len(tcpmsg), sent_time)
|
||||
|
@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration):
|
|||
"""Read the specified number of bytes from stream. Keep trying until we
|
||||
either get the desired amount, or we hit EOF.
|
||||
"""
|
||||
s = b''
|
||||
s = b""
|
||||
while count > 0:
|
||||
n = await sock.recv(count, _timeout(expiration))
|
||||
if n == b'':
|
||||
if n == b"":
|
||||
raise EOFError
|
||||
count = count - len(n)
|
||||
s = s + n
|
||||
return s
|
||||
|
||||
|
||||
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
|
||||
keyring=None, request_mac=b'', ignore_trailing=False):
|
||||
async def receive_tcp(
|
||||
sock: dns.asyncbackend.StreamSocket,
|
||||
expiration: Optional[float] = None,
|
||||
one_rr_per_rrset: bool = False,
|
||||
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
|
||||
request_mac: Optional[bytes] = b"",
|
||||
ignore_trailing: bool = False,
|
||||
) -> Tuple[dns.message.Message, float]:
|
||||
"""Read a DNS message from a TCP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.StreamSocket``.
|
||||
|
@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
|
|||
(l,) = struct.unpack("!H", ldata)
|
||||
wire = await _read_exactly(sock, l, expiration)
|
||||
received_time = time.time()
|
||||
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing)
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
return (r, received_time)
|
||||
|
||||
|
||||
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
|
||||
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
|
||||
backend=None):
|
||||
async def tcp(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 53,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
sock: Optional[dns.asyncbackend.StreamSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TCP.
|
||||
|
||||
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
|
||||
|
@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
|
|||
|
||||
wire = q.to_wire()
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
s = None
|
||||
# After 3.6 is no longer supported, this can use an AsyncExitStack.
|
||||
try:
|
||||
if sock:
|
||||
# Verify that the socket is connected, as if it's not connected,
|
||||
# it's not writable, and the polling in send_tcp() will time out or
|
||||
# hang forever.
|
||||
await sock.getpeername()
|
||||
s = sock
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
# These are simple (address, port) pairs, not
|
||||
# family-dependent tuples you pass to lowlevel socket
|
||||
# code.
|
||||
# These are simple (address, port) pairs, not family-dependent tuples
|
||||
# you pass to low-level socket code.
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
|
||||
dtuple, timeout)
|
||||
cm = await backend.make_socket(
|
||||
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
|
||||
)
|
||||
async with cm as s:
|
||||
await send_tcp(s, wire, expiration)
|
||||
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
|
||||
q.keyring, q.mac,
|
||||
ignore_trailing)
|
||||
(r, received_time) = await receive_tcp(
|
||||
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
finally:
|
||||
if not sock and s:
|
||||
await s.close()
|
||||
|
||||
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
|
||||
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
|
||||
backend=None, ssl_context=None, server_hostname=None):
|
||||
|
||||
async def tls(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
sock: Optional[dns.asyncbackend.StreamSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TLS.
|
||||
|
||||
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
|
||||
|
@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
|
|||
See :py:func:`dns.query.tls()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
# After 3.6 is no longer supported, this can use an AsyncExitStack.
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
if not sock:
|
||||
if sock:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
if ssl_context is None:
|
||||
ssl_context = ssl.create_default_context()
|
||||
# See the comment about ssl.create_default_context() in query.py
|
||||
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
if server_hostname is None:
|
||||
ssl_context.check_hostname = False
|
||||
else:
|
||||
|
@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
|
|||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
|
||||
dtuple, timeout, ssl_context,
|
||||
server_hostname)
|
||||
else:
|
||||
s = sock
|
||||
try:
|
||||
cm = await backend.make_socket(
|
||||
af,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
stuple,
|
||||
dtuple,
|
||||
timeout,
|
||||
ssl_context,
|
||||
server_hostname,
|
||||
)
|
||||
async with cm as s:
|
||||
timeout = _timeout(expiration)
|
||||
response = await tcp(q, where, timeout, port, source, source_port,
|
||||
one_rr_per_rrset, ignore_trailing, s, backend)
|
||||
response = await tcp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
s,
|
||||
backend,
|
||||
)
|
||||
end_time = time.time()
|
||||
response.time = end_time - begin_time
|
||||
return response
|
||||
finally:
|
||||
if not sock and s:
|
||||
await s.close()
|
||||
|
||||
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
|
||||
one_rr_per_rrset=False, ignore_trailing=False, client=None,
|
||||
path='/dns-query', post=True, verify=True):
|
||||
|
||||
async def https(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 443,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0, # pylint: disable=W0613
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
client: Optional["httpx.AsyncClient"] = None,
|
||||
path: str = "/dns-query",
|
||||
post: bool = True,
|
||||
verify: Union[bool, str] = True,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
|
||||
|
@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
|
|||
"""
|
||||
|
||||
if not _have_httpx:
|
||||
raise NoDOH('httpx is not available.') # pragma: no cover
|
||||
raise NoDOH("httpx is not available.") # pragma: no cover
|
||||
|
||||
wire = q.to_wire()
|
||||
try:
|
||||
|
@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
|
|||
except ValueError:
|
||||
af = None
|
||||
transport = None
|
||||
headers = {
|
||||
"accept": "application/dns-message"
|
||||
}
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None:
|
||||
if af == socket.AF_INET:
|
||||
url = 'https://{}:{}{}'.format(where, port, path)
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
elif af == socket.AF_INET6:
|
||||
url = 'https://[{}]:{}{}'.format(where, port, path)
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
else:
|
||||
url = where
|
||||
if source is not None:
|
||||
transport = httpx.AsyncHTTPTransport(local_address=source[0])
|
||||
|
||||
# After 3.6 is no longer supported, this can use an AsyncExitStack
|
||||
client_to_close = None
|
||||
try:
|
||||
if not client:
|
||||
client = httpx.AsyncClient(http1=True, http2=_have_http2,
|
||||
verify=verify, transport=transport)
|
||||
client_to_close = client
|
||||
if client:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||
else:
|
||||
cm = httpx.AsyncClient(
|
||||
http1=True, http2=_have_http2, verify=verify, transport=transport
|
||||
)
|
||||
|
||||
async with cm as the_client:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
if post:
|
||||
headers.update({
|
||||
headers.update(
|
||||
{
|
||||
"content-type": "application/dns-message",
|
||||
"content-length": str(len(wire))
|
||||
})
|
||||
response = await client.post(url, headers=headers, content=wire,
|
||||
timeout=timeout)
|
||||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
response = await the_client.post(
|
||||
url, headers=headers, content=wire, timeout=timeout
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
wire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = await client.get(url, headers=headers, timeout=timeout,
|
||||
params={"dns": wire})
|
||||
finally:
|
||||
if client_to_close:
|
||||
await client.aclose()
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = await the_client.get(
|
||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
if response.status_code < 200 or response.status_code > 299:
|
||||
raise ValueError('{} responded with status code {}'
|
||||
'\nResponse body: {}'.format(where,
|
||||
response.status_code,
|
||||
response.content))
|
||||
r = dns.message.from_wire(response.content,
|
||||
raise ValueError(
|
||||
"{} responded with status code {}"
|
||||
"\nResponse body: {!r}".format(
|
||||
where, response.status_code, response.content
|
||||
)
|
||||
)
|
||||
r = dns.message.from_wire(
|
||||
response.content,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing)
|
||||
r.time = response.elapsed
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = response.elapsed.total_seconds()
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
async def inbound_xfr(where, txn_manager, query=None,
|
||||
port=53, timeout=None, lifetime=None, source=None,
|
||||
source_port=0, udp_mode=UDPMode.NEVER, backend=None):
|
||||
|
||||
async def inbound_xfr(
|
||||
where: str,
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
query: Optional[dns.message.Message] = None,
|
||||
port: int = 53,
|
||||
timeout: Optional[float] = None,
|
||||
lifetime: Optional[float] = None,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
udp_mode: UDPMode = UDPMode.NEVER,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> None:
|
||||
"""Conduct an inbound transfer and apply it via a transaction from the
|
||||
txn_manager.
|
||||
|
||||
|
@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None,
|
|||
is_udp = False
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
|
||||
_timeout(expiration))
|
||||
s = await backend.make_socket(
|
||||
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
|
||||
)
|
||||
async with s:
|
||||
if is_udp:
|
||||
await s.sendto(wire, dtuple, _timeout(expiration))
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
await s.sendall(tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial,
|
||||
is_udp) as inbound:
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or \
|
||||
(expiration is not None and mexpiration > expiration):
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
destination = _lltuple((where, port), af)
|
||||
while True:
|
||||
timeout = _timeout(mexpiration)
|
||||
(rwire, from_address) = await s.recvfrom(65535,
|
||||
timeout)
|
||||
if _matches_destination(af, from_address,
|
||||
destination, True):
|
||||
(rwire, from_address) = await s.recvfrom(65535, timeout)
|
||||
if _matches_destination(
|
||||
af, from_address, destination, True
|
||||
):
|
||||
break
|
||||
else:
|
||||
ldata = await _read_exactly(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = await _read_exactly(s, l, mexpiration)
|
||||
is_ixfr = (rdtype == dns.rdatatype.IXFR)
|
||||
r = dns.message.from_wire(rwire, keyring=query.keyring,
|
||||
request_mac=query.mac, xfr=True,
|
||||
origin=origin, tsig_ctx=tsig_ctx,
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr)
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
try:
|
||||
done = inbound.process_message(r)
|
||||
except dns.xfr.UseTCP:
|
||||
|
@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None,
|
|||
tsig_ctx = r.tsig_ctx
|
||||
if not retry and query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
|
||||
|
||||
async def quic(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending an asynchronous query via
|
||||
DNS-over-QUIC.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.quic()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.AsyncQuicConnection
|
||||
if connection:
|
||||
cfactory = dns.quic.null_factory
|
||||
mfactory = dns.quic.null_factory
|
||||
the_connection = connection
|
||||
else:
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(context, verify_mode=verify) as the_manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
start = time.time()
|
||||
stream = await the_connection.make_stream()
|
||||
async with stream:
|
||||
await stream.send(wire, True)
|
||||
wire = await stream.receive(timeout)
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
from typing import Optional, Union, Dict, Generator, Any
|
||||
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
|
||||
|
||||
# If the ssl import works, then
|
||||
#
|
||||
# error: Name 'ssl' already defined (by an import)
|
||||
#
|
||||
# is expected and can be ignored.
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
class ssl: # type: ignore
|
||||
SSLContext : Dict = {}
|
||||
|
||||
async def udp(q : message.Message, where : str,
|
||||
timeout : Optional[float] = None, port=53,
|
||||
source : Optional[str] = None, source_port : Optional[int] = 0,
|
||||
ignore_unexpected : Optional[bool] = False,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[asyncbackend.DatagramSocket] = None,
|
||||
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
|
||||
pass
|
||||
|
||||
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
|
||||
af : Optional[int] = None, source : Optional[str] = None,
|
||||
source_port : Optional[int] = 0,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[asyncbackend.StreamSocket] = None,
|
||||
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
|
||||
pass
|
||||
|
||||
async def tls(q : message.Message, where : str,
|
||||
timeout : Optional[float] = None, port=53,
|
||||
source : Optional[str] = None, source_port : Optional[int] = 0,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[asyncbackend.StreamSocket] = None,
|
||||
backend : Optional[asyncbackend.Backend] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
server_hostname: Optional[str] = None) -> message.Message:
|
||||
pass
|
|
@ -17,13 +17,18 @@
|
|||
|
||||
"""Asynchronous DNS stub resolver."""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import time
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.asyncquery
|
||||
import dns.exception
|
||||
import dns.name
|
||||
import dns.query
|
||||
import dns.resolver
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.resolver # lgtm[py/import-and-import-from]
|
||||
|
||||
# import some resolver symbols for brevity
|
||||
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
|
||||
|
@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp
|
|||
class Resolver(dns.resolver.BaseResolver):
|
||||
"""Asynchronous DNS stub resolver."""
|
||||
|
||||
async def resolve(self, qname, rdtype=dns.rdatatype.A,
|
||||
rdclass=dns.rdataclass.IN,
|
||||
tcp=False, source=None, raise_on_no_answer=True,
|
||||
source_port=0, lifetime=None, search=None,
|
||||
backend=None):
|
||||
async def resolve(
|
||||
self,
|
||||
qname: Union[dns.name.Name, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: Optional[str] = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: Optional[float] = None,
|
||||
search: Optional[bool] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers asynchronously to find the answer to the question.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
|
@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
type of this method.
|
||||
"""
|
||||
|
||||
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
|
||||
raise_on_no_answer, search)
|
||||
resolution = dns.resolver._Resolution(
|
||||
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
|
||||
)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
start = time.time()
|
||||
|
@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
if answer is not None:
|
||||
# cache hit!
|
||||
return answer
|
||||
assert request is not None # needed for type checking
|
||||
done = False
|
||||
while not done:
|
||||
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
|
||||
if backoff:
|
||||
await backend.sleep(backoff)
|
||||
timeout = self._compute_timeout(start, lifetime,
|
||||
resolution.errors)
|
||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||
try:
|
||||
if dns.inet.is_address(nameserver):
|
||||
if tcp:
|
||||
response = await _tcp(request, nameserver,
|
||||
timeout, port,
|
||||
source, source_port,
|
||||
backend=backend)
|
||||
else:
|
||||
response = await _udp(request, nameserver,
|
||||
timeout, port,
|
||||
source, source_port,
|
||||
raise_on_truncation=True,
|
||||
backend=backend)
|
||||
else:
|
||||
response = await dns.asyncquery.https(request,
|
||||
response = await _tcp(
|
||||
request,
|
||||
nameserver,
|
||||
timeout=timeout)
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
backend=backend,
|
||||
)
|
||||
else:
|
||||
response = await _udp(
|
||||
request,
|
||||
nameserver,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
raise_on_truncation=True,
|
||||
backend=backend,
|
||||
)
|
||||
else:
|
||||
response = await dns.asyncquery.https(
|
||||
request, nameserver, timeout=timeout
|
||||
)
|
||||
except Exception as ex:
|
||||
(_, done) = resolution.query_result(None, ex)
|
||||
continue
|
||||
|
@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
if answer is not None:
|
||||
return answer
|
||||
|
||||
async def resolve_address(self, ipaddr, *args, **kwargs):
|
||||
async def resolve_address(
|
||||
self, ipaddr: str, *args: Any, **kwargs: Any
|
||||
) -> dns.resolver.Answer:
|
||||
"""Use an asynchronous resolver to run a reverse query for PTR
|
||||
records.
|
||||
|
||||
|
@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
function.
|
||||
|
||||
"""
|
||||
|
||||
return await self.resolve(dns.reversename.from_address(ipaddr),
|
||||
rdtype=dns.rdatatype.PTR,
|
||||
rdclass=dns.rdataclass.IN,
|
||||
*args, **kwargs)
|
||||
# We make a modified kwargs for type checking happiness, as otherwise
|
||||
# we get a legit warning about possibly having rdtype and rdclass
|
||||
# in the kwargs more than once.
|
||||
modified_kwargs: Dict[str, Any] = {}
|
||||
modified_kwargs.update(kwargs)
|
||||
modified_kwargs["rdtype"] = dns.rdatatype.PTR
|
||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
return await self.resolve(
|
||||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||
)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
async def canonical_name(self, name):
|
||||
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
The canonical name is the name the resolver uses for queries
|
||||
|
@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
default_resolver = None
|
||||
|
||||
|
||||
def get_default_resolver():
|
||||
def get_default_resolver() -> Resolver:
|
||||
"""Get the default asynchronous resolver, initializing it if necessary."""
|
||||
if default_resolver is None:
|
||||
reset_default_resolver()
|
||||
assert default_resolver is not None
|
||||
return default_resolver
|
||||
|
||||
|
||||
def reset_default_resolver():
|
||||
def reset_default_resolver() -> None:
|
||||
"""Re-initialize default asynchronous resolver.
|
||||
|
||||
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
|
||||
|
@ -167,9 +199,18 @@ def reset_default_resolver():
|
|||
default_resolver = Resolver()
|
||||
|
||||
|
||||
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
|
||||
tcp=False, source=None, raise_on_no_answer=True,
|
||||
source_port=0, lifetime=None, search=None, backend=None):
|
||||
async def resolve(
|
||||
qname: Union[dns.name.Name, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: Optional[str] = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: Optional[float] = None,
|
||||
search: Optional[bool] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers asynchronously to find the answer to the question.
|
||||
|
||||
This is a convenience function that uses the default resolver
|
||||
|
@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
|
|||
information on the parameters.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp,
|
||||
source, raise_on_no_answer,
|
||||
source_port, lifetime, search,
|
||||
backend)
|
||||
return await get_default_resolver().resolve(
|
||||
qname,
|
||||
rdtype,
|
||||
rdclass,
|
||||
tcp,
|
||||
source,
|
||||
raise_on_no_answer,
|
||||
source_port,
|
||||
lifetime,
|
||||
search,
|
||||
backend,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_address(ipaddr, *args, **kwargs):
|
||||
async def resolve_address(
|
||||
ipaddr: str, *args: Any, **kwargs: Any
|
||||
) -> dns.resolver.Answer:
|
||||
"""Use a resolver to run a reverse query for PTR records.
|
||||
|
||||
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
|
||||
|
@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs):
|
|||
|
||||
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||
|
||||
async def canonical_name(name):
|
||||
|
||||
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.canonical_name` for more
|
||||
|
@ -203,8 +255,14 @@ async def canonical_name(name):
|
|||
|
||||
return await get_default_resolver().canonical_name(name)
|
||||
|
||||
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
|
||||
resolver=None, backend=None):
|
||||
|
||||
async def zone_for_name(
|
||||
name: Union[dns.name.Name, str],
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
resolver: Optional[Resolver] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
) -> dns.name.Name:
|
||||
"""Find the name of the zone which contains the specified name.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
|
||||
|
@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
|
|||
raise NotAbsolute(name)
|
||||
while True:
|
||||
try:
|
||||
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
|
||||
tcp, backend=backend)
|
||||
answer = await resolver.resolve(
|
||||
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
|
||||
)
|
||||
assert answer.rrset is not None
|
||||
if answer.rrset.name == name:
|
||||
return name
|
||||
# otherwise we were CNAMEd or DNAMEd and need to look higher
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
from typing import Union, Optional, List, Any, Dict
|
||||
from . import exception, rdataclass, name, rdatatype, asyncbackend
|
||||
|
||||
async def resolve(qname : str, rdtype : Union[int,str] = 0,
|
||||
rdclass : Union[int,str] = 0,
|
||||
tcp=False, source=None, raise_on_no_answer=True,
|
||||
source_port=0, lifetime : Optional[float]=None,
|
||||
search : Optional[bool]=None,
|
||||
backend : Optional[asyncbackend.Backend]=None):
|
||||
...
|
||||
async def resolve_address(self, ipaddr: str,
|
||||
*args: Any, **kwargs: Optional[Dict]):
|
||||
...
|
||||
|
||||
class Resolver:
|
||||
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
|
||||
configure : Optional[bool] = True):
|
||||
self.nameservers : List[str]
|
||||
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
|
||||
rdclass : Union[int,str] = rdataclass.IN,
|
||||
tcp : bool = False, source : Optional[str] = None,
|
||||
raise_on_no_answer=True, source_port : int = 0,
|
||||
lifetime : Optional[float]=None,
|
||||
search : Optional[bool]=None,
|
||||
backend : Optional[asyncbackend.Backend]=None):
|
||||
...
|
File diff suppressed because it is too large
Load diff
|
@ -1,21 +0,0 @@
|
|||
from typing import Union, Dict, Tuple, Optional
|
||||
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
|
||||
import dns.rdtypes.ANY.DS as DS
|
||||
import dns.rdtypes.ANY.DNSKEY as DNSKEY
|
||||
|
||||
_have_pyca : bool
|
||||
|
||||
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
|
||||
...
|
||||
|
||||
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
|
||||
...
|
||||
|
||||
class ValidationFailure(exception.DNSException):
|
||||
...
|
||||
|
||||
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
|
||||
...
|
||||
|
||||
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
|
||||
...
|
71
lib/dns/dnssectypes.py
Normal file
71
lib/dns/dnssectypes.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Common DNSSEC-related types."""
|
||||
|
||||
# This is a separate file to avoid import circularity between dns.dnssec and
|
||||
# the implementations of the DS and DNSKEY types.
|
||||
|
||||
import dns.enum
|
||||
|
||||
|
||||
class Algorithm(dns.enum.IntEnum):
|
||||
RSAMD5 = 1
|
||||
DH = 2
|
||||
DSA = 3
|
||||
ECC = 4
|
||||
RSASHA1 = 5
|
||||
DSANSEC3SHA1 = 6
|
||||
RSASHA1NSEC3SHA1 = 7
|
||||
RSASHA256 = 8
|
||||
RSASHA512 = 10
|
||||
ECCGOST = 12
|
||||
ECDSAP256SHA256 = 13
|
||||
ECDSAP384SHA384 = 14
|
||||
ED25519 = 15
|
||||
ED448 = 16
|
||||
INDIRECT = 252
|
||||
PRIVATEDNS = 253
|
||||
PRIVATEOID = 254
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
||||
|
||||
|
||||
class DSDigest(dns.enum.IntEnum):
|
||||
"""DNSSEC Delegation Signer Digest Algorithm"""
|
||||
|
||||
NULL = 0
|
||||
SHA1 = 1
|
||||
SHA256 = 2
|
||||
GOST = 3
|
||||
SHA384 = 4
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
||||
|
||||
|
||||
class NSEC3Hash(dns.enum.IntEnum):
|
||||
"""NSEC3 hash algorithm"""
|
||||
|
||||
SHA1 = 1
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
|
@ -17,15 +17,19 @@
|
|||
|
||||
"""DNS E.164 helpers."""
|
||||
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.name
|
||||
import dns.resolver
|
||||
|
||||
#: The public E.164 domain.
|
||||
public_enum_domain = dns.name.from_text('e164.arpa.')
|
||||
public_enum_domain = dns.name.from_text("e164.arpa.")
|
||||
|
||||
|
||||
def from_e164(text, origin=public_enum_domain):
|
||||
def from_e164(
|
||||
text: str, origin: Optional[dns.name.Name] = public_enum_domain
|
||||
) -> dns.name.Name:
|
||||
"""Convert an E.164 number in textual form into a Name object whose
|
||||
value is the ENUM domain name for that number.
|
||||
|
||||
|
@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain):
|
|||
|
||||
parts = [d for d in text if d.isdigit()]
|
||||
parts.reverse()
|
||||
return dns.name.from_text('.'.join(parts), origin=origin)
|
||||
return dns.name.from_text(".".join(parts), origin=origin)
|
||||
|
||||
|
||||
def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
|
||||
def to_e164(
|
||||
name: dns.name.Name,
|
||||
origin: Optional[dns.name.Name] = public_enum_domain,
|
||||
want_plus_prefix: bool = True,
|
||||
) -> str:
|
||||
"""Convert an ENUM domain name into an E.164 number.
|
||||
|
||||
Note that dnspython does not have any information about preferred
|
||||
|
@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
|
|||
name = name.relativize(origin)
|
||||
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
|
||||
if len(dlabels) != len(name.labels):
|
||||
raise dns.exception.SyntaxError('non-digit labels in ENUM domain name')
|
||||
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
|
||||
dlabels.reverse()
|
||||
text = b''.join(dlabels)
|
||||
text = b"".join(dlabels)
|
||||
if want_plus_prefix:
|
||||
text = b'+' + text
|
||||
text = b"+" + text
|
||||
return text.decode()
|
||||
|
||||
|
||||
def query(number, domains, resolver=None):
|
||||
def query(
|
||||
number: str,
|
||||
domains: Iterable[Union[dns.name.Name, str]],
|
||||
resolver: Optional[dns.resolver.Resolver] = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Look for NAPTR RRs for the specified number in the specified domains.
|
||||
|
||||
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
|
||||
|
@ -98,7 +110,7 @@ def query(number, domains, resolver=None):
|
|||
domain = dns.name.from_text(domain)
|
||||
qname = dns.e164.from_e164(number, domain)
|
||||
try:
|
||||
return resolver.resolve(qname, 'NAPTR')
|
||||
return resolver.resolve(qname, "NAPTR")
|
||||
except dns.resolver.NXDOMAIN as e:
|
||||
e_nx += e
|
||||
raise e_nx
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
from typing import Optional, Iterable
|
||||
from . import name, resolver
|
||||
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
|
||||
...
|
||||
|
||||
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
|
||||
...
|
||||
|
||||
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
|
||||
...
|
152
lib/dns/edns.py
152
lib/dns/edns.py
|
@ -17,6 +17,8 @@
|
|||
|
||||
"""EDNS Options"""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import math
|
||||
import socket
|
||||
import struct
|
||||
|
@ -24,6 +26,7 @@ import struct
|
|||
import dns.enum
|
||||
import dns.inet
|
||||
import dns.rdata
|
||||
import dns.wire
|
||||
|
||||
|
||||
class OptionType(dns.enum.IntEnum):
|
||||
|
@ -59,14 +62,14 @@ class Option:
|
|||
|
||||
"""Base class for all EDNS option types."""
|
||||
|
||||
def __init__(self, otype):
|
||||
def __init__(self, otype: Union[OptionType, str]):
|
||||
"""Initialize an option.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
*otype*, a ``dns.edns.OptionType``, is the option type.
|
||||
"""
|
||||
self.otype = OptionType.make(otype)
|
||||
|
||||
def to_wire(self, file=None):
|
||||
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
|
||||
"""Convert an option to wire format.
|
||||
|
||||
Returns a ``bytes`` or ``None``.
|
||||
|
@ -75,10 +78,10 @@ class Option:
|
|||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype, parser):
|
||||
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
*otype*, a ``dns.edns.OptionType``, is the option type.
|
||||
|
||||
*parser*, a ``dns.wire.Parser``, the parser, which should be
|
||||
restructed to the option length.
|
||||
|
@ -115,26 +118,22 @@ class Option:
|
|||
return self._cmp(other) != 0
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Option) or \
|
||||
self.otype != other.otype:
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) < 0
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, Option) or \
|
||||
self.otype != other.otype:
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) <= 0
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, Option) or \
|
||||
self.otype != other.otype:
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) >= 0
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, Option) or \
|
||||
self.otype != other.otype:
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) > 0
|
||||
|
||||
|
@ -142,7 +141,7 @@ class Option:
|
|||
return self.to_text()
|
||||
|
||||
|
||||
class GenericOption(Option):
|
||||
class GenericOption(Option): # lgtm[py/missing-equals]
|
||||
|
||||
"""Generic Option Class
|
||||
|
||||
|
@ -150,28 +149,31 @@ class GenericOption(Option):
|
|||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, otype, data):
|
||||
def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
|
||||
super().__init__(otype)
|
||||
self.data = dns.rdata.Rdata._as_bytes(data, True)
|
||||
|
||||
def to_wire(self, file=None):
|
||||
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
|
||||
if file:
|
||||
file.write(self.data)
|
||||
return None
|
||||
else:
|
||||
return self.data
|
||||
|
||||
def to_text(self):
|
||||
def to_text(self) -> str:
|
||||
return "Generic %d" % self.otype
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype, parser):
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
return cls(otype, parser.get_remaining())
|
||||
|
||||
|
||||
class ECSOption(Option):
|
||||
class ECSOption(Option): # lgtm[py/missing-equals]
|
||||
"""EDNS Client Subnet (ECS, RFC7871)"""
|
||||
|
||||
def __init__(self, address, srclen=None, scopelen=0):
|
||||
def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
|
||||
"""*address*, a ``str``, is the client address information.
|
||||
|
||||
*srclen*, an ``int``, the source prefix length, which is the
|
||||
|
@ -200,8 +202,9 @@ class ECSOption(Option):
|
|||
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
|
||||
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
|
||||
else: # pragma: no cover (this will never happen)
|
||||
raise ValueError('Bad address family')
|
||||
raise ValueError("Bad address family")
|
||||
|
||||
assert srclen is not None
|
||||
self.address = address
|
||||
self.srclen = srclen
|
||||
self.scopelen = scopelen
|
||||
|
@ -214,16 +217,14 @@ class ECSOption(Option):
|
|||
self.addrdata = addrdata[:nbytes]
|
||||
nbits = srclen % 8
|
||||
if nbits != 0:
|
||||
last = struct.pack('B',
|
||||
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
|
||||
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
|
||||
self.addrdata = self.addrdata[:-1] + last
|
||||
|
||||
def to_text(self):
|
||||
return "ECS {}/{} scope/{}".format(self.address, self.srclen,
|
||||
self.scopelen)
|
||||
def to_text(self) -> str:
|
||||
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
|
||||
|
||||
@staticmethod
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> Option:
|
||||
"""Convert a string into a `dns.edns.ECSOption`
|
||||
|
||||
*text*, a `str`, the text form of the option.
|
||||
|
@ -246,7 +247,7 @@ class ECSOption(Option):
|
|||
>>> # it understands results from `dns.edns.ECSOption.to_text()`
|
||||
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
|
||||
"""
|
||||
optional_prefix = 'ECS'
|
||||
optional_prefix = "ECS"
|
||||
tokens = text.split()
|
||||
ecs_text = None
|
||||
if len(tokens) == 1:
|
||||
|
@ -257,47 +258,53 @@ class ECSOption(Option):
|
|||
ecs_text = tokens[1]
|
||||
else:
|
||||
raise ValueError('could not parse ECS from "{}"'.format(text))
|
||||
n_slashes = ecs_text.count('/')
|
||||
n_slashes = ecs_text.count("/")
|
||||
if n_slashes == 1:
|
||||
address, srclen = ecs_text.split('/')
|
||||
scope = 0
|
||||
address, tsrclen = ecs_text.split("/")
|
||||
tscope = "0"
|
||||
elif n_slashes == 2:
|
||||
address, srclen, scope = ecs_text.split('/')
|
||||
address, tsrclen, tscope = ecs_text.split("/")
|
||||
else:
|
||||
raise ValueError('could not parse ECS from "{}"'.format(text))
|
||||
try:
|
||||
scope = int(scope)
|
||||
scope = int(tscope)
|
||||
except ValueError:
|
||||
raise ValueError('invalid scope ' +
|
||||
'"{}": scope must be an integer'.format(scope))
|
||||
raise ValueError(
|
||||
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
|
||||
)
|
||||
try:
|
||||
srclen = int(srclen)
|
||||
srclen = int(tsrclen)
|
||||
except ValueError:
|
||||
raise ValueError('invalid srclen ' +
|
||||
'"{}": srclen must be an integer'.format(srclen))
|
||||
raise ValueError(
|
||||
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
|
||||
)
|
||||
return ECSOption(address, srclen, scope)
|
||||
|
||||
def to_wire(self, file=None):
|
||||
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
|
||||
self.addrdata)
|
||||
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
|
||||
value = (
|
||||
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
|
||||
)
|
||||
if file:
|
||||
file.write(value)
|
||||
return None
|
||||
else:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype, parser):
|
||||
family, src, scope = parser.get_struct('!HBB')
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
family, src, scope = parser.get_struct("!HBB")
|
||||
addrlen = int(math.ceil(src / 8.0))
|
||||
prefix = parser.get_bytes(addrlen)
|
||||
if family == 1:
|
||||
pad = 4 - addrlen
|
||||
addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad)
|
||||
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
|
||||
elif family == 2:
|
||||
pad = 16 - addrlen
|
||||
addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad)
|
||||
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
|
||||
else:
|
||||
raise ValueError('unsupported family')
|
||||
raise ValueError("unsupported family")
|
||||
|
||||
return cls(addr, src, scope)
|
||||
|
||||
|
@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum):
|
|||
return 65535
|
||||
|
||||
|
||||
class EDEOption(Option):
|
||||
class EDEOption(Option): # lgtm[py/missing-equals]
|
||||
"""Extended DNS Error (EDE, RFC8914)"""
|
||||
|
||||
def __init__(self, code, text=None):
|
||||
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
|
||||
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
|
||||
extended error.
|
||||
|
||||
|
@ -349,49 +356,50 @@ class EDEOption(Option):
|
|||
|
||||
self.code = EDECode.make(code)
|
||||
if text is not None and not isinstance(text, str):
|
||||
raise ValueError('text must be string or None')
|
||||
|
||||
self.code = code
|
||||
raise ValueError("text must be string or None")
|
||||
self.text = text
|
||||
|
||||
def to_text(self):
|
||||
output = f'EDE {self.code}'
|
||||
def to_text(self) -> str:
|
||||
output = f"EDE {self.code}"
|
||||
if self.text is not None:
|
||||
output += f': {self.text}'
|
||||
output += f": {self.text}"
|
||||
return output
|
||||
|
||||
def to_wire(self, file=None):
|
||||
value = struct.pack('!H', self.code)
|
||||
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
|
||||
value = struct.pack("!H", self.code)
|
||||
if self.text is not None:
|
||||
value += self.text.encode('utf8')
|
||||
value += self.text.encode("utf8")
|
||||
|
||||
if file:
|
||||
file.write(value)
|
||||
return None
|
||||
else:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype, parser):
|
||||
code = parser.get_uint16()
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
the_code = EDECode.make(parser.get_uint16())
|
||||
text = parser.get_remaining()
|
||||
|
||||
if text:
|
||||
if text[-1] == 0: # text MAY be null-terminated
|
||||
text = text[:-1]
|
||||
text = text.decode('utf8')
|
||||
btext = text.decode("utf8")
|
||||
else:
|
||||
text = None
|
||||
btext = None
|
||||
|
||||
return cls(code, text)
|
||||
return cls(the_code, btext)
|
||||
|
||||
|
||||
_type_to_class = {
|
||||
_type_to_class: Dict[OptionType, Any] = {
|
||||
OptionType.ECS: ECSOption,
|
||||
OptionType.EDE: EDEOption,
|
||||
}
|
||||
|
||||
|
||||
def get_option_class(otype):
|
||||
def get_option_class(otype: OptionType) -> Any:
|
||||
"""Return the class for the specified option type.
|
||||
|
||||
The GenericOption class is used if a more specific class is not
|
||||
|
@ -404,7 +412,9 @@ def get_option_class(otype):
|
|||
return cls
|
||||
|
||||
|
||||
def option_from_wire_parser(otype, parser):
|
||||
def option_from_wire_parser(
|
||||
otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
|
@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser):
|
|||
|
||||
Returns an instance of a subclass of ``dns.edns.Option``.
|
||||
"""
|
||||
cls = get_option_class(otype)
|
||||
otype = OptionType.make(otype)
|
||||
the_otype = OptionType.make(otype)
|
||||
cls = get_option_class(the_otype)
|
||||
return cls.from_wire_parser(otype, parser)
|
||||
|
||||
|
||||
def option_from_wire(otype, wire, current, olen):
|
||||
def option_from_wire(
|
||||
otype: Union[OptionType, str], wire: bytes, current: int, olen: int
|
||||
) -> Option:
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
|
@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen):
|
|||
with parser.restrict_to(olen):
|
||||
return option_from_wire_parser(otype, parser)
|
||||
|
||||
def register_type(implementation, otype):
|
||||
|
||||
def register_type(implementation: Any, otype: OptionType) -> None:
|
||||
"""Register the implementation of an option type.
|
||||
|
||||
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
|
||||
|
@ -447,6 +460,7 @@ def register_type(implementation, otype):
|
|||
|
||||
_type_to_class[otype] = implementation
|
||||
|
||||
|
||||
### BEGIN generated OptionType constants
|
||||
|
||||
NSID = OptionType.NSID
|
||||
|
|
|
@ -15,14 +15,13 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
try:
|
||||
import threading as _threading
|
||||
except ImportError: # pragma: no cover
|
||||
import dummy_threading as _threading # type: ignore
|
||||
|
||||
|
||||
class EntropyPool:
|
||||
|
@ -32,51 +31,51 @@ class EntropyPool:
|
|||
# leaving this code doesn't hurt anything as the library code
|
||||
# is used if present.
|
||||
|
||||
def __init__(self, seed=None):
|
||||
def __init__(self, seed: Optional[bytes] = None):
|
||||
self.pool_index = 0
|
||||
self.digest = None
|
||||
self.digest: Optional[bytearray] = None
|
||||
self.next_byte = 0
|
||||
self.lock = _threading.Lock()
|
||||
self.lock = threading.Lock()
|
||||
self.hash = hashlib.sha1()
|
||||
self.hash_len = 20
|
||||
self.pool = bytearray(b'\0' * self.hash_len)
|
||||
self.pool = bytearray(b"\0" * self.hash_len)
|
||||
if seed is not None:
|
||||
self._stir(bytearray(seed))
|
||||
self._stir(seed)
|
||||
self.seeded = True
|
||||
self.seed_pid = os.getpid()
|
||||
else:
|
||||
self.seeded = False
|
||||
self.seed_pid = 0
|
||||
|
||||
def _stir(self, entropy):
|
||||
def _stir(self, entropy: bytes) -> None:
|
||||
for c in entropy:
|
||||
if self.pool_index == self.hash_len:
|
||||
self.pool_index = 0
|
||||
b = c & 0xff
|
||||
b = c & 0xFF
|
||||
self.pool[self.pool_index] ^= b
|
||||
self.pool_index += 1
|
||||
|
||||
def stir(self, entropy):
|
||||
def stir(self, entropy: bytes) -> None:
|
||||
with self.lock:
|
||||
self._stir(entropy)
|
||||
|
||||
def _maybe_seed(self):
|
||||
def _maybe_seed(self) -> None:
|
||||
if not self.seeded or self.seed_pid != os.getpid():
|
||||
try:
|
||||
seed = os.urandom(16)
|
||||
except Exception: # pragma: no cover
|
||||
try:
|
||||
with open('/dev/urandom', 'rb', 0) as r:
|
||||
with open("/dev/urandom", "rb", 0) as r:
|
||||
seed = r.read(16)
|
||||
except Exception:
|
||||
seed = str(time.time())
|
||||
seed = str(time.time()).encode()
|
||||
self.seeded = True
|
||||
self.seed_pid = os.getpid()
|
||||
self.digest = None
|
||||
seed = bytearray(seed)
|
||||
self._stir(seed)
|
||||
|
||||
def random_8(self):
|
||||
def random_8(self) -> int:
|
||||
with self.lock:
|
||||
self._maybe_seed()
|
||||
if self.digest is None or self.next_byte == self.hash_len:
|
||||
|
@ -88,16 +87,16 @@ class EntropyPool:
|
|||
self.next_byte += 1
|
||||
return value
|
||||
|
||||
def random_16(self):
|
||||
def random_16(self) -> int:
|
||||
return self.random_8() * 256 + self.random_8()
|
||||
|
||||
def random_32(self):
|
||||
def random_32(self) -> int:
|
||||
return self.random_16() * 65536 + self.random_16()
|
||||
|
||||
def random_between(self, first, last):
|
||||
def random_between(self, first: int, last: int) -> int:
|
||||
size = last - first + 1
|
||||
if size > 4294967296:
|
||||
raise ValueError('too big')
|
||||
raise ValueError("too big")
|
||||
if size > 65536:
|
||||
rand = self.random_32
|
||||
max = 4294967295
|
||||
|
@ -109,20 +108,24 @@ class EntropyPool:
|
|||
max = 255
|
||||
return first + size * rand() // (max + 1)
|
||||
|
||||
|
||||
pool = EntropyPool()
|
||||
|
||||
system_random: Optional[Any]
|
||||
try:
|
||||
system_random = random.SystemRandom()
|
||||
except Exception: # pragma: no cover
|
||||
system_random = None
|
||||
|
||||
def random_16():
|
||||
|
||||
def random_16() -> int:
|
||||
if system_random is not None:
|
||||
return system_random.randrange(0, 65536)
|
||||
else:
|
||||
return pool.random_16()
|
||||
|
||||
def between(first, last):
|
||||
|
||||
def between(first: int, last: int) -> int:
|
||||
if system_random is not None:
|
||||
return system_random.randrange(first, last + 1)
|
||||
else:
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
from typing import Optional
|
||||
from random import SystemRandom
|
||||
|
||||
system_random : Optional[SystemRandom]
|
||||
|
||||
def random_16() -> int:
|
||||
pass
|
||||
|
||||
def between(first: int, last: int) -> int:
|
||||
pass
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
import enum
|
||||
|
||||
|
||||
class IntEnum(enum.IntEnum):
|
||||
@classmethod
|
||||
def _check_value(cls, value):
|
||||
|
@ -32,9 +33,12 @@ class IntEnum(enum.IntEnum):
|
|||
return cls[text]
|
||||
except KeyError:
|
||||
pass
|
||||
value = cls._extra_from_text(text)
|
||||
if value:
|
||||
return value
|
||||
prefix = cls._prefix()
|
||||
if text.startswith(prefix) and text[len(prefix):].isdigit():
|
||||
value = int(text[len(prefix):])
|
||||
if text.startswith(prefix) and text[len(prefix) :].isdigit():
|
||||
value = int(text[len(prefix) :])
|
||||
cls._check_value(value)
|
||||
try:
|
||||
return cls(value)
|
||||
|
@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum):
|
|||
def to_text(cls, value):
|
||||
cls._check_value(value)
|
||||
try:
|
||||
return cls(value).name
|
||||
text = cls(value).name
|
||||
except ValueError:
|
||||
return f"{cls._prefix()}{value}"
|
||||
text = None
|
||||
text = cls._extra_to_text(value, text)
|
||||
if text is None:
|
||||
text = f"{cls._prefix()}{value}"
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def make(cls, value):
|
||||
|
@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum):
|
|||
|
||||
@classmethod
|
||||
def _prefix(cls):
|
||||
return ''
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _extra_from_text(cls, text): # pylint: disable=W0613
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
|
||||
return current_text
|
||||
|
||||
@classmethod
|
||||
def _unknown_exception_class(cls):
|
||||
|
|
|
@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
|
|||
always be subclasses of ``DNSException``.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
|
||||
class DNSException(Exception):
|
||||
"""Abstract base class shared by all dnspython exceptions.
|
||||
|
||||
|
@ -44,14 +48,15 @@ class DNSException(Exception):
|
|||
and ``fmt`` class variables to get nice parametrized messages.
|
||||
"""
|
||||
|
||||
msg = None # non-parametrized message
|
||||
supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check)
|
||||
fmt = None # message parametrized with results from _fmt_kwargs
|
||||
msg: Optional[str] = None # non-parametrized message
|
||||
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
|
||||
fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._check_params(*args, **kwargs)
|
||||
if kwargs:
|
||||
self.kwargs = self._check_kwargs(**kwargs)
|
||||
# This call to a virtual method from __init__ is ok in our usage
|
||||
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
|
||||
self.msg = str(self)
|
||||
else:
|
||||
self.kwargs = dict() # defined but empty for old mode exceptions
|
||||
|
@ -68,14 +73,15 @@ class DNSException(Exception):
|
|||
|
||||
For sanity we do not allow to mix old and new behavior."""
|
||||
if args or kwargs:
|
||||
assert bool(args) != bool(kwargs), \
|
||||
'keyword arguments are mutually exclusive with positional args'
|
||||
assert bool(args) != bool(
|
||||
kwargs
|
||||
), "keyword arguments are mutually exclusive with positional args"
|
||||
|
||||
def _check_kwargs(self, **kwargs):
|
||||
if kwargs:
|
||||
assert set(kwargs.keys()) == self.supp_kwargs, \
|
||||
'following set of keyword args is required: %s' % (
|
||||
self.supp_kwargs)
|
||||
assert (
|
||||
set(kwargs.keys()) == self.supp_kwargs
|
||||
), "following set of keyword args is required: %s" % (self.supp_kwargs)
|
||||
return kwargs
|
||||
|
||||
def _fmt_kwargs(self, **kwargs):
|
||||
|
@ -124,9 +130,15 @@ class TooBig(DNSException):
|
|||
|
||||
class Timeout(DNSException):
|
||||
"""The DNS operation timed out."""
|
||||
supp_kwargs = {'timeout'}
|
||||
|
||||
supp_kwargs = {"timeout"}
|
||||
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
# idna_exception
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ExceptionWrapper:
|
||||
def __init__(self, exception_class):
|
||||
|
@ -136,7 +148,6 @@ class ExceptionWrapper:
|
|||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None and not isinstance(exc_val,
|
||||
self.exception_class):
|
||||
if exc_type is not None and not isinstance(exc_val, self.exception_class):
|
||||
raise self.exception_class(str(exc_val)) from exc_val
|
||||
return False
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
from typing import Set, Optional, Dict
|
||||
|
||||
class DNSException(Exception):
|
||||
supp_kwargs : Set[str]
|
||||
kwargs : Optional[Dict]
|
||||
fmt : Optional[str]
|
||||
|
||||
class SyntaxError(DNSException): ...
|
||||
class FormError(DNSException): ...
|
||||
class Timeout(DNSException): ...
|
||||
class TooBig(DNSException): ...
|
||||
class UnexpectedEnd(SyntaxError): ...
|
|
@ -17,10 +17,13 @@
|
|||
|
||||
"""DNS Message Flags."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import enum
|
||||
|
||||
# Standard DNS flags
|
||||
|
||||
|
||||
class Flag(enum.IntFlag):
|
||||
#: Query Response
|
||||
QR = 0x8000
|
||||
|
@ -40,12 +43,13 @@ class Flag(enum.IntFlag):
|
|||
|
||||
# EDNS flags
|
||||
|
||||
|
||||
class EDNSFlag(enum.IntFlag):
|
||||
#: DNSSEC answer OK
|
||||
DO = 0x8000
|
||||
|
||||
|
||||
def _from_text(text, enum_class):
|
||||
def _from_text(text: str, enum_class: Any) -> int:
|
||||
flags = 0
|
||||
tokens = text.split()
|
||||
for t in tokens:
|
||||
|
@ -53,15 +57,15 @@ def _from_text(text, enum_class):
|
|||
return flags
|
||||
|
||||
|
||||
def _to_text(flags, enum_class):
|
||||
def _to_text(flags: int, enum_class: Any) -> str:
|
||||
text_flags = []
|
||||
for k, v in enum_class.__members__.items():
|
||||
if flags & v != 0:
|
||||
text_flags.append(k)
|
||||
return ' '.join(text_flags)
|
||||
return " ".join(text_flags)
|
||||
|
||||
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> int:
|
||||
"""Convert a space-separated list of flag text values into a flags
|
||||
value.
|
||||
|
||||
|
@ -71,7 +75,7 @@ def from_text(text):
|
|||
return _from_text(text, Flag)
|
||||
|
||||
|
||||
def to_text(flags):
|
||||
def to_text(flags: int) -> str:
|
||||
"""Convert a flags value into a space-separated list of flag text
|
||||
values.
|
||||
|
||||
|
@ -81,7 +85,7 @@ def to_text(flags):
|
|||
return _to_text(flags, Flag)
|
||||
|
||||
|
||||
def edns_from_text(text):
|
||||
def edns_from_text(text: str) -> int:
|
||||
"""Convert a space-separated list of EDNS flag text values into a EDNS
|
||||
flags value.
|
||||
|
||||
|
@ -91,7 +95,7 @@ def edns_from_text(text):
|
|||
return _from_text(text, EDNSFlag)
|
||||
|
||||
|
||||
def edns_to_text(flags):
|
||||
def edns_to_text(flags: int) -> str:
|
||||
"""Convert an EDNS flags value into a space-separated list of EDNS flag
|
||||
text values.
|
||||
|
||||
|
@ -100,6 +104,7 @@ def edns_to_text(flags):
|
|||
|
||||
return _to_text(flags, EDNSFlag)
|
||||
|
||||
|
||||
### BEGIN generated Flag constants
|
||||
|
||||
QR = Flag.QR
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
|
||||
"""DNS GENERATE range conversion."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import dns
|
||||
|
||||
def from_text(text):
|
||||
|
||||
def from_text(text: str) -> Tuple[int, int, int]:
|
||||
"""Convert the text form of a range in a ``$GENERATE`` statement to an
|
||||
integer.
|
||||
|
||||
|
@ -31,22 +34,22 @@ def from_text(text):
|
|||
start = -1
|
||||
stop = -1
|
||||
step = 1
|
||||
cur = ''
|
||||
cur = ""
|
||||
state = 0
|
||||
# state 0 1 2
|
||||
# x - y / z
|
||||
|
||||
if text and text[0] == '-':
|
||||
if text and text[0] == "-":
|
||||
raise dns.exception.SyntaxError("Start cannot be a negative number")
|
||||
|
||||
for c in text:
|
||||
if c == '-' and state == 0:
|
||||
if c == "-" and state == 0:
|
||||
start = int(cur)
|
||||
cur = ''
|
||||
cur = ""
|
||||
state = 1
|
||||
elif c == '/':
|
||||
elif c == "/":
|
||||
stop = int(cur)
|
||||
cur = ''
|
||||
cur = ""
|
||||
state = 2
|
||||
elif c.isdigit():
|
||||
cur += c
|
||||
|
@ -64,6 +67,6 @@ def from_text(text):
|
|||
assert step >= 1
|
||||
assert start >= 0
|
||||
if start > stop:
|
||||
raise dns.exception.SyntaxError('start must be <= stop')
|
||||
raise dns.exception.SyntaxError("start must be <= stop")
|
||||
|
||||
return (start, stop, step)
|
||||
|
|
|
@ -1,32 +1,25 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import collections.abc
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# pylint: disable=unused-import
|
||||
if sys.version_info >= (3, 7):
|
||||
odict = dict
|
||||
from dns._immutable_ctx import immutable
|
||||
else:
|
||||
# pragma: no cover
|
||||
from collections import OrderedDict as odict
|
||||
from dns._immutable_attr import immutable # noqa
|
||||
# pylint: enable=unused-import
|
||||
import collections.abc
|
||||
|
||||
from dns._immutable_ctx import immutable
|
||||
|
||||
|
||||
@immutable
|
||||
class Dict(collections.abc.Mapping):
|
||||
def __init__(self, dictionary, no_copy=False):
|
||||
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
|
||||
def __init__(self, dictionary: Any, no_copy: bool = False):
|
||||
"""Make an immutable dictionary from the specified dictionary.
|
||||
|
||||
If *no_copy* is `True`, then *dictionary* will be wrapped instead
|
||||
of copied. Only set this if you are sure there will be no external
|
||||
references to the dictionary.
|
||||
"""
|
||||
if no_copy and isinstance(dictionary, odict):
|
||||
if no_copy and isinstance(dictionary, dict):
|
||||
self._odict = dictionary
|
||||
else:
|
||||
self._odict = odict(dictionary)
|
||||
self._odict = dict(dictionary)
|
||||
self._hash = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
|
@ -37,7 +30,7 @@ class Dict(collections.abc.Mapping):
|
|||
h = 0
|
||||
for key in sorted(self._odict.keys()):
|
||||
h ^= hash(key)
|
||||
object.__setattr__(self, '_hash', h)
|
||||
object.__setattr__(self, "_hash", h)
|
||||
# this does return an int, but pylint doesn't figure that out
|
||||
return self._hash
|
||||
|
||||
|
@ -48,7 +41,7 @@ class Dict(collections.abc.Mapping):
|
|||
return iter(self._odict)
|
||||
|
||||
|
||||
def constify(o):
|
||||
def constify(o: Any) -> Any:
|
||||
"""
|
||||
Convert mutable types to immutable types.
|
||||
"""
|
||||
|
@ -63,7 +56,7 @@ def constify(o):
|
|||
if isinstance(o, list):
|
||||
return tuple(constify(elt) for elt in o)
|
||||
if isinstance(o, dict):
|
||||
cdict = odict()
|
||||
cdict = dict()
|
||||
for k, v in o.items():
|
||||
cdict[k] = constify(v)
|
||||
return Dict(cdict, True)
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
"""Generic Internet address helper functions."""
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import socket
|
||||
|
||||
import dns.ipv4
|
||||
|
@ -30,7 +32,7 @@ AF_INET = socket.AF_INET
|
|||
AF_INET6 = socket.AF_INET6
|
||||
|
||||
|
||||
def inet_pton(family, text):
|
||||
def inet_pton(family: int, text: str) -> bytes:
|
||||
"""Convert the textual form of a network address into its binary form.
|
||||
|
||||
*family* is an ``int``, the address family.
|
||||
|
@ -51,7 +53,7 @@ def inet_pton(family, text):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def inet_ntop(family, address):
|
||||
def inet_ntop(family: int, address: bytes) -> str:
|
||||
"""Convert the binary form of a network address into its textual form.
|
||||
|
||||
*family* is an ``int``, the address family.
|
||||
|
@ -72,7 +74,7 @@ def inet_ntop(family, address):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def af_for_address(text):
|
||||
def af_for_address(text: str) -> int:
|
||||
"""Determine the address family of a textual-form network address.
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
@ -94,7 +96,7 @@ def af_for_address(text):
|
|||
raise ValueError
|
||||
|
||||
|
||||
def is_multicast(text):
|
||||
def is_multicast(text: str) -> bool:
|
||||
"""Is the textual-form network address a multicast address?
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
@ -116,7 +118,7 @@ def is_multicast(text):
|
|||
raise ValueError
|
||||
|
||||
|
||||
def is_address(text):
|
||||
def is_address(text: str) -> bool:
|
||||
"""Is the specified string an IPv4 or IPv6 address?
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
@ -135,7 +137,9 @@ def is_address(text):
|
|||
return False
|
||||
|
||||
|
||||
def low_level_address_tuple(high_tuple, af=None):
|
||||
def low_level_address_tuple(
|
||||
high_tuple: Tuple[str, int], af: Optional[int] = None
|
||||
) -> Any:
|
||||
"""Given a "high-level" address tuple, i.e.
|
||||
an (address, port) return the appropriate "low-level" address tuple
|
||||
suitable for use in socket calls.
|
||||
|
@ -143,7 +147,6 @@ def low_level_address_tuple(high_tuple, af=None):
|
|||
If an *af* other than ``None`` is provided, it is assumed the
|
||||
address in the high-level tuple is valid and has that af. If af
|
||||
is ``None``, then af_for_address will be called.
|
||||
|
||||
"""
|
||||
address, port = high_tuple
|
||||
if af is None:
|
||||
|
@ -151,13 +154,13 @@ def low_level_address_tuple(high_tuple, af=None):
|
|||
if af == AF_INET:
|
||||
return (address, port)
|
||||
elif af == AF_INET6:
|
||||
i = address.find('%')
|
||||
i = address.find("%")
|
||||
if i < 0:
|
||||
# no scope, shortcut!
|
||||
return (address, port, 0, 0)
|
||||
# try to avoid getaddrinfo()
|
||||
addrpart = address[:i]
|
||||
scope = address[i + 1:]
|
||||
scope = address[i + 1 :]
|
||||
if scope.isdigit():
|
||||
return (addrpart, port, 0, int(scope))
|
||||
try:
|
||||
|
@ -167,4 +170,4 @@ def low_level_address_tuple(high_tuple, af=None):
|
|||
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
|
||||
return tup
|
||||
else:
|
||||
raise NotImplementedError(f'unknown address family {af}')
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from typing import Union
|
||||
from socket import AddressFamily
|
||||
|
||||
AF_INET6 : Union[int, AddressFamily]
|
|
@ -17,11 +17,14 @@
|
|||
|
||||
"""IPv4 helper functions."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
|
||||
def inet_ntoa(address):
|
||||
|
||||
def inet_ntoa(address: bytes) -> str:
|
||||
"""Convert an IPv4 address in binary form to text form.
|
||||
|
||||
*address*, a ``bytes``, the IPv4 address in binary form.
|
||||
|
@ -31,30 +34,32 @@ def inet_ntoa(address):
|
|||
|
||||
if len(address) != 4:
|
||||
raise dns.exception.SyntaxError
|
||||
return ('%u.%u.%u.%u' % (address[0], address[1],
|
||||
address[2], address[3]))
|
||||
return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3])
|
||||
|
||||
def inet_aton(text):
|
||||
|
||||
def inet_aton(text: Union[str, bytes]) -> bytes:
|
||||
"""Convert an IPv4 address in text form to binary form.
|
||||
|
||||
*text*, a ``str``, the IPv4 address in textual form.
|
||||
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
|
||||
|
||||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
if not isinstance(text, bytes):
|
||||
text = text.encode()
|
||||
parts = text.split(b'.')
|
||||
btext = text.encode()
|
||||
else:
|
||||
btext = text
|
||||
parts = btext.split(b".")
|
||||
if len(parts) != 4:
|
||||
raise dns.exception.SyntaxError
|
||||
for part in parts:
|
||||
if not part.isdigit():
|
||||
raise dns.exception.SyntaxError
|
||||
if len(part) > 1 and part[0] == ord('0'):
|
||||
if len(part) > 1 and part[0] == ord("0"):
|
||||
# No leading zeros
|
||||
raise dns.exception.SyntaxError
|
||||
try:
|
||||
b = [int(part) for part in parts]
|
||||
return struct.pack('BBBB', *b)
|
||||
return struct.pack("BBBB", *b)
|
||||
except Exception:
|
||||
raise dns.exception.SyntaxError
|
||||
|
|
|
@ -17,15 +17,18 @@
|
|||
|
||||
"""IPv6 helper functions."""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import re
|
||||
import binascii
|
||||
|
||||
import dns.exception
|
||||
import dns.ipv4
|
||||
|
||||
_leading_zero = re.compile(r'0+([0-9a-f]+)')
|
||||
_leading_zero = re.compile(r"0+([0-9a-f]+)")
|
||||
|
||||
def inet_ntoa(address):
|
||||
|
||||
def inet_ntoa(address: bytes) -> str:
|
||||
"""Convert an IPv6 address in binary form to text form.
|
||||
|
||||
*address*, a ``bytes``, the IPv6 address in binary form.
|
||||
|
@ -41,7 +44,7 @@ def inet_ntoa(address):
|
|||
i = 0
|
||||
l = len(hex)
|
||||
while i < l:
|
||||
chunk = hex[i:i + 4].decode()
|
||||
chunk = hex[i : i + 4].decode()
|
||||
# strip leading zeros. we do this with an re instead of
|
||||
# with lstrip() because lstrip() didn't support chars until
|
||||
# python 2.2.2
|
||||
|
@ -58,7 +61,7 @@ def inet_ntoa(address):
|
|||
start = -1
|
||||
last_was_zero = False
|
||||
for i in range(8):
|
||||
if chunks[i] != '0':
|
||||
if chunks[i] != "0":
|
||||
if last_was_zero:
|
||||
end = i
|
||||
current_len = end - start
|
||||
|
@ -76,27 +79,30 @@ def inet_ntoa(address):
|
|||
best_start = start
|
||||
best_len = current_len
|
||||
if best_len > 1:
|
||||
if best_start == 0 and \
|
||||
(best_len == 6 or
|
||||
best_len == 5 and chunks[5] == 'ffff'):
|
||||
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
|
||||
# We have an embedded IPv4 address
|
||||
if best_len == 6:
|
||||
prefix = '::'
|
||||
prefix = "::"
|
||||
else:
|
||||
prefix = '::ffff:'
|
||||
hex = prefix + dns.ipv4.inet_ntoa(address[12:])
|
||||
prefix = "::ffff:"
|
||||
thex = prefix + dns.ipv4.inet_ntoa(address[12:])
|
||||
else:
|
||||
hex = ':'.join(chunks[:best_start]) + '::' + \
|
||||
':'.join(chunks[best_start + best_len:])
|
||||
thex = (
|
||||
":".join(chunks[:best_start])
|
||||
+ "::"
|
||||
+ ":".join(chunks[best_start + best_len :])
|
||||
)
|
||||
else:
|
||||
hex = ':'.join(chunks)
|
||||
return hex
|
||||
thex = ":".join(chunks)
|
||||
return thex
|
||||
|
||||
_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
|
||||
_colon_colon_start = re.compile(br'::.*')
|
||||
_colon_colon_end = re.compile(br'.*::$')
|
||||
|
||||
def inet_aton(text, ignore_scope=False):
|
||||
_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
|
||||
_colon_colon_start = re.compile(rb"::.*")
|
||||
_colon_colon_end = re.compile(rb".*::$")
|
||||
|
||||
|
||||
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
|
||||
"""Convert an IPv6 address in text form to binary form.
|
||||
|
||||
*text*, a ``str``, the IPv6 address in textual form.
|
||||
|
@ -111,82 +117,88 @@ def inet_aton(text, ignore_scope=False):
|
|||
# Our aim here is not something fast; we just want something that works.
|
||||
#
|
||||
if not isinstance(text, bytes):
|
||||
text = text.encode()
|
||||
btext = text.encode()
|
||||
else:
|
||||
btext = text
|
||||
|
||||
if ignore_scope:
|
||||
parts = text.split(b'%')
|
||||
parts = btext.split(b"%")
|
||||
l = len(parts)
|
||||
if l == 2:
|
||||
text = parts[0]
|
||||
btext = parts[0]
|
||||
elif l > 2:
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
if text == b'':
|
||||
if btext == b"":
|
||||
raise dns.exception.SyntaxError
|
||||
elif text.endswith(b':') and not text.endswith(b'::'):
|
||||
elif btext.endswith(b":") and not btext.endswith(b"::"):
|
||||
raise dns.exception.SyntaxError
|
||||
elif text.startswith(b':') and not text.startswith(b'::'):
|
||||
elif btext.startswith(b":") and not btext.startswith(b"::"):
|
||||
raise dns.exception.SyntaxError
|
||||
elif text == b'::':
|
||||
text = b'0::'
|
||||
elif btext == b"::":
|
||||
btext = b"0::"
|
||||
#
|
||||
# Get rid of the icky dot-quad syntax if we have it.
|
||||
#
|
||||
m = _v4_ending.match(text)
|
||||
m = _v4_ending.match(btext)
|
||||
if m is not None:
|
||||
b = dns.ipv4.inet_aton(m.group(2))
|
||||
text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
|
||||
b[0], b[1], b[2],
|
||||
b[3])).encode()
|
||||
btext = (
|
||||
"{}:{:02x}{:02x}:{:02x}{:02x}".format(
|
||||
m.group(1).decode(), b[0], b[1], b[2], b[3]
|
||||
)
|
||||
).encode()
|
||||
#
|
||||
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
|
||||
# turn '<whatever>::' into '<whatever>:'
|
||||
#
|
||||
m = _colon_colon_start.match(text)
|
||||
m = _colon_colon_start.match(btext)
|
||||
if m is not None:
|
||||
text = text[1:]
|
||||
btext = btext[1:]
|
||||
else:
|
||||
m = _colon_colon_end.match(text)
|
||||
m = _colon_colon_end.match(btext)
|
||||
if m is not None:
|
||||
text = text[:-1]
|
||||
btext = btext[:-1]
|
||||
#
|
||||
# Now canonicalize into 8 chunks of 4 hex digits each
|
||||
#
|
||||
chunks = text.split(b':')
|
||||
chunks = btext.split(b":")
|
||||
l = len(chunks)
|
||||
if l > 8:
|
||||
raise dns.exception.SyntaxError
|
||||
seen_empty = False
|
||||
canonical = []
|
||||
canonical: List[bytes] = []
|
||||
for c in chunks:
|
||||
if c == b'':
|
||||
if c == b"":
|
||||
if seen_empty:
|
||||
raise dns.exception.SyntaxError
|
||||
seen_empty = True
|
||||
for _ in range(0, 8 - l + 1):
|
||||
canonical.append(b'0000')
|
||||
canonical.append(b"0000")
|
||||
else:
|
||||
lc = len(c)
|
||||
if lc > 4:
|
||||
raise dns.exception.SyntaxError
|
||||
if lc != 4:
|
||||
c = (b'0' * (4 - lc)) + c
|
||||
c = (b"0" * (4 - lc)) + c
|
||||
canonical.append(c)
|
||||
if l < 8 and not seen_empty:
|
||||
raise dns.exception.SyntaxError
|
||||
text = b''.join(canonical)
|
||||
btext = b"".join(canonical)
|
||||
|
||||
#
|
||||
# Finally we can go to binary.
|
||||
#
|
||||
try:
|
||||
return binascii.unhexlify(text)
|
||||
return binascii.unhexlify(btext)
|
||||
except (binascii.Error, TypeError):
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
_mapped_prefix = b'\x00' * 10 + b'\xff\xff'
|
||||
|
||||
def is_mapped(address):
|
||||
_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
|
||||
|
||||
|
||||
def is_mapped(address: bytes) -> bool:
|
||||
"""Is the specified address a mapped IPv4 address?
|
||||
|
||||
*address*, a ``bytes`` is an IPv6 address in binary form.
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,47 +0,0 @@
|
|||
from typing import Optional, Dict, List, Tuple, Union
|
||||
from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
|
||||
import hmac
|
||||
|
||||
class Message:
|
||||
def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
|
||||
...
|
||||
def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
|
||||
covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
|
||||
force_unique=False) -> rrset.RRset:
|
||||
...
|
||||
def __init__(self, id : Optional[int] =None) -> None:
|
||||
self.id : int
|
||||
self.flags = 0
|
||||
self.sections : List[List[rrset.RRset]] = [[], [], [], []]
|
||||
self.opt : rrset.RRset = None
|
||||
self.request_payload = 0
|
||||
self.keyring = None
|
||||
self.tsig : rrset.RRset = None
|
||||
self.request_mac = b''
|
||||
self.xfr = False
|
||||
self.origin = None
|
||||
self.tsig_ctx = None
|
||||
self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
|
||||
|
||||
def is_response(self, other : Message) -> bool:
|
||||
...
|
||||
|
||||
def set_rcode(self, rcode : rcode.Rcode):
|
||||
...
|
||||
|
||||
def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
|
||||
...
|
||||
|
||||
def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
|
||||
tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
|
||||
question_only=False, one_rr_per_rrset=False,
|
||||
ignore_trailing=False) -> Message:
|
||||
...
|
||||
def make_response(query : Message, recursion_available=False, our_payload=8192,
|
||||
fudge=300) -> Message:
|
||||
...
|
||||
|
||||
def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
|
||||
want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
|
||||
request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
|
||||
...
|
395
lib/dns/name.py
395
lib/dns/name.py
|
@ -18,32 +18,61 @@
|
|||
"""DNS Names.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import copy
|
||||
import struct
|
||||
|
||||
import encodings.idna # type: ignore
|
||||
|
||||
try:
|
||||
import idna # type: ignore
|
||||
|
||||
have_idna_2008 = True
|
||||
except ImportError: # pragma: no cover
|
||||
have_idna_2008 = False
|
||||
|
||||
import dns.enum
|
||||
import dns.wire
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
||||
# fullcompare() result values
|
||||
|
||||
#: The compared names have no relationship to each other.
|
||||
NAMERELN_NONE = 0
|
||||
#: the first name is a superdomain of the second.
|
||||
NAMERELN_SUPERDOMAIN = 1
|
||||
#: The first name is a subdomain of the second.
|
||||
NAMERELN_SUBDOMAIN = 2
|
||||
#: The compared names are equal.
|
||||
NAMERELN_EQUAL = 3
|
||||
#: The compared names have a common ancestor.
|
||||
NAMERELN_COMMONANCESTOR = 4
|
||||
CompressType = Dict["Name", int]
|
||||
|
||||
|
||||
class NameRelation(dns.enum.IntEnum):
|
||||
"""Name relation result from fullcompare()."""
|
||||
|
||||
# This is an IntEnum for backwards compatibility in case anyone
|
||||
# has hardwired the constants.
|
||||
|
||||
#: The compared names have no relationship to each other.
|
||||
NONE = 0
|
||||
#: the first name is a superdomain of the second.
|
||||
SUPERDOMAIN = 1
|
||||
#: The first name is a subdomain of the second.
|
||||
SUBDOMAIN = 2
|
||||
#: The compared names are equal.
|
||||
EQUAL = 3
|
||||
#: The compared names have a common ancestor.
|
||||
COMMONANCESTOR = 4
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return cls.COMMONANCESTOR
|
||||
|
||||
@classmethod
|
||||
def _short_name(cls):
|
||||
return cls.__name__
|
||||
|
||||
|
||||
# Backwards compatibility
|
||||
NAMERELN_NONE = NameRelation.NONE
|
||||
NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
|
||||
NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
|
||||
NAMERELN_EQUAL = NameRelation.EQUAL
|
||||
NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR
|
||||
|
||||
|
||||
class EmptyLabel(dns.exception.SyntaxError):
|
||||
|
@ -84,6 +113,7 @@ class NoParent(dns.exception.DNSException):
|
|||
"""An attempt was made to get the parent of the root name
|
||||
or the empty name."""
|
||||
|
||||
|
||||
class NoIDNA2008(dns.exception.DNSException):
|
||||
"""IDNA 2008 processing was requested but the idna module is not
|
||||
available."""
|
||||
|
@ -92,9 +122,47 @@ class NoIDNA2008(dns.exception.DNSException):
|
|||
class IDNAException(dns.exception.DNSException):
|
||||
"""IDNA processing raised an exception."""
|
||||
|
||||
supp_kwargs = {'idna_exception'}
|
||||
supp_kwargs = {"idna_exception"}
|
||||
fmt = "IDNA processing exception: {idna_exception}"
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
# idna_exception
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
_escaped = b'"().;\\@$'
|
||||
_escaped_text = '"().;\\@$'
|
||||
|
||||
|
||||
def _escapify(label: Union[bytes, str]) -> str:
|
||||
"""Escape the characters in label which need it.
|
||||
@returns: the escaped string
|
||||
@rtype: string"""
|
||||
if isinstance(label, bytes):
|
||||
# Ordinary DNS label mode. Escape special characters and values
|
||||
# < 0x20 or > 0x7f.
|
||||
text = ""
|
||||
for c in label:
|
||||
if c in _escaped:
|
||||
text += "\\" + chr(c)
|
||||
elif c > 0x20 and c < 0x7F:
|
||||
text += chr(c)
|
||||
else:
|
||||
text += "\\%03d" % c
|
||||
return text
|
||||
|
||||
# Unicode label mode. Escape only special characters and values < 0x20
|
||||
text = ""
|
||||
for uc in label:
|
||||
if uc in _escaped_text:
|
||||
text += "\\" + uc
|
||||
elif uc <= "\x20":
|
||||
text += "\\%03d" % ord(uc)
|
||||
else:
|
||||
text += uc
|
||||
return text
|
||||
|
||||
|
||||
class IDNACodec:
|
||||
"""Abstract base class for IDNA encoder/decoders."""
|
||||
|
@ -102,26 +170,28 @@ class IDNACodec:
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def is_idna(self, label):
|
||||
return label.lower().startswith(b'xn--')
|
||||
def is_idna(self, label: bytes) -> bool:
|
||||
return label.lower().startswith(b"xn--")
|
||||
|
||||
def encode(self, label):
|
||||
def encode(self, label: str) -> bytes:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def decode(self, label):
|
||||
def decode(self, label: bytes) -> str:
|
||||
# We do not apply any IDNA policy on decode.
|
||||
if self.is_idna(label):
|
||||
try:
|
||||
label = label[4:].decode('punycode')
|
||||
slabel = label[4:].decode("punycode")
|
||||
return _escapify(slabel)
|
||||
except Exception as e:
|
||||
raise IDNAException(idna_exception=e)
|
||||
else:
|
||||
return _escapify(label)
|
||||
|
||||
|
||||
class IDNA2003Codec(IDNACodec):
|
||||
"""IDNA 2003 encoder/decoder."""
|
||||
|
||||
def __init__(self, strict_decode=False):
|
||||
def __init__(self, strict_decode: bool = False):
|
||||
"""Initialize the IDNA 2003 encoder/decoder.
|
||||
|
||||
*strict_decode* is a ``bool``. If `True`, then IDNA2003 checking
|
||||
|
@ -132,22 +202,22 @@ class IDNA2003Codec(IDNACodec):
|
|||
super().__init__()
|
||||
self.strict_decode = strict_decode
|
||||
|
||||
def encode(self, label):
|
||||
def encode(self, label: str) -> bytes:
|
||||
"""Encode *label*."""
|
||||
|
||||
if label == '':
|
||||
return b''
|
||||
if label == "":
|
||||
return b""
|
||||
try:
|
||||
return encodings.idna.ToASCII(label)
|
||||
except UnicodeError:
|
||||
raise LabelTooLong
|
||||
|
||||
def decode(self, label):
|
||||
def decode(self, label: bytes) -> str:
|
||||
"""Decode *label*."""
|
||||
if not self.strict_decode:
|
||||
return super().decode(label)
|
||||
if label == b'':
|
||||
return ''
|
||||
if label == b"":
|
||||
return ""
|
||||
try:
|
||||
return _escapify(encodings.idna.ToUnicode(label))
|
||||
except Exception as e:
|
||||
|
@ -155,16 +225,20 @@ class IDNA2003Codec(IDNACodec):
|
|||
|
||||
|
||||
class IDNA2008Codec(IDNACodec):
|
||||
"""IDNA 2008 encoder/decoder.
|
||||
"""
|
||||
"""IDNA 2008 encoder/decoder."""
|
||||
|
||||
def __init__(self, uts_46=False, transitional=False,
|
||||
allow_pure_ascii=False, strict_decode=False):
|
||||
def __init__(
|
||||
self,
|
||||
uts_46: bool = False,
|
||||
transitional: bool = False,
|
||||
allow_pure_ascii: bool = False,
|
||||
strict_decode: bool = False,
|
||||
):
|
||||
"""Initialize the IDNA 2008 encoder/decoder.
|
||||
|
||||
*uts_46* is a ``bool``. If True, apply Unicode IDNA
|
||||
compatibility processing as described in Unicode Technical
|
||||
Standard #46 (http://unicode.org/reports/tr46/).
|
||||
Standard #46 (https://unicode.org/reports/tr46/).
|
||||
If False, do not apply the mapping. The default is False.
|
||||
|
||||
*transitional* is a ``bool``: If True, use the
|
||||
|
@ -188,11 +262,11 @@ class IDNA2008Codec(IDNACodec):
|
|||
self.allow_pure_ascii = allow_pure_ascii
|
||||
self.strict_decode = strict_decode
|
||||
|
||||
def encode(self, label):
|
||||
if label == '':
|
||||
return b''
|
||||
def encode(self, label: str) -> bytes:
|
||||
if label == "":
|
||||
return b""
|
||||
if self.allow_pure_ascii and is_all_ascii(label):
|
||||
encoded = label.encode('ascii')
|
||||
encoded = label.encode("ascii")
|
||||
if len(encoded) > 63:
|
||||
raise LabelTooLong
|
||||
return encoded
|
||||
|
@ -203,16 +277,16 @@ class IDNA2008Codec(IDNACodec):
|
|||
label = idna.uts46_remap(label, False, self.transitional)
|
||||
return idna.alabel(label)
|
||||
except idna.IDNAError as e:
|
||||
if e.args[0] == 'Label too long':
|
||||
if e.args[0] == "Label too long":
|
||||
raise LabelTooLong
|
||||
else:
|
||||
raise IDNAException(idna_exception=e)
|
||||
|
||||
def decode(self, label):
|
||||
def decode(self, label: bytes) -> str:
|
||||
if not self.strict_decode:
|
||||
return super().decode(label)
|
||||
if label == b'':
|
||||
return ''
|
||||
if label == b"":
|
||||
return ""
|
||||
if not have_idna_2008:
|
||||
raise NoIDNA2008
|
||||
try:
|
||||
|
@ -223,8 +297,6 @@ class IDNA2008Codec(IDNACodec):
|
|||
except (idna.IDNAError, UnicodeError) as e:
|
||||
raise IDNAException(idna_exception=e)
|
||||
|
||||
_escaped = b'"().;\\@$'
|
||||
_escaped_text = '"().;\\@$'
|
||||
|
||||
IDNA_2003_Practical = IDNA2003Codec(False)
|
||||
IDNA_2003_Strict = IDNA2003Codec(True)
|
||||
|
@ -235,35 +307,8 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
|
|||
IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
|
||||
IDNA_2008 = IDNA_2008_Practical
|
||||
|
||||
def _escapify(label):
|
||||
"""Escape the characters in label which need it.
|
||||
@returns: the escaped string
|
||||
@rtype: string"""
|
||||
if isinstance(label, bytes):
|
||||
# Ordinary DNS label mode. Escape special characters and values
|
||||
# < 0x20 or > 0x7f.
|
||||
text = ''
|
||||
for c in label:
|
||||
if c in _escaped:
|
||||
text += '\\' + chr(c)
|
||||
elif c > 0x20 and c < 0x7F:
|
||||
text += chr(c)
|
||||
else:
|
||||
text += '\\%03d' % c
|
||||
return text
|
||||
|
||||
# Unicode label mode. Escape only special characters and values < 0x20
|
||||
text = ''
|
||||
for c in label:
|
||||
if c in _escaped_text:
|
||||
text += '\\' + c
|
||||
elif c <= '\x20':
|
||||
text += '\\%03d' % ord(c)
|
||||
else:
|
||||
text += c
|
||||
return text
|
||||
|
||||
def _validate_labels(labels):
|
||||
def _validate_labels(labels: Tuple[bytes, ...]) -> None:
|
||||
"""Check for empty labels in the middle of a label sequence,
|
||||
labels that are too long, and for too many labels.
|
||||
|
||||
|
@ -284,7 +329,7 @@ def _validate_labels(labels):
|
|||
total += ll + 1
|
||||
if ll > 63:
|
||||
raise LabelTooLong
|
||||
if i < 0 and label == b'':
|
||||
if i < 0 and label == b"":
|
||||
i = j
|
||||
j += 1
|
||||
if total > 255:
|
||||
|
@ -293,7 +338,7 @@ def _validate_labels(labels):
|
|||
raise EmptyLabel
|
||||
|
||||
|
||||
def _maybe_convert_to_binary(label):
|
||||
def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
|
||||
"""If label is ``str``, convert it to ``bytes``. If it is already
|
||||
``bytes`` just return it.
|
||||
|
||||
|
@ -316,14 +361,13 @@ class Name:
|
|||
of the class are immutable.
|
||||
"""
|
||||
|
||||
__slots__ = ['labels']
|
||||
__slots__ = ["labels"]
|
||||
|
||||
def __init__(self, labels):
|
||||
"""*labels* is any iterable whose values are ``str`` or ``bytes``.
|
||||
"""
|
||||
def __init__(self, labels: Iterable[Union[bytes, str]]):
|
||||
"""*labels* is any iterable whose values are ``str`` or ``bytes``."""
|
||||
|
||||
labels = [_maybe_convert_to_binary(x) for x in labels]
|
||||
self.labels = tuple(labels)
|
||||
blabels = [_maybe_convert_to_binary(x) for x in labels]
|
||||
self.labels = tuple(blabels)
|
||||
_validate_labels(self.labels)
|
||||
|
||||
def __copy__(self):
|
||||
|
@ -334,29 +378,29 @@ class Name:
|
|||
|
||||
def __getstate__(self):
|
||||
# Names can be pickled
|
||||
return {'labels': self.labels}
|
||||
return {"labels": self.labels}
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setattr__('labels', state['labels'])
|
||||
super().__setattr__("labels", state["labels"])
|
||||
_validate_labels(self.labels)
|
||||
|
||||
def is_absolute(self):
|
||||
def is_absolute(self) -> bool:
|
||||
"""Is the most significant label of this name the root label?
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
return len(self.labels) > 0 and self.labels[-1] == b''
|
||||
return len(self.labels) > 0 and self.labels[-1] == b""
|
||||
|
||||
def is_wild(self):
|
||||
def is_wild(self) -> bool:
|
||||
"""Is this name wild? (I.e. Is the least significant label '*'?)
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
return len(self.labels) > 0 and self.labels[0] == b'*'
|
||||
return len(self.labels) > 0 and self.labels[0] == b"*"
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""Return a case-insensitive hash of the name.
|
||||
|
||||
Returns an ``int``.
|
||||
|
@ -368,14 +412,14 @@ class Name:
|
|||
h += (h << 3) + c
|
||||
return h
|
||||
|
||||
def fullcompare(self, other):
|
||||
def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]:
|
||||
"""Compare two names, returning a 3-tuple
|
||||
``(relation, order, nlabels)``.
|
||||
|
||||
*relation* describes the relation ship between the names,
|
||||
and is one of: ``dns.name.NAMERELN_NONE``,
|
||||
``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``,
|
||||
``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``.
|
||||
and is one of: ``dns.name.NameRelation.NONE``,
|
||||
``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``,
|
||||
``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``.
|
||||
|
||||
*order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and ==
|
||||
0 if *self* == *other*. A relative name is always less than an
|
||||
|
@ -404,9 +448,9 @@ class Name:
|
|||
oabs = other.is_absolute()
|
||||
if sabs != oabs:
|
||||
if sabs:
|
||||
return (NAMERELN_NONE, 1, 0)
|
||||
return (NameRelation.NONE, 1, 0)
|
||||
else:
|
||||
return (NAMERELN_NONE, -1, 0)
|
||||
return (NameRelation.NONE, -1, 0)
|
||||
l1 = len(self.labels)
|
||||
l2 = len(other.labels)
|
||||
ldiff = l1 - l2
|
||||
|
@ -417,7 +461,7 @@ class Name:
|
|||
|
||||
order = 0
|
||||
nlabels = 0
|
||||
namereln = NAMERELN_NONE
|
||||
namereln = NameRelation.NONE
|
||||
while l > 0:
|
||||
l -= 1
|
||||
l1 -= 1
|
||||
|
@ -427,52 +471,52 @@ class Name:
|
|||
if label1 < label2:
|
||||
order = -1
|
||||
if nlabels > 0:
|
||||
namereln = NAMERELN_COMMONANCESTOR
|
||||
namereln = NameRelation.COMMONANCESTOR
|
||||
return (namereln, order, nlabels)
|
||||
elif label1 > label2:
|
||||
order = 1
|
||||
if nlabels > 0:
|
||||
namereln = NAMERELN_COMMONANCESTOR
|
||||
namereln = NameRelation.COMMONANCESTOR
|
||||
return (namereln, order, nlabels)
|
||||
nlabels += 1
|
||||
order = ldiff
|
||||
if ldiff < 0:
|
||||
namereln = NAMERELN_SUPERDOMAIN
|
||||
namereln = NameRelation.SUPERDOMAIN
|
||||
elif ldiff > 0:
|
||||
namereln = NAMERELN_SUBDOMAIN
|
||||
namereln = NameRelation.SUBDOMAIN
|
||||
else:
|
||||
namereln = NAMERELN_EQUAL
|
||||
namereln = NameRelation.EQUAL
|
||||
return (namereln, order, nlabels)
|
||||
|
||||
def is_subdomain(self, other):
|
||||
def is_subdomain(self, other: "Name") -> bool:
|
||||
"""Is self a subdomain of other?
|
||||
|
||||
Note that the notion of subdomain includes equality, e.g.
|
||||
"dnpython.org" is a subdomain of itself.
|
||||
"dnspython.org" is a subdomain of itself.
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
(nr, _, _) = self.fullcompare(other)
|
||||
if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
|
||||
if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_superdomain(self, other):
|
||||
def is_superdomain(self, other: "Name") -> bool:
|
||||
"""Is self a superdomain of other?
|
||||
|
||||
Note that the notion of superdomain includes equality, e.g.
|
||||
"dnpython.org" is a superdomain of itself.
|
||||
"dnspython.org" is a superdomain of itself.
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
(nr, _, _) = self.fullcompare(other)
|
||||
if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
|
||||
if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
|
||||
return True
|
||||
return False
|
||||
|
||||
def canonicalize(self):
|
||||
def canonicalize(self) -> "Name":
|
||||
"""Return a name which is equal to the current name, but is in
|
||||
DNSSEC canonical form.
|
||||
"""
|
||||
|
@ -516,12 +560,12 @@ class Name:
|
|||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
return '<DNS name ' + self.__str__() + '>'
|
||||
return "<DNS name " + self.__str__() + ">"
|
||||
|
||||
def __str__(self):
|
||||
return self.to_text(False)
|
||||
|
||||
def to_text(self, omit_final_dot=False):
|
||||
def to_text(self, omit_final_dot: bool = False) -> str:
|
||||
"""Convert name to DNS text format.
|
||||
|
||||
*omit_final_dot* is a ``bool``. If True, don't emit the final
|
||||
|
@ -532,17 +576,19 @@ class Name:
|
|||
"""
|
||||
|
||||
if len(self.labels) == 0:
|
||||
return '@'
|
||||
if len(self.labels) == 1 and self.labels[0] == b'':
|
||||
return '.'
|
||||
return "@"
|
||||
if len(self.labels) == 1 and self.labels[0] == b"":
|
||||
return "."
|
||||
if omit_final_dot and self.is_absolute():
|
||||
l = self.labels[:-1]
|
||||
else:
|
||||
l = self.labels
|
||||
s = '.'.join(map(_escapify, l))
|
||||
s = ".".join(map(_escapify, l))
|
||||
return s
|
||||
|
||||
def to_unicode(self, omit_final_dot=False, idna_codec=None):
|
||||
def to_unicode(
|
||||
self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None
|
||||
) -> str:
|
||||
"""Convert name to Unicode text format.
|
||||
|
||||
IDN ACE labels are converted to Unicode.
|
||||
|
@ -561,18 +607,18 @@ class Name:
|
|||
"""
|
||||
|
||||
if len(self.labels) == 0:
|
||||
return '@'
|
||||
if len(self.labels) == 1 and self.labels[0] == b'':
|
||||
return '.'
|
||||
return "@"
|
||||
if len(self.labels) == 1 and self.labels[0] == b"":
|
||||
return "."
|
||||
if omit_final_dot and self.is_absolute():
|
||||
l = self.labels[:-1]
|
||||
else:
|
||||
l = self.labels
|
||||
if idna_codec is None:
|
||||
idna_codec = IDNA_2003_Practical
|
||||
return '.'.join([idna_codec.decode(x) for x in l])
|
||||
return ".".join([idna_codec.decode(x) for x in l])
|
||||
|
||||
def to_digestable(self, origin=None):
|
||||
def to_digestable(self, origin: Optional["Name"] = None) -> bytes:
|
||||
"""Convert name to a format suitable for digesting in hashes.
|
||||
|
||||
The name is canonicalized and converted to uncompressed wire
|
||||
|
@ -589,10 +635,17 @@ class Name:
|
|||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
return self.to_wire(origin=origin, canonicalize=True)
|
||||
digest = self.to_wire(origin=origin, canonicalize=True)
|
||||
assert digest is not None
|
||||
return digest
|
||||
|
||||
def to_wire(self, file=None, compress=None, origin=None,
|
||||
canonicalize=False):
|
||||
def to_wire(
|
||||
self,
|
||||
file: Optional[Any] = None,
|
||||
compress: Optional[CompressType] = None,
|
||||
origin: Optional["Name"] = None,
|
||||
canonicalize: bool = False,
|
||||
) -> Optional[bytes]:
|
||||
"""Convert name to wire format, possibly compressing it.
|
||||
|
||||
*file* is the file where the name is emitted (typically an
|
||||
|
@ -638,6 +691,7 @@ class Name:
|
|||
out += label
|
||||
return bytes(out)
|
||||
|
||||
labels: Iterable[bytes]
|
||||
if not self.is_absolute():
|
||||
if origin is None or not origin.is_absolute():
|
||||
raise NeedAbsoluteNameOrOrigin
|
||||
|
@ -654,24 +708,25 @@ class Name:
|
|||
else:
|
||||
pos = None
|
||||
if pos is not None:
|
||||
value = 0xc000 + pos
|
||||
s = struct.pack('!H', value)
|
||||
value = 0xC000 + pos
|
||||
s = struct.pack("!H", value)
|
||||
file.write(s)
|
||||
break
|
||||
else:
|
||||
if compress is not None and len(n) > 1:
|
||||
pos = file.tell()
|
||||
if pos <= 0x3fff:
|
||||
if pos <= 0x3FFF:
|
||||
compress[n] = pos
|
||||
l = len(label)
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
if l > 0:
|
||||
if canonicalize:
|
||||
file.write(label.lower())
|
||||
else:
|
||||
file.write(label)
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""The length of the name (in labels).
|
||||
|
||||
Returns an ``int``.
|
||||
|
@ -688,7 +743,7 @@ class Name:
|
|||
def __sub__(self, other):
|
||||
return self.relativize(other)
|
||||
|
||||
def split(self, depth):
|
||||
def split(self, depth: int) -> Tuple["Name", "Name"]:
|
||||
"""Split a name into a prefix and suffix names at the specified depth.
|
||||
|
||||
*depth* is an ``int`` specifying the number of labels in the suffix
|
||||
|
@ -705,11 +760,10 @@ class Name:
|
|||
elif depth == l:
|
||||
return (dns.name.empty, self)
|
||||
elif depth < 0 or depth > l:
|
||||
raise ValueError(
|
||||
'depth must be >= 0 and <= the length of the name')
|
||||
return (Name(self[: -depth]), Name(self[-depth:]))
|
||||
raise ValueError("depth must be >= 0 and <= the length of the name")
|
||||
return (Name(self[:-depth]), Name(self[-depth:]))
|
||||
|
||||
def concatenate(self, other):
|
||||
def concatenate(self, other: "Name") -> "Name":
|
||||
"""Return a new name which is the concatenation of self and other.
|
||||
|
||||
Raises ``dns.name.AbsoluteConcatenation`` if the name is
|
||||
|
@ -724,7 +778,7 @@ class Name:
|
|||
labels.extend(list(other.labels))
|
||||
return Name(labels)
|
||||
|
||||
def relativize(self, origin):
|
||||
def relativize(self, origin: "Name") -> "Name":
|
||||
"""If the name is a subdomain of *origin*, return a new name which is
|
||||
the name relative to origin. Otherwise return the name.
|
||||
|
||||
|
@ -740,7 +794,7 @@ class Name:
|
|||
else:
|
||||
return self
|
||||
|
||||
def derelativize(self, origin):
|
||||
def derelativize(self, origin: "Name") -> "Name":
|
||||
"""If the name is a relative name, return a new name which is the
|
||||
concatenation of the name and origin. Otherwise return the name.
|
||||
|
||||
|
@ -756,7 +810,9 @@ class Name:
|
|||
else:
|
||||
return self
|
||||
|
||||
def choose_relativity(self, origin=None, relativize=True):
|
||||
def choose_relativity(
|
||||
self, origin: Optional["Name"] = None, relativize: bool = True
|
||||
) -> "Name":
|
||||
"""Return a name with the relativity desired by the caller.
|
||||
|
||||
If *origin* is ``None``, then the name is returned.
|
||||
|
@ -775,7 +831,7 @@ class Name:
|
|||
else:
|
||||
return self
|
||||
|
||||
def parent(self):
|
||||
def parent(self) -> "Name":
|
||||
"""Return the parent of the name.
|
||||
|
||||
For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
|
||||
|
@ -790,13 +846,17 @@ class Name:
|
|||
raise NoParent
|
||||
return Name(self.labels[1:])
|
||||
|
||||
|
||||
#: The root name, '.'
|
||||
root = Name([b''])
|
||||
root = Name([b""])
|
||||
|
||||
#: The empty name.
|
||||
empty = Name([])
|
||||
|
||||
def from_unicode(text, origin=root, idna_codec=None):
|
||||
|
||||
def from_unicode(
|
||||
text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None
|
||||
) -> Name:
|
||||
"""Convert unicode text into a Name object.
|
||||
|
||||
Labels are encoded in IDN ACE form according to rules specified by
|
||||
|
@ -819,17 +879,17 @@ def from_unicode(text, origin=root, idna_codec=None):
|
|||
if not (origin is None or isinstance(origin, Name)):
|
||||
raise ValueError("origin must be a Name or None")
|
||||
labels = []
|
||||
label = ''
|
||||
label = ""
|
||||
escaping = False
|
||||
edigits = 0
|
||||
total = 0
|
||||
if idna_codec is None:
|
||||
idna_codec = IDNA_2003
|
||||
if text == '@':
|
||||
text = ''
|
||||
if text == "@":
|
||||
text = ""
|
||||
if text:
|
||||
if text in ['.', '\u3002', '\uff0e', '\uff61']:
|
||||
return Name([b'']) # no Unicode "u" on this constant!
|
||||
if text in [".", "\u3002", "\uff0e", "\uff61"]:
|
||||
return Name([b""]) # no Unicode "u" on this constant!
|
||||
for c in text:
|
||||
if escaping:
|
||||
if edigits == 0:
|
||||
|
@ -848,12 +908,12 @@ def from_unicode(text, origin=root, idna_codec=None):
|
|||
if edigits == 3:
|
||||
escaping = False
|
||||
label += chr(total)
|
||||
elif c in ['.', '\u3002', '\uff0e', '\uff61']:
|
||||
elif c in [".", "\u3002", "\uff0e", "\uff61"]:
|
||||
if len(label) == 0:
|
||||
raise EmptyLabel
|
||||
labels.append(idna_codec.encode(label))
|
||||
label = ''
|
||||
elif c == '\\':
|
||||
label = ""
|
||||
elif c == "\\":
|
||||
escaping = True
|
||||
edigits = 0
|
||||
total = 0
|
||||
|
@ -864,22 +924,28 @@ def from_unicode(text, origin=root, idna_codec=None):
|
|||
if len(label) > 0:
|
||||
labels.append(idna_codec.encode(label))
|
||||
else:
|
||||
labels.append(b'')
|
||||
labels.append(b"")
|
||||
|
||||
if (len(labels) == 0 or labels[-1] != b'') and origin is not None:
|
||||
if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
|
||||
labels.extend(list(origin.labels))
|
||||
return Name(labels)
|
||||
|
||||
def is_all_ascii(text):
|
||||
|
||||
def is_all_ascii(text: str) -> bool:
|
||||
for c in text:
|
||||
if ord(c) > 0x7f:
|
||||
if ord(c) > 0x7F:
|
||||
return False
|
||||
return True
|
||||
|
||||
def from_text(text, origin=root, idna_codec=None):
|
||||
|
||||
def from_text(
|
||||
text: Union[bytes, str],
|
||||
origin: Optional[Name] = root,
|
||||
idna_codec: Optional[IDNACodec] = None,
|
||||
) -> Name:
|
||||
"""Convert text into a Name object.
|
||||
|
||||
*text*, a ``str``, is the text to convert into a name.
|
||||
*text*, a ``bytes`` or ``str``, is the text to convert into a name.
|
||||
|
||||
*origin*, a ``dns.name.Name``, specifies the origin to
|
||||
append to non-absolute names. The default is the root name.
|
||||
|
@ -903,23 +969,23 @@ def from_text(text, origin=root, idna_codec=None):
|
|||
#
|
||||
# then it's still "all ASCII" even though the domain name has
|
||||
# codepoints > 127.
|
||||
text = text.encode('ascii')
|
||||
text = text.encode("ascii")
|
||||
if not isinstance(text, bytes):
|
||||
raise ValueError("input to from_text() must be a string")
|
||||
if not (origin is None or isinstance(origin, Name)):
|
||||
raise ValueError("origin must be a Name or None")
|
||||
labels = []
|
||||
label = b''
|
||||
label = b""
|
||||
escaping = False
|
||||
edigits = 0
|
||||
total = 0
|
||||
if text == b'@':
|
||||
text = b''
|
||||
if text == b"@":
|
||||
text = b""
|
||||
if text:
|
||||
if text == b'.':
|
||||
return Name([b''])
|
||||
if text == b".":
|
||||
return Name([b""])
|
||||
for c in text:
|
||||
byte_ = struct.pack('!B', c)
|
||||
byte_ = struct.pack("!B", c)
|
||||
if escaping:
|
||||
if edigits == 0:
|
||||
if byte_.isdigit():
|
||||
|
@ -936,13 +1002,13 @@ def from_text(text, origin=root, idna_codec=None):
|
|||
edigits += 1
|
||||
if edigits == 3:
|
||||
escaping = False
|
||||
label += struct.pack('!B', total)
|
||||
elif byte_ == b'.':
|
||||
label += struct.pack("!B", total)
|
||||
elif byte_ == b".":
|
||||
if len(label) == 0:
|
||||
raise EmptyLabel
|
||||
labels.append(label)
|
||||
label = b''
|
||||
elif byte_ == b'\\':
|
||||
label = b""
|
||||
elif byte_ == b"\\":
|
||||
escaping = True
|
||||
edigits = 0
|
||||
total = 0
|
||||
|
@ -953,13 +1019,16 @@ def from_text(text, origin=root, idna_codec=None):
|
|||
if len(label) > 0:
|
||||
labels.append(label)
|
||||
else:
|
||||
labels.append(b'')
|
||||
if (len(labels) == 0 or labels[-1] != b'') and origin is not None:
|
||||
labels.append(b"")
|
||||
if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
|
||||
labels.extend(list(origin.labels))
|
||||
return Name(labels)
|
||||
|
||||
|
||||
def from_wire_parser(parser):
|
||||
# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
|
||||
|
||||
|
||||
def from_wire_parser(parser: "dns.wire.Parser") -> Name:
|
||||
"""Convert possibly compressed wire format into a Name.
|
||||
|
||||
*parser* is a dns.wire.Parser.
|
||||
|
@ -980,7 +1049,7 @@ def from_wire_parser(parser):
|
|||
if count < 64:
|
||||
labels.append(parser.get_bytes(count))
|
||||
elif count >= 192:
|
||||
current = (count & 0x3f) * 256 + parser.get_uint8()
|
||||
current = (count & 0x3F) * 256 + parser.get_uint8()
|
||||
if current >= biggest_pointer:
|
||||
raise BadPointer
|
||||
biggest_pointer = current
|
||||
|
@ -988,11 +1057,11 @@ def from_wire_parser(parser):
|
|||
else:
|
||||
raise BadLabelType
|
||||
count = parser.get_uint8()
|
||||
labels.append(b'')
|
||||
labels.append(b"")
|
||||
return Name(labels)
|
||||
|
||||
|
||||
def from_wire(message, current):
|
||||
def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
|
||||
"""Convert possibly compressed wire format into a Name.
|
||||
|
||||
*message* is a ``bytes`` containing an entire DNS message in DNS
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
from typing import Optional, Union, Tuple, Iterable, List
|
||||
|
||||
have_idna_2008: bool
|
||||
|
||||
class Name:
|
||||
def is_subdomain(self, o : Name) -> bool: ...
|
||||
def is_superdomain(self, o : Name) -> bool: ...
|
||||
def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
|
||||
self.labels : List[bytes]
|
||||
def is_absolute(self) -> bool: ...
|
||||
def is_wild(self) -> bool: ...
|
||||
def fullcompare(self, other) -> Tuple[int,int,int]: ...
|
||||
def canonicalize(self) -> Name: ...
|
||||
def __eq__(self, other) -> bool: ...
|
||||
def __ne__(self, other) -> bool: ...
|
||||
def __lt__(self, other : Name) -> bool: ...
|
||||
def __le__(self, other : Name) -> bool: ...
|
||||
def __ge__(self, other : Name) -> bool: ...
|
||||
def __gt__(self, other : Name) -> bool: ...
|
||||
def to_text(self, omit_final_dot=False) -> str: ...
|
||||
def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
|
||||
def to_digestable(self, origin=None) -> bytes: ...
|
||||
def to_wire(self, file=None, compress=None, origin=None,
|
||||
canonicalize=False) -> Optional[bytes]: ...
|
||||
def __add__(self, other : Name) -> Name: ...
|
||||
def __sub__(self, other : Name) -> Name: ...
|
||||
def split(self, depth) -> List[Tuple[str,str]]: ...
|
||||
def concatenate(self, other : Name) -> Name: ...
|
||||
def relativize(self, origin) -> Name: ...
|
||||
def derelativize(self, origin) -> Name: ...
|
||||
def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
|
||||
def parent(self) -> Name: ...
|
||||
|
||||
class IDNACodec:
|
||||
pass
|
||||
|
||||
def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
|
||||
...
|
||||
|
||||
empty : Name
|
|
@ -27,7 +27,8 @@
|
|||
|
||||
"""DNS name dictionary"""
|
||||
|
||||
from collections.abc import MutableMapping
|
||||
# pylint seems to be confused about this one!
|
||||
from collections.abc import MutableMapping # pylint: disable=no-name-in-module
|
||||
|
||||
import dns.name
|
||||
|
||||
|
@ -62,7 +63,7 @@ class NameDict(MutableMapping):
|
|||
|
||||
def __setitem__(self, key, value):
|
||||
if not isinstance(key, dns.name.Name):
|
||||
raise ValueError('NameDict key must be a name')
|
||||
raise ValueError("NameDict key must be a name")
|
||||
self.__store[key] = value
|
||||
self.__update_max_depth(key)
|
||||
|
||||
|
|
120
lib/dns/node.py
120
lib/dns/node.py
|
@ -17,12 +17,17 @@
|
|||
|
||||
"""DNS nodes. A node is a set of rdatasets."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import enum
|
||||
import io
|
||||
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdataclass
|
||||
import dns.rdataset
|
||||
import dns.rdatatype
|
||||
import dns.rrset
|
||||
import dns.renderer
|
||||
|
||||
|
||||
|
@ -37,21 +42,23 @@ _neutral_types = {
|
|||
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
|
||||
}
|
||||
|
||||
|
||||
def _matches_type_or_its_signature(rdtypes, rdtype, covers):
|
||||
return rdtype in rdtypes or \
|
||||
(rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
|
||||
return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class NodeKind(enum.Enum):
|
||||
"""Rdatasets in nodes
|
||||
"""
|
||||
"""Rdatasets in nodes"""
|
||||
|
||||
REGULAR = 0 # a.k.a "other data"
|
||||
NEUTRAL = 1
|
||||
CNAME = 2
|
||||
|
||||
@classmethod
|
||||
def classify(cls, rdtype, covers):
|
||||
def classify(
|
||||
cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
|
||||
) -> "NodeKind":
|
||||
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
|
||||
return NodeKind.CNAME
|
||||
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
|
||||
|
@ -60,7 +67,7 @@ class NodeKind(enum.Enum):
|
|||
return NodeKind.REGULAR
|
||||
|
||||
@classmethod
|
||||
def classify_rdataset(cls, rdataset):
|
||||
def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
|
||||
return cls.classify(rdataset.rdtype, rdataset.covers)
|
||||
|
||||
|
||||
|
@ -81,19 +88,19 @@ class Node:
|
|||
deleted.
|
||||
"""
|
||||
|
||||
__slots__ = ['rdatasets']
|
||||
__slots__ = ["rdatasets"]
|
||||
|
||||
def __init__(self):
|
||||
# the set of rdatasets, represented as a list.
|
||||
self.rdatasets = []
|
||||
|
||||
def to_text(self, name, **kw):
|
||||
def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
|
||||
"""Convert a node to text format.
|
||||
|
||||
Each rdataset at the node is printed. Any keyword arguments
|
||||
to this method are passed on to the rdataset's to_text() method.
|
||||
|
||||
*name*, a ``dns.name.Name`` or ``str``, the owner name of the
|
||||
*name*, a ``dns.name.Name``, the owner name of the
|
||||
rdatasets.
|
||||
|
||||
Returns a ``str``.
|
||||
|
@ -103,12 +110,12 @@ class Node:
|
|||
s = io.StringIO()
|
||||
for rds in self.rdatasets:
|
||||
if len(rds) > 0:
|
||||
s.write(rds.to_text(name, **kw))
|
||||
s.write('\n')
|
||||
s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
|
||||
s.write("\n")
|
||||
return s.getvalue()[:-1]
|
||||
|
||||
def __repr__(self):
|
||||
return '<DNS node ' + str(id(self)) + '>'
|
||||
return "<DNS node " + str(id(self)) + ">"
|
||||
|
||||
def __eq__(self, other):
|
||||
#
|
||||
|
@ -144,27 +151,36 @@ class Node:
|
|||
if len(self.rdatasets) > 0:
|
||||
kind = NodeKind.classify_rdataset(rdataset)
|
||||
if kind == NodeKind.CNAME:
|
||||
self.rdatasets = [rds for rds in self.rdatasets if
|
||||
NodeKind.classify_rdataset(rds) !=
|
||||
NodeKind.REGULAR]
|
||||
self.rdatasets = [
|
||||
rds
|
||||
for rds in self.rdatasets
|
||||
if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
|
||||
]
|
||||
elif kind == NodeKind.REGULAR:
|
||||
self.rdatasets = [rds for rds in self.rdatasets if
|
||||
NodeKind.classify_rdataset(rds) !=
|
||||
NodeKind.CNAME]
|
||||
self.rdatasets = [
|
||||
rds
|
||||
for rds in self.rdatasets
|
||||
if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
|
||||
]
|
||||
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
|
||||
# edit self.rdatasets.
|
||||
self.rdatasets.append(rdataset)
|
||||
|
||||
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
|
||||
create=False):
|
||||
def find_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> dns.rdataset.Rdataset:
|
||||
"""Find an rdataset matching the specified properties in the
|
||||
current node.
|
||||
|
||||
*rdclass*, an ``int``, the class of the rdataset.
|
||||
*rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
|
||||
|
||||
*rdtype*, an ``int``, the type of the rdataset.
|
||||
*rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
|
||||
|
||||
*covers*, an ``int`` or ``None``, the covered type.
|
||||
*covers*, a ``dns.rdatatype.RdataType``, the covered type.
|
||||
Usually this value is ``dns.rdatatype.NONE``, but if the
|
||||
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
|
||||
then the covers value will be the rdata type the SIG/RRSIG
|
||||
|
@ -191,8 +207,13 @@ class Node:
|
|||
self._append_rdataset(rds)
|
||||
return rds
|
||||
|
||||
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
|
||||
create=False):
|
||||
def get_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> Optional[dns.rdataset.Rdataset]:
|
||||
"""Get an rdataset matching the specified properties in the
|
||||
current node.
|
||||
|
||||
|
@ -223,7 +244,12 @@ class Node:
|
|||
rds = None
|
||||
return rds
|
||||
|
||||
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
|
||||
def delete_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
) -> None:
|
||||
"""Delete the rdataset matching the specified properties in the
|
||||
current node.
|
||||
|
||||
|
@ -240,7 +266,7 @@ class Node:
|
|||
if rds is not None:
|
||||
self.rdatasets.remove(rds)
|
||||
|
||||
def replace_rdataset(self, replacement):
|
||||
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
|
||||
"""Replace an rdataset.
|
||||
|
||||
It is not an error if there is no rdataset matching *replacement*.
|
||||
|
@ -256,16 +282,17 @@ class Node:
|
|||
"""
|
||||
|
||||
if not isinstance(replacement, dns.rdataset.Rdataset):
|
||||
raise ValueError('replacement is not an rdataset')
|
||||
raise ValueError("replacement is not an rdataset")
|
||||
if isinstance(replacement, dns.rrset.RRset):
|
||||
# RRsets are not good replacements as the match() method
|
||||
# is not compatible.
|
||||
replacement = replacement.to_rdataset()
|
||||
self.delete_rdataset(replacement.rdclass, replacement.rdtype,
|
||||
replacement.covers)
|
||||
self.delete_rdataset(
|
||||
replacement.rdclass, replacement.rdtype, replacement.covers
|
||||
)
|
||||
self._append_rdataset(replacement)
|
||||
|
||||
def classify(self):
|
||||
def classify(self) -> NodeKind:
|
||||
"""Classify a node.
|
||||
|
||||
A node which contains a CNAME or RRSIG(CNAME) is a
|
||||
|
@ -286,7 +313,7 @@ class Node:
|
|||
return kind
|
||||
return NodeKind.NEUTRAL
|
||||
|
||||
def is_immutable(self):
|
||||
def is_immutable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
@ -298,23 +325,38 @@ class ImmutableNode(Node):
|
|||
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
|
||||
)
|
||||
|
||||
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
|
||||
create=False):
|
||||
def find_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> dns.rdataset.Rdataset:
|
||||
if create:
|
||||
raise TypeError("immutable")
|
||||
return super().find_rdataset(rdclass, rdtype, covers, False)
|
||||
|
||||
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
|
||||
create=False):
|
||||
def get_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> Optional[dns.rdataset.Rdataset]:
|
||||
if create:
|
||||
raise TypeError("immutable")
|
||||
return super().get_rdataset(rdclass, rdtype, covers, False)
|
||||
|
||||
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
|
||||
def delete_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
) -> None:
|
||||
raise TypeError("immutable")
|
||||
|
||||
def replace_rdataset(self, replacement):
|
||||
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
|
||||
raise TypeError("immutable")
|
||||
|
||||
def is_immutable(self):
|
||||
def is_immutable(self) -> bool:
|
||||
return True
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
from typing import List, Optional, Union
|
||||
from . import rdataset, rdatatype, name
|
||||
class Node:
|
||||
def __init__(self):
|
||||
self.rdatasets : List[rdataset.Rdataset]
|
||||
def to_text(self, name : Union[str,name.Name], **kw) -> str:
|
||||
...
|
||||
def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
|
||||
create=False) -> rdataset.Rdataset:
|
||||
...
|
||||
def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
|
||||
create=False) -> Optional[rdataset.Rdataset]:
|
||||
...
|
||||
def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
|
||||
...
|
||||
def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
|
||||
...
|
|
@ -20,6 +20,7 @@
|
|||
import dns.enum
|
||||
import dns.exception
|
||||
|
||||
|
||||
class Opcode(dns.enum.IntEnum):
|
||||
#: Query
|
||||
QUERY = 0
|
||||
|
@ -45,7 +46,7 @@ class UnknownOpcode(dns.exception.DNSException):
|
|||
"""An DNS opcode is unknown."""
|
||||
|
||||
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> Opcode:
|
||||
"""Convert text into an opcode.
|
||||
|
||||
*text*, a ``str``, the textual opcode
|
||||
|
@ -58,7 +59,7 @@ def from_text(text):
|
|||
return Opcode.from_text(text)
|
||||
|
||||
|
||||
def from_flags(flags):
|
||||
def from_flags(flags: int) -> Opcode:
|
||||
"""Extract an opcode from DNS message flags.
|
||||
|
||||
*flags*, an ``int``, the DNS flags.
|
||||
|
@ -66,10 +67,10 @@ def from_flags(flags):
|
|||
Returns an ``int``.
|
||||
"""
|
||||
|
||||
return (flags & 0x7800) >> 11
|
||||
return Opcode((flags & 0x7800) >> 11)
|
||||
|
||||
|
||||
def to_flags(value):
|
||||
def to_flags(value: Opcode) -> int:
|
||||
"""Convert an opcode to a value suitable for ORing into DNS message
|
||||
flags.
|
||||
|
||||
|
@ -81,7 +82,7 @@ def to_flags(value):
|
|||
return (value << 11) & 0x7800
|
||||
|
||||
|
||||
def to_text(value):
|
||||
def to_text(value: Opcode) -> str:
|
||||
"""Convert an opcode to text.
|
||||
|
||||
*value*, an ``int`` the opcode value,
|
||||
|
@ -94,7 +95,7 @@ def to_text(value):
|
|||
return Opcode.to_text(value)
|
||||
|
||||
|
||||
def is_update(flags):
|
||||
def is_update(flags: int) -> bool:
|
||||
"""Is the opcode in flags UPDATE?
|
||||
|
||||
*flags*, an ``int``, the DNS message flags.
|
||||
|
@ -104,6 +105,7 @@ def is_update(flags):
|
|||
|
||||
return from_flags(flags) == Opcode.UPDATE
|
||||
|
||||
|
||||
### BEGIN generated Opcode constants
|
||||
|
||||
QUERY = Opcode.QUERY
|
||||
|
|
677
lib/dns/query.py
677
lib/dns/query.py
File diff suppressed because it is too large
Load diff
|
@ -1,64 +0,0 @@
|
|||
from typing import Optional, Union, Dict, Generator, Any
|
||||
from . import tsig, rdatatype, rdataclass, name, message
|
||||
from requests.sessions import Session
|
||||
|
||||
import socket
|
||||
|
||||
# If the ssl import works, then
|
||||
#
|
||||
# error: Name 'ssl' already defined (by an import)
|
||||
#
|
||||
# is expected and can be ignored.
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
class ssl: # type: ignore
|
||||
SSLContext : Dict = {}
|
||||
|
||||
have_doh: bool
|
||||
|
||||
def https(q : message.Message, where: str, timeout : Optional[float] = None,
|
||||
port : Optional[int] = 443, source : Optional[str] = None,
|
||||
source_port : Optional[int] = 0,
|
||||
session: Optional[Session] = None,
|
||||
path : Optional[str] = '/dns-query', post : Optional[bool] = True,
|
||||
bootstrap_address : Optional[str] = None,
|
||||
verify : Optional[bool] = True) -> message.Message:
|
||||
pass
|
||||
|
||||
def tcp(q : message.Message, where : str, timeout : float = None, port=53,
|
||||
af : Optional[int] = None, source : Optional[str] = None,
|
||||
source_port : Optional[int] = 0,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[socket.socket] = None) -> message.Message:
|
||||
pass
|
||||
|
||||
def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
|
||||
rdclass=rdataclass.IN,
|
||||
timeout : Optional[float] = None, port=53,
|
||||
keyring : Optional[Dict[name.Name, bytes]] = None,
|
||||
keyname : Union[str,name.Name]= None, relativize=True,
|
||||
lifetime : Optional[float] = None,
|
||||
source : Optional[str] = None, source_port=0, serial=0,
|
||||
use_udp : Optional[bool] = False,
|
||||
keyalgorithm=tsig.default_algorithm) \
|
||||
-> Generator[Any,Any,message.Message]:
|
||||
pass
|
||||
|
||||
def udp(q : message.Message, where : str, timeout : Optional[float] = None,
|
||||
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
|
||||
ignore_unexpected : Optional[bool] = False,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[socket.socket] = None) -> message.Message:
|
||||
pass
|
||||
|
||||
def tls(q : message.Message, where : str, timeout : Optional[float] = None,
|
||||
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
|
||||
one_rr_per_rrset : Optional[bool] = False,
|
||||
ignore_trailing : Optional[bool] = False,
|
||||
sock : Optional[socket.socket] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
server_hostname: Optional[str] = None) -> message.Message:
|
||||
pass
|
74
lib/dns/quic/__init__.py
Normal file
74
lib/dns/quic/__init__.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
try:
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
|
||||
import dns.asyncbackend
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
|
||||
from dns.quic._asyncio import (
|
||||
AsyncioQuicManager,
|
||||
AsyncioQuicConnection,
|
||||
AsyncioQuicStream,
|
||||
)
|
||||
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
|
||||
|
||||
have_quic = True
|
||||
|
||||
def null_factory(
|
||||
*args, # pylint: disable=unused-argument
|
||||
**kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
return NullContext(None)
|
||||
|
||||
def _asyncio_manager_factory(
|
||||
context, *args, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
return AsyncioQuicManager(*args, **kwargs)
|
||||
|
||||
# We have a context factory and a manager factory as for trio we need to have
|
||||
# a nursery.
|
||||
|
||||
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
|
||||
|
||||
try:
|
||||
import trio
|
||||
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
|
||||
TrioQuicManager,
|
||||
TrioQuicConnection,
|
||||
TrioQuicStream,
|
||||
)
|
||||
|
||||
def _trio_context_factory():
|
||||
return trio.open_nursery()
|
||||
|
||||
def _trio_manager_factory(context, *args, **kwargs):
|
||||
return TrioQuicManager(context, *args, **kwargs)
|
||||
|
||||
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def factories_for_backend(backend=None):
|
||||
if backend is None:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
return _async_factories[backend.name()]
|
||||
|
||||
except ImportError:
|
||||
have_quic = False
|
||||
|
||||
from typing import Any
|
||||
|
||||
class AsyncQuicStream: # type: ignore
|
||||
pass
|
||||
|
||||
class AsyncQuicConnection: # type: ignore
|
||||
async def make_stream(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
class SyncQuicStream: # type: ignore
|
||||
pass
|
||||
|
||||
class SyncQuicConnection: # type: ignore
|
||||
def make_stream(self) -> Any:
|
||||
raise NotImplementedError
|
206
lib/dns/quic/_asyncio.py
Normal file
206
lib/dns/quic/_asyncio.py
Normal file
|
@ -0,0 +1,206 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import ssl
|
||||
import struct
|
||||
import time
|
||||
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import aioquic.quic.events # type: ignore
|
||||
import dns.inet
|
||||
import dns.asyncbackend
|
||||
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
AsyncQuicConnection,
|
||||
AsyncQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
)
|
||||
|
||||
|
||||
class AsyncioQuicStream(BaseQuicStream):
|
||||
def __init__(self, connection, stream_id):
|
||||
super().__init__(connection, stream_id)
|
||||
self._wake_up = asyncio.Condition()
|
||||
|
||||
async def _wait_for_wake_up(self):
|
||||
async with self._wake_up:
|
||||
await self._wake_up.wait()
|
||||
|
||||
async def wait_for(self, amount, expiration):
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
while True:
|
||||
if self._buffer.have(amount):
|
||||
return
|
||||
self._expecting = amount
|
||||
try:
|
||||
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
|
||||
except Exception:
|
||||
pass
|
||||
self._expecting = 0
|
||||
|
||||
async def receive(self, timeout=None):
|
||||
expiration = self._expiration_from_timeout(timeout)
|
||||
await self.wait_for(2, expiration)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size, expiration)
|
||||
return self._buffer.get(size)
|
||||
|
||||
async def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
await self._connection.write(self._stream_id, data, is_end)
|
||||
|
||||
async def _add_input(self, data, is_end):
|
||||
if self._common_add_input(data, is_end):
|
||||
async with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
|
||||
async def close(self):
|
||||
self._close()
|
||||
|
||||
# Streams are async context managers
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
async with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
return False
|
||||
|
||||
|
||||
class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
def __init__(self, connection, address, port, source, source_port, manager=None):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = None
|
||||
self._handshake_complete = asyncio.Event()
|
||||
self._socket_created = asyncio.Event()
|
||||
self._wake_timer = asyncio.Condition()
|
||||
self._receiver_task = None
|
||||
self._sender_task = None
|
||||
|
||||
async def _receiver(self):
|
||||
try:
|
||||
af = dns.inet.af_for_address(self._address)
|
||||
backend = dns.asyncbackend.get_backend("asyncio")
|
||||
self._socket = await backend.make_socket(
|
||||
af, socket.SOCK_DGRAM, 0, self._source, self._peer
|
||||
)
|
||||
self._socket_created.set()
|
||||
async with self._socket:
|
||||
while not self._done:
|
||||
(datagram, address) = await self._socket.recvfrom(
|
||||
QUIC_MAX_DATAGRAM, None
|
||||
)
|
||||
if address[0] != self._peer[0] or address[1] != self._peer[1]:
|
||||
continue
|
||||
self._connection.receive_datagram(
|
||||
datagram, self._peer[0], time.time()
|
||||
)
|
||||
# Wake up the timer in case the sender is sleeping, as there may be
|
||||
# stuff to send now.
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _wait_for_wake_timer(self):
|
||||
async with self._wake_timer:
|
||||
await self._wake_timer.wait()
|
||||
|
||||
async def _sender(self):
|
||||
await self._socket_created.wait()
|
||||
while not self._done:
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, address) in datagrams:
|
||||
assert address == self._peer[0]
|
||||
await self._socket.sendto(datagram, self._peer, None)
|
||||
(expiration, interval) = self._get_timer_values()
|
||||
try:
|
||||
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
|
||||
except Exception:
|
||||
pass
|
||||
self._handle_timer(expiration)
|
||||
await self._handle_events()
|
||||
|
||||
async def _handle_events(self):
|
||||
count = 0
|
||||
while True:
|
||||
event = self._connection.next_event()
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
self._done = True
|
||||
self._receiver_task.cancel()
|
||||
count += 1
|
||||
if count > 10:
|
||||
# yield
|
||||
count = 0
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def write(self, stream, data, is_end=False):
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
|
||||
def run(self):
|
||||
if self._closed:
|
||||
return
|
||||
self._receiver_task = asyncio.Task(self._receiver())
|
||||
self._sender_task = asyncio.Task(self._sender())
|
||||
|
||||
async def make_stream(self):
|
||||
await self._handshake_complete.wait()
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = AsyncioQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
return stream
|
||||
|
||||
async def close(self):
|
||||
if not self._closed:
|
||||
self._manager.closed(self._peer[0], self._peer[1])
|
||||
self._closed = True
|
||||
self._connection.close()
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
try:
|
||||
await self._receiver_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
try:
|
||||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection)
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
if start:
|
||||
connection.run()
|
||||
return connection
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
await connection.close()
|
||||
return False
|
180
lib/dns/quic/_common.py
Normal file
180
lib/dns/quic/_common.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import dns.inet
|
||||
|
||||
|
||||
QUIC_MAX_DATAGRAM = 2048
|
||||
|
||||
|
||||
class UnexpectedEOF(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Buffer:
|
||||
def __init__(self):
|
||||
self._buffer = b""
|
||||
self._seen_end = False
|
||||
|
||||
def put(self, data, is_end):
|
||||
if self._seen_end:
|
||||
return
|
||||
self._buffer += data
|
||||
if is_end:
|
||||
self._seen_end = True
|
||||
|
||||
def have(self, amount):
|
||||
if len(self._buffer) >= amount:
|
||||
return True
|
||||
if self._seen_end:
|
||||
raise UnexpectedEOF
|
||||
return False
|
||||
|
||||
def seen_end(self):
|
||||
return self._seen_end
|
||||
|
||||
def get(self, amount):
|
||||
assert self.have(amount)
|
||||
data = self._buffer[:amount]
|
||||
self._buffer = self._buffer[amount:]
|
||||
return data
|
||||
|
||||
|
||||
class BaseQuicStream:
|
||||
def __init__(self, connection, stream_id):
|
||||
self._connection = connection
|
||||
self._stream_id = stream_id
|
||||
self._buffer = Buffer()
|
||||
self._expecting = 0
|
||||
|
||||
def id(self):
|
||||
return self._stream_id
|
||||
|
||||
def _expiration_from_timeout(self, timeout):
|
||||
if timeout is not None:
|
||||
expiration = time.time() + timeout
|
||||
else:
|
||||
expiration = None
|
||||
return expiration
|
||||
|
||||
def _timeout_from_expiration(self, expiration):
|
||||
if expiration is not None:
|
||||
timeout = max(expiration - time.time(), 0.0)
|
||||
else:
|
||||
timeout = None
|
||||
return timeout
|
||||
|
||||
# Subclass must implement receive() as sync / async and which returns a message
|
||||
# or raises UnexpectedEOF.
|
||||
|
||||
def _encapsulate(self, datagram):
|
||||
l = len(datagram)
|
||||
return struct.pack("!H", l) + datagram
|
||||
|
||||
def _common_add_input(self, data, is_end):
|
||||
self._buffer.put(data, is_end)
|
||||
return self._expecting > 0 and self._buffer.have(self._expecting)
|
||||
|
||||
def _close(self):
|
||||
self._connection.close_stream(self._stream_id)
|
||||
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
|
||||
|
||||
|
||||
class BaseQuicConnection:
|
||||
def __init__(
|
||||
self, connection, address, port, source=None, source_port=0, manager=None
|
||||
):
|
||||
self._done = False
|
||||
self._connection = connection
|
||||
self._address = address
|
||||
self._port = port
|
||||
self._closed = False
|
||||
self._manager = manager
|
||||
self._streams = {}
|
||||
self._af = dns.inet.af_for_address(address)
|
||||
self._peer = dns.inet.low_level_address_tuple((address, port))
|
||||
if source is None and source_port != 0:
|
||||
if self._af == socket.AF_INET:
|
||||
source = "0.0.0.0"
|
||||
elif self._af == socket.AF_INET6:
|
||||
source = "::"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if source:
|
||||
self._source = (source, source_port)
|
||||
else:
|
||||
self._source = None
|
||||
|
||||
def close_stream(self, stream_id):
|
||||
del self._streams[stream_id]
|
||||
|
||||
def _get_timer_values(self, closed_is_special=True):
|
||||
now = time.time()
|
||||
expiration = self._connection.get_timer()
|
||||
if expiration is None:
|
||||
expiration = now + 3600 # arbitrary "big" value
|
||||
interval = max(expiration - now, 0)
|
||||
if self._closed and closed_is_special:
|
||||
# lower sleep interval to avoid a race in the closing process
|
||||
# which can lead to higher latency closing due to sleeping when
|
||||
# we have events.
|
||||
interval = min(interval, 0.05)
|
||||
return (expiration, interval)
|
||||
|
||||
def _handle_timer(self, expiration):
|
||||
now = time.time()
|
||||
if expiration <= now:
|
||||
self._connection.handle_timer(now)
|
||||
|
||||
|
||||
class AsyncQuicConnection(BaseQuicConnection):
|
||||
async def make_stream(self) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
class BaseQuicManager:
|
||||
def __init__(self, conf, verify_mode, connection_factory):
|
||||
self._connections = {}
|
||||
self._connection_factory = connection_factory
|
||||
if conf is None:
|
||||
verify_path = None
|
||||
if isinstance(verify_mode, str):
|
||||
verify_path = verify_mode
|
||||
verify_mode = True
|
||||
conf = aioquic.quic.configuration.QuicConfiguration(
|
||||
alpn_protocols=["doq", "doq-i03"],
|
||||
verify_mode=verify_mode,
|
||||
)
|
||||
if verify_path is not None:
|
||||
conf.load_verify_locations(verify_path)
|
||||
self._conf = conf
|
||||
|
||||
def _connect(self, address, port=853, source=None, source_port=0):
|
||||
connection = self._connections.get((address, port))
|
||||
if connection is not None:
|
||||
return (connection, False)
|
||||
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
|
||||
qconn.connect(address, time.time())
|
||||
connection = self._connection_factory(
|
||||
qconn, address, port, source, source_port, self
|
||||
)
|
||||
self._connections[(address, port)] = connection
|
||||
return (connection, True)
|
||||
|
||||
def closed(self, address, port):
|
||||
try:
|
||||
del self._connections[(address, port)]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncQuicManager(BaseQuicManager):
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
raise NotImplementedError
|
214
lib/dns/quic/_sync.py
Normal file
214
lib/dns/quic/_sync.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import selectors
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import aioquic.quic.events # type: ignore
|
||||
import dns.inet
|
||||
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
BaseQuicConnection,
|
||||
BaseQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
)
|
||||
|
||||
# Avoid circularity with dns.query
|
||||
if hasattr(selectors, "PollSelector"):
|
||||
_selector_class = selectors.PollSelector # type: ignore
|
||||
else:
|
||||
_selector_class = selectors.SelectSelector # type: ignore
|
||||
|
||||
|
||||
class SyncQuicStream(BaseQuicStream):
|
||||
def __init__(self, connection, stream_id):
|
||||
super().__init__(connection, stream_id)
|
||||
self._wake_up = threading.Condition()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def wait_for(self, amount, expiration):
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
while True:
|
||||
with self._lock:
|
||||
if self._buffer.have(amount):
|
||||
return
|
||||
self._expecting = amount
|
||||
with self._wake_up:
|
||||
self._wake_up.wait(timeout)
|
||||
self._expecting = 0
|
||||
|
||||
def receive(self, timeout=None):
|
||||
expiration = self._expiration_from_timeout(timeout)
|
||||
self.wait_for(2, expiration)
|
||||
with self._lock:
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
self.wait_for(size, expiration)
|
||||
with self._lock:
|
||||
return self._buffer.get(size)
|
||||
|
||||
def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
self._connection.write(self._stream_id, data, is_end)
|
||||
|
||||
def _add_input(self, data, is_end):
|
||||
if self._common_add_input(data, is_end):
|
||||
with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
|
||||
def close(self):
|
||||
with self._lock:
|
||||
self._close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
return False
|
||||
|
||||
|
||||
class SyncQuicConnection(BaseQuicConnection):
|
||||
def __init__(self, connection, address, port, source, source_port, manager):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
|
||||
self._socket.connect(self._peer)
|
||||
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
|
||||
self._receive_wakeup.setblocking(False)
|
||||
self._socket.setblocking(False)
|
||||
if self._source is not None:
|
||||
try:
|
||||
self._socket.bind(
|
||||
dns.inet.low_level_address_tuple(self._source, self._af)
|
||||
)
|
||||
except Exception:
|
||||
self._socket.close()
|
||||
raise
|
||||
self._handshake_complete = threading.Event()
|
||||
self._worker_thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _read(self):
|
||||
count = 0
|
||||
while count < 10:
|
||||
count += 1
|
||||
try:
|
||||
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
|
||||
except BlockingIOError:
|
||||
return
|
||||
with self._lock:
|
||||
self._connection.receive_datagram(datagram, self._peer[0], time.time())
|
||||
|
||||
def _drain_wakeup(self):
|
||||
while True:
|
||||
try:
|
||||
self._receive_wakeup.recv(32)
|
||||
except BlockingIOError:
|
||||
return
|
||||
|
||||
def _worker(self):
|
||||
sel = _selector_class()
|
||||
sel.register(self._socket, selectors.EVENT_READ, self._read)
|
||||
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
|
||||
while not self._done:
|
||||
(expiration, interval) = self._get_timer_values(False)
|
||||
items = sel.select(interval)
|
||||
for (key, _) in items:
|
||||
key.data()
|
||||
with self._lock:
|
||||
self._handle_timer(expiration)
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, _) in datagrams:
|
||||
try:
|
||||
self._socket.send(datagram)
|
||||
except BlockingIOError:
|
||||
# we let QUIC handle any lossage
|
||||
pass
|
||||
self._handle_events()
|
||||
|
||||
def _handle_events(self):
|
||||
while True:
|
||||
with self._lock:
|
||||
event = self._connection.next_event()
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
with self._lock:
|
||||
self._done = True
|
||||
|
||||
def write(self, stream, data, is_end=False):
|
||||
with self._lock:
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
self._send_wakeup.send(b"\x01")
|
||||
|
||||
def run(self):
|
||||
if self._closed:
|
||||
return
|
||||
self._worker_thread = threading.Thread(target=self._worker)
|
||||
self._worker_thread.start()
|
||||
|
||||
def make_stream(self):
|
||||
self._handshake_complete.wait()
|
||||
with self._lock:
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = SyncQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
return stream
|
||||
|
||||
def close_stream(self, stream_id):
|
||||
with self._lock:
|
||||
super().close_stream(stream_id)
|
||||
|
||||
def close(self):
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
return
|
||||
self._manager.closed(self._peer[0], self._peer[1])
|
||||
self._closed = True
|
||||
self._connection.close()
|
||||
self._send_wakeup.send(b"\x01")
|
||||
self._worker_thread.join()
|
||||
|
||||
|
||||
class SyncQuicManager(BaseQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, SyncQuicConnection)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
with self._lock:
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
if start:
|
||||
connection.run()
|
||||
return connection
|
||||
|
||||
def closed(self, address, port):
|
||||
with self._lock:
|
||||
super().closed(address, port)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
connection.close()
|
||||
return False
|
170
lib/dns/quic/_trio.py
Normal file
170
lib/dns/quic/_trio.py
Normal file
|
@ -0,0 +1,170 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import struct
|
||||
import time
|
||||
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import aioquic.quic.events # type: ignore
|
||||
import trio
|
||||
|
||||
import dns.inet
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
AsyncQuicConnection,
|
||||
AsyncQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
)
|
||||
|
||||
|
||||
class TrioQuicStream(BaseQuicStream):
|
||||
def __init__(self, connection, stream_id):
|
||||
super().__init__(connection, stream_id)
|
||||
self._wake_up = trio.Condition()
|
||||
|
||||
async def wait_for(self, amount):
|
||||
while True:
|
||||
if self._buffer.have(amount):
|
||||
return
|
||||
self._expecting = amount
|
||||
async with self._wake_up:
|
||||
await self._wake_up.wait()
|
||||
self._expecting = 0
|
||||
|
||||
async def receive(self, timeout=None):
|
||||
if timeout is None:
|
||||
context = NullContext(None)
|
||||
else:
|
||||
context = trio.move_on_after(timeout)
|
||||
with context:
|
||||
await self.wait_for(2)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size)
|
||||
return self._buffer.get(size)
|
||||
|
||||
async def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
await self._connection.write(self._stream_id, data, is_end)
|
||||
|
||||
async def _add_input(self, data, is_end):
|
||||
if self._common_add_input(data, is_end):
|
||||
async with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
|
||||
async def close(self):
|
||||
self._close()
|
||||
|
||||
# Streams are async context managers
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
async with self._wake_up:
|
||||
self._wake_up.notify()
|
||||
return False
|
||||
|
||||
|
||||
class TrioQuicConnection(AsyncQuicConnection):
|
||||
def __init__(self, connection, address, port, source, source_port, manager=None):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
|
||||
if self._source:
|
||||
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
|
||||
self._handshake_complete = trio.Event()
|
||||
self._run_done = trio.Event()
|
||||
self._worker_scope = None
|
||||
|
||||
async def _worker(self):
|
||||
await self._socket.connect(self._peer)
|
||||
while not self._done:
|
||||
(expiration, interval) = self._get_timer_values(False)
|
||||
with trio.CancelScope(
|
||||
deadline=trio.current_time() + interval
|
||||
) as self._worker_scope:
|
||||
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
|
||||
self._connection.receive_datagram(datagram, self._peer[0], time.time())
|
||||
self._worker_scope = None
|
||||
self._handle_timer(expiration)
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, _) in datagrams:
|
||||
await self._socket.send(datagram)
|
||||
await self._handle_events()
|
||||
|
||||
async def _handle_events(self):
|
||||
count = 0
|
||||
while True:
|
||||
event = self._connection.next_event()
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
self._done = True
|
||||
self._socket.close()
|
||||
count += 1
|
||||
if count > 10:
|
||||
# yield
|
||||
count = 0
|
||||
await trio.sleep(0)
|
||||
|
||||
async def write(self, stream, data, is_end=False):
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
if self._worker_scope is not None:
|
||||
self._worker_scope.cancel()
|
||||
|
||||
async def run(self):
|
||||
if self._closed:
|
||||
return
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._worker)
|
||||
self._run_done.set()
|
||||
|
||||
async def make_stream(self):
|
||||
await self._handshake_complete.wait()
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = TrioQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
return stream
|
||||
|
||||
async def close(self):
|
||||
if not self._closed:
|
||||
self._manager.closed(self._peer[0], self._peer[1])
|
||||
self._closed = True
|
||||
self._connection.close()
|
||||
if self._worker_scope is not None:
|
||||
self._worker_scope.cancel()
|
||||
await self._run_done.wait()
|
||||
|
||||
|
||||
class TrioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, TrioQuicConnection)
|
||||
self._nursery = nursery
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
if start:
|
||||
self._nursery.start_soon(connection.run)
|
||||
return connection
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
await connection.close()
|
||||
return False
|
|
@ -17,9 +17,12 @@
|
|||
|
||||
"""DNS Result Codes."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import dns.enum
|
||||
import dns.exception
|
||||
|
||||
|
||||
class Rcode(dns.enum.IntEnum):
|
||||
#: No error
|
||||
NOERROR = 0
|
||||
|
@ -77,20 +80,20 @@ class UnknownRcode(dns.exception.DNSException):
|
|||
"""A DNS rcode is unknown."""
|
||||
|
||||
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> Rcode:
|
||||
"""Convert text into an rcode.
|
||||
|
||||
*text*, a ``str``, the textual rcode or an integer in textual form.
|
||||
|
||||
Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
|
||||
|
||||
Returns an ``int``.
|
||||
Returns a ``dns.rcode.Rcode``.
|
||||
"""
|
||||
|
||||
return Rcode.from_text(text)
|
||||
|
||||
|
||||
def from_flags(flags, ednsflags):
|
||||
def from_flags(flags: int, ednsflags: int) -> Rcode:
|
||||
"""Return the rcode value encoded by flags and ednsflags.
|
||||
|
||||
*flags*, an ``int``, the DNS flags field.
|
||||
|
@ -99,17 +102,17 @@ def from_flags(flags, ednsflags):
|
|||
|
||||
Raises ``ValueError`` if rcode is < 0 or > 4095
|
||||
|
||||
Returns an ``int``.
|
||||
Returns a ``dns.rcode.Rcode``.
|
||||
"""
|
||||
|
||||
value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0)
|
||||
return value
|
||||
value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
|
||||
return Rcode.make(value)
|
||||
|
||||
|
||||
def to_flags(value):
|
||||
def to_flags(value: Rcode) -> Tuple[int, int]:
|
||||
"""Return a (flags, ednsflags) tuple which encodes the rcode.
|
||||
|
||||
*value*, an ``int``, the rcode.
|
||||
*value*, a ``dns.rcode.Rcode``, the rcode.
|
||||
|
||||
Raises ``ValueError`` if rcode is < 0 or > 4095.
|
||||
|
||||
|
@ -117,16 +120,16 @@ def to_flags(value):
|
|||
"""
|
||||
|
||||
if value < 0 or value > 4095:
|
||||
raise ValueError('rcode must be >= 0 and <= 4095')
|
||||
v = value & 0xf
|
||||
ev = (value & 0xff0) << 20
|
||||
raise ValueError("rcode must be >= 0 and <= 4095")
|
||||
v = value & 0xF
|
||||
ev = (value & 0xFF0) << 20
|
||||
return (v, ev)
|
||||
|
||||
|
||||
def to_text(value, tsig=False):
|
||||
def to_text(value: Rcode, tsig: bool = False) -> str:
|
||||
"""Convert rcode into text.
|
||||
|
||||
*value*, an ``int``, the rcode.
|
||||
*value*, a ``dns.rcode.Rcode``, the rcode.
|
||||
|
||||
Raises ``ValueError`` if rcode is < 0 or > 4095.
|
||||
|
||||
|
@ -134,9 +137,10 @@ def to_text(value, tsig=False):
|
|||
"""
|
||||
|
||||
if tsig and value == Rcode.BADVERS:
|
||||
return 'BADSIG'
|
||||
return "BADSIG"
|
||||
return Rcode.to_text(value)
|
||||
|
||||
|
||||
### BEGIN generated Rcode constants
|
||||
|
||||
NOERROR = Rcode.NOERROR
|
||||
|
|
372
lib/dns/rdata.py
372
lib/dns/rdata.py
|
@ -17,6 +17,8 @@
|
|||
|
||||
"""DNS rdata."""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from importlib import import_module
|
||||
import base64
|
||||
import binascii
|
||||
|
@ -55,21 +57,22 @@ class NoRelativeRdataOrdering(dns.exception.DNSException):
|
|||
"""
|
||||
|
||||
|
||||
def _wordbreak(data, chunksize=_chunksize, separator=b' '):
|
||||
def _wordbreak(data, chunksize=_chunksize, separator=b" "):
|
||||
"""Break a binary string into chunks of chunksize characters separated by
|
||||
a space.
|
||||
"""
|
||||
|
||||
if not chunksize:
|
||||
return data.decode()
|
||||
return separator.join([data[i:i + chunksize]
|
||||
for i
|
||||
in range(0, len(data), chunksize)]).decode()
|
||||
return separator.join(
|
||||
[data[i : i + chunksize] for i in range(0, len(data), chunksize)]
|
||||
).decode()
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
|
||||
|
||||
def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
|
||||
"""Convert a binary string into its hex encoding, broken up into chunks
|
||||
of chunksize characters separated by a separator.
|
||||
"""
|
||||
|
@ -77,17 +80,19 @@ def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
|
|||
return _wordbreak(binascii.hexlify(data), chunksize, separator)
|
||||
|
||||
|
||||
def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw):
|
||||
def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
|
||||
"""Convert a binary string into its base64 encoding, broken up into chunks
|
||||
of chunksize characters separated by a separator.
|
||||
"""
|
||||
|
||||
return _wordbreak(base64.b64encode(data), chunksize, separator)
|
||||
|
||||
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
__escaped = b'"\\'
|
||||
|
||||
|
||||
def _escapify(qstring):
|
||||
"""Escape the characters in a quoted string which need it."""
|
||||
|
||||
|
@ -96,14 +101,14 @@ def _escapify(qstring):
|
|||
if not isinstance(qstring, bytearray):
|
||||
qstring = bytearray(qstring)
|
||||
|
||||
text = ''
|
||||
text = ""
|
||||
for c in qstring:
|
||||
if c in __escaped:
|
||||
text += '\\' + chr(c)
|
||||
text += "\\" + chr(c)
|
||||
elif c >= 0x20 and c < 0x7F:
|
||||
text += chr(c)
|
||||
else:
|
||||
text += '\\%03d' % c
|
||||
text += "\\%03d" % c
|
||||
return text
|
||||
|
||||
|
||||
|
@ -114,9 +119,10 @@ def _truncate_bitmap(what):
|
|||
|
||||
for i in range(len(what) - 1, -1, -1):
|
||||
if what[i] != 0:
|
||||
return what[0: i + 1]
|
||||
return what[0 : i + 1]
|
||||
return what[0:1]
|
||||
|
||||
|
||||
# So we don't have to edit all the rdata classes...
|
||||
_constify = dns.immutable.constify
|
||||
|
||||
|
@ -125,7 +131,7 @@ _constify = dns.immutable.constify
|
|||
class Rdata:
|
||||
"""Base class for all DNS rdata types."""
|
||||
|
||||
__slots__ = ['rdclass', 'rdtype', 'rdcomment']
|
||||
__slots__ = ["rdclass", "rdtype", "rdcomment"]
|
||||
|
||||
def __init__(self, rdclass, rdtype):
|
||||
"""Initialize an rdata.
|
||||
|
@ -140,8 +146,9 @@ class Rdata:
|
|||
self.rdcomment = None
|
||||
|
||||
def _get_all_slots(self):
|
||||
return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
|
||||
for cls in self.__class__.__mro__)
|
||||
return itertools.chain.from_iterable(
|
||||
getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
|
||||
)
|
||||
|
||||
def __getstate__(self):
|
||||
# We used to try to do a tuple of all slots here, but it
|
||||
|
@ -160,12 +167,12 @@ class Rdata:
|
|||
def __setstate__(self, state):
|
||||
for slot, val in state.items():
|
||||
object.__setattr__(self, slot, val)
|
||||
if not hasattr(self, 'rdcomment'):
|
||||
if not hasattr(self, "rdcomment"):
|
||||
# Pickled rdata from 2.0.x might not have a rdcomment, so add
|
||||
# it if needed.
|
||||
object.__setattr__(self, 'rdcomment', None)
|
||||
object.__setattr__(self, "rdcomment", None)
|
||||
|
||||
def covers(self):
|
||||
def covers(self) -> dns.rdatatype.RdataType:
|
||||
"""Return the type a Rdata covers.
|
||||
|
||||
DNS SIG/RRSIG rdatas apply to a specific type; this type is
|
||||
|
@ -174,12 +181,12 @@ class Rdata:
|
|||
creating rdatasets, allowing the rdataset to contain only RRSIGs
|
||||
of a particular type, e.g. RRSIG(NS).
|
||||
|
||||
Returns an ``int``.
|
||||
Returns a ``dns.rdatatype.RdataType``.
|
||||
"""
|
||||
|
||||
return dns.rdatatype.NONE
|
||||
|
||||
def extended_rdatatype(self):
|
||||
def extended_rdatatype(self) -> int:
|
||||
"""Return a 32-bit type value, the least significant 16 bits of
|
||||
which are the ordinary DNS type, and the upper 16 bits of which are
|
||||
the "covered" type, if any.
|
||||
|
@ -189,7 +196,12 @@ class Rdata:
|
|||
|
||||
return self.covers() << 16 | self.rdtype
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
def to_text(
|
||||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Convert an rdata to text format.
|
||||
|
||||
Returns a ``str``.
|
||||
|
@ -197,11 +209,22 @@ class Rdata:
|
|||
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
def _to_wire(
|
||||
self,
|
||||
file: Optional[Any],
|
||||
compress: Optional[dns.name.CompressType] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
canonicalize: bool = False,
|
||||
) -> bytes:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def to_wire(self, file=None, compress=None, origin=None,
|
||||
canonicalize=False):
|
||||
def to_wire(
|
||||
self,
|
||||
file: Optional[Any] = None,
|
||||
compress: Optional[dns.name.CompressType] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
canonicalize: bool = False,
|
||||
) -> bytes:
|
||||
"""Convert an rdata to wire format.
|
||||
|
||||
Returns a ``bytes`` or ``None``.
|
||||
|
@ -214,15 +237,18 @@ class Rdata:
|
|||
self._to_wire(f, compress, origin, canonicalize)
|
||||
return f.getvalue()
|
||||
|
||||
def to_generic(self, origin=None):
|
||||
def to_generic(
|
||||
self, origin: Optional[dns.name.Name] = None
|
||||
) -> "dns.rdata.GenericRdata":
|
||||
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
|
||||
|
||||
Returns a ``dns.rdata.GenericRdata``.
|
||||
"""
|
||||
return dns.rdata.GenericRdata(self.rdclass, self.rdtype,
|
||||
self.to_wire(origin=origin))
|
||||
return dns.rdata.GenericRdata(
|
||||
self.rdclass, self.rdtype, self.to_wire(origin=origin)
|
||||
)
|
||||
|
||||
def to_digestable(self, origin=None):
|
||||
def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
|
||||
"""Convert rdata to a format suitable for digesting in hashes. This
|
||||
is also the DNSSEC canonical form.
|
||||
|
||||
|
@ -234,12 +260,19 @@ class Rdata:
|
|||
def __repr__(self):
|
||||
covers = self.covers()
|
||||
if covers == dns.rdatatype.NONE:
|
||||
ctext = ''
|
||||
ctext = ""
|
||||
else:
|
||||
ctext = '(' + dns.rdatatype.to_text(covers) + ')'
|
||||
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \
|
||||
dns.rdatatype.to_text(self.rdtype) + ctext + ' rdata: ' + \
|
||||
str(self) + '>'
|
||||
ctext = "(" + dns.rdatatype.to_text(covers) + ")"
|
||||
return (
|
||||
"<DNS "
|
||||
+ dns.rdataclass.to_text(self.rdclass)
|
||||
+ " "
|
||||
+ dns.rdatatype.to_text(self.rdtype)
|
||||
+ ctext
|
||||
+ " rdata: "
|
||||
+ str(self)
|
||||
+ ">"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_text()
|
||||
|
@ -320,27 +353,39 @@ class Rdata:
|
|||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Rdata) or \
|
||||
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
|
||||
if (
|
||||
not isinstance(other, Rdata)
|
||||
or self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
):
|
||||
|
||||
return NotImplemented
|
||||
return self._cmp(other) < 0
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, Rdata) or \
|
||||
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
|
||||
if (
|
||||
not isinstance(other, Rdata)
|
||||
or self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
):
|
||||
return NotImplemented
|
||||
return self._cmp(other) <= 0
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, Rdata) or \
|
||||
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
|
||||
if (
|
||||
not isinstance(other, Rdata)
|
||||
or self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
):
|
||||
return NotImplemented
|
||||
return self._cmp(other) >= 0
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, Rdata) or \
|
||||
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
|
||||
if (
|
||||
not isinstance(other, Rdata)
|
||||
or self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
):
|
||||
return NotImplemented
|
||||
return self._cmp(other) > 0
|
||||
|
||||
|
@ -348,15 +393,28 @@ class Rdata:
|
|||
return hash(self.to_digestable(dns.name.root))
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
tok: dns.tokenizer.Tokenizer,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
relativize_to: Optional[dns.name.Name] = None,
|
||||
) -> "Rdata":
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
def from_wire_parser(
|
||||
cls,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
parser: dns.wire.Parser,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
) -> "Rdata":
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def replace(self, **kwargs):
|
||||
def replace(self, **kwargs: Any) -> "Rdata":
|
||||
"""
|
||||
Create a new Rdata instance based on the instance replace was
|
||||
invoked on. It is possible to pass different parameters to
|
||||
|
@ -369,19 +427,25 @@ class Rdata:
|
|||
"""
|
||||
|
||||
# Get the constructor parameters.
|
||||
parameters = inspect.signature(self.__init__).parameters
|
||||
parameters = inspect.signature(self.__init__).parameters # type: ignore
|
||||
|
||||
# Ensure that all of the arguments correspond to valid fields.
|
||||
# Don't allow rdclass or rdtype to be changed, though.
|
||||
for key in kwargs:
|
||||
if key == 'rdcomment':
|
||||
if key == "rdcomment":
|
||||
continue
|
||||
if key not in parameters:
|
||||
raise AttributeError("'{}' object has no attribute '{}'"
|
||||
.format(self.__class__.__name__, key))
|
||||
if key in ('rdclass', 'rdtype'):
|
||||
raise AttributeError("Cannot overwrite '{}' attribute '{}'"
|
||||
.format(self.__class__.__name__, key))
|
||||
raise AttributeError(
|
||||
"'{}' object has no attribute '{}'".format(
|
||||
self.__class__.__name__, key
|
||||
)
|
||||
)
|
||||
if key in ("rdclass", "rdtype"):
|
||||
raise AttributeError(
|
||||
"Cannot overwrite '{}' attribute '{}'".format(
|
||||
self.__class__.__name__, key
|
||||
)
|
||||
)
|
||||
|
||||
# Construct the parameter list. For each field, use the value in
|
||||
# kwargs if present, and the current value otherwise.
|
||||
|
@ -391,9 +455,9 @@ class Rdata:
|
|||
rd = self.__class__(*args)
|
||||
# The comment is not set in the constructor, so give it special
|
||||
# handling.
|
||||
rdcomment = kwargs.get('rdcomment', self.rdcomment)
|
||||
rdcomment = kwargs.get("rdcomment", self.rdcomment)
|
||||
if rdcomment is not None:
|
||||
object.__setattr__(rd, 'rdcomment', rdcomment)
|
||||
object.__setattr__(rd, "rdcomment", rdcomment)
|
||||
return rd
|
||||
|
||||
# Type checking and conversion helpers. These are class methods as
|
||||
|
@ -408,18 +472,26 @@ class Rdata:
|
|||
return dns.rdatatype.RdataType.make(value)
|
||||
|
||||
@classmethod
|
||||
def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
|
||||
def _as_bytes(
|
||||
cls,
|
||||
value: Any,
|
||||
encode: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
empty_ok: bool = True,
|
||||
) -> bytes:
|
||||
if encode and isinstance(value, str):
|
||||
value = value.encode()
|
||||
bvalue = value.encode()
|
||||
elif isinstance(value, bytearray):
|
||||
value = bytes(value)
|
||||
elif not isinstance(value, bytes):
|
||||
raise ValueError('not bytes')
|
||||
if max_length is not None and len(value) > max_length:
|
||||
raise ValueError('too long')
|
||||
if not empty_ok and len(value) == 0:
|
||||
raise ValueError('empty bytes not allowed')
|
||||
return value
|
||||
bvalue = bytes(value)
|
||||
elif isinstance(value, bytes):
|
||||
bvalue = value
|
||||
else:
|
||||
raise ValueError("not bytes")
|
||||
if max_length is not None and len(bvalue) > max_length:
|
||||
raise ValueError("too long")
|
||||
if not empty_ok and len(bvalue) == 0:
|
||||
raise ValueError("empty bytes not allowed")
|
||||
return bvalue
|
||||
|
||||
@classmethod
|
||||
def _as_name(cls, value):
|
||||
|
@ -429,49 +501,49 @@ class Rdata:
|
|||
if isinstance(value, str):
|
||||
return dns.name.from_text(value)
|
||||
elif not isinstance(value, dns.name.Name):
|
||||
raise ValueError('not a name')
|
||||
raise ValueError("not a name")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _as_uint8(cls, value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError('not an integer')
|
||||
raise ValueError("not an integer")
|
||||
if value < 0 or value > 255:
|
||||
raise ValueError('not a uint8')
|
||||
raise ValueError("not a uint8")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _as_uint16(cls, value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError('not an integer')
|
||||
raise ValueError("not an integer")
|
||||
if value < 0 or value > 65535:
|
||||
raise ValueError('not a uint16')
|
||||
raise ValueError("not a uint16")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _as_uint32(cls, value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError('not an integer')
|
||||
raise ValueError("not an integer")
|
||||
if value < 0 or value > 4294967295:
|
||||
raise ValueError('not a uint32')
|
||||
raise ValueError("not a uint32")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _as_uint48(cls, value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError('not an integer')
|
||||
raise ValueError("not an integer")
|
||||
if value < 0 or value > 281474976710655:
|
||||
raise ValueError('not a uint48')
|
||||
raise ValueError("not a uint48")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _as_int(cls, value, low=None, high=None):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError('not an integer')
|
||||
raise ValueError("not an integer")
|
||||
if low is not None and value < low:
|
||||
raise ValueError('value too small')
|
||||
raise ValueError("value too small")
|
||||
if high is not None and value > high:
|
||||
raise ValueError('value too large')
|
||||
raise ValueError("value too large")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
|
@ -483,7 +555,7 @@ class Rdata:
|
|||
elif isinstance(value, bytes):
|
||||
return dns.ipv4.inet_ntoa(value)
|
||||
else:
|
||||
raise ValueError('not an IPv4 address')
|
||||
raise ValueError("not an IPv4 address")
|
||||
|
||||
@classmethod
|
||||
def _as_ipv6_address(cls, value):
|
||||
|
@ -494,14 +566,14 @@ class Rdata:
|
|||
elif isinstance(value, bytes):
|
||||
return dns.ipv6.inet_ntoa(value)
|
||||
else:
|
||||
raise ValueError('not an IPv6 address')
|
||||
raise ValueError("not an IPv6 address")
|
||||
|
||||
@classmethod
|
||||
def _as_bool(cls, value):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
else:
|
||||
raise ValueError('not a boolean')
|
||||
raise ValueError("not a boolean")
|
||||
|
||||
@classmethod
|
||||
def _as_ttl(cls, value):
|
||||
|
@ -510,7 +582,7 @@ class Rdata:
|
|||
elif isinstance(value, str):
|
||||
return dns.ttl.from_text(value)
|
||||
else:
|
||||
raise ValueError('not a TTL')
|
||||
raise ValueError("not a TTL")
|
||||
|
||||
@classmethod
|
||||
def _as_tuple(cls, value, as_value):
|
||||
|
@ -532,6 +604,7 @@ class Rdata:
|
|||
return items
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class GenericRdata(Rdata):
|
||||
|
||||
"""Generic Rdata Class
|
||||
|
@ -540,28 +613,32 @@ class GenericRdata(Rdata):
|
|||
implementation. It implements the DNS "unknown RRs" scheme.
|
||||
"""
|
||||
|
||||
__slots__ = ['data']
|
||||
__slots__ = ["data"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, data):
|
||||
super().__init__(rdclass, rdtype)
|
||||
object.__setattr__(self, 'data', data)
|
||||
self.data = data
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return r'\# %d ' % len(self.data) + _hexify(self.data, **kw)
|
||||
def to_text(
|
||||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
) -> str:
|
||||
return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
token = tok.get()
|
||||
if not token.is_identifier() or token.value != r'\#':
|
||||
raise dns.exception.SyntaxError(
|
||||
r'generic rdata does not start with \#')
|
||||
if not token.is_identifier() or token.value != r"\#":
|
||||
raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
|
||||
length = tok.get_int()
|
||||
hex = tok.concatenate_remaining_identifiers(True).encode()
|
||||
data = binascii.unhexlify(hex)
|
||||
if len(data) != length:
|
||||
raise dns.exception.SyntaxError(
|
||||
'generic rdata hex data has wrong length')
|
||||
raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
|
||||
return cls(rdclass, rdtype, data)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
|
@ -571,8 +648,12 @@ class GenericRdata(Rdata):
|
|||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
return cls(rdclass, rdtype, parser.get_remaining())
|
||||
|
||||
_rdata_classes = {}
|
||||
_module_prefix = 'dns.rdtypes'
|
||||
|
||||
_rdata_classes: Dict[
|
||||
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any
|
||||
] = {}
|
||||
_module_prefix = "dns.rdtypes"
|
||||
|
||||
|
||||
def get_rdata_class(rdclass, rdtype):
|
||||
cls = _rdata_classes.get((rdclass, rdtype))
|
||||
|
@ -581,16 +662,16 @@ def get_rdata_class(rdclass, rdtype):
|
|||
if not cls:
|
||||
rdclass_text = dns.rdataclass.to_text(rdclass)
|
||||
rdtype_text = dns.rdatatype.to_text(rdtype)
|
||||
rdtype_text = rdtype_text.replace('-', '_')
|
||||
rdtype_text = rdtype_text.replace("-", "_")
|
||||
try:
|
||||
mod = import_module('.'.join([_module_prefix,
|
||||
rdclass_text, rdtype_text]))
|
||||
mod = import_module(
|
||||
".".join([_module_prefix, rdclass_text, rdtype_text])
|
||||
)
|
||||
cls = getattr(mod, rdtype_text)
|
||||
_rdata_classes[(rdclass, rdtype)] = cls
|
||||
except ImportError:
|
||||
try:
|
||||
mod = import_module('.'.join([_module_prefix,
|
||||
'ANY', rdtype_text]))
|
||||
mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
|
||||
cls = getattr(mod, rdtype_text)
|
||||
_rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
|
||||
_rdata_classes[(rdclass, rdtype)] = cls
|
||||
|
@ -602,8 +683,15 @@ def get_rdata_class(rdclass, rdtype):
|
|||
return cls
|
||||
|
||||
|
||||
def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None, idna_codec=None):
|
||||
def from_text(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
tok: Union[dns.tokenizer.Tokenizer, str],
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
relativize_to: Optional[dns.name.Name] = None,
|
||||
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||
) -> Rdata:
|
||||
"""Build an rdata object from text format.
|
||||
|
||||
This function attempts to dynamically load a class which
|
||||
|
@ -617,9 +705,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
|
|||
If *tok* is a ``str``, then a tokenizer is created and the string
|
||||
is used as its input.
|
||||
|
||||
*rdclass*, an ``int``, the rdataclass.
|
||||
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
|
||||
|
||||
*rdtype*, an ``int``, the rdatatype.
|
||||
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
|
||||
|
||||
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
|
||||
|
||||
|
@ -651,17 +739,18 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
|
|||
# peek at first token
|
||||
token = tok.get()
|
||||
tok.unget(token)
|
||||
if token.is_identifier() and \
|
||||
token.value == r'\#':
|
||||
if token.is_identifier() and token.value == r"\#":
|
||||
#
|
||||
# Known type using the generic syntax. Extract the
|
||||
# wire form from the generic syntax, and then run
|
||||
# from_wire on it.
|
||||
#
|
||||
grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
|
||||
relativize, relativize_to)
|
||||
rdata = from_wire(rdclass, rdtype, grdata.data, 0,
|
||||
len(grdata.data), origin)
|
||||
grdata = GenericRdata.from_text(
|
||||
rdclass, rdtype, tok, origin, relativize, relativize_to
|
||||
)
|
||||
rdata = from_wire(
|
||||
rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
|
||||
)
|
||||
#
|
||||
# If this comparison isn't equal, then there must have been
|
||||
# compressed names in the wire format, which is an error,
|
||||
|
@ -669,19 +758,27 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
|
|||
#
|
||||
rwire = rdata.to_wire()
|
||||
if rwire != grdata.data:
|
||||
raise dns.exception.SyntaxError('compressed data in '
|
||||
'generic syntax form '
|
||||
'of known rdatatype')
|
||||
raise dns.exception.SyntaxError(
|
||||
"compressed data in "
|
||||
"generic syntax form "
|
||||
"of known rdatatype"
|
||||
)
|
||||
if rdata is None:
|
||||
rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
|
||||
relativize_to)
|
||||
rdata = cls.from_text(
|
||||
rdclass, rdtype, tok, origin, relativize, relativize_to
|
||||
)
|
||||
token = tok.get_eol_as_token()
|
||||
if token.comment is not None:
|
||||
object.__setattr__(rdata, 'rdcomment', token.comment)
|
||||
object.__setattr__(rdata, "rdcomment", token.comment)
|
||||
return rdata
|
||||
|
||||
|
||||
def from_wire_parser(rdclass, rdtype, parser, origin=None):
|
||||
def from_wire_parser(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
parser: dns.wire.Parser,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
) -> Rdata:
|
||||
"""Build an rdata object from wire format
|
||||
|
||||
This function attempts to dynamically load a class which
|
||||
|
@ -692,9 +789,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
|
|||
Once a class is chosen, its from_wire() class method is called
|
||||
with the parameters to this function.
|
||||
|
||||
*rdclass*, an ``int``, the rdataclass.
|
||||
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
|
||||
|
||||
*rdtype*, an ``int``, the rdatatype.
|
||||
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
|
||||
|
||||
*parser*, a ``dns.wire.Parser``, the parser, which should be
|
||||
restricted to the rdata length.
|
||||
|
@ -712,7 +809,14 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
|
|||
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
|
||||
|
||||
|
||||
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
|
||||
def from_wire(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
wire: bytes,
|
||||
current: int,
|
||||
rdlen: int,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
) -> Rdata:
|
||||
"""Build an rdata object from wire format
|
||||
|
||||
This function attempts to dynamically load a class which
|
||||
|
@ -746,13 +850,21 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
|
|||
|
||||
class RdatatypeExists(dns.exception.DNSException):
|
||||
"""DNS rdatatype already exists."""
|
||||
supp_kwargs = {'rdclass', 'rdtype'}
|
||||
fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \
|
||||
"already exists."
|
||||
|
||||
supp_kwargs = {"rdclass", "rdtype"}
|
||||
fmt = (
|
||||
"The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
|
||||
+ "already exists."
|
||||
)
|
||||
|
||||
|
||||
def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
|
||||
rdclass=dns.rdataclass.IN):
|
||||
def register_type(
|
||||
implementation: Any,
|
||||
rdtype: int,
|
||||
rdtype_text: str,
|
||||
is_singleton: bool = False,
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
) -> None:
|
||||
"""Dynamically register a module to handle an rdatatype.
|
||||
|
||||
*implementation*, a module implementing the type in the usual dnspython
|
||||
|
@ -769,14 +881,16 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
|
|||
it applies to all classes.
|
||||
"""
|
||||
|
||||
existing_cls = get_rdata_class(rdclass, rdtype)
|
||||
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
existing_cls = get_rdata_class(rdclass, the_rdtype)
|
||||
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
||||
try:
|
||||
if dns.rdatatype.RdataType(rdtype).name != rdtype_text:
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
|
||||
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
||||
except ValueError:
|
||||
pass
|
||||
_rdata_classes[(rdclass, rdtype)] = getattr(implementation,
|
||||
rdtype_text.replace('-', '_'))
|
||||
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
|
||||
_rdata_classes[(rdclass, the_rdtype)] = getattr(
|
||||
implementation, rdtype_text.replace("-", "_")
|
||||
)
|
||||
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
from typing import Dict, Tuple, Any, Optional, BinaryIO
|
||||
from .name import Name, IDNACodec
|
||||
class Rdata:
|
||||
def __init__(self):
|
||||
self.address : str
|
||||
def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
|
||||
...
|
||||
@classmethod
|
||||
def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
|
||||
...
|
||||
_rdata_modules : Dict[Tuple[Any,Rdata],Any]
|
||||
|
||||
def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
|
||||
relativize : bool = True, relativize_to : Optional[Name] = None,
|
||||
idna_codec : Optional[IDNACodec] = None):
|
||||
...
|
||||
|
||||
def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
|
||||
...
|
|
@ -20,8 +20,10 @@
|
|||
import dns.enum
|
||||
import dns.exception
|
||||
|
||||
|
||||
class RdataClass(dns.enum.IntEnum):
|
||||
"""DNS Rdata Class"""
|
||||
|
||||
RESERVED0 = 0
|
||||
IN = 1
|
||||
INTERNET = IN
|
||||
|
@ -56,7 +58,7 @@ class UnknownRdataclass(dns.exception.DNSException):
|
|||
"""A DNS class is unknown."""
|
||||
|
||||
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> RdataClass:
|
||||
"""Convert text into a DNS rdata class value.
|
||||
|
||||
The input text can be a defined DNS RR class mnemonic or
|
||||
|
@ -68,13 +70,13 @@ def from_text(text):
|
|||
|
||||
Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
|
||||
|
||||
Returns an ``int``.
|
||||
Returns a ``dns.rdataclass.RdataClass``.
|
||||
"""
|
||||
|
||||
return RdataClass.from_text(text)
|
||||
|
||||
|
||||
def to_text(value):
|
||||
def to_text(value: RdataClass) -> str:
|
||||
"""Convert a DNS rdata class value to text.
|
||||
|
||||
If the value has a known mnemonic, it will be used, otherwise the
|
||||
|
@ -88,18 +90,19 @@ def to_text(value):
|
|||
return RdataClass.to_text(value)
|
||||
|
||||
|
||||
def is_metaclass(rdclass):
|
||||
def is_metaclass(rdclass: RdataClass) -> bool:
|
||||
"""True if the specified class is a metaclass.
|
||||
|
||||
The currently defined metaclasses are ANY and NONE.
|
||||
|
||||
*rdclass* is an ``int``.
|
||||
*rdclass* is a ``dns.rdataclass.RdataClass``.
|
||||
"""
|
||||
|
||||
if rdclass in _metaclasses:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
### BEGIN generated RdataClass constants
|
||||
|
||||
RESERVED0 = RdataClass.RESERVED0
|
||||
|
|
|
@ -17,16 +17,20 @@
|
|||
|
||||
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
|
||||
|
||||
from typing import Any, cast, Collection, Dict, List, Optional, Union
|
||||
|
||||
import io
|
||||
import random
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdatatype
|
||||
import dns.rdataclass
|
||||
import dns.rdata
|
||||
import dns.set
|
||||
import dns.ttl
|
||||
|
||||
# define SimpleSet here for backwards compatibility
|
||||
SimpleSet = dns.set.Set
|
||||
|
@ -45,24 +49,30 @@ class Rdataset(dns.set.Set):
|
|||
|
||||
"""A DNS rdataset."""
|
||||
|
||||
__slots__ = ['rdclass', 'rdtype', 'covers', 'ttl']
|
||||
__slots__ = ["rdclass", "rdtype", "covers", "ttl"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0):
|
||||
def __init__(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
ttl: int = 0,
|
||||
):
|
||||
"""Create a new rdataset of the specified class and type.
|
||||
|
||||
*rdclass*, an ``int``, the rdataclass.
|
||||
*rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
|
||||
|
||||
*rdtype*, an ``int``, the rdatatype.
|
||||
*rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
|
||||
|
||||
*covers*, an ``int``, the covered rdatatype.
|
||||
*covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
|
||||
|
||||
*ttl*, an ``int``, the TTL.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.rdclass = rdclass
|
||||
self.rdtype = rdtype
|
||||
self.covers = covers
|
||||
self.rdtype: dns.rdatatype.RdataType = rdtype
|
||||
self.covers: dns.rdatatype.RdataType = covers
|
||||
self.ttl = ttl
|
||||
|
||||
def _clone(self):
|
||||
|
@ -73,7 +83,7 @@ class Rdataset(dns.set.Set):
|
|||
obj.ttl = self.ttl
|
||||
return obj
|
||||
|
||||
def update_ttl(self, ttl):
|
||||
def update_ttl(self, ttl: int) -> None:
|
||||
"""Perform TTL minimization.
|
||||
|
||||
Set the TTL of the rdataset to be the lesser of the set's current
|
||||
|
@ -88,7 +98,9 @@ class Rdataset(dns.set.Set):
|
|||
elif ttl < self.ttl:
|
||||
self.ttl = ttl
|
||||
|
||||
def add(self, rd, ttl=None): # pylint: disable=arguments-differ
|
||||
def add( # pylint: disable=arguments-differ,arguments-renamed
|
||||
self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
|
||||
) -> None:
|
||||
"""Add the specified rdata to the rdataset.
|
||||
|
||||
If the optional *ttl* parameter is supplied, then
|
||||
|
@ -115,8 +127,7 @@ class Rdataset(dns.set.Set):
|
|||
raise IncompatibleTypes
|
||||
if ttl is not None:
|
||||
self.update_ttl(ttl)
|
||||
if self.rdtype == dns.rdatatype.RRSIG or \
|
||||
self.rdtype == dns.rdatatype.SIG:
|
||||
if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
|
||||
covers = rd.covers()
|
||||
if len(self) == 0 and self.covers == dns.rdatatype.NONE:
|
||||
self.covers = covers
|
||||
|
@ -147,19 +158,26 @@ class Rdataset(dns.set.Set):
|
|||
def _rdata_repr(self):
|
||||
def maybe_truncate(s):
|
||||
if len(s) > 100:
|
||||
return s[:100] + '...'
|
||||
return s[:100] + "..."
|
||||
return s
|
||||
return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr))
|
||||
for rr in self)
|
||||
|
||||
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
|
||||
|
||||
def __repr__(self):
|
||||
if self.covers == 0:
|
||||
ctext = ''
|
||||
ctext = ""
|
||||
else:
|
||||
ctext = '(' + dns.rdatatype.to_text(self.covers) + ')'
|
||||
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \
|
||||
dns.rdatatype.to_text(self.rdtype) + ctext + \
|
||||
' rdataset: ' + self._rdata_repr() + '>'
|
||||
ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
|
||||
return (
|
||||
"<DNS "
|
||||
+ dns.rdataclass.to_text(self.rdclass)
|
||||
+ " "
|
||||
+ dns.rdatatype.to_text(self.rdtype)
|
||||
+ ctext
|
||||
+ " rdataset: "
|
||||
+ self._rdata_repr()
|
||||
+ ">"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_text()
|
||||
|
@ -167,17 +185,26 @@ class Rdataset(dns.set.Set):
|
|||
def __eq__(self, other):
|
||||
if not isinstance(other, Rdataset):
|
||||
return False
|
||||
if self.rdclass != other.rdclass or \
|
||||
self.rdtype != other.rdtype or \
|
||||
self.covers != other.covers:
|
||||
if (
|
||||
self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
or self.covers != other.covers
|
||||
):
|
||||
return False
|
||||
return super().__eq__(other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def to_text(self, name=None, origin=None, relativize=True,
|
||||
override_rdclass=None, want_comments=False, **kw):
|
||||
def to_text(
|
||||
self,
|
||||
name: Optional[dns.name.Name] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
|
||||
want_comments: bool = False,
|
||||
**kw: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Convert the rdataset into DNS zone file format.
|
||||
|
||||
See ``dns.name.Name.choose_relativity`` for more information
|
||||
|
@ -206,10 +233,10 @@ class Rdataset(dns.set.Set):
|
|||
if name is not None:
|
||||
name = name.choose_relativity(origin, relativize)
|
||||
ntext = str(name)
|
||||
pad = ' '
|
||||
pad = " "
|
||||
else:
|
||||
ntext = ''
|
||||
pad = ''
|
||||
ntext = ""
|
||||
pad = ""
|
||||
s = io.StringIO()
|
||||
if override_rdclass is not None:
|
||||
rdclass = override_rdclass
|
||||
|
@ -221,28 +248,46 @@ class Rdataset(dns.set.Set):
|
|||
# some dynamic updates, so we don't need to print out the TTL
|
||||
# (which is meaningless anyway).
|
||||
#
|
||||
s.write('{}{}{} {}\n'.format(ntext, pad,
|
||||
s.write(
|
||||
"{}{}{} {}\n".format(
|
||||
ntext,
|
||||
pad,
|
||||
dns.rdataclass.to_text(rdclass),
|
||||
dns.rdatatype.to_text(self.rdtype)))
|
||||
dns.rdatatype.to_text(self.rdtype),
|
||||
)
|
||||
)
|
||||
else:
|
||||
for rd in self:
|
||||
extra = ''
|
||||
extra = ""
|
||||
if want_comments:
|
||||
if rd.rdcomment:
|
||||
extra = f' ;{rd.rdcomment}'
|
||||
s.write('%s%s%d %s %s %s%s\n' %
|
||||
(ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass),
|
||||
extra = f" ;{rd.rdcomment}"
|
||||
s.write(
|
||||
"%s%s%d %s %s %s%s\n"
|
||||
% (
|
||||
ntext,
|
||||
pad,
|
||||
self.ttl,
|
||||
dns.rdataclass.to_text(rdclass),
|
||||
dns.rdatatype.to_text(self.rdtype),
|
||||
rd.to_text(origin=origin, relativize=relativize,
|
||||
**kw),
|
||||
extra))
|
||||
rd.to_text(origin=origin, relativize=relativize, **kw),
|
||||
extra,
|
||||
)
|
||||
)
|
||||
#
|
||||
# We strip off the final \n for the caller's convenience in printing
|
||||
#
|
||||
return s.getvalue()[:-1]
|
||||
|
||||
def to_wire(self, name, file, compress=None, origin=None,
|
||||
override_rdclass=None, want_shuffle=True):
|
||||
def to_wire(
|
||||
self,
|
||||
name: dns.name.Name,
|
||||
file: Any,
|
||||
compress: Optional[dns.name.CompressType] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
|
||||
want_shuffle: bool = True,
|
||||
) -> int:
|
||||
"""Convert the rdataset to wire format.
|
||||
|
||||
*name*, a ``dns.name.Name`` is the owner name to use.
|
||||
|
@ -279,6 +324,7 @@ class Rdataset(dns.set.Set):
|
|||
file.write(stuff)
|
||||
return 1
|
||||
else:
|
||||
l: Union[Rdataset, List[dns.rdata.Rdata]]
|
||||
if want_shuffle:
|
||||
l = list(self)
|
||||
random.shuffle(l)
|
||||
|
@ -286,8 +332,7 @@ class Rdataset(dns.set.Set):
|
|||
l = self
|
||||
for rd in l:
|
||||
name.to_wire(file, compress, origin)
|
||||
stuff = struct.pack("!HHIH", self.rdtype, rdclass,
|
||||
self.ttl, 0)
|
||||
stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
|
||||
file.write(stuff)
|
||||
start = file.tell()
|
||||
rd.to_wire(file, compress, origin)
|
||||
|
@ -299,17 +344,20 @@ class Rdataset(dns.set.Set):
|
|||
file.seek(0, io.SEEK_END)
|
||||
return len(self)
|
||||
|
||||
def match(self, rdclass, rdtype, covers):
|
||||
def match(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType,
|
||||
) -> bool:
|
||||
"""Returns ``True`` if this rdataset matches the specified class,
|
||||
type, and covers.
|
||||
"""
|
||||
if self.rdclass == rdclass and \
|
||||
self.rdtype == rdtype and \
|
||||
self.covers == covers:
|
||||
if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers:
|
||||
return True
|
||||
return False
|
||||
|
||||
def processing_order(self):
|
||||
def processing_order(self) -> List[dns.rdata.Rdata]:
|
||||
"""Return rdatas in a valid processing order according to the type's
|
||||
specification. For example, MX records are in preference order from
|
||||
lowest to highest preferences, with items of the same preference
|
||||
|
@ -325,51 +373,56 @@ class Rdataset(dns.set.Set):
|
|||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class ImmutableRdataset(Rdataset):
|
||||
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
|
||||
|
||||
"""An immutable DNS rdataset."""
|
||||
|
||||
_clone_class = Rdataset
|
||||
|
||||
def __init__(self, rdataset):
|
||||
def __init__(self, rdataset: Rdataset):
|
||||
"""Create an immutable rdataset from the specified rdataset."""
|
||||
|
||||
super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
|
||||
rdataset.ttl)
|
||||
super().__init__(
|
||||
rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl
|
||||
)
|
||||
self.items = dns.immutable.Dict(rdataset.items)
|
||||
|
||||
def update_ttl(self, ttl):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def add(self, rd, ttl=None):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def union_update(self, other):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def intersection_update(self, other):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def update(self, other):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __delitem__(self, i):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __ior__(self, other):
|
||||
raise TypeError('immutable')
|
||||
# lgtm complains about these not raising ArithmeticError, but there is
|
||||
# precedent for overrides of these methods in other classes to raise
|
||||
# TypeError, and it seems like the better exception.
|
||||
|
||||
def __iand__(self, other):
|
||||
raise TypeError('immutable')
|
||||
def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method]
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __iadd__(self, other):
|
||||
raise TypeError('immutable')
|
||||
def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method]
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __isub__(self, other):
|
||||
raise TypeError('immutable')
|
||||
def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method]
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method]
|
||||
raise TypeError("immutable")
|
||||
|
||||
def clear(self):
|
||||
raise TypeError('immutable')
|
||||
raise TypeError("immutable")
|
||||
|
||||
def __copy__(self):
|
||||
return ImmutableRdataset(super().copy())
|
||||
|
@ -386,9 +439,20 @@ class ImmutableRdataset(Rdataset):
|
|||
def difference(self, other):
|
||||
return ImmutableRdataset(super().difference(other))
|
||||
|
||||
def symmetric_difference(self, other):
|
||||
return ImmutableRdataset(super().symmetric_difference(other))
|
||||
|
||||
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
|
||||
origin=None, relativize=True, relativize_to=None):
|
||||
|
||||
def from_text_list(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
ttl: int,
|
||||
text_rdatas: Collection[str],
|
||||
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
relativize_to: Optional[dns.name.Name] = None,
|
||||
) -> Rdataset:
|
||||
"""Create an rdataset with the specified class, type, and TTL, and with
|
||||
the specified list of rdatas in text format.
|
||||
|
||||
|
@ -407,28 +471,34 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
|
|||
Returns a ``dns.rdataset.Rdataset`` object.
|
||||
"""
|
||||
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = Rdataset(rdclass, rdtype)
|
||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = Rdataset(the_rdclass, the_rdtype)
|
||||
r.update_ttl(ttl)
|
||||
for t in text_rdatas:
|
||||
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
|
||||
relativize_to, idna_codec)
|
||||
rd = dns.rdata.from_text(
|
||||
r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
|
||||
)
|
||||
r.add(rd)
|
||||
return r
|
||||
|
||||
|
||||
def from_text(rdclass, rdtype, ttl, *text_rdatas):
|
||||
def from_text(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
ttl: int,
|
||||
*text_rdatas: Any,
|
||||
) -> Rdataset:
|
||||
"""Create an rdataset with the specified class, type, and TTL, and with
|
||||
the specified rdatas in text format.
|
||||
|
||||
Returns a ``dns.rdataset.Rdataset`` object.
|
||||
"""
|
||||
|
||||
return from_text_list(rdclass, rdtype, ttl, text_rdatas)
|
||||
return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
|
||||
|
||||
|
||||
def from_rdata_list(ttl, rdatas):
|
||||
def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
|
||||
"""Create an rdataset with the specified TTL, and with
|
||||
the specified list of rdata objects.
|
||||
|
||||
|
@ -443,14 +513,15 @@ def from_rdata_list(ttl, rdatas):
|
|||
r = Rdataset(rd.rdclass, rd.rdtype)
|
||||
r.update_ttl(ttl)
|
||||
r.add(rd)
|
||||
assert r is not None
|
||||
return r
|
||||
|
||||
|
||||
def from_rdata(ttl, *rdatas):
|
||||
def from_rdata(ttl: int, *rdatas: Any) -> Rdataset:
|
||||
"""Create an rdataset with the specified TTL, and with
|
||||
the specified rdata objects.
|
||||
|
||||
Returns a ``dns.rdataset.Rdataset`` object.
|
||||
"""
|
||||
|
||||
return from_rdata_list(ttl, rdatas)
|
||||
return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
from typing import Optional, Dict, List, Union
|
||||
from io import BytesIO
|
||||
from . import exception, name, set, rdatatype, rdata, rdataset
|
||||
|
||||
class DifferingCovers(exception.DNSException):
|
||||
"""An attempt was made to add a DNS SIG/RRSIG whose covered type
|
||||
is not the same as that of the other rdatas in the rdataset."""
|
||||
|
||||
|
||||
class IncompatibleTypes(exception.DNSException):
|
||||
"""An attempt was made to add DNS RR data of an incompatible type."""
|
||||
|
||||
|
||||
class Rdataset(set.Set):
|
||||
def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
|
||||
self.rdclass : int = rdclass
|
||||
self.rdtype : int = rdtype
|
||||
self.covers : int = covers
|
||||
self.ttl : int = ttl
|
||||
|
||||
def update_ttl(self, ttl : int) -> None:
|
||||
...
|
||||
|
||||
def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
|
||||
...
|
||||
|
||||
def union_update(self, other : Rdataset):
|
||||
...
|
||||
|
||||
def intersection_update(self, other : Rdataset):
|
||||
...
|
||||
|
||||
def update(self, other : Rdataset):
|
||||
...
|
||||
|
||||
def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
|
||||
override_rdclass : Optional[int] =None, **kw) -> bytes:
|
||||
...
|
||||
|
||||
def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
|
||||
override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
|
||||
...
|
||||
|
||||
def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
|
||||
...
|
||||
|
||||
|
||||
def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
|
||||
...
|
||||
|
||||
def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
|
||||
...
|
||||
|
||||
def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
|
||||
...
|
||||
|
||||
def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
|
||||
...
|
|
@ -17,11 +17,15 @@
|
|||
|
||||
"""DNS Rdata Types."""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import dns.enum
|
||||
import dns.exception
|
||||
|
||||
|
||||
class RdataType(dns.enum.IntEnum):
|
||||
"""DNS Rdata Type"""
|
||||
|
||||
TYPE0 = 0
|
||||
NONE = 0
|
||||
A = 1
|
||||
|
@ -116,24 +120,47 @@ class RdataType(dns.enum.IntEnum):
|
|||
def _prefix(cls):
|
||||
return "TYPE"
|
||||
|
||||
@classmethod
|
||||
def _extra_from_text(cls, text):
|
||||
if text.find("-") >= 0:
|
||||
try:
|
||||
return cls[text.replace("-", "_")]
|
||||
except KeyError:
|
||||
pass
|
||||
return _registered_by_text.get(text)
|
||||
|
||||
@classmethod
|
||||
def _extra_to_text(cls, value, current_text):
|
||||
if current_text is None:
|
||||
return _registered_by_value.get(value)
|
||||
if current_text.find("_") >= 0:
|
||||
return current_text.replace("_", "-")
|
||||
return current_text
|
||||
|
||||
@classmethod
|
||||
def _unknown_exception_class(cls):
|
||||
return UnknownRdatatype
|
||||
|
||||
_registered_by_text = {}
|
||||
_registered_by_value = {}
|
||||
|
||||
_registered_by_text: Dict[str, RdataType] = {}
|
||||
_registered_by_value: Dict[RdataType, str] = {}
|
||||
|
||||
_metatypes = {RdataType.OPT}
|
||||
|
||||
_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME,
|
||||
RdataType.NSEC, RdataType.CNAME}
|
||||
_singletons = {
|
||||
RdataType.SOA,
|
||||
RdataType.NXT,
|
||||
RdataType.DNAME,
|
||||
RdataType.NSEC,
|
||||
RdataType.CNAME,
|
||||
}
|
||||
|
||||
|
||||
class UnknownRdatatype(dns.exception.DNSException):
|
||||
"""DNS resource record type is unknown."""
|
||||
|
||||
|
||||
def from_text(text):
|
||||
def from_text(text: str) -> RdataType:
|
||||
"""Convert text into a DNS rdata type value.
|
||||
|
||||
The input text can be a defined DNS RR type mnemonic or
|
||||
|
@ -145,20 +172,13 @@ def from_text(text):
|
|||
|
||||
Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
|
||||
|
||||
Returns an ``int``.
|
||||
Returns a ``dns.rdatatype.RdataType``.
|
||||
"""
|
||||
|
||||
text = text.upper().replace('-', '_')
|
||||
try:
|
||||
return RdataType.from_text(text)
|
||||
except UnknownRdatatype:
|
||||
registered_type = _registered_by_text.get(text)
|
||||
if registered_type:
|
||||
return registered_type
|
||||
raise
|
||||
|
||||
|
||||
def to_text(value):
|
||||
def to_text(value: RdataType) -> str:
|
||||
"""Convert a DNS rdata type value to text.
|
||||
|
||||
If the value has a known mnemonic, it will be used, otherwise the
|
||||
|
@ -169,18 +189,13 @@ def to_text(value):
|
|||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
text = RdataType.to_text(value)
|
||||
if text.startswith("TYPE"):
|
||||
registered_text = _registered_by_value.get(value)
|
||||
if registered_text:
|
||||
text = registered_text
|
||||
return text.replace('_', '-')
|
||||
return RdataType.to_text(value)
|
||||
|
||||
|
||||
def is_metatype(rdtype):
|
||||
def is_metatype(rdtype: RdataType) -> bool:
|
||||
"""True if the specified type is a metatype.
|
||||
|
||||
*rdtype* is an ``int``.
|
||||
*rdtype* is a ``dns.rdatatype.RdataType``.
|
||||
|
||||
The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA,
|
||||
MAILB, ANY, and OPT.
|
||||
|
@ -191,7 +206,7 @@ def is_metatype(rdtype):
|
|||
return (256 > rdtype >= 128) or rdtype in _metatypes
|
||||
|
||||
|
||||
def is_singleton(rdtype):
|
||||
def is_singleton(rdtype: RdataType) -> bool:
|
||||
"""Is the specified type a singleton type?
|
||||
|
||||
Singleton types can only have a single rdata in an rdataset, or a single
|
||||
|
@ -209,11 +224,14 @@ def is_singleton(rdtype):
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def register_type(rdtype, rdtype_text, is_singleton=False):
|
||||
def register_type(
|
||||
rdtype: RdataType, rdtype_text: str, is_singleton: bool = False
|
||||
) -> None:
|
||||
"""Dynamically register an rdatatype.
|
||||
|
||||
*rdtype*, an ``int``, the rdatatype to register.
|
||||
*rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
|
||||
|
||||
*rdtype_text*, a ``str``, the textual form of the rdatatype.
|
||||
|
||||
|
@ -226,6 +244,7 @@ def register_type(rdtype, rdtype_text, is_singleton=False):
|
|||
if is_singleton:
|
||||
_singletons.add(rdtype)
|
||||
|
||||
|
||||
### BEGIN generated RdataType constants
|
||||
|
||||
TYPE0 = RdataType.TYPE0
|
||||
|
|
|
@ -23,7 +23,7 @@ import dns.rdtypes.util
|
|||
|
||||
|
||||
class Relay(dns.rdtypes.util.Gateway):
|
||||
name = 'AMTRELAY relay'
|
||||
name = "AMTRELAY relay"
|
||||
|
||||
@property
|
||||
def relay(self):
|
||||
|
@ -37,10 +37,11 @@ class AMTRELAY(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 8777
|
||||
|
||||
__slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay']
|
||||
__slots__ = ["precedence", "discovery_optional", "relay_type", "relay"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, precedence, discovery_optional,
|
||||
relay_type, relay):
|
||||
def __init__(
|
||||
self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay
|
||||
):
|
||||
super().__init__(rdclass, rdtype)
|
||||
relay = Relay(relay_type, relay)
|
||||
self.precedence = self._as_uint8(precedence)
|
||||
|
@ -50,37 +51,42 @@ class AMTRELAY(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
|
||||
return '%d %d %d %s' % (self.precedence, self.discovery_optional,
|
||||
self.relay_type, relay)
|
||||
return "%d %d %d %s" % (
|
||||
self.precedence,
|
||||
self.discovery_optional,
|
||||
self.relay_type,
|
||||
relay,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
precedence = tok.get_uint8()
|
||||
discovery_optional = tok.get_uint8()
|
||||
if discovery_optional > 1:
|
||||
raise dns.exception.SyntaxError('expecting 0 or 1')
|
||||
raise dns.exception.SyntaxError("expecting 0 or 1")
|
||||
discovery_optional = bool(discovery_optional)
|
||||
relay_type = tok.get_uint8()
|
||||
if relay_type > 0x7f:
|
||||
raise dns.exception.SyntaxError('expecting an integer <= 127')
|
||||
relay = Relay.from_text(relay_type, tok, origin, relativize,
|
||||
relativize_to)
|
||||
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
|
||||
relay.relay)
|
||||
if relay_type > 0x7F:
|
||||
raise dns.exception.SyntaxError("expecting an integer <= 127")
|
||||
relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to)
|
||||
return cls(
|
||||
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
|
||||
)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
relay_type = self.relay_type | (self.discovery_optional << 7)
|
||||
header = struct.pack("!BB", self.precedence, relay_type)
|
||||
file.write(header)
|
||||
Relay(self.relay_type, self.relay).to_wire(file, compress, origin,
|
||||
canonicalize)
|
||||
Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(precedence, relay_type) = parser.get_struct('!BB')
|
||||
(precedence, relay_type) = parser.get_struct("!BB")
|
||||
discovery_optional = bool(relay_type >> 7)
|
||||
relay_type &= 0x7f
|
||||
relay_type &= 0x7F
|
||||
relay = Relay.from_wire_parser(relay_type, parser, origin)
|
||||
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
|
||||
relay.relay)
|
||||
return cls(
|
||||
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
|
||||
)
|
||||
|
|
|
@ -30,7 +30,7 @@ class CAA(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 6844
|
||||
|
||||
__slots__ = ['flags', 'tag', 'value']
|
||||
__slots__ = ["flags", "tag", "value"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, flags, tag, value):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -41,23 +41,26 @@ class CAA(dns.rdata.Rdata):
|
|||
self.value = self._as_bytes(value)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return '%u %s "%s"' % (self.flags,
|
||||
return '%u %s "%s"' % (
|
||||
self.flags,
|
||||
dns.rdata._escapify(self.tag),
|
||||
dns.rdata._escapify(self.value))
|
||||
dns.rdata._escapify(self.value),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
flags = tok.get_uint8()
|
||||
tag = tok.get_string().encode()
|
||||
value = tok.get_string().encode()
|
||||
return cls(rdclass, rdtype, flags, tag, value)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!B', self.flags))
|
||||
file.write(struct.pack("!B", self.flags))
|
||||
l = len(self.tag)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.tag)
|
||||
file.write(self.value)
|
||||
|
||||
|
|
|
@ -15,13 +15,19 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dnskeybase
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
import dns.immutable
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
|
||||
from dns.rdtypes.dnskeybase import (
|
||||
SEP,
|
||||
REVOKE,
|
||||
ZONE,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
|
||||
|
||||
|
|
|
@ -20,34 +20,34 @@ import base64
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.dnssec
|
||||
import dns.dnssectypes
|
||||
import dns.rdata
|
||||
import dns.tokenizer
|
||||
|
||||
_ctype_by_value = {
|
||||
1: 'PKIX',
|
||||
2: 'SPKI',
|
||||
3: 'PGP',
|
||||
4: 'IPKIX',
|
||||
5: 'ISPKI',
|
||||
6: 'IPGP',
|
||||
7: 'ACPKIX',
|
||||
8: 'IACPKIX',
|
||||
253: 'URI',
|
||||
254: 'OID',
|
||||
1: "PKIX",
|
||||
2: "SPKI",
|
||||
3: "PGP",
|
||||
4: "IPKIX",
|
||||
5: "ISPKI",
|
||||
6: "IPGP",
|
||||
7: "ACPKIX",
|
||||
8: "IACPKIX",
|
||||
253: "URI",
|
||||
254: "OID",
|
||||
}
|
||||
|
||||
_ctype_by_name = {
|
||||
'PKIX': 1,
|
||||
'SPKI': 2,
|
||||
'PGP': 3,
|
||||
'IPKIX': 4,
|
||||
'ISPKI': 5,
|
||||
'IPGP': 6,
|
||||
'ACPKIX': 7,
|
||||
'IACPKIX': 8,
|
||||
'URI': 253,
|
||||
'OID': 254,
|
||||
"PKIX": 1,
|
||||
"SPKI": 2,
|
||||
"PGP": 3,
|
||||
"IPKIX": 4,
|
||||
"ISPKI": 5,
|
||||
"IPGP": 6,
|
||||
"ACPKIX": 7,
|
||||
"IACPKIX": 8,
|
||||
"URI": 253,
|
||||
"OID": 254,
|
||||
}
|
||||
|
||||
|
||||
|
@ -72,10 +72,11 @@ class CERT(dns.rdata.Rdata):
|
|||
|
||||
# see RFC 4398
|
||||
|
||||
__slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate']
|
||||
__slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm,
|
||||
certificate):
|
||||
def __init__(
|
||||
self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate
|
||||
):
|
||||
super().__init__(rdclass, rdtype)
|
||||
self.certificate_type = self._as_uint16(certificate_type)
|
||||
self.key_tag = self._as_uint16(key_tag)
|
||||
|
@ -84,24 +85,28 @@ class CERT(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
certificate_type = _ctype_to_text(self.certificate_type)
|
||||
return "%s %d %s %s" % (certificate_type, self.key_tag,
|
||||
dns.dnssec.algorithm_to_text(self.algorithm),
|
||||
dns.rdata._base64ify(self.certificate, **kw))
|
||||
return "%s %d %s %s" % (
|
||||
certificate_type,
|
||||
self.key_tag,
|
||||
dns.dnssectypes.Algorithm.to_text(self.algorithm),
|
||||
dns.rdata._base64ify(self.certificate, **kw),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
certificate_type = _ctype_from_text(tok.get_string())
|
||||
key_tag = tok.get_uint16()
|
||||
algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
|
||||
algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
|
||||
b64 = tok.concatenate_remaining_identifiers().encode()
|
||||
certificate = base64.b64decode(b64)
|
||||
return cls(rdclass, rdtype, certificate_type, key_tag,
|
||||
algorithm, certificate)
|
||||
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
prefix = struct.pack("!HHB", self.certificate_type, self.key_tag,
|
||||
self.algorithm)
|
||||
prefix = struct.pack(
|
||||
"!HHB", self.certificate_type, self.key_tag, self.algorithm
|
||||
)
|
||||
file.write(prefix)
|
||||
file.write(self.certificate)
|
||||
|
||||
|
@ -109,5 +114,4 @@ class CERT(dns.rdata.Rdata):
|
|||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
|
||||
certificate = parser.get_remaining()
|
||||
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm,
|
||||
certificate)
|
||||
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
|
||||
|
|
|
@ -27,7 +27,7 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class Bitmap(dns.rdtypes.util.Bitmap):
|
||||
type_name = 'CSYNC'
|
||||
type_name = "CSYNC"
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -35,7 +35,7 @@ class CSYNC(dns.rdata.Rdata):
|
|||
|
||||
"""CSYNC record"""
|
||||
|
||||
__slots__ = ['serial', 'flags', 'windows']
|
||||
__slots__ = ["serial", "flags", "windows"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, serial, flags, windows):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -47,18 +47,19 @@ class CSYNC(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
text = Bitmap(self.windows).to_text()
|
||||
return '%d %d%s' % (self.serial, self.flags, text)
|
||||
return "%d %d%s" % (self.serial, self.flags, text)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
serial = tok.get_uint32()
|
||||
flags = tok.get_uint16()
|
||||
bitmap = Bitmap.from_text(tok)
|
||||
return cls(rdclass, rdtype, serial, flags, bitmap)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!IH', self.serial, self.flags))
|
||||
file.write(struct.pack("!IH", self.serial, self.flags))
|
||||
Bitmap(self.windows).to_wire(file)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -15,13 +15,19 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dnskeybase
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
import dns.immutable
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
|
||||
from dns.rdtypes.dnskeybase import (
|
||||
SEP,
|
||||
REVOKE,
|
||||
ZONE,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
|
||||
|
||||
|
|
|
@ -26,19 +26,19 @@ import dns.tokenizer
|
|||
def _validate_float_string(what):
|
||||
if len(what) == 0:
|
||||
raise dns.exception.FormError
|
||||
if what[0] == b'-'[0] or what[0] == b'+'[0]:
|
||||
if what[0] == b"-"[0] or what[0] == b"+"[0]:
|
||||
what = what[1:]
|
||||
if what.isdigit():
|
||||
return
|
||||
try:
|
||||
(left, right) = what.split(b'.')
|
||||
(left, right) = what.split(b".")
|
||||
except ValueError:
|
||||
raise dns.exception.FormError
|
||||
if left == b'' and right == b'':
|
||||
if left == b"" and right == b"":
|
||||
raise dns.exception.FormError
|
||||
if not left == b'' and not left.decode().isdigit():
|
||||
if not left == b"" and not left.decode().isdigit():
|
||||
raise dns.exception.FormError
|
||||
if not right == b'' and not right.decode().isdigit():
|
||||
if not right == b"" and not right.decode().isdigit():
|
||||
raise dns.exception.FormError
|
||||
|
||||
|
||||
|
@ -49,18 +49,15 @@ class GPOS(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 1712
|
||||
|
||||
__slots__ = ['latitude', 'longitude', 'altitude']
|
||||
__slots__ = ["latitude", "longitude", "altitude"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, latitude, longitude, altitude):
|
||||
super().__init__(rdclass, rdtype)
|
||||
if isinstance(latitude, float) or \
|
||||
isinstance(latitude, int):
|
||||
if isinstance(latitude, float) or isinstance(latitude, int):
|
||||
latitude = str(latitude)
|
||||
if isinstance(longitude, float) or \
|
||||
isinstance(longitude, int):
|
||||
if isinstance(longitude, float) or isinstance(longitude, int):
|
||||
longitude = str(longitude)
|
||||
if isinstance(altitude, float) or \
|
||||
isinstance(altitude, int):
|
||||
if isinstance(altitude, float) or isinstance(altitude, int):
|
||||
altitude = str(altitude)
|
||||
latitude = self._as_bytes(latitude, True, 255)
|
||||
longitude = self._as_bytes(longitude, True, 255)
|
||||
|
@ -73,19 +70,20 @@ class GPOS(dns.rdata.Rdata):
|
|||
self.altitude = altitude
|
||||
flat = self.float_latitude
|
||||
if flat < -90.0 or flat > 90.0:
|
||||
raise dns.exception.FormError('bad latitude')
|
||||
raise dns.exception.FormError("bad latitude")
|
||||
flong = self.float_longitude
|
||||
if flong < -180.0 or flong > 180.0:
|
||||
raise dns.exception.FormError('bad longitude')
|
||||
raise dns.exception.FormError("bad longitude")
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return '{} {} {}'.format(self.latitude.decode(),
|
||||
self.longitude.decode(),
|
||||
self.altitude.decode())
|
||||
return "{} {} {}".format(
|
||||
self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
latitude = tok.get_string()
|
||||
longitude = tok.get_string()
|
||||
altitude = tok.get_string()
|
||||
|
@ -94,15 +92,15 @@ class GPOS(dns.rdata.Rdata):
|
|||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
l = len(self.latitude)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.latitude)
|
||||
l = len(self.longitude)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.longitude)
|
||||
l = len(self.altitude)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.altitude)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -30,7 +30,7 @@ class HINFO(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 1035
|
||||
|
||||
__slots__ = ['cpu', 'os']
|
||||
__slots__ = ["cpu", "os"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, cpu, os):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -38,12 +38,14 @@ class HINFO(dns.rdata.Rdata):
|
|||
self.os = self._as_bytes(os, True, 255)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu),
|
||||
dns.rdata._escapify(self.os))
|
||||
return '"{}" "{}"'.format(
|
||||
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
cpu = tok.get_string(max_length=255)
|
||||
os = tok.get_string(max_length=255)
|
||||
return cls(rdclass, rdtype, cpu, os)
|
||||
|
@ -51,11 +53,11 @@ class HINFO(dns.rdata.Rdata):
|
|||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
l = len(self.cpu)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.cpu)
|
||||
l = len(self.os)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.os)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -32,7 +32,7 @@ class HIP(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 5205
|
||||
|
||||
__slots__ = ['hit', 'algorithm', 'key', 'servers']
|
||||
__slots__ = ["hit", "algorithm", "key", "servers"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -43,18 +43,19 @@ class HIP(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
hit = binascii.hexlify(self.hit).decode()
|
||||
key = base64.b64encode(self.key).replace(b'\n', b'').decode()
|
||||
text = ''
|
||||
key = base64.b64encode(self.key).replace(b"\n", b"").decode()
|
||||
text = ""
|
||||
servers = []
|
||||
for server in self.servers:
|
||||
servers.append(server.choose_relativity(origin, relativize))
|
||||
if len(servers) > 0:
|
||||
text += (' ' + ' '.join((x.to_unicode() for x in servers)))
|
||||
return '%u %s %s%s' % (self.algorithm, hit, key, text)
|
||||
text += " " + " ".join((x.to_unicode() for x in servers))
|
||||
return "%u %s %s%s" % (self.algorithm, hit, key, text)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
algorithm = tok.get_uint8()
|
||||
hit = binascii.unhexlify(tok.get_string().encode())
|
||||
key = base64.b64decode(tok.get_string().encode())
|
||||
|
@ -75,7 +76,7 @@ class HIP(dns.rdata.Rdata):
|
|||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(lh, algorithm, lk) = parser.get_struct('!BBH')
|
||||
(lh, algorithm, lk) = parser.get_struct("!BBH")
|
||||
hit = parser.get_bytes(lh)
|
||||
key = parser.get_bytes(lk)
|
||||
servers = []
|
||||
|
|
|
@ -30,7 +30,7 @@ class ISDN(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 1183
|
||||
|
||||
__slots__ = ['address', 'subaddress']
|
||||
__slots__ = ["address", "subaddress"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, address, subaddress):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -39,31 +39,33 @@ class ISDN(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
if self.subaddress:
|
||||
return '"{}" "{}"'.format(dns.rdata._escapify(self.address),
|
||||
dns.rdata._escapify(self.subaddress))
|
||||
return '"{}" "{}"'.format(
|
||||
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
|
||||
)
|
||||
else:
|
||||
return '"%s"' % dns.rdata._escapify(self.address)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
address = tok.get_string()
|
||||
tokens = tok.get_remaining(max_tokens=1)
|
||||
if len(tokens) >= 1:
|
||||
subaddress = tokens[0].unescape().value
|
||||
else:
|
||||
subaddress = ''
|
||||
subaddress = ""
|
||||
return cls(rdclass, rdtype, address, subaddress)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
l = len(self.address)
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.address)
|
||||
l = len(self.subaddress)
|
||||
if l > 0:
|
||||
assert l < 256
|
||||
file.write(struct.pack('!B', l))
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(self.subaddress)
|
||||
|
||||
@classmethod
|
||||
|
@ -72,5 +74,5 @@ class ISDN(dns.rdata.Rdata):
|
|||
if parser.remaining() > 0:
|
||||
subaddress = parser.get_counted_bytes()
|
||||
else:
|
||||
subaddress = b''
|
||||
subaddress = b""
|
||||
return cls(rdclass, rdtype, address, subaddress)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import struct
|
||||
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -12,7 +13,7 @@ class L32(dns.rdata.Rdata):
|
|||
|
||||
# see: rfc6742.txt
|
||||
|
||||
__slots__ = ['preference', 'locator32']
|
||||
__slots__ = ["preference", "locator32"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, preference, locator32):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -20,17 +21,18 @@ class L32(dns.rdata.Rdata):
|
|||
self.locator32 = self._as_ipv4_address(locator32)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return f'{self.preference} {self.locator32}'
|
||||
return f"{self.preference} {self.locator32}"
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
preference = tok.get_uint16()
|
||||
nodeid = tok.get_identifier()
|
||||
return cls(rdclass, rdtype, preference, nodeid)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!H', self.preference))
|
||||
file.write(struct.pack("!H", self.preference))
|
||||
file.write(dns.ipv4.inet_aton(self.locator32))
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -13,33 +13,33 @@ class L64(dns.rdata.Rdata):
|
|||
|
||||
# see: rfc6742.txt
|
||||
|
||||
__slots__ = ['preference', 'locator64']
|
||||
__slots__ = ["preference", "locator64"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, preference, locator64):
|
||||
super().__init__(rdclass, rdtype)
|
||||
self.preference = self._as_uint16(preference)
|
||||
if isinstance(locator64, bytes):
|
||||
if len(locator64) != 8:
|
||||
raise ValueError('invalid locator64')
|
||||
self.locator64 = dns.rdata._hexify(locator64, 4, b':')
|
||||
raise ValueError("invalid locator64")
|
||||
self.locator64 = dns.rdata._hexify(locator64, 4, b":")
|
||||
else:
|
||||
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':')
|
||||
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":")
|
||||
self.locator64 = locator64
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return f'{self.preference} {self.locator64}'
|
||||
return f"{self.preference} {self.locator64}"
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
preference = tok.get_uint16()
|
||||
locator64 = tok.get_identifier()
|
||||
return cls(rdclass, rdtype, preference, locator64)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!H', self.preference))
|
||||
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64,
|
||||
4, 4, ':'))
|
||||
file.write(struct.pack("!H", self.preference))
|
||||
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":"))
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
|
|
|
@ -93,15 +93,15 @@ def _decode_size(what, desc):
|
|||
|
||||
def _check_coordinate_list(value, low, high):
|
||||
if value[0] < low or value[0] > high:
|
||||
raise ValueError(f'not in range [{low}, {high}]')
|
||||
raise ValueError(f"not in range [{low}, {high}]")
|
||||
if value[1] < 0 or value[1] > 59:
|
||||
raise ValueError('bad minutes value')
|
||||
raise ValueError("bad minutes value")
|
||||
if value[2] < 0 or value[2] > 59:
|
||||
raise ValueError('bad seconds value')
|
||||
raise ValueError("bad seconds value")
|
||||
if value[3] < 0 or value[3] > 999:
|
||||
raise ValueError('bad milliseconds value')
|
||||
raise ValueError("bad milliseconds value")
|
||||
if value[4] != 1 and value[4] != -1:
|
||||
raise ValueError('bad hemisphere value')
|
||||
raise ValueError("bad hemisphere value")
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -111,12 +111,26 @@ class LOC(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 1876
|
||||
|
||||
__slots__ = ['latitude', 'longitude', 'altitude', 'size',
|
||||
'horizontal_precision', 'vertical_precision']
|
||||
__slots__ = [
|
||||
"latitude",
|
||||
"longitude",
|
||||
"altitude",
|
||||
"size",
|
||||
"horizontal_precision",
|
||||
"vertical_precision",
|
||||
]
|
||||
|
||||
def __init__(self, rdclass, rdtype, latitude, longitude, altitude,
|
||||
size=_default_size, hprec=_default_hprec,
|
||||
vprec=_default_vprec):
|
||||
def __init__(
|
||||
self,
|
||||
rdclass,
|
||||
rdtype,
|
||||
latitude,
|
||||
longitude,
|
||||
altitude,
|
||||
size=_default_size,
|
||||
hprec=_default_hprec,
|
||||
vprec=_default_vprec,
|
||||
):
|
||||
"""Initialize a LOC record instance.
|
||||
|
||||
The parameters I{latitude} and I{longitude} may be either a 4-tuple
|
||||
|
@ -145,34 +159,44 @@ class LOC(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
if self.latitude[4] > 0:
|
||||
lat_hemisphere = 'N'
|
||||
lat_hemisphere = "N"
|
||||
else:
|
||||
lat_hemisphere = 'S'
|
||||
lat_hemisphere = "S"
|
||||
if self.longitude[4] > 0:
|
||||
long_hemisphere = 'E'
|
||||
long_hemisphere = "E"
|
||||
else:
|
||||
long_hemisphere = 'W'
|
||||
long_hemisphere = "W"
|
||||
text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % (
|
||||
self.latitude[0], self.latitude[1],
|
||||
self.latitude[2], self.latitude[3], lat_hemisphere,
|
||||
self.longitude[0], self.longitude[1], self.longitude[2],
|
||||
self.longitude[3], long_hemisphere,
|
||||
self.altitude / 100.0
|
||||
self.latitude[0],
|
||||
self.latitude[1],
|
||||
self.latitude[2],
|
||||
self.latitude[3],
|
||||
lat_hemisphere,
|
||||
self.longitude[0],
|
||||
self.longitude[1],
|
||||
self.longitude[2],
|
||||
self.longitude[3],
|
||||
long_hemisphere,
|
||||
self.altitude / 100.0,
|
||||
)
|
||||
|
||||
# do not print default values
|
||||
if self.size != _default_size or \
|
||||
self.horizontal_precision != _default_hprec or \
|
||||
self.vertical_precision != _default_vprec:
|
||||
if (
|
||||
self.size != _default_size
|
||||
or self.horizontal_precision != _default_hprec
|
||||
or self.vertical_precision != _default_vprec
|
||||
):
|
||||
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
|
||||
self.size / 100.0, self.horizontal_precision / 100.0,
|
||||
self.vertical_precision / 100.0
|
||||
self.size / 100.0,
|
||||
self.horizontal_precision / 100.0,
|
||||
self.vertical_precision / 100.0,
|
||||
)
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
latitude = [0, 0, 0, 0, 1]
|
||||
longitude = [0, 0, 0, 0, 1]
|
||||
size = _default_size
|
||||
|
@ -184,16 +208,14 @@ class LOC(dns.rdata.Rdata):
|
|||
if t.isdigit():
|
||||
latitude[1] = int(t)
|
||||
t = tok.get_string()
|
||||
if '.' in t:
|
||||
(seconds, milliseconds) = t.split('.')
|
||||
if "." in t:
|
||||
(seconds, milliseconds) = t.split(".")
|
||||
if not seconds.isdigit():
|
||||
raise dns.exception.SyntaxError(
|
||||
'bad latitude seconds value')
|
||||
raise dns.exception.SyntaxError("bad latitude seconds value")
|
||||
latitude[2] = int(seconds)
|
||||
l = len(milliseconds)
|
||||
if l == 0 or l > 3 or not milliseconds.isdigit():
|
||||
raise dns.exception.SyntaxError(
|
||||
'bad latitude milliseconds value')
|
||||
raise dns.exception.SyntaxError("bad latitude milliseconds value")
|
||||
if l == 1:
|
||||
m = 100
|
||||
elif l == 2:
|
||||
|
@ -205,26 +227,24 @@ class LOC(dns.rdata.Rdata):
|
|||
elif t.isdigit():
|
||||
latitude[2] = int(t)
|
||||
t = tok.get_string()
|
||||
if t == 'S':
|
||||
if t == "S":
|
||||
latitude[4] = -1
|
||||
elif t != 'N':
|
||||
raise dns.exception.SyntaxError('bad latitude hemisphere value')
|
||||
elif t != "N":
|
||||
raise dns.exception.SyntaxError("bad latitude hemisphere value")
|
||||
|
||||
longitude[0] = tok.get_int()
|
||||
t = tok.get_string()
|
||||
if t.isdigit():
|
||||
longitude[1] = int(t)
|
||||
t = tok.get_string()
|
||||
if '.' in t:
|
||||
(seconds, milliseconds) = t.split('.')
|
||||
if "." in t:
|
||||
(seconds, milliseconds) = t.split(".")
|
||||
if not seconds.isdigit():
|
||||
raise dns.exception.SyntaxError(
|
||||
'bad longitude seconds value')
|
||||
raise dns.exception.SyntaxError("bad longitude seconds value")
|
||||
longitude[2] = int(seconds)
|
||||
l = len(milliseconds)
|
||||
if l == 0 or l > 3 or not milliseconds.isdigit():
|
||||
raise dns.exception.SyntaxError(
|
||||
'bad longitude milliseconds value')
|
||||
raise dns.exception.SyntaxError("bad longitude milliseconds value")
|
||||
if l == 1:
|
||||
m = 100
|
||||
elif l == 2:
|
||||
|
@ -236,31 +256,31 @@ class LOC(dns.rdata.Rdata):
|
|||
elif t.isdigit():
|
||||
longitude[2] = int(t)
|
||||
t = tok.get_string()
|
||||
if t == 'W':
|
||||
if t == "W":
|
||||
longitude[4] = -1
|
||||
elif t != 'E':
|
||||
raise dns.exception.SyntaxError('bad longitude hemisphere value')
|
||||
elif t != "E":
|
||||
raise dns.exception.SyntaxError("bad longitude hemisphere value")
|
||||
|
||||
t = tok.get_string()
|
||||
if t[-1] == 'm':
|
||||
t = t[0: -1]
|
||||
if t[-1] == "m":
|
||||
t = t[0:-1]
|
||||
altitude = float(t) * 100.0 # m -> cm
|
||||
|
||||
tokens = tok.get_remaining(max_tokens=3)
|
||||
if len(tokens) >= 1:
|
||||
value = tokens[0].unescape().value
|
||||
if value[-1] == 'm':
|
||||
value = value[0: -1]
|
||||
if value[-1] == "m":
|
||||
value = value[0:-1]
|
||||
size = float(value) * 100.0 # m -> cm
|
||||
if len(tokens) >= 2:
|
||||
value = tokens[1].unescape().value
|
||||
if value[-1] == 'm':
|
||||
value = value[0: -1]
|
||||
if value[-1] == "m":
|
||||
value = value[0:-1]
|
||||
hprec = float(value) * 100.0 # m -> cm
|
||||
if len(tokens) >= 3:
|
||||
value = tokens[2].unescape().value
|
||||
if value[-1] == 'm':
|
||||
value = value[0: -1]
|
||||
if value[-1] == "m":
|
||||
value = value[0:-1]
|
||||
vprec = float(value) * 100.0 # m -> cm
|
||||
|
||||
# Try encoding these now so we raise if they are bad
|
||||
|
@ -268,32 +288,43 @@ class LOC(dns.rdata.Rdata):
|
|||
_encode_size(hprec, "horizontal precision")
|
||||
_encode_size(vprec, "vertical precision")
|
||||
|
||||
return cls(rdclass, rdtype, latitude, longitude, altitude,
|
||||
size, hprec, vprec)
|
||||
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
milliseconds = (self.latitude[0] * 3600000 +
|
||||
self.latitude[1] * 60000 +
|
||||
self.latitude[2] * 1000 +
|
||||
self.latitude[3]) * self.latitude[4]
|
||||
milliseconds = (
|
||||
self.latitude[0] * 3600000
|
||||
+ self.latitude[1] * 60000
|
||||
+ self.latitude[2] * 1000
|
||||
+ self.latitude[3]
|
||||
) * self.latitude[4]
|
||||
latitude = 0x80000000 + milliseconds
|
||||
milliseconds = (self.longitude[0] * 3600000 +
|
||||
self.longitude[1] * 60000 +
|
||||
self.longitude[2] * 1000 +
|
||||
self.longitude[3]) * self.longitude[4]
|
||||
milliseconds = (
|
||||
self.longitude[0] * 3600000
|
||||
+ self.longitude[1] * 60000
|
||||
+ self.longitude[2] * 1000
|
||||
+ self.longitude[3]
|
||||
) * self.longitude[4]
|
||||
longitude = 0x80000000 + milliseconds
|
||||
altitude = int(self.altitude) + 10000000
|
||||
size = _encode_size(self.size, "size")
|
||||
hprec = _encode_size(self.horizontal_precision, "horizontal precision")
|
||||
vprec = _encode_size(self.vertical_precision, "vertical precision")
|
||||
wire = struct.pack("!BBBBIII", 0, size, hprec, vprec, latitude,
|
||||
longitude, altitude)
|
||||
wire = struct.pack(
|
||||
"!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude
|
||||
)
|
||||
file.write(wire)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(version, size, hprec, vprec, latitude, longitude, altitude) = \
|
||||
parser.get_struct("!BBBBIII")
|
||||
(
|
||||
version,
|
||||
size,
|
||||
hprec,
|
||||
vprec,
|
||||
latitude,
|
||||
longitude,
|
||||
altitude,
|
||||
) = parser.get_struct("!BBBBIII")
|
||||
if version != 0:
|
||||
raise dns.exception.FormError("LOC version not zero")
|
||||
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
|
||||
|
@ -312,8 +343,7 @@ class LOC(dns.rdata.Rdata):
|
|||
size = _decode_size(size, "size")
|
||||
hprec = _decode_size(hprec, "horizontal precision")
|
||||
vprec = _decode_size(vprec, "vertical precision")
|
||||
return cls(rdclass, rdtype, latitude, longitude, altitude,
|
||||
size, hprec, vprec)
|
||||
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
|
||||
|
||||
@property
|
||||
def float_latitude(self):
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import struct
|
||||
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -12,7 +13,7 @@ class LP(dns.rdata.Rdata):
|
|||
|
||||
# see: rfc6742.txt
|
||||
|
||||
__slots__ = ['preference', 'fqdn']
|
||||
__slots__ = ["preference", "fqdn"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, preference, fqdn):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -21,17 +22,18 @@ class LP(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
fqdn = self.fqdn.choose_relativity(origin, relativize)
|
||||
return '%d %s' % (self.preference, fqdn)
|
||||
return "%d %s" % (self.preference, fqdn)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
preference = tok.get_uint16()
|
||||
fqdn = tok.get_name(origin, relativize, relativize_to)
|
||||
return cls(rdclass, rdtype, preference, fqdn)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!H', self.preference))
|
||||
file.write(struct.pack("!H", self.preference))
|
||||
self.fqdn.to_wire(file, compress, origin, canonicalize)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -13,32 +13,33 @@ class NID(dns.rdata.Rdata):
|
|||
|
||||
# see: rfc6742.txt
|
||||
|
||||
__slots__ = ['preference', 'nodeid']
|
||||
__slots__ = ["preference", "nodeid"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, preference, nodeid):
|
||||
super().__init__(rdclass, rdtype)
|
||||
self.preference = self._as_uint16(preference)
|
||||
if isinstance(nodeid, bytes):
|
||||
if len(nodeid) != 8:
|
||||
raise ValueError('invalid nodeid')
|
||||
self.nodeid = dns.rdata._hexify(nodeid, 4, b':')
|
||||
raise ValueError("invalid nodeid")
|
||||
self.nodeid = dns.rdata._hexify(nodeid, 4, b":")
|
||||
else:
|
||||
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':')
|
||||
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":")
|
||||
self.nodeid = nodeid
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return f'{self.preference} {self.nodeid}'
|
||||
return f"{self.preference} {self.nodeid}"
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
preference = tok.get_uint16()
|
||||
nodeid = tok.get_identifier()
|
||||
return cls(rdclass, rdtype, preference, nodeid)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
file.write(struct.pack('!H', self.preference))
|
||||
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':'))
|
||||
file.write(struct.pack("!H", self.preference))
|
||||
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":"))
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
|
|
|
@ -25,7 +25,7 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class Bitmap(dns.rdtypes.util.Bitmap):
|
||||
type_name = 'NSEC'
|
||||
type_name = "NSEC"
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -33,7 +33,7 @@ class NSEC(dns.rdata.Rdata):
|
|||
|
||||
"""NSEC record"""
|
||||
|
||||
__slots__ = ['next', 'windows']
|
||||
__slots__ = ["next", "windows"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, next, windows):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -45,11 +45,12 @@ class NSEC(dns.rdata.Rdata):
|
|||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
next = self.next.choose_relativity(origin, relativize)
|
||||
text = Bitmap(self.windows).to_text()
|
||||
return '{}{}'.format(next, text)
|
||||
return "{}{}".format(next, text)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
next = tok.get_name(origin, relativize, relativize_to)
|
||||
windows = Bitmap.from_text(tok)
|
||||
return cls(rdclass, rdtype, next, windows)
|
||||
|
|
|
@ -26,10 +26,12 @@ import dns.rdatatype
|
|||
import dns.rdtypes.util
|
||||
|
||||
|
||||
b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV',
|
||||
b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567')
|
||||
b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567',
|
||||
b'0123456789ABCDEFGHIJKLMNOPQRSTUV')
|
||||
b32_hex_to_normal = bytes.maketrans(
|
||||
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
||||
)
|
||||
b32_normal_to_hex = bytes.maketrans(
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV"
|
||||
)
|
||||
|
||||
# hash algorithm constants
|
||||
SHA1 = 1
|
||||
|
@ -40,7 +42,7 @@ OPTOUT = 1
|
|||
|
||||
@dns.immutable.immutable
|
||||
class Bitmap(dns.rdtypes.util.Bitmap):
|
||||
type_name = 'NSEC3'
|
||||
type_name = "NSEC3"
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
@ -48,10 +50,11 @@ class NSEC3(dns.rdata.Rdata):
|
|||
|
||||
"""NSEC3 record"""
|
||||
|
||||
__slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows']
|
||||
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt,
|
||||
next, windows):
|
||||
def __init__(
|
||||
self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows
|
||||
):
|
||||
super().__init__(rdclass, rdtype)
|
||||
self.algorithm = self._as_uint8(algorithm)
|
||||
self.flags = self._as_uint8(flags)
|
||||
|
@ -63,38 +66,41 @@ class NSEC3(dns.rdata.Rdata):
|
|||
self.windows = tuple(windows.windows)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
next = base64.b32encode(self.next).translate(
|
||||
b32_normal_to_hex).lower().decode()
|
||||
if self.salt == b'':
|
||||
salt = '-'
|
||||
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
|
||||
if self.salt == b"":
|
||||
salt = "-"
|
||||
else:
|
||||
salt = binascii.hexlify(self.salt).decode()
|
||||
text = Bitmap(self.windows).to_text()
|
||||
return '%u %u %u %s %s%s' % (self.algorithm, self.flags,
|
||||
self.iterations, salt, next, text)
|
||||
return "%u %u %u %s %s%s" % (
|
||||
self.algorithm,
|
||||
self.flags,
|
||||
self.iterations,
|
||||
salt,
|
||||
next,
|
||||
text,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
algorithm = tok.get_uint8()
|
||||
flags = tok.get_uint8()
|
||||
iterations = tok.get_uint16()
|
||||
salt = tok.get_string()
|
||||
if salt == '-':
|
||||
salt = b''
|
||||
if salt == "-":
|
||||
salt = b""
|
||||
else:
|
||||
salt = binascii.unhexlify(salt.encode('ascii'))
|
||||
next = tok.get_string().encode(
|
||||
'ascii').upper().translate(b32_hex_to_normal)
|
||||
salt = binascii.unhexlify(salt.encode("ascii"))
|
||||
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
|
||||
next = base64.b32decode(next)
|
||||
bitmap = Bitmap.from_text(tok)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
|
||||
bitmap)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
l = len(self.salt)
|
||||
file.write(struct.pack("!BBHB", self.algorithm, self.flags,
|
||||
self.iterations, l))
|
||||
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
|
||||
file.write(self.salt)
|
||||
l = len(self.next)
|
||||
file.write(struct.pack("!B", l))
|
||||
|
@ -103,9 +109,8 @@ class NSEC3(dns.rdata.Rdata):
|
|||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(algorithm, flags, iterations) = parser.get_struct('!BBH')
|
||||
(algorithm, flags, iterations) = parser.get_struct("!BBH")
|
||||
salt = parser.get_counted_bytes()
|
||||
next = parser.get_counted_bytes()
|
||||
bitmap = Bitmap.from_wire_parser(parser)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
|
||||
bitmap)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
||||
|
|
|
@ -28,7 +28,7 @@ class NSEC3PARAM(dns.rdata.Rdata):
|
|||
|
||||
"""NSEC3PARAM record"""
|
||||
|
||||
__slots__ = ['algorithm', 'flags', 'iterations', 'salt']
|
||||
__slots__ = ["algorithm", "flags", "iterations", "salt"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -38,34 +38,33 @@ class NSEC3PARAM(dns.rdata.Rdata):
|
|||
self.salt = self._as_bytes(salt, True, 255)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
if self.salt == b'':
|
||||
salt = '-'
|
||||
if self.salt == b"":
|
||||
salt = "-"
|
||||
else:
|
||||
salt = binascii.hexlify(self.salt).decode()
|
||||
return '%u %u %u %s' % (self.algorithm, self.flags, self.iterations,
|
||||
salt)
|
||||
return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
algorithm = tok.get_uint8()
|
||||
flags = tok.get_uint8()
|
||||
iterations = tok.get_uint16()
|
||||
salt = tok.get_string()
|
||||
if salt == '-':
|
||||
salt = ''
|
||||
if salt == "-":
|
||||
salt = ""
|
||||
else:
|
||||
salt = binascii.unhexlify(salt.encode())
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
l = len(self.salt)
|
||||
file.write(struct.pack("!BBHB", self.algorithm, self.flags,
|
||||
self.iterations, l))
|
||||
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
|
||||
file.write(self.salt)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
(algorithm, flags, iterations) = parser.get_struct('!BBH')
|
||||
(algorithm, flags, iterations) = parser.get_struct("!BBH")
|
||||
salt = parser.get_counted_bytes()
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
|
||||
|
|
|
@ -22,6 +22,7 @@ import dns.immutable
|
|||
import dns.rdata
|
||||
import dns.tokenizer
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class OPENPGPKEY(dns.rdata.Rdata):
|
||||
|
||||
|
@ -37,8 +38,9 @@ class OPENPGPKEY(dns.rdata.Rdata):
|
|||
return dns.rdata._base64ify(self.key, chunksize=None, **kw)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
b64 = tok.concatenate_remaining_identifiers().encode()
|
||||
key = base64.b64decode(b64)
|
||||
return cls(rdclass, rdtype, key)
|
||||
|
|
|
@ -26,12 +26,13 @@ import dns.rdata
|
|||
# We don't implement from_text, and that's ok.
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class OPT(dns.rdata.Rdata):
|
||||
|
||||
"""OPT record"""
|
||||
|
||||
__slots__ = ['options']
|
||||
__slots__ = ["options"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, options):
|
||||
"""Initialize an OPT rdata.
|
||||
|
@ -45,10 +46,12 @@ class OPT(dns.rdata.Rdata):
|
|||
"""
|
||||
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
||||
def as_option(option):
|
||||
if not isinstance(option, dns.edns.Option):
|
||||
raise ValueError('option is not a dns.edns.option')
|
||||
raise ValueError("option is not a dns.edns.option")
|
||||
return option
|
||||
|
||||
self.options = self._as_tuple(options, as_option)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
|
@ -58,13 +61,13 @@ class OPT(dns.rdata.Rdata):
|
|||
file.write(owire)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return ' '.join(opt.to_text() for opt in self.options)
|
||||
return " ".join(opt.to_text() for opt in self.options)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
options = []
|
||||
while parser.remaining() > 0:
|
||||
(otype, olen) = parser.get_struct('!HH')
|
||||
(otype, olen) = parser.get_struct("!HH")
|
||||
with parser.restrict_to(olen):
|
||||
opt = dns.edns.option_from_wire_parser(otype, parser)
|
||||
options.append(opt)
|
||||
|
|
|
@ -28,7 +28,7 @@ class RP(dns.rdata.Rdata):
|
|||
|
||||
# see: RFC 1183
|
||||
|
||||
__slots__ = ['mbox', 'txt']
|
||||
__slots__ = ["mbox", "txt"]
|
||||
|
||||
def __init__(self, rdclass, rdtype, mbox, txt):
|
||||
super().__init__(rdclass, rdtype)
|
||||
|
@ -41,8 +41,9 @@ class RP(dns.rdata.Rdata):
|
|||
return "{} {}".format(str(mbox), str(txt))
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
|
||||
relativize_to=None):
|
||||
def from_text(
|
||||
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
|
||||
):
|
||||
mbox = tok.get_name(origin, relativize, relativize_to)
|
||||
txt = tok.get_name(origin, relativize, relativize_to)
|
||||
return cls(rdclass, rdtype, mbox, txt)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue