diff --git a/modules/net_probe/net_probe_mdns.go b/modules/net_probe/net_probe_mdns.go index ce830731..cb863e1a 100644 --- a/modules/net_probe/net_probe_mdns.go +++ b/modules/net_probe/net_probe_mdns.go @@ -5,8 +5,6 @@ import ( "io/ioutil" "log" "net" - "sync" - "time" "github.com/bettercap/bettercap/packets" @@ -38,58 +36,50 @@ func (mod *Prober) sendProbeMDNS(from net.IP, from_hw net.HardwareAddr) { } } -func (mod *Prober) mdnsProber() { - mod.waitGroup.Add(1) - defer mod.waitGroup.Done() +func (mod *Prober) mdnsListener(c chan *mdns.ServiceEntry) { + mod.Debug("mdns listener started") + defer mod.Debug("mdns listener stopped") + for entry := range c { + if host := mod.Session.Lan.GetByIp(entry.AddrV4.String()); host != nil { + meta := make(map[string]string) + + meta["mdns:name"] = entry.Name + meta["mdns:hostname"] = entry.Host + meta["mdns:ipv4"] = entry.AddrV4.String() + + if entry.AddrV6 != nil { + meta["mdns:ipv6"] = entry.AddrV6.String() + } + + meta["mdns:port"] = fmt.Sprintf("%d", entry.Port) + + host.OnMeta(meta) + } else { + mod.Debug("got mdns entry for known ip %s", entry.AddrV4) + } + } +} + +func (mod *Prober) mdnsProber() { mod.Debug("mdns prober started") defer mod.Debug("mdns.prober stopped") + mod.waitGroup.Add(1) + defer mod.waitGroup.Done() + log.SetOutput(ioutil.Discard) ch := make(chan *mdns.ServiceEntry) - wait := sync.WaitGroup{} - defer close(ch) - go func(c chan *mdns.ServiceEntry) { - mod.Debug("mdns channel read started") - defer mod.Debug("mdns channel read stopped") - - for entry := range c { - if host := mod.Session.Lan.GetByIp(entry.AddrV4.String()); host != nil { - meta := make(map[string]string) - - meta["mdns:name"] = entry.Name - meta["mdns:hostname"] = entry.Host - meta["mdns:ipv4"] = entry.AddrV4.String() - - if entry.AddrV6 != nil { - meta["mdns:ipv6"] = entry.AddrV6.String() - } - - meta["mdns:port"] = fmt.Sprintf("%d", entry.Port) - - host.OnMeta(meta) - } else { - mod.Debug("got mdns entry for known ip %s", entry.AddrV4) - } - } - }(ch) + go mod.mdnsListener(ch) for mod.Running() { for _, svc := range services { - go func(svc string, w *sync.WaitGroup) { - w.Add(1) - defer w.Done() - - params := mdns.DefaultParams(svc) - params.Entries = ch - params.Timeout = time.Duration(5) * time.Second - - mdns.Query(params) - }(svc, &wait) + if mod.Running() { + mdns.Lookup(svc, ch) + } } - wait.Wait() } }