it works!

This commit is contained in:
Simone Margaritelli 2024-09-19 16:33:44 +02:00
commit 91d360327a
12 changed files with 446 additions and 1356 deletions

View file

@ -11,13 +11,13 @@ import (
func (mod *EventsStream) viewMDNSEvent(output io.Writer, e session.Event) {
event := e.Data.(mdns.ServiceDiscoveryEvent)
fmt.Fprintf(output, "[%s] [%s] service %s detected for %s (%s):%d : %s\n",
fmt.Fprintf(output, "[%s] [%s] service %s detected for %s (%s):%d with %d records\n",
e.Time.Format(mod.timeFormat),
tui.Green(e.Tag),
tui.Bold(event.Service.Name),
event.Service.AddrV4.String(),
tui.Dim(event.Service.Host),
tui.Bold(event.Service.ServiceInstanceName()),
event.Service.AddrIPv4,
tui.Dim(event.Service.HostName),
event.Service.Port,
event.Service.Info,
len(event.Service.Text),
)
}

View file

@ -1,475 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MIT
package mdns
import (
"context"
"fmt"
"log"
"net"
"strings"
"sync/atomic"
"time"
"github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// ServiceEntry is returned after we query for a service
type ServiceEntry struct {
Name string
Host string
AddrV4 net.IP
AddrV6 net.IP // @Deprecated
AddrV6IPAddr *net.IPAddr
Port int
Info string
InfoFields []string
Addr net.IP // @Deprecated
hasTXT bool
sent bool
}
// complete is used to check if we have all the info we need
func (s *ServiceEntry) complete() bool {
return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT
}
// QueryParam is used to customize how a Lookup is performed
type QueryParam struct {
Module *MDNSModule
Service string // Service to lookup
Domain string // Lookup domain, default "local"
Timeout time.Duration // Lookup timeout, default 1 second
Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
DisableIPv4 bool // Whether to disable usage of IPv4 for MDNS operations. Does not affect discovered addresses.
DisableIPv6 bool // Whether to disable usage of IPv6 for MDNS operations. Does not affect discovered addresses.
Logger *log.Logger // Optionally provide a *log.Logger to better manage log output.
}
// DefaultParams is used to return a default set of QueryParam's
func DefaultParams(service string) *QueryParam {
return &QueryParam{
Service: service,
Domain: "local",
Timeout: time.Second,
Entries: make(chan *ServiceEntry),
WantUnicastResponse: false, // TODO(reddaly): Change this default.
DisableIPv4: false,
DisableIPv6: false,
}
}
// Query looks up a given service, in a domain, waiting at most
// for a timeout before finishing the query. The results are streamed
// to a channel. Sends will not block, so clients should make sure to
// either read or buffer.
func Query(params *QueryParam) error {
return QueryContext(context.Background(), params)
}
// QueryContext looks up a given service, in a domain, waiting at most
// for a timeout before finishing the query. The results are streamed
// to a channel. Sends will not block, so clients should make sure to
// either read or buffer. QueryContext will attempt to stop the query
// on cancellation.
func QueryContext(ctx context.Context, params *QueryParam) error {
if params.Logger == nil {
params.Logger = log.Default()
}
// Create a new client
client, err := newClient(!params.DisableIPv4, !params.DisableIPv6, params.Logger)
if err != nil {
return err
}
defer client.Close()
go func() {
select {
case <-ctx.Done():
client.Close()
case <-client.closedCh:
return
}
}()
// Set the multicast interface
if params.Interface != nil {
if err := client.setInterface(params.Interface); err != nil {
return err
}
}
// Ensure defaults are set
if params.Domain == "" {
params.Domain = "local"
}
if params.Timeout == 0 {
params.Timeout = time.Second
}
// Run the query
return client.query(params)
}
// Lookup is the same as Query, however it uses all the default parameters
func Lookup(service string, entries chan<- *ServiceEntry) error {
params := DefaultParams(service)
params.Entries = entries
return Query(params)
}
// Client provides a query interface that can be used to
// search for service providers using mDNS
type client struct {
use_ipv4 bool
use_ipv6 bool
ipv4UnicastConn *net.UDPConn
ipv6UnicastConn *net.UDPConn
ipv4MulticastConn *net.UDPConn
ipv6MulticastConn *net.UDPConn
closed int32
closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used.
log *log.Logger
}
// NewClient creates a new mdns Client that can be used to query
// for records
func newClient(v4 bool, v6 bool, logger *log.Logger) (*client, error) {
if !v4 && !v6 {
return nil, fmt.Errorf("Must enable at least one of IPv4 and IPv6 querying")
}
// TODO(reddaly): At least attempt to bind to the port required in the spec.
// Create a IPv4 listener
var uconn4 *net.UDPConn
var uconn6 *net.UDPConn
var mconn4 *net.UDPConn
var mconn6 *net.UDPConn
var err error
// Establish unicast connections
if v4 {
uconn4, err = net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
logger.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
}
}
if v6 {
uconn6, err = net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
logger.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
}
}
if uconn4 == nil && uconn6 == nil {
return nil, fmt.Errorf("failed to bind to any unicast udp port")
}
// Establish multicast connections
if v4 {
mconn4, err = net.ListenMulticastUDP("udp4", nil, ipv4Addr)
if err != nil {
logger.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
}
}
if v6 {
mconn6, err = net.ListenMulticastUDP("udp6", nil, ipv6Addr)
if err != nil {
logger.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
}
}
if mconn4 == nil && mconn6 == nil {
return nil, fmt.Errorf("failed to bind to any multicast udp port")
}
// Check that unicast and multicast connections have been made for IPv4 and IPv6
// and disable the respective protocol if not.
if uconn4 == nil || mconn4 == nil {
logger.Printf("[INFO] mdns: Failed to listen to both unicast and multicast on IPv4")
uconn4 = nil
mconn4 = nil
v4 = false
}
if uconn6 == nil || mconn6 == nil {
logger.Printf("[INFO] mdns: Failed to listen to both unicast and multicast on IPv6")
uconn6 = nil
mconn6 = nil
v6 = false
}
if !v4 && !v6 {
return nil, fmt.Errorf("at least one of IPv4 and IPv6 must be enabled for querying")
}
c := &client{
use_ipv4: v4,
use_ipv6: v6,
ipv4MulticastConn: mconn4,
ipv6MulticastConn: mconn6,
ipv4UnicastConn: uconn4,
ipv6UnicastConn: uconn6,
closedCh: make(chan struct{}),
log: logger,
}
return c, nil
}
// Close is used to cleanup the client
func (c *client) Close() error {
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
// something else already closed it
return nil
}
c.log.Printf("[INFO] mdns: Closing client %v", *c)
close(c.closedCh)
if c.ipv4UnicastConn != nil {
c.ipv4UnicastConn.Close()
}
if c.ipv6UnicastConn != nil {
c.ipv6UnicastConn.Close()
}
if c.ipv4MulticastConn != nil {
c.ipv4MulticastConn.Close()
}
if c.ipv6MulticastConn != nil {
c.ipv6MulticastConn.Close()
}
return nil
}
// setInterface is used to set the query interface, uses system
// default if not provided
func (c *client) setInterface(iface *net.Interface) error {
if c.use_ipv4 {
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.SetMulticastInterface(iface); err != nil {
return err
}
}
if c.use_ipv6 {
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.SetMulticastInterface(iface); err != nil {
return err
}
}
return nil
}
// msgAddr carries the message and source address from recv to message processing.
type msgAddr struct {
msg *dns.Msg
src *net.UDPAddr
}
// query is used to perform a lookup and stream results
func (c *client) query(params *QueryParam) error {
// Create the service name
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
// Start listening for response packets
msgCh := make(chan *msgAddr, 32)
if c.use_ipv4 {
go c.recv(c.ipv4UnicastConn, msgCh)
go c.recv(c.ipv4MulticastConn, msgCh)
}
if c.use_ipv6 {
go c.recv(c.ipv6UnicastConn, msgCh)
go c.recv(c.ipv6MulticastConn, msgCh)
}
// Send the query
m := new(dns.Msg)
m.SetQuestion(serviceAddr, dns.TypePTR)
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the qclass
// field is used to indicate that unicast responses are preferred for this
// particular question. (See Section 5.4.)
if params.WantUnicastResponse {
m.Question[0].Qclass |= 1 << 15
}
m.RecursionDesired = true
if err := c.sendQuery(m); err != nil {
return err
}
// Map the in-progress responses
inprogress := make(map[string]*ServiceEntry)
// Listen until we reach the timeout
finish := time.After(params.Timeout)
for {
select {
case resp := <-msgCh:
var inp *ServiceEntry
for _, answer := range append(resp.msg.Answer, resp.msg.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr)
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name)
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name)
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name)
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA // @Deprecated
inp.AddrV6IPAddr = &net.IPAddr{IP: rr.AAAA}
// link-local IPv6 addresses must be qualified with a zone (interface). Zone is
// specific to this machine/network-namespace and so won't be carried in the
// mDNS message itself. We borrow the zone from the source address of the UDP
// packet, as the link-local address should be valid on that interface.
if rr.AAAA.IsLinkLocalUnicast() || rr.AAAA.IsLinkLocalMulticast() {
inp.AddrV6IPAddr.Zone = resp.src.Zone
}
}
if inp == nil {
params.Module.Debug("no inp for %v", answer)
continue
}
// Check if this entry is complete
if inp.complete() {
if inp.sent {
continue
}
inp.sent = true
select {
case params.Entries <- inp:
default:
}
} else {
// Fire off a node specific query
params.Module.Debug("sending query for service %s", inp.Name)
m := new(dns.Msg)
m.SetQuestion(inp.Name, dns.TypePTR)
m.RecursionDesired = true
if err := c.sendQuery(m); err != nil {
params.Module.Error("failed to query instance %s: %v", inp.Name, err)
}
time.Sleep(time.Duration(1) * time.Millisecond)
}
}
case <-finish:
return nil
}
}
}
// sendQuery is used to multicast a query out
func (c *client) sendQuery(q *dns.Msg) error {
buf, err := q.Pack()
if err != nil {
return err
}
if c.ipv4UnicastConn != nil {
_, err = c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
if err != nil {
return err
}
}
if c.ipv6UnicastConn != nil {
_, err = c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
if err != nil {
return err
}
}
return nil
}
// recv is used to receive until we get a shutdown
func (c *client) recv(l *net.UDPConn, msgCh chan *msgAddr) {
if l == nil {
return
}
buf := make([]byte, 65536)
for atomic.LoadInt32(&c.closed) == 0 {
n, addr, err := l.ReadFromUDP(buf)
if atomic.LoadInt32(&c.closed) == 1 {
return
}
if err != nil {
c.log.Printf("[ERR] mdns: Failed to read packet: %v", err)
continue
}
msg := new(dns.Msg)
if err := msg.Unpack(buf[:n]); err != nil {
c.log.Printf("[ERR] mdns: Failed to unpack packet: %v", err)
continue
}
select {
case msgCh <- &msgAddr{
msg: msg,
src: addr,
}:
case <-c.closedCh:
return
}
}
}
// ensureName is used to ensure the named node is in progress
func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry {
if inp, ok := inprogress[name]; ok {
return inp
}
inp := &ServiceEntry{
Name: name,
}
inprogress[name] = inp
return inp
}
// alias is used to setup an alias between two entries
func alias(inprogress map[string]*ServiceEntry, src, dst string) {
srcEntry := ensureName(inprogress, src)
inprogress[dst] = srcEntry
}

View file

@ -4,34 +4,57 @@ import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"github.com/miekg/dns"
"github.com/evilsocket/islazy/tui"
"github.com/grandcat/zeroconf"
yaml "gopkg.in/yaml.v3"
)
/*
type multiService struct {
mod *MDNSModule
services []*MDNSService
}
func (m multiService) Records(q dns.Question) []dns.RR {
records := make([]dns.RR, 0)
for _, svc := range m.services {
records = append(records, svc.Records(q)...)
m.mod.Debug("QUESTION: %+v", q)
if strings.HasPrefix(q.Name, "_services._dns-sd._udp.") {
for _, svc := range m.services {
records = append(records, svc.Records(q)...)
}
} else {
for _, svc := range m.services {
if svcRecords := svc.Records(q); len(svcRecords) > 0 {
records = svcRecords
break
}
}
}
if num := len(records); num == 0 {
m.mod.Debug("unhandled service %+v", q)
} else {
m.mod.Info("responding to query %s with %d records", tui.Green(q.Name), num)
if q.Name == "_services._dns-sd._udp.local." {
for _, r := range records {
m.mod.Info(" %+v", r)
}
}
}
return records
}
*/
type Advertiser struct {
Filename string
Mapping map[string]ServiceEntry
Service multiService
Server *Server
Mapping map[string]zeroconf.ServiceEntry
Servers map[string]*zeroconf.Server
}
func (mod *MDNSModule) startAdvertiser(fileName string) error {
@ -44,7 +67,7 @@ func (mod *MDNSModule) startAdvertiser(fileName string) error {
return fmt.Errorf("could not read %s: %v", fileName, err)
}
mapping := make(map[string]ServiceEntry)
mapping := make(map[string]zeroconf.ServiceEntry)
if err = yaml.Unmarshal(data, &mapping); err != nil {
return fmt.Errorf("could not deserialize %s: %v", fileName, err)
}
@ -53,54 +76,41 @@ func (mod *MDNSModule) startAdvertiser(fileName string) error {
if err != nil {
return fmt.Errorf("could not get hostname: %v", err)
}
if !strings.HasSuffix(hostName, ".") {
hostName += "."
}
mod.Info("loaded %d services from %s, advertising with: host=%s ipv4=%s ipv6=%s",
ifName := mod.Session.Interface.Name()
/*
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("error getting interface %s: %v", ifName, err)
}
*/
mod.Info("loaded %d services from %s, advertising with host=%s iface=%s ipv4=%s ipv6=%s",
len(mapping),
fileName,
hostName,
ifName,
mod.Session.Interface.IpAddress,
mod.Session.Interface.Ip6Address)
advertiser := &Advertiser{
Filename: fileName,
Mapping: mapping,
Service: multiService{
services: make([]*MDNSService, 0),
},
Servers: make(map[string]*zeroconf.Server),
}
for _, svcData := range mapping {
svcParts := strings.SplitN(svcData.Name, ".", 2)
svcInstance := svcParts[0]
svcService := strings.Replace(svcParts[1], ".local.", "", 1)
// TODO: patch UUID
service, err := NewMDNSService(
mod,
svcInstance,
svcService,
"local.",
hostName,
svcData.Port,
[]net.IP{
mod.Session.Interface.IP,
mod.Session.Interface.IPv6,
},
svcData.InfoFields)
for key, svc := range mapping {
server, err := zeroconf.Register(svc.Instance, svc.Service, svc.Domain, svc.Port, svc.Text, nil)
if err != nil {
return fmt.Errorf("could not create service %s: %v", svcData.Name, err)
return fmt.Errorf("could not create service %s: %v", svc.Instance, err)
}
advertiser.Service.services = append(advertiser.Service.services, service)
}
mod.Info("advertising service %s", tui.Yellow(svc.Service))
if advertiser.Server, err = NewServer(mod, &Config{Zone: advertiser.Service}); err != nil {
return fmt.Errorf("could not create server: %v", err)
advertiser.Servers[key] = server
}
mod.advertiser = advertiser
@ -117,7 +127,12 @@ func (mod *MDNSModule) stopAdvertiser() error {
mod.Info("stopping %d services ...", len(mod.advertiser.Mapping))
mod.advertiser.Server.Shutdown()
for key, server := range mod.advertiser.Servers {
mod.Info("stopping %s ...", key)
server.Shutdown()
}
mod.Info("all services stopped")
mod.advertiser = nil
return nil

View file

@ -1,31 +1,34 @@
package mdns
import (
"context"
"fmt"
"strings"
"time"
"github.com/bettercap/bettercap/v2/modules/syn_scan"
"github.com/bettercap/bettercap/v2/network"
"github.com/bettercap/bettercap/v2/session"
"github.com/evilsocket/islazy/str"
"github.com/evilsocket/islazy/tui"
"github.com/grandcat/zeroconf"
)
type MDNSModule struct {
session.SessionModule
advertiser *Advertiser
discoChannel chan *ServiceEntry
mapping map[string]map[string]*ServiceEntry
advertiser *Advertiser
rootContext context.Context
rootCancel context.CancelFunc
resolvers map[string]*zeroconf.Resolver
mapping map[string]map[string]*zeroconf.ServiceEntry
}
func NewMDNSModule(s *session.Session) *MDNSModule {
mod := &MDNSModule{
SessionModule: session.NewSessionModule("mdns", s),
discoChannel: make(chan *ServiceEntry),
mapping: make(map[string]map[string]*ServiceEntry),
advertiser: nil,
mapping: make(map[string]map[string]*zeroconf.ServiceEntry),
resolvers: make(map[string]*zeroconf.Resolver),
}
mod.SessionModule.Requires("net.recon")
@ -108,42 +111,47 @@ func (mod *MDNSModule) Configure() (err error) {
return session.ErrAlreadyStarted(mod.Name())
}
if mod.discoChannel != nil {
close(mod.discoChannel)
if mod.rootContext != nil {
mod.rootCancel()
}
mod.discoChannel = make(chan *ServiceEntry)
mod.mapping = make(map[string]map[string]*ServiceEntry)
mod.mapping = make(map[string]map[string]*zeroconf.ServiceEntry)
mod.resolvers = make(map[string]*zeroconf.Resolver)
mod.rootContext, mod.rootCancel = context.WithCancel(context.Background())
return
}
type ServiceDiscoveryEvent struct {
Service ServiceEntry `json:"service"`
Endpoint *network.Endpoint `json:"endpoint"`
Service zeroconf.ServiceEntry `json:"service"`
Endpoint *network.Endpoint `json:"endpoint"`
}
func (mod *MDNSModule) updateEndpointMeta(address string, endpoint *network.Endpoint, svc *ServiceEntry) {
func (mod *MDNSModule) updateEndpointMeta(address string, endpoint *network.Endpoint, svc *zeroconf.ServiceEntry) {
mod.Debug("found endpoint %s for address %s", endpoint.HwAddress, address)
// TODO: this is shit and needs to be refactored
// update mdns metadata
meta := make(map[string]string)
svcType := strings.SplitN(svc.Name, ".", 2)[1]
svcType := svc.Service
meta[fmt.Sprintf("mdns:%s:name", svcType)] = svc.Name
meta[fmt.Sprintf("mdns:%s:hostname", svcType)] = svc.Host
meta[fmt.Sprintf("mdns:%s:name", svcType)] = svc.ServiceName()
meta[fmt.Sprintf("mdns:%s:hostname", svcType)] = svc.HostName
if svc.AddrV4 != nil {
meta[fmt.Sprintf("mdns:%s:ipv4", svcType)] = svc.AddrV4.String()
// TODO: include all
if len(svc.AddrIPv4) > 0 {
meta[fmt.Sprintf("mdns:%s:ipv4", svcType)] = svc.AddrIPv4[0].String()
}
if svc.AddrV6 != nil {
meta[fmt.Sprintf("mdns:%s:ipv6", svcType)] = svc.AddrV6.String()
if len(svc.AddrIPv6) > 0 {
meta[fmt.Sprintf("mdns:%s:ipv6", svcType)] = svc.AddrIPv6[0].String()
}
meta[fmt.Sprintf("mdns:%s:port", svcType)] = fmt.Sprintf("%d", svc.Port)
for _, field := range svc.InfoFields {
for _, field := range svc.Text {
field = str.Trim(field)
if len(field) == 0 {
continue
@ -180,86 +188,122 @@ func (mod *MDNSModule) updateEndpointMeta(address string, endpoint *network.Endp
endpoint.Meta.Set("ports", ports)
}
func (mod *MDNSModule) onServiceDiscovered(svc *ServiceEntry) {
mod.Debug("discovered service %s (%s) [%v / %v]:%d", tui.Green(svc.Name), tui.Dim(svc.Host), svc.AddrV4, svc.AddrV6, svc.Port)
func (mod *MDNSModule) onServiceDiscovered(svc *zeroconf.ServiceEntry) {
mod.Debug("%++v", *svc)
if svc.Service == "_services._dns-sd._udp" && len(svc.AddrIPv4) == 0 && len(svc.AddrIPv6) == 0 {
svcName := strings.Replace(svc.Instance, ".local", "", 1)
if _, found := mod.resolvers[svcName]; !found {
mod.Debug("discovered service %s", tui.Green(svcName))
if err := mod.startResolver(svcName); err != nil {
mod.Error("%v", err)
}
}
return
}
mod.Debug("discovered instance %s (%s) [%v / %v]:%d",
tui.Green(svc.ServiceInstanceName()),
tui.Dim(svc.HostName),
svc.AddrIPv4,
svc.AddrIPv6,
svc.Port)
event := ServiceDiscoveryEvent{
Service: *svc,
Endpoint: nil,
}
addresses := []string{}
if svc.AddrV4 != nil {
addresses = append(addresses, svc.AddrV4.String())
}
if svc.AddrV6 != nil {
addresses = append(addresses, svc.AddrV6.String())
}
addresses := append(svc.AddrIPv4, svc.AddrIPv6...)
for _, address := range addresses {
for _, ip := range addresses {
address := ip.String()
if event.Endpoint = mod.Session.Lan.GetByIp(address); event.Endpoint != nil {
// update endpoint metadata
mod.updateEndpointMeta(address, event.Endpoint, svc)
// update internal module mapping
if ipServices, found := mod.mapping[address]; found {
ipServices[svc.Name] = svc
ipServices[svc.ServiceInstanceName()] = svc
} else {
mod.mapping[address] = map[string]*ServiceEntry{
svc.Name: svc,
mod.mapping[address] = map[string]*zeroconf.ServiceEntry{
svc.ServiceInstanceName(): svc,
}
}
break
} else {
mod.Warning("got mdns entry for unknown ip %s", svc.AddrV4)
}
}
if event.Endpoint == nil {
// TODO: this is probably an IPv6 only record, try to somehow check which known IPv4 it is
mod.Debug("got mdns entry for unknown ip: %++v", *svc)
}
session.I.Events.Add("mdns.service", event)
session.I.Refresh()
}
func (mod *MDNSModule) startResolver(service string) error {
mod.Debug("starting resolver for service %s", tui.Yellow(service))
resolver, err := zeroconf.NewResolver(nil)
if err != nil {
return err
}
// start listening
channel := make(chan *zeroconf.ServiceEntry)
go func() {
for entry := range channel {
mod.onServiceDiscovered(entry)
}
}()
// start browsing
go func() {
err = resolver.Browse(mod.rootContext, service, "local.", channel)
if err != nil {
mod.Error("%v", err)
}
mod.Debug("resolver for service %s stopped", tui.Yellow(service))
}()
mod.resolvers[service] = resolver
return nil
}
func (mod *MDNSModule) Start() (err error) {
if err = mod.Configure(); err != nil {
return err
}
// start the discovery
service := "_services._dns-sd._udp"
params := DefaultParams(service)
params.Module = mod
params.Service = service
params.Domain = "local"
params.Entries = mod.discoChannel
params.DisableIPv6 = true // https://github.com/hashicorp/mdns/issues/35
params.Timeout = time.Duration(10) * time.Minute
go func() {
mod.Info("starting query routine ...")
if err := Query(params); err != nil {
mod.Error("service discovery query: %v", err)
}
mod.Info("stopping query routine ...")
}()
// start the root discovery
if err = mod.startResolver("_services._dns-sd._udp"); err != nil {
return err
}
return mod.SetRunning(true, func() {
mod.Info("mDNS service discovery started")
mod.Info("service discovery started")
for entry := range mod.discoChannel {
mod.onServiceDiscovered(entry)
}
<-mod.rootContext.Done()
mod.Info("mDNS service discovery stopped")
mod.Info("service discovery stopped")
})
}
func (mod *MDNSModule) Stop() error {
return mod.SetRunning(false, func() {
if mod.discoChannel != nil {
mod.Info("closing mDNS discovery channel")
close(mod.discoChannel)
mod.discoChannel = nil
if mod.rootCancel != nil {
mod.Debug("stopping mDNS discovery")
mod.rootCancel()
<-mod.rootContext.Done()
mod.Debug("stopped")
mod.rootContext = nil
mod.rootCancel = nil
}
})
}

View file

@ -6,11 +6,12 @@ import (
"github.com/evilsocket/islazy/str"
"github.com/evilsocket/islazy/tui"
"github.com/grandcat/zeroconf"
)
type entry struct {
ip string
services map[string]*ServiceEntry
services map[string]*zeroconf.ServiceEntry
}
func (mod *MDNSModule) show(filter string, withData bool) error {
@ -30,7 +31,7 @@ func (mod *MDNSModule) show(filter string, withData bool) error {
for _, entry := range entries {
if endpoint := mod.Session.Lan.GetByIp(entry.ip); endpoint != nil {
fmt.Fprintf(mod.Session.Events.Stdout, "* %s (%s)\n", endpoint.IpAddress, tui.Dim(endpoint.Vendor))
fmt.Fprintf(mod.Session.Events.Stdout, "* %s (%s)\n", tui.Bold(endpoint.IpAddress), tui.Dim(endpoint.Vendor))
} else {
fmt.Fprintf(mod.Session.Events.Stdout, "* %s\n", tui.Bold(entry.ip))
}
@ -38,16 +39,16 @@ func (mod *MDNSModule) show(filter string, withData bool) error {
for name, svc := range entry.services {
fmt.Fprintf(mod.Session.Events.Stdout, " %s (%s) [%v / %v]:%s\n",
tui.Green(name),
tui.Dim(svc.Host),
svc.AddrV4,
svc.AddrV6,
tui.Dim(svc.HostName),
svc.AddrIPv4,
svc.AddrIPv6,
tui.Red(fmt.Sprintf("%d", svc.Port)),
)
numFields := len(svc.InfoFields)
numFields := len(svc.Text)
if withData {
if numFields > 0 {
for _, field := range svc.InfoFields {
for _, field := range svc.Text {
if field = str.Trim(field); len(field) > 0 {
fmt.Fprintf(mod.Session.Events.Stdout, " %s\n", field)
}

View file

@ -1,306 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MIT
package mdns
import (
"fmt"
"net"
"strings"
"sync/atomic"
"github.com/miekg/dns"
)
const (
ipv4mdns = "224.0.0.251"
ipv6mdns = "ff02::fb"
mdnsPort = 5353
forceUnicastResponses = false
)
var (
ipv4Addr = &net.UDPAddr{
IP: net.ParseIP(ipv4mdns),
Port: mdnsPort,
}
ipv6Addr = &net.UDPAddr{
IP: net.ParseIP(ipv6mdns),
Port: mdnsPort,
}
)
// Config is used to configure the mDNS server
type Config struct {
// Zone must be provided to support responding to queries
Zone Zone
// Iface if provided binds the multicast listener to the given
// interface. If not provided, the system default multicase interface
// is used.
Iface *net.Interface
// LogEmptyResponses indicates the server should print an informative message
// when there is an mDNS query for which the server has no response.
LogEmptyResponses bool
}
// mDNS server is used to listen for mDNS queries and respond if we
// have a matching local record
type Server struct {
mod *MDNSModule
config *Config
ipv4List *net.UDPConn
ipv6List *net.UDPConn
shutdown int32
shutdownCh chan struct{}
}
// NewServer is used to create a new mDNS server from a config
func NewServer(mod *MDNSModule, config *Config) (*Server, error) {
// Create the listeners
ipv4List, err := net.ListenMulticastUDP("udp4", config.Iface, ipv4Addr)
if err != nil {
return nil, err
}
ipv6List, _ := net.ListenMulticastUDP("udp6", config.Iface, ipv6Addr)
// Check if we have any listener
if ipv4List == nil && ipv6List == nil {
return nil, fmt.Errorf("no multicast listeners could be started")
}
s := &Server{
mod: mod,
config: config,
ipv4List: ipv4List,
ipv6List: ipv6List,
shutdownCh: make(chan struct{}),
}
if ipv4List != nil {
mod.Info("starting ipv4 receiver for %v", s.ipv4List)
go s.recv(s.ipv4List)
}
if ipv6List != nil {
mod.Info("starting ipv6 receiver for %v", s.ipv6List)
go s.recv(s.ipv6List)
}
return s, nil
}
// Shutdown is used to shutdown the listener
func (s *Server) Shutdown() error {
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
// something else already closed us
return nil
}
close(s.shutdownCh)
if s.ipv4List != nil {
s.ipv4List.Close()
}
if s.ipv6List != nil {
s.ipv6List.Close()
}
return nil
}
// recv is a long running routine to receive packets from an interface
func (s *Server) recv(c *net.UDPConn) {
if c == nil {
return
}
buf := make([]byte, 65536)
for atomic.LoadInt32(&s.shutdown) == 0 {
s.mod.Debug("receiving from %v ...", c)
n, from, err := c.ReadFrom(buf)
if err != nil {
s.mod.Error("error while receiving datagram: %v", err)
continue
}
if err := s.parsePacket(buf[:n], from); err != nil {
s.mod.Debug("failed to handle query: %v", err)
}
}
}
// parsePacket is used to parse an incoming packet
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
var msg dns.Msg
if err := msg.Unpack(packet); err != nil {
s.mod.Error("failed to unpack packet: %v", err)
return err
}
return s.handleQuery(&msg, from)
}
// handleQuery is used to handle an incoming query
func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
if query.Opcode != dns.OpcodeQuery {
// "In both multicast query and multicast response messages, the OPCODE MUST
// be zero on transmission (only standard queries are currently supported
// over multicast). Multicast DNS messages received with an OPCODE other
// than zero MUST be silently ignored." Note: OpcodeQuery == 0
return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query)
}
if query.Rcode != 0 {
// "In both multicast query and multicast response messages, the Response
// Code MUST be zero on transmission. Multicast DNS messages received with
// non-zero Response Codes MUST be silently ignored."
return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query)
}
// TODO(reddaly): Handle "TC (Truncated) Bit":
// In query messages, if the TC bit is set, it means that additional
// Known-Answer records may be following shortly. A responder SHOULD
// record this fact, and wait for those additional Known-Answer records,
// before deciding whether to respond. If the TC bit is clear, it means
// that the querying host has no additional Known Answers.
if query.Truncated {
return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query)
}
var unicastAnswer, multicastAnswer []dns.RR
// Handle each question
for _, q := range query.Question {
mrecs, urecs := s.handleQuestion(q)
multicastAnswer = append(multicastAnswer, mrecs...)
unicastAnswer = append(unicastAnswer, urecs...)
}
// See section 18 of RFC 6762 for rules about DNS headers.
resp := func(unicast bool) *dns.Msg {
// 18.1: ID (Query Identifier)
// 0 for multicast response, query.Id for unicast response
id := uint16(0)
if unicast {
id = query.Id
}
var answer []dns.RR
if unicast {
answer = unicastAnswer
} else {
answer = multicastAnswer
}
if len(answer) == 0 {
return nil
}
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: id,
// 18.2: QR (Query/Response) Bit - must be set to 1 in response.
Response: true,
// 18.3: OPCODE - must be zero in response (OpcodeQuery == 0)
Opcode: dns.OpcodeQuery,
// 18.4: AA (Authoritative Answer) Bit - must be set to 1
Authoritative: true,
// The following fields must all be set to 0:
// 18.5: TC (TRUNCATED) Bit
// 18.6: RD (Recursion Desired) Bit
// 18.7: RA (Recursion Available) Bit
// 18.8: Z (Zero) Bit
// 18.9: AD (Authentic Data) Bit
// 18.10: CD (Checking Disabled) Bit
// 18.11: RCODE (Response Code)
},
// 18.12 pertains to questions (handled by handleQuestion)
// 18.13 pertains to resource records (handled by handleQuestion)
// 18.14: Name Compression - responses should be compressed (though see
// caveats in the RFC), so set the Compress bit (part of the dns library
// API, not part of the DNS packet) to true.
Compress: true,
Answer: answer,
}
}
if s.config.LogEmptyResponses && len(multicastAnswer) == 0 && len(unicastAnswer) == 0 {
questions := make([]string, len(query.Question))
for i, q := range query.Question {
questions[i] = q.Name
}
s.mod.Warning("no responses for query with questions: %s", strings.Join(questions, ", "))
}
if mresp := resp(false); mresp != nil {
if err := s.sendResponse(mresp, from, false); err != nil {
return fmt.Errorf("mdns: error sending multicast response: %v", err)
}
}
if uresp := resp(true); uresp != nil {
if err := s.sendResponse(uresp, from, true); err != nil {
return fmt.Errorf("mdns: error sending unicast response: %v", err)
}
}
return nil
}
// handleQuestion is used to handle an incoming question
//
// The response to a question may be transmitted over multicast, unicast, or
// both. The return values are DNS records for each transmission type.
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
records := s.config.Zone.Records(q)
if len(records) == 0 {
return nil, nil
}
s.mod.Info("%+v :", q)
for _, rec := range records {
s.mod.Info(" %+v", rec)
}
// Handle unicast and multicast responses.
// TODO(reddaly): The decision about sending over unicast vs. multicast is not
// yet fully compliant with RFC 6762. For example, the unicast bit should be
// ignored if the records in question are close to TTL expiration. For now,
// we just use the unicast bit to make the decision, as per the spec:
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the
// qclass field is used to indicate that unicast responses are preferred
// for this particular question. (See Section 5.4.)
if q.Qclass&(1<<15) != 0 || forceUnicastResponses {
return nil, records
}
return records, nil
}
// sendResponse is used to send a response packet
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr, unicast bool) error {
s.mod.Debug("sending response=%v from=%v", *resp, from)
// TODO(reddaly): Respect the unicast argument, and allow sending responses
// over multicast.
buf, err := resp.Pack()
if err != nil {
return err
}
// Determine the socket to send from
addr := from.(*net.UDPAddr)
if addr.IP.To4() != nil {
_, err = s.ipv4List.WriteToUDP(buf, addr)
return err
} else {
_, err = s.ipv6List.WriteToUDP(buf, addr)
return err
}
}

View file

@ -1,317 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MIT
package mdns
import (
"fmt"
"net"
"os"
"strings"
"github.com/miekg/dns"
)
const (
// defaultTTL is the default TTL value in returned DNS records in seconds.
defaultTTL = 120
)
// Zone is the interface used to integrate with the server and
// to serve records dynamically
type Zone interface {
// Records returns DNS records in response to a DNS question.
Records(q dns.Question) []dns.RR
}
// MDNSService is used to export a named service by implementing a Zone
type MDNSService struct {
mod *MDNSModule
Instance string // Instance name (e.g. "hostService name")
Service string // Service name (e.g. "_http._tcp.")
Domain string // If blank, assumes "local"
HostName string // Host machine DNS name (e.g. "mymachine.net.")
Port int // Service Port
IPs []net.IP // IP addresses for the service's host
TXT []string // Service TXT records
serviceAddr string // Fully qualified service address
instanceAddr string // Fully qualified instance address
enumAddr string // _services._dns-sd._udp.<domain>
}
// validateFQDN returns an error if the passed string is not a fully qualified
// hdomain name (more specifically, a hostname).
func validateFQDN(s string) error {
if len(s) == 0 {
return fmt.Errorf("FQDN must not be blank")
}
if s[len(s)-1] != '.' {
return fmt.Errorf("FQDN must end in period: %s", s)
}
// TODO(reddaly): Perform full validation.
return nil
}
// NewMDNSService returns a new instance of MDNSService.
//
// If domain, hostName, or ips is set to the zero value, then a default value
// will be inferred from the operating system.
//
// TODO(reddaly): This interface may need to change to account for "unique
// record" conflict rules of the mDNS protocol. Upon startup, the server should
// check to ensure that the instance name does not conflict with other instance
// names, and, if required, select a new name. There may also be conflicting
// hostName A/AAAA records.
func NewMDNSService(mod *MDNSModule, instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) {
// Sanity check inputs
if instance == "" {
return nil, fmt.Errorf("missing service instance name")
}
if service == "" {
return nil, fmt.Errorf("missing service name")
}
if port == 0 {
return nil, fmt.Errorf("missing service port")
}
// Set default domain
if domain == "" {
domain = "local."
}
if err := validateFQDN(domain); err != nil {
return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err)
}
// Get host information if no host is specified.
if hostName == "" {
var err error
hostName, err = os.Hostname()
if err != nil {
return nil, fmt.Errorf("could not determine host: %v", err)
}
hostName = fmt.Sprintf("%s.", hostName)
}
if err := validateFQDN(hostName); err != nil {
return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err)
}
if len(ips) == 0 {
var err error
ips, err = net.LookupIP(hostName)
if err != nil {
// Try appending the host domain suffix and lookup again
// (required for Linux-based hosts)
tmpHostName := fmt.Sprintf("%s%s", hostName, domain)
ips, err = net.LookupIP(tmpHostName)
if err != nil {
return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName)
}
}
}
for _, ip := range ips {
if ip.To4() == nil && ip.To16() == nil {
return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip)
}
}
mod.Debug("serviceAddr=%s.%s.", trimDot(service), trimDot(domain))
mod.Debug("instanceAddr=%s.%s.%s.", instance, trimDot(service), trimDot(domain))
mod.Debug("enumAddr=_services._dns-sd._udp.%s.", trimDot(domain))
return &MDNSService{
mod: mod,
Instance: instance,
Service: service,
Domain: domain,
HostName: hostName,
Port: port,
IPs: ips,
TXT: txt,
serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
}, nil
}
// trimDot is used to trim the dots from the start or end of a string
func trimDot(s string) string {
return strings.Trim(s, ".")
}
// Records returns DNS records in response to a DNS question.
func (m *MDNSService) Records(q dns.Question) []dns.RR {
switch q.Name {
case m.enumAddr:
return m.serviceEnum(q)
case m.serviceAddr:
return m.serviceRecords(q)
case m.instanceAddr:
return m.instanceRecords(q)
case m.HostName:
if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
return m.instanceRecords(q)
}
fallthrough
default:
return nil
}
}
func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
fallthrough
case dns.TypePTR:
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Ptr: m.serviceAddr,
}
return []dns.RR{rr}
default:
return nil
}
}
// serviceRecords is called when the query matches the service name
func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
fallthrough
case dns.TypePTR:
// Build a PTR response for the service
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Ptr: m.instanceAddr,
}
servRec := []dns.RR{rr}
// Get the instance records
instRecs := m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeANY,
})
// Return the service record with the instance records
return append(servRec, instRecs...)
default:
return nil
}
}
// serviceRecords is called when the query matches the instance name
func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
// Get the SRV, which includes A and AAAA
recs := m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeSRV,
})
// Add the TXT record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeTXT,
})...)
return recs
case dns.TypeA:
var rr []dns.RR
for _, ip := range m.IPs {
if ip4 := ip.To4(); ip4 != nil {
rr = append(rr, &dns.A{
Hdr: dns.RR_Header{
Name: m.HostName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
A: ip4,
})
}
}
return rr
case dns.TypeAAAA:
var rr []dns.RR
for _, ip := range m.IPs {
if ip.To4() != nil {
// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and
// putinto AAAA records, but the current logic puts ipv4-encodable
// addresses into the A records exclusively. Perhaps this should be
// configurable?
continue
}
if ip16 := ip.To16(); ip16 != nil {
rr = append(rr, &dns.AAAA{
Hdr: dns.RR_Header{
Name: m.HostName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
AAAA: ip16,
})
}
}
return rr
case dns.TypeSRV:
// Create the SRV Record
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Priority: 10,
Weight: 1,
Port: uint16(m.Port),
Target: m.HostName,
}
recs := []dns.RR{srv}
// Add the A record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeA,
})...)
// Add the AAAA record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeAAAA,
})...)
return recs
case dns.TypeTXT:
txt := &dns.TXT{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Txt: m.TXT,
}
return []dns.RR{txt}
}
return nil
}