adding vendor folder

This commit is contained in:
evilsocket 2018-03-23 15:25:11 +01:00
parent 49c65021ea
commit c304ca4696
No known key found for this signature in database
GPG key ID: 1564D7F30393A456
1145 changed files with 369961 additions and 2 deletions

13
vendor/github.com/inconshreveable/go-vhost/LICENSE generated vendored Normal file
View file

@ -0,0 +1,13 @@
Copyright 2014 Alan Shreve
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

110
vendor/github.com/inconshreveable/go-vhost/README.md generated vendored Normal file
View file

@ -0,0 +1,110 @@
# go-vhost
go-vhost is a simple library that lets you implement virtual hosting functionality for different protocols (HTTP and TLS so far). go-vhost has a high-level and a low-level interface. The high-level interface lets you wrap existing net.Listeners with "muxer" objects. You can then Listen() on a muxer for a particular virtual host name of interest which will return to you a net.Listener for just connections with the virtual hostname of interest.
The lower-level go-vhost interface are just functions which extract the name/routing information for the given protocol and return an object implementing net.Conn which works as if no bytes had been consumed.
### [API Documentation](https://godoc.org/github.com/inconshreveable/go-vhost)
### Usage
```go
l, _ := net.Listen("tcp", *listen)
// start multiplexing on it
mux, _ := vhost.NewHTTPMuxer(l, muxTimeout)
// listen for connections to different domains
for _, v := range virtualHosts {
vhost := v
// vhost.Name is a virtual hostname like "foo.example.com"
muxListener, _ := mux.Listen(vhost.Name())
go func(vh virtualHost, ml net.Listener) {
for {
conn, _ := ml.Accept()
go vh.Handle(conn)
}
}(vhost, muxListener)
}
for {
conn, err := mux.NextError()
switch err.(type) {
case vhost.BadRequest:
log.Printf("got a bad request!")
conn.Write([]byte("bad request"))
case vhost.NotFound:
log.Printf("got a connection for an unknown vhost")
conn.Write([]byte("vhost not found"))
case vhost.Closed:
log.Printf("closed conn: %s", err)
default:
if conn != nil {
conn.Write([]byte("server error"))
}
}
if conn != nil {
conn.Close()
}
}
```
### Low-level API usage
```go
// accept a new connection
conn, _ := listener.Accept()
// parse out the HTTP request and the Host header
if vhostConn, err = vhost.HTTP(conn); err != nil {
panic("Not a valid http connection!")
}
fmt.Printf("Target Host: ", vhostConn.Host())
// Target Host: example.com
// vhostConn contains the entire request as if no bytes had been consumed
bytes, _ := ioutil.ReadAll(vhostConn)
fmt.Printf("%s", bytes)
// GET / HTTP/1.1
// Host: example.com
// User-Agent: ...
// ...
```
### Advanced introspection
The entire HTTP request headers are available for inspection in case you want to mux on something besides the Host header:
```go
// parse out the HTTP request and the Host header
if vhostConn, err = vhost.HTTP(conn); err != nil {
panic("Not a valid http connection!")
}
httpVersion := vhost.Request.MinorVersion
customRouting := vhost.Request.Header["X-Custom-Routing-Header"]
```
Likewise for TLS, you can look at detailed information about the ClientHello message:
```go
if vhostConn, err = vhost.TLS(conn); err != nil {
panic("Not a valid TLS connection!")
}
cipherSuites := vhost.ClientHelloMsg.CipherSuites
sessionId := vhost.ClientHelloMsg.SessionId
```
##### Memory reduction with Free
After you're done muxing, you probably don't need to inspect the header data anymore, so you can make it available for garbage collection:
```go
// look up the upstream host
upstreamHost := hostMapping[vhostConn.Host()]
// free up the muxing data
vhostConn.Free()
// vhostConn.Host() == ""
// vhostConn.Request == nil (HTTP)
// vhostConn.ClientHelloMsg == nil (TLS)
```

42
vendor/github.com/inconshreveable/go-vhost/http.go generated vendored Normal file
View file

@ -0,0 +1,42 @@
package vhost
import (
"bufio"
"net"
"net/http"
)
type HTTPConn struct {
*sharedConn
Request *http.Request
}
// HTTP parses the head of the first HTTP request on conn and returns
// a new, unread connection with metadata for virtual host muxing
func HTTP(conn net.Conn) (httpConn *HTTPConn, err error) {
c, rd := newShared(conn)
httpConn = &HTTPConn{sharedConn: c}
if httpConn.Request, err = http.ReadRequest(bufio.NewReader(rd)); err != nil {
return
}
// You probably don't need access to the request body and this makes the API
// simpler by allowing you to call Free() optionally
httpConn.Request.Body.Close()
return
}
// Free sets Request to nil so that it can be garbage collected
func (c *HTTPConn) Free() {
c.Request = nil
}
func (c *HTTPConn) Host() string {
if c.Request == nil {
return ""
}
return c.Request.Host
}

View file

@ -0,0 +1,45 @@
package vhost
import (
"net"
"net/http"
"testing"
)
func TestHTTPHost(t *testing.T) {
var testHostname string = "foo.example.com"
l, err := net.Listen("tcp", "127.0.0.1:12345")
if err != nil {
panic(err)
}
defer l.Close()
go func() {
conn, err := net.Dial("tcp", "127.0.0.1:12345")
if err != nil {
panic(err)
}
defer conn.Close()
req, err := http.NewRequest("GET", "http://"+testHostname+"/bar", nil)
if err != nil {
panic(err)
}
if err = req.Write(conn); err != nil {
panic(err)
}
}()
conn, err := l.Accept()
if err != nil {
panic(err)
}
c, err := HTTP(conn)
if err != nil {
panic(err)
}
if c.Host() != testHostname {
t.Errorf("Connection Host() is %s, expected %s", c.Host(), testHostname)
}
}

View file

@ -0,0 +1,11 @@
package vhost
import (
"net"
)
type Conn interface {
net.Conn
Host() string
Free()
}

337
vendor/github.com/inconshreveable/go-vhost/mux.go generated vendored Normal file
View file

@ -0,0 +1,337 @@
package vhost
import (
"fmt"
"net"
"strings"
"sync"
"time"
)
var (
normalize = strings.ToLower
isClosed = func(err error) bool {
netErr, ok := err.(net.Error)
if ok {
return netErr.Temporary()
}
return false
}
)
// NotFound is returned when a vhost is not found
type NotFound struct {
error
}
// BadRequest is returned when extraction of the vhost name fails
type BadRequest struct {
error
}
// Closed is returned when the underlying connection is closed
type Closed struct {
error
}
type (
// this is the function you apply to a net.Conn to get
// a new virtual-host multiplexed connection
muxFn func(net.Conn) (Conn, error)
// an error encountered when multiplexing a connection
muxErr struct {
err error
conn net.Conn
}
)
type VhostMuxer struct {
listener net.Listener // listener on which we mux connections
muxTimeout time.Duration // a connection fails if it doesn't send enough data to mux after this timeout
vhostFn muxFn // new connections are multiplexed by applying this function
muxErrors chan muxErr // all muxing errors are sent over this channel
registry map[string]*Listener // registry of name -> listener
sync.RWMutex // protects the registry
}
func NewVhostMuxer(listener net.Listener, vhostFn muxFn, muxTimeout time.Duration) (*VhostMuxer, error) {
mux := &VhostMuxer{
listener: listener,
muxTimeout: muxTimeout,
vhostFn: vhostFn,
muxErrors: make(chan muxErr),
registry: make(map[string]*Listener),
}
go mux.run()
return mux, nil
}
// Listen begins multiplexing the underlying connection to send new
// connections for the given name over the returned listener.
func (m *VhostMuxer) Listen(name string) (net.Listener, error) {
name = normalize(name)
vhost := &Listener{
name: name,
mux: m,
accept: make(chan Conn),
}
if err := m.set(name, vhost); err != nil {
return nil, err
}
return vhost, nil
}
// NextError returns the next error encountered while mux'ing a connection.
// The net.Conn may be nil if the wrapped listener returned an error from Accept()
func (m *VhostMuxer) NextError() (net.Conn, error) {
muxErr := <-m.muxErrors
return muxErr.conn, muxErr.err
}
// Close closes the underlying listener
func (m *VhostMuxer) Close() {
m.listener.Close()
}
// run is the VhostMuxer's main loop for accepting new connections from the wrapped listener
func (m *VhostMuxer) run() {
for {
conn, err := m.listener.Accept()
if err != nil {
if isClosed(err) {
m.sendError(nil, Closed{err})
return
} else {
m.sendError(nil, err)
continue
}
}
go m.handle(conn)
}
}
// handle muxes a connection accepted from the listener
func (m *VhostMuxer) handle(conn net.Conn) {
defer func() {
// recover from failures
if r := recover(); r != nil {
m.sendError(conn, fmt.Errorf("NameMux.handle failed with error %v", r))
}
}()
// Make sure we detect dead connections while we decide how to multiplex
if err := conn.SetDeadline(time.Now().Add(m.muxTimeout)); err != nil {
m.sendError(conn, fmt.Errorf("Failed to set deadline: %v", err))
return
}
// extract the name
vconn, err := m.vhostFn(conn)
if err != nil {
m.sendError(conn, BadRequest{fmt.Errorf("Failed to extract vhost name: %v", err)})
return
}
// normalize the name
host := normalize(vconn.Host())
// look up the correct listener
l, ok := m.get(host)
if !ok {
m.sendError(vconn, NotFound{fmt.Errorf("Host not found: %v", host)})
return
}
if err = vconn.SetDeadline(time.Time{}); err != nil {
m.sendError(vconn, fmt.Errorf("Failed unset connection deadline: %v", err))
return
}
l.accept <- vconn
}
func (m *VhostMuxer) sendError(conn net.Conn, err error) {
m.muxErrors <- muxErr{conn: conn, err: err}
}
func (m *VhostMuxer) get(name string) (l *Listener, ok bool) {
m.RLock()
defer m.RUnlock()
l, ok = m.registry[name]
if !ok {
// look for a matching wildcard
parts := strings.Split(name, ".")
for i := 0; i < len(parts)-1; i++ {
parts[i] = "*"
name = strings.Join(parts[i:], ".")
l, ok = m.registry[name]
if ok {
break
}
}
}
return
}
func (m *VhostMuxer) set(name string, l *Listener) error {
m.Lock()
defer m.Unlock()
if _, exists := m.registry[name]; exists {
return fmt.Errorf("name %s is already bound", name)
}
m.registry[name] = l
return nil
}
func (m *VhostMuxer) del(name string) {
m.Lock()
defer m.Unlock()
delete(m.registry, name)
}
const (
serverError = `HTTP/1.0 500 Internal Server Error
Content-Length: 22
Internal Server Error
`
notFound = `HTTP/1.0 404 Not Found
Content-Length: 14
404 not found
`
badRequest = `HTTP/1.0 400 Bad Request
Content-Length: 12
Bad Request
`
)
type HTTPMuxer struct {
*VhostMuxer
}
// HandleErrors handles muxing errors by calling .NextError(). You must
// invoke this function if you do not want to handle the errors yourself.
func (m *HTTPMuxer) HandleErrors() {
for {
m.HandleError(m.NextError())
}
}
func (m *HTTPMuxer) HandleError(conn net.Conn, err error) {
switch err.(type) {
case Closed:
return
case NotFound:
conn.Write([]byte(notFound))
case BadRequest:
conn.Write([]byte(badRequest))
default:
if conn != nil {
conn.Write([]byte(serverError))
}
}
if conn != nil {
conn.Close()
}
}
// NewHTTPMuxer begins muxing HTTP connections on the given listener by inspecting
// the HTTP Host header in new connections.
func NewHTTPMuxer(listener net.Listener, muxTimeout time.Duration) (*HTTPMuxer, error) {
fn := func(c net.Conn) (Conn, error) { return HTTP(c) }
mux, err := NewVhostMuxer(listener, fn, muxTimeout)
return &HTTPMuxer{mux}, err
}
type TLSMuxer struct {
*VhostMuxer
}
// HandleErrors is the default error handler for TLS muxers. At the moment, it simply
// closes connections which are invalid or destined for virtual host names that it is
// not listening for.
// You must invoke this function if you do not want to handle the errors yourself.
func (m *TLSMuxer) HandleErrors() {
for {
conn, err := m.NextError()
if conn == nil {
if _, ok := err.(Closed); ok {
return
} else {
continue
}
} else {
// XXX: respond with valid TLS close messages
conn.Close()
}
}
}
func (m *TLSMuxer) Listen(name string) (net.Listener, error) {
// TLS SNI never includes the port
host, _, err := net.SplitHostPort(name)
if err != nil {
host = name
}
return m.VhostMuxer.Listen(host)
}
// NewTLSMuxer begins muxing TLS connections by inspecting the SNI extension.
func NewTLSMuxer(listener net.Listener, muxTimeout time.Duration) (*TLSMuxer, error) {
fn := func(c net.Conn) (Conn, error) { return TLS(c) }
mux, err := NewVhostMuxer(listener, fn, muxTimeout)
return &TLSMuxer{mux}, err
}
// Listener is returned by a call to Listen() on a muxer. A Listener
// only receives connections that were made to the name passed into the muxer's
// Listen call.
//
// Listener implements the net.Listener interface, so you can Accept() new
// connections and Close() it when finished. When you Close() a Listener,
// the parent muxer will stop listening for connections to the Listener's name.
type Listener struct {
name string
mux *VhostMuxer
accept chan Conn
}
// Accept returns the next mux'd connection for this listener and blocks
// until one is available.
func (l *Listener) Accept() (net.Conn, error) {
conn, ok := <-l.accept
if !ok {
return nil, fmt.Errorf("Listener closed")
}
return conn, nil
}
// Close stops the parent muxer from listening for connections to the mux'd
// virtual host name.
func (l *Listener) Close() error {
l.mux.del(l.name)
close(l.accept)
return nil
}
// Addr returns the address of the bound listener used by the parent muxer.
func (l *Listener) Addr() net.Addr {
// XXX: include name in address?
return l.mux.listener.Addr()
}
// Name returns the name of the virtual host this listener receives connections on.
func (l *Listener) Name() string {
return l.name
}

195
vendor/github.com/inconshreveable/go-vhost/mux_test.go generated vendored Normal file
View file

@ -0,0 +1,195 @@
package vhost
import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"testing"
"time"
)
// TestErrors ensures that error types for this package are implemented properly
func TestErrors(t *testing.T) {
// test case for https://github.com/inconshreveable/go-vhost/pull/2
// create local err vars of error interface type
var notFoundErr error
var badRequestErr error
var closedErr error
// stuff local error types in to interface values
notFoundErr = NotFound{fmt.Errorf("test NotFound")}
badRequestErr = BadRequest{fmt.Errorf("test BadRequest")}
closedErr = Closed{fmt.Errorf("test Closed")}
// assert the types
switch errType := notFoundErr.(type) {
case NotFound:
default:
t.Fatalf("expected NotFound, got: %s", errType)
}
switch errType := badRequestErr.(type) {
case BadRequest:
default:
t.Fatalf("expected BadRequest, got: %s", errType)
}
switch errType := closedErr.(type) {
case Closed:
default:
t.Fatalf("expected Closed, got: %s", errType)
}
}
func localListener(t *testing.T) (net.Listener, string) {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
return l, strconv.Itoa(l.Addr().(*net.TCPAddr).Port)
}
func TestHTTPMux(t *testing.T) {
l, port := localListener(t)
mux, err := NewHTTPMuxer(l, time.Second)
if err != nil {
t.Fatalf("failed to start muxer: %v", err)
}
go mux.HandleErrors()
muxed, err := mux.Listen("example.com")
if err != nil {
t.Fatalf("failed to listen on muxer: %v", muxed)
}
go http.Serve(muxed, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
}))
msg := "test"
url := "http://localhost:" + port
resp, err := http.Post(url, "text/plain", strings.NewReader(msg))
if err != nil {
t.Fatalf("failed to post: %v", err)
}
if resp.StatusCode != 404 {
t.Fatalf("sent incorrect host header, expected 404 but got %d", resp.StatusCode)
}
req, err := http.NewRequest("POST", url, strings.NewReader(msg))
if err != nil {
t.Fatalf("failed to construct HTTP request: %v", err)
}
req.Host = "example.com"
req.Header.Set("Content-Type", "text/plain")
resp, err = new(http.Client).Do(req)
if err != nil {
t.Fatalf("failed to make HTTP request", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read: %v", err)
}
got := string(body)
if got != msg {
t.Fatalf("unexpected resposne. got: %v, expected: %v", got, msg)
}
}
func testMux(t *testing.T, listen, dial string) {
muxFn := func(c net.Conn) (Conn, error) {
return fakeConn{c, dial}, nil
}
fakel := make(fakeListener, 1)
mux, err := NewVhostMuxer(fakel, muxFn, time.Second)
if err != nil {
t.Fatalf("failed to start vhost muxer: %v", err)
}
l, err := mux.Listen(listen)
if err != nil {
t.Fatalf("failed to listen for %s", err)
}
done := make(chan struct{})
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("failed to accept connection: %v", err)
return
}
got := conn.(Conn).Host()
expected := dial
if got != expected {
t.Fatalf("got connection with unexpected host. got: %s, expected: %s", got, expected)
return
}
close(done)
}()
go func() {
_, err := mux.NextError()
if err != nil {
t.Fatalf("muxing error: %v", err)
}
}()
fakel <- struct{}{}
select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("test timed out: dial: %s listen: %s", dial, listen)
}
}
func TestMuxingPatterns(t *testing.T) {
var tests = []struct {
listen string
dial string
}{
{"example.com", "example.com"},
{"sub.example.com", "sub.example.com"},
{"*.example.com", "sub.example.com"},
{"*.example.com", "nested.sub.example.com"},
}
for _, test := range tests {
testMux(t, test.listen, test.dial)
}
}
type fakeConn struct {
net.Conn
host string
}
func (c fakeConn) SetDeadline(d time.Time) error { return nil }
func (c fakeConn) Host() string { return c.host }
func (c fakeConn) Free() {}
type fakeNetConn struct {
net.Conn
}
func (fakeNetConn) SetDeadline(time.Time) error { return nil }
type fakeListener chan struct{}
func (l fakeListener) Accept() (net.Conn, error) {
for _ = range l {
return fakeNetConn{nil}, nil
}
select {}
}
func (fakeListener) Addr() net.Addr { return nil }
func (fakeListener) Close() error { return nil }

52
vendor/github.com/inconshreveable/go-vhost/shared.go generated vendored Normal file
View file

@ -0,0 +1,52 @@
package vhost
import (
"bytes"
"io"
"net"
"sync"
)
const (
initVhostBufSize = 1024 // allocate 1 KB up front to try to avoid resizing
)
type sharedConn struct {
sync.Mutex
net.Conn // the raw connection
vhostBuf *bytes.Buffer // all of the initial data that has to be read in order to vhost a connection is saved here
}
func newShared(conn net.Conn) (*sharedConn, io.Reader) {
c := &sharedConn{
Conn: conn,
vhostBuf: bytes.NewBuffer(make([]byte, 0, initVhostBufSize)),
}
return c, io.TeeReader(conn, c.vhostBuf)
}
func (c *sharedConn) Read(p []byte) (n int, err error) {
c.Lock()
if c.vhostBuf == nil {
c.Unlock()
return c.Conn.Read(p)
}
n, err = c.vhostBuf.Read(p)
// end of the request buffer
if err == io.EOF {
// let the request buffer get garbage collected
// and make sure we don't read from it again
c.vhostBuf = nil
// continue reading from the connection
var n2 int
n2, err = c.Conn.Read(p[n:])
// update total read
n += n2
}
c.Unlock()
return
}

View file

@ -0,0 +1,64 @@
package vhost
import (
"bytes"
"io"
"net"
"reflect"
"testing"
)
func TestHeaderPreserved(t *testing.T) {
var msg string = "TestHeaderPreserved message! Hello world!"
var headerLen int = 15
l, err := net.Listen("tcp", "127.0.0.1:12345")
if err != nil {
panic(err)
}
defer l.Close()
go func() {
conn, err := net.Dial("tcp", "127.0.0.1:12345")
if err != nil {
panic(err)
}
if _, err := conn.Write([]byte(msg)); err != nil {
panic(err)
}
if err = conn.Close(); err != nil {
panic(err)
}
}()
conn, err := l.Accept()
if err != nil {
panic(err)
}
// create a shared connection object
c, rd := newShared(conn)
// read out a "header"
p := make([]byte, headerLen)
_, err = io.ReadFull(rd, p)
if err != nil {
panic(err)
}
// make sure we got the header
expectedHeader := []byte(msg[:headerLen])
if !reflect.DeepEqual(p, expectedHeader) {
t.Errorf("Read header bytes %s, expected %s", p, expectedHeader)
return
}
// read out the entire connection. make sure it includes the header
buf := bytes.NewBuffer([]byte{})
io.Copy(buf, c)
expected := []byte(msg)
if !reflect.DeepEqual(buf.Bytes(), expected) {
t.Errorf("Read full connection bytes %s, expected %s", buf.Bytes(), expected)
}
}

434
vendor/github.com/inconshreveable/go-vhost/tls.go generated vendored Normal file
View file

@ -0,0 +1,434 @@
// Portions of the TLS code are:
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TLS virtual hosting
package vhost
import (
"bytes"
"errors"
"io"
"net"
"strconv"
)
const (
maxPlaintext = 16384 // maximum plaintext payload length
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
)
type alert uint8
const (
alertUnexpectedMessage alert = 10
alertRecordOverflow alert = 22
alertInternalError alert = 80
)
var alertText = map[alert]string{
alertUnexpectedMessage: "unexpected message",
alertRecordOverflow: "record overflow",
alertInternalError: "internal error",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}
// TLS record types.
type recordType uint8
const (
recordTypeHandshake recordType = 22
)
// TLS handshake message types.
const (
typeClientHello uint8 = 1
)
// TLS extension numbers
var (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10
extensionSupportedPoints uint16 = 11
extensionSessionTicket uint16 = 35
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
)
// TLS CertificateStatusType (RFC 3546)
const (
statusTypeOCSP uint8 = 1
)
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type TLSConn struct {
*sharedConn
ClientHelloMsg *ClientHelloMsg
}
// TLS parses the ClientHello message on conn and returns
// a new, unread connection with metadata for virtual host muxing
func TLS(conn net.Conn) (tlsConn *TLSConn, err error) {
c, rd := newShared(conn)
tlsConn = &TLSConn{sharedConn: c}
if tlsConn.ClientHelloMsg, err = readClientHello(rd); err != nil {
return
}
return
}
func (c *TLSConn) Host() string {
if c.ClientHelloMsg == nil {
return ""
}
return c.ClientHelloMsg.ServerName
}
func (c *TLSConn) Free() {
c.ClientHelloMsg = nil
}
// A block is a simple data buffer.
type block struct {
data []byte
off int // index for Read
}
// resize resizes block to be n bytes, growing if necessary.
func (b *block) resize(n int) {
if n > cap(b.data) {
b.reserve(n)
}
b.data = b.data[0:n]
}
// reserve makes sure that block contains a capacity of at least n bytes.
func (b *block) reserve(n int) {
if cap(b.data) >= n {
return
}
m := cap(b.data)
if m == 0 {
m = 1024
}
for m < n {
m *= 2
}
data := make([]byte, len(b.data), m)
copy(data, b.data)
b.data = data
}
// readFromUntil reads from r into b until b contains at least n bytes
// or else returns an error.
func (b *block) readFromUntil(r io.Reader, n int) error {
// quick case
if len(b.data) >= n {
return nil
}
// read until have enough.
b.reserve(n)
for {
m, err := r.Read(b.data[len(b.data):cap(b.data)])
b.data = b.data[0 : len(b.data)+m]
if len(b.data) >= n {
break
}
if err != nil {
return err
}
}
return nil
}
func (b *block) Read(p []byte) (n int, err error) {
n = copy(p, b.data[b.off:])
b.off += n
return
}
// newBlock allocates a new block
func newBlock() *block {
return new(block)
}
// splitBlock splits a block after the first n bytes,
// returning a block with those n bytes and a
// block with the remainder. the latter may be nil.
func splitBlock(b *block, n int) (*block, *block) {
if len(b.data) <= n {
return b, nil
}
bb := newBlock()
bb.resize(len(b.data) - n)
copy(bb.data, b.data[n:])
b.data = b.data[0:n]
return b, bb
}
// readHandshake reads the next handshake message from
// the record layer.
func readClientHello(rd io.Reader) (*ClientHelloMsg, error) {
var nextBlock *block // raw input, right off the wire
var hand bytes.Buffer // handshake data waiting to be read
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
readRecord := func() error {
// Caller must be in sync with connection:
// handshake data if handshake not yet completed,
// else application data. (We don't support renegotiation.)
if nextBlock == nil {
nextBlock = newBlock()
}
b := nextBlock
// Read header, payload.
if err := b.readFromUntil(rd, recordHeaderLen); err != nil {
return err
}
typ := recordType(b.data[0])
// No valid TLS record has a type of 0x80, however SSLv2 handshakes
// start with a uint16 length where the MSB is set and the first record
// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
// an SSLv2 client.
if typ == 0x80 {
return errors.New("tls: unsupported SSLv2 handshake received")
}
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if n > maxCiphertext {
return alertRecordOverflow
}
// First message, be extra suspicious:
// this might not be a TLS client.
// Bail out before reading a full 'body', if possible.
// The current max version is 3.1.
// If the version is >= 16.0, it's probably not real.
// Similarly, a clientHello message encodes in
// well under a kilobyte. If the length is >= 12 kB,
// it's probably not real.
if (typ != recordTypeHandshake) || vers >= 0x1000 || n >= 0x3000 {
return alertUnexpectedMessage
}
if err := b.readFromUntil(rd, recordHeaderLen+n); err != nil {
return err
}
// Process message.
b, nextBlock = splitBlock(b, recordHeaderLen+n)
b.off = recordHeaderLen
data := b.data[b.off:]
if len(data) > maxPlaintext {
return alertRecordOverflow
}
hand.Write(data)
return nil
}
if err := readRecord(); err != nil {
return nil, err
}
data := hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
return nil, alertInternalError
}
for hand.Len() < 4+n {
if err := readRecord(); err != nil {
return nil, err
}
}
data = hand.Next(4 + n)
if data[0] != typeClientHello {
return nil, alertUnexpectedMessage
}
msg := new(ClientHelloMsg)
if !msg.unmarshal(data) {
return nil, alertUnexpectedMessage
}
return msg, nil
}
type ClientHelloMsg struct {
Raw []byte
Vers uint16
Random []byte
SessionId []byte
CipherSuites []uint16
CompressionMethods []uint8
NextProtoNeg bool
ServerName string
OcspStapling bool
SupportedCurves []uint16
SupportedPoints []uint8
TicketSupported bool
SessionTicket []uint8
}
func (m *ClientHelloMsg) unmarshal(data []byte) bool {
if len(data) < 42 {
return false
}
m.Raw = data
m.Vers = uint16(data[4])<<8 | uint16(data[5])
m.Random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return false
}
m.SessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
if len(data) < 2 {
return false
}
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
// they are uint16s, the number must be even.
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
return false
}
numCipherSuites := cipherSuiteLen / 2
m.CipherSuites = make([]uint16, numCipherSuites)
for i := 0; i < numCipherSuites; i++ {
m.CipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
return false
}
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
return false
}
m.CompressionMethods = data[1 : 1+compressionMethodsLen]
data = data[1+compressionMethodsLen:]
m.NextProtoNeg = false
m.ServerName = ""
m.OcspStapling = false
m.TicketSupported = false
m.SessionTicket = nil
if len(data) == 0 {
// ClientHello is optionally followed by extension data
return true
}
if len(data) < 2 {
return false
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
return false
}
for len(data) != 0 {
if len(data) < 4 {
return false
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return false
}
switch extension {
case extensionServerName:
if length < 2 {
return false
}
numNames := int(data[0])<<8 | int(data[1])
d := data[2:]
for i := 0; i < numNames; i++ {
if len(d) < 3 {
return false
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
return false
}
if nameType == 0 {
m.ServerName = string(d[0:nameLen])
break
}
d = d[nameLen:]
}
case extensionNextProtoNeg:
if length > 0 {
return false
}
m.NextProtoNeg = true
case extensionStatusRequest:
m.OcspStapling = length > 0 && data[0] == statusTypeOCSP
case extensionSupportedCurves:
// http://tools.ietf.org/html/rfc4492#section-5.5.1
if length < 2 {
return false
}
l := int(data[0])<<8 | int(data[1])
if l%2 == 1 || length != l+2 {
return false
}
numCurves := l / 2
m.SupportedCurves = make([]uint16, numCurves)
d := data[2:]
for i := 0; i < numCurves; i++ {
m.SupportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
d = d[2:]
}
case extensionSupportedPoints:
// http://tools.ietf.org/html/rfc4492#section-5.5.2
if length < 1 {
return false
}
l := int(data[0])
if length != l+1 {
return false
}
m.SupportedPoints = make([]uint8, l)
copy(m.SupportedPoints, data[1:])
case extensionSessionTicket:
// http://tools.ietf.org/html/rfc5077#section-3.2
m.TicketSupported = true
m.SessionTicket = data[:length]
}
data = data[length:]
}
return true
}

39
vendor/github.com/inconshreveable/go-vhost/tls_test.go generated vendored Normal file
View file

@ -0,0 +1,39 @@
package vhost
import (
"crypto/tls"
"net"
"testing"
)
func TestSNI(t *testing.T) {
var testHostname string = "foo.example.com"
l, err := net.Listen("tcp", "127.0.0.1:12345")
if err != nil {
panic(err)
}
defer l.Close()
go func() {
conf := &tls.Config{ServerName: testHostname}
conn, err := tls.Dial("tcp", "127.0.0.1:12345", conf)
if err != nil {
panic(err)
}
conn.Close()
}()
conn, err := l.Accept()
if err != nil {
panic(err)
}
c, err := TLS(conn)
if err != nil {
panic(err)
}
if c.Host() != testHostname {
t.Errorf("Connection Host() is %s, expected %s", c.Host(), testHostname)
}
}