diff --git a/modules/dns_spoof.go b/modules/dns_spoof.go index f870cc59..c70def22 100644 --- a/modules/dns_spoof.go +++ b/modules/dns_spoof.go @@ -14,15 +14,12 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" - - "github.com/gobwas/glob" ) type DNSSpoofer struct { session.SessionModule Handle *pcap.Handle - Domains []glob.Glob - Address net.IP + Hosts Hosts All bool waitGroup *sync.WaitGroup pktSourceChan chan gopacket.Packet @@ -33,13 +30,18 @@ func NewDNSSpoofer(s *session.Session) *DNSSpoofer { SessionModule: session.NewSessionModule("dns.spoof", s), Handle: nil, All: false, - Domains: make([]glob.Glob, 0), + Hosts: Hosts{}, waitGroup: &sync.WaitGroup{}, } + spoof.AddParam(session.NewStringParameter("dns.spoof.hosts", + "", + "", + "If not empty, this hosts file will be used to map domains to IP addresses.")) + spoof.AddParam(session.NewStringParameter("dns.spoof.domains", - "*", - ``, + "", + "", "Comma separated values of domain names to spoof.")) spoof.AddParam(session.NewStringParameter("dns.spoof.address", @@ -80,43 +82,46 @@ func (s DNSSpoofer) Author() string { func (s *DNSSpoofer) Configure() error { var err error - var addr string + var hostsFile string var domains []string + var address net.IP if s.Running() { return session.ErrAlreadyStarted - } - - if s.Handle, err = pcap.OpenLive(s.Session.Interface.Name(), 65536, true, pcap.BlockForever); err != nil { + } else if s.Handle, err = pcap.OpenLive(s.Session.Interface.Name(), 65536, true, pcap.BlockForever); err != nil { return err - } - - err = s.Handle.SetBPFFilter("udp") - if err != nil { + } else if err = s.Handle.SetBPFFilter("udp"); err != nil { return err - } - - if err, s.All = s.BoolParam("dns.spoof.all"); err != nil { + } else if err, s.All = s.BoolParam("dns.spoof.all"); err != nil { return err - } - - if err, domains = s.ListParam("dns.spoof.domains"); err != nil { + } else if err, address = s.IPParam("dns.spoof.address"); err != nil { + return err + } else if err, domains = s.ListParam("dns.spoof.domains"); err != nil { + return err + } else if err, hostsFile = s.StringParam("dns.spoof.hosts"); err != nil { return err } for _, domain := range domains { - if expr, err := glob.Compile(domain); err != nil { - return fmt.Errorf("'%s' is not a valid domain glob expression: %s", domain, err) + s.Hosts = append(s.Hosts, NewHostEntry(domain, address)) + } + + if hostsFile != "" { + log.Info("loading hosts from file %s ...", hostsFile) + if err, hosts := HostsFromFile(hostsFile); err != nil { + return fmt.Errorf("error reading hosts from file %s: %v", hostsFile, err) } else { - s.Domains = append(s.Domains, expr) + s.Hosts = append(s.Hosts, hosts...) } } - if err, addr = s.StringParam("dns.spoof.address"); err != nil { - return err + if len(s.Hosts) == 0 { + return fmt.Errorf("at least dns.spoof.hosts or dns.spoof.domains must be filled") } - s.Address = net.ParseIP(addr) + for _, entry := range s.Hosts { + log.Info("[%s] %s -> %s", core.Green("dns.spoof"), entry.Host, entry.Address) + } if !s.Session.Firewall.IsForwardingEnabled() { log.Info("Enabling forwarding.") @@ -126,8 +131,8 @@ func (s *DNSSpoofer) Configure() error { return nil } -func (s *DNSSpoofer) dnsReply(pkt gopacket.Packet, peth *layers.Ethernet, pudp *layers.UDP, domain string, req *layers.DNS, target net.HardwareAddr) { - redir := fmt.Sprintf("(->%s)", s.Address) +func (s *DNSSpoofer) dnsReply(pkt gopacket.Packet, peth *layers.Ethernet, pudp *layers.UDP, domain string, address net.IP, req *layers.DNS, target net.HardwareAddr) { + redir := fmt.Sprintf("(->%s)", address.String()) who := target.String() if t, found := s.Session.Lan.Get(target.String()); found { @@ -177,7 +182,7 @@ func (s *DNSSpoofer) dnsReply(pkt gopacket.Packet, peth *layers.Ethernet, pudp * Type: q.Type, Class: q.Class, TTL: 1024, - IP: s.Address, + IP: address, }) } @@ -242,15 +247,6 @@ func (s *DNSSpoofer) dnsReply(pkt gopacket.Packet, peth *layers.Ethernet, pudp * } } -func (s *DNSSpoofer) shouldSpoof(domain string) bool { - for _, expr := range s.Domains { - if expr.Match(domain) { - return true - } - } - return false -} - func (s *DNSSpoofer) onPacket(pkt gopacket.Packet) { typeEth := pkt.Layer(layers.LayerTypeEthernet) typeUDP := pkt.Layer(layers.LayerTypeUDP) @@ -265,8 +261,8 @@ func (s *DNSSpoofer) onPacket(pkt gopacket.Packet) { udp := typeUDP.(*layers.UDP) for _, q := range dns.Questions { qName := string(q.Name) - if s.shouldSpoof(qName) { - s.dnsReply(pkt, eth, udp, qName, dns, eth.SrcMAC) + if address := s.Hosts.Resolve(qName); address != nil { + s.dnsReply(pkt, eth, udp, qName, address, dns, eth.SrcMAC) break } else { log.Debug("skipping domain %s", qName) diff --git a/modules/dns_spoof_hosts.go b/modules/dns_spoof_hosts.go new file mode 100644 index 00000000..8e8ac198 --- /dev/null +++ b/modules/dns_spoof_hosts.go @@ -0,0 +1,83 @@ +package modules + +import ( + "bufio" + "fmt" + "net" + "os" + "regexp" + "strings" + + "github.com/bettercap/bettercap/core" + + "github.com/gobwas/glob" +) + +var hostsSplitter = regexp.MustCompile(`\s+`) + +type HostEntry struct { + Host string + Suffix string + Expr glob.Glob + Address net.IP +} + +func (e HostEntry) Matches(host string) bool { + return e.Host == host || strings.HasSuffix(host, e.Suffix) || (e.Expr != nil && e.Expr.Match(host)) +} + +type Hosts []HostEntry + +func NewHostEntry(host string, address net.IP) HostEntry { + entry := HostEntry{ + Host: host, + Address: address, + } + + if host[0] == '.' { + entry.Suffix = host + } else { + entry.Suffix = "." + host + } + + if expr, err := glob.Compile(host); err == nil { + entry.Expr = expr + } + + return entry +} + +func HostsFromFile(filename string) (err error, entries []HostEntry) { + input, err := os.Open(filename) + if err != nil { + return + } + defer input.Close() + + scanner := bufio.NewScanner(input) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + line := core.Trim(scanner.Text()) + if line == "" || line[0] == '#' { + continue + } + if parts := hostsSplitter.Split(line, 2); len(parts) == 2 { + address := net.ParseIP(parts[0]) + domain := parts[1] + entries = append(entries, NewHostEntry(domain, address)) + } else { + return fmt.Errorf("'%s' invalid hosts line", line), nil + } + } + + return +} + +func (h Hosts) Resolve(host string) net.IP { + for _, entry := range h { + if entry.Matches(host) { + return entry.Address + } + } + return nil +} diff --git a/session/events.go b/session/events.go index 6de9defa..5b424344 100644 --- a/session/events.go +++ b/session/events.go @@ -62,8 +62,8 @@ func (p *EventPool) Listen() <-chan Event { // make sure, without blocking, the new listener // will receive all the queued events go func() { - for _, e := range p.events { - l <- e + for i := len(p.events) - 1; i >= 0; i-- { + l <- p.events[i] } }() diff --git a/session/module.go b/session/module.go index a7006d7b..b57b7b59 100644 --- a/session/module.go +++ b/session/module.go @@ -2,6 +2,7 @@ package session import ( "fmt" + "net" "strings" "sync" "time" @@ -86,6 +87,14 @@ func (m SessionModule) StringParam(name string) (error, string) { } } +func (m SessionModule) IPParam(name string) (error, net.IP) { + if err, v := m.StringParam(name); err != nil { + return err, nil + } else { + return nil, net.ParseIP(v) + } +} + func (m SessionModule) IntParam(name string) (error, int) { if p, found := m.params[name]; found { if err, v := p.Get(m.Session); err != nil {