diff --git a/modules/net_recon/net_show.go b/modules/net_recon/net_show.go index ff793fdb..630f76cf 100644 --- a/modules/net_recon/net_show.go +++ b/modules/net_recon/net_show.go @@ -61,8 +61,11 @@ func (mod *Discovery) getRow(e *network.Endpoint, withMeta bool) [][]string { var traffic *packets.Traffic var found bool - if traffic, found = mod.Session.Queue.Traffic[e.IpAddress]; !found { + var v interface{} + if v, found = mod.Session.Queue.Traffic.Load(e.IpAddress); !found { traffic = &packets.Traffic{} + } else { + traffic = v.(*packets.Traffic) } seen := e.LastSeen.Format("15:04:05") @@ -203,9 +206,6 @@ func (mod *Discovery) colNames(hasMeta bool) []string { } func (mod *Discovery) showStatusBar() { - mod.Session.Queue.Stats.RLock() - defer mod.Session.Queue.Stats.RUnlock() - parts := []string{ fmt.Sprintf("%s %s", tui.Red("↑"), humanize.Bytes(mod.Session.Queue.Stats.Sent)), fmt.Sprintf("%s %s", tui.Green("↓"), humanize.Bytes(mod.Session.Queue.Stats.Received)), diff --git a/modules/net_recon/net_show_sort.go b/modules/net_recon/net_show_sort.go index 10ad08e5..2109543f 100644 --- a/modules/net_recon/net_show_sort.go +++ b/modules/net_recon/net_show_sort.go @@ -41,24 +41,19 @@ func (a BySeenSorter) Less(i, j int) bool { return a[i].LastSeen.Before(a[j].Las type BySentSorter []*network.Endpoint +func trafficOf(ip string) *packets.Traffic { + if v, found := session.I.Queue.Traffic.Load(ip); !found { + return &packets.Traffic{} + } else { + return v.(*packets.Traffic) + } +} + func (a BySentSorter) Len() int { return len(a) } func (a BySentSorter) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a BySentSorter) Less(i, j int) bool { - session.I.Queue.Lock() - defer session.I.Queue.Unlock() - - var found bool = false - var aTraffic *packets.Traffic = nil - var bTraffic *packets.Traffic = nil - - if aTraffic, found = session.I.Queue.Traffic[a[i].IpAddress]; !found { - aTraffic = &packets.Traffic{} - } - - if bTraffic, found = session.I.Queue.Traffic[a[j].IpAddress]; !found { - bTraffic = &packets.Traffic{} - } - + aTraffic := trafficOf(a[i].IpAddress) + bTraffic := trafficOf(a[j].IpAddress) return bTraffic.Sent > aTraffic.Sent } @@ -67,20 +62,7 @@ type ByRcvdSorter []*network.Endpoint func (a ByRcvdSorter) Len() int { return len(a) } func (a ByRcvdSorter) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a ByRcvdSorter) Less(i, j int) bool { - session.I.Queue.Lock() - defer session.I.Queue.Unlock() - - var found bool = false - var aTraffic *packets.Traffic = nil - var bTraffic *packets.Traffic = nil - - if aTraffic, found = session.I.Queue.Traffic[a[i].IpAddress]; !found { - aTraffic = &packets.Traffic{} - } - - if bTraffic, found = session.I.Queue.Traffic[a[j].IpAddress]; !found { - bTraffic = &packets.Traffic{} - } - + aTraffic := trafficOf(a[i].IpAddress) + bTraffic := trafficOf(a[j].IpAddress) return bTraffic.Received > aTraffic.Received } diff --git a/modules/wifi/wifi_show.go b/modules/wifi/wifi_show.go index 8e45403f..6517c6b2 100644 --- a/modules/wifi/wifi_show.go +++ b/modules/wifi/wifi_show.go @@ -296,9 +296,6 @@ func (mod *WiFiModule) colNames(nrows int) []string { } func (mod *WiFiModule) showStatusBar() { - mod.Session.Queue.Stats.RLock() - defer mod.Session.Queue.Stats.RUnlock() - parts := []string{ fmt.Sprintf("%s (ch. %d)", mod.iface.Name(), network.GetInterfaceChannel(mod.iface.Name())), fmt.Sprintf("%s %s", tui.Red("↑"), humanize.Bytes(mod.Session.Queue.Stats.Sent)), diff --git a/packets/queue.go b/packets/queue.go index 94aef1a1..a3a1b44e 100644 --- a/packets/queue.go +++ b/packets/queue.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "github.com/bettercap/bettercap/network" @@ -25,8 +26,6 @@ type Traffic struct { } type Stats struct { - sync.RWMutex - Sent uint64 Received uint64 PktReceived uint64 @@ -39,10 +38,9 @@ type Queue struct { sync.RWMutex Activities chan Activity `json:"-"` - - Stats Stats - Protos map[string]uint64 - Traffic map[string]*Traffic + Stats Stats + Protos sync.Map + Traffic sync.Map iface *network.Endpoint handle *pcap.Handle @@ -55,8 +53,8 @@ type Queue struct { func NewQueue(iface *network.Endpoint) (q *Queue, err error) { q = &Queue{ - Protos: make(map[string]uint64), - Traffic: make(map[string]*Traffic), + Protos: sync.Map{}, + Traffic: sync.Map{}, Activities: make(chan Activity), writes: &sync.WaitGroup{}, @@ -102,14 +100,12 @@ func (q *Queue) trackProtocols(pkt gopacket.Packet) { continue } - q.Lock() name := proto.String() - if _, found := q.Protos[name]; !found { - q.Protos[name] = 1 + if v, found := q.Protos.Load(name); !found { + q.Protos.Store(name, 1) } else { - q.Protos[name]++ + q.Protos.Store(name, v.(int)+1) } - q.Unlock() } } @@ -122,46 +118,34 @@ func (q *Queue) trackActivity(eth *layers.Ethernet, ip4 *layers.IPv4, address ne Source: isSent, } - q.Lock() - defer q.Unlock() - // initialize or update stats addr := address.String() - if _, found := q.Traffic[addr]; !found { + if v, found := q.Traffic.Load(addr); !found { if isSent { - q.Traffic[addr] = &Traffic{Sent: pktSize} + q.Traffic.Store(addr, &Traffic{Sent: pktSize}) } else { - q.Traffic[addr] = &Traffic{Received: pktSize} + q.Traffic.Store(addr, &Traffic{Received: pktSize}) } } else { if isSent { - q.Traffic[addr].Sent += pktSize + v.(*Traffic).Sent += pktSize } else { - q.Traffic[addr].Received += pktSize + v.(*Traffic).Received += pktSize } } } func (q *Queue) TrackPacket(size uint64) { - q.Stats.Lock() - defer q.Stats.Unlock() - - q.Stats.PktReceived++ - q.Stats.Received += size + atomic.AddUint64(&q.Stats.PktReceived, 1) + atomic.AddUint64(&q.Stats.Received, size) } func (q *Queue) TrackSent(size uint64) { - q.Stats.Lock() - defer q.Stats.Unlock() - - q.Stats.Sent += size + atomic.AddUint64(&q.Stats.Sent, size) } func (q *Queue) TrackError() { - q.Stats.Lock() - defer q.Stats.Unlock() - - q.Stats.Errors++ + atomic.AddUint64(&q.Stats.Errors, 1) } func (q *Queue) getPacketMeta(pkt gopacket.Packet) map[string]string {