fix: using sync.Map to avoid race conditions on the packets.Queue

This commit is contained in:
evilsocket 2019-03-17 13:12:31 +01:00
parent b676d68b4c
commit 64a5ce2b58
No known key found for this signature in database
GPG key ID: 1564D7F30393A456
4 changed files with 34 additions and 71 deletions

View file

@ -4,6 +4,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"github.com/bettercap/bettercap/network"
@ -25,8 +26,6 @@ type Traffic struct {
}
type Stats struct {
sync.RWMutex
Sent uint64
Received uint64
PktReceived uint64
@ -39,10 +38,9 @@ type Queue struct {
sync.RWMutex
Activities chan Activity `json:"-"`
Stats Stats
Protos map[string]uint64
Traffic map[string]*Traffic
Stats Stats
Protos sync.Map
Traffic sync.Map
iface *network.Endpoint
handle *pcap.Handle
@ -55,8 +53,8 @@ type Queue struct {
func NewQueue(iface *network.Endpoint) (q *Queue, err error) {
q = &Queue{
Protos: make(map[string]uint64),
Traffic: make(map[string]*Traffic),
Protos: sync.Map{},
Traffic: sync.Map{},
Activities: make(chan Activity),
writes: &sync.WaitGroup{},
@ -102,14 +100,12 @@ func (q *Queue) trackProtocols(pkt gopacket.Packet) {
continue
}
q.Lock()
name := proto.String()
if _, found := q.Protos[name]; !found {
q.Protos[name] = 1
if v, found := q.Protos.Load(name); !found {
q.Protos.Store(name, 1)
} else {
q.Protos[name]++
q.Protos.Store(name, v.(int)+1)
}
q.Unlock()
}
}
@ -122,46 +118,34 @@ func (q *Queue) trackActivity(eth *layers.Ethernet, ip4 *layers.IPv4, address ne
Source: isSent,
}
q.Lock()
defer q.Unlock()
// initialize or update stats
addr := address.String()
if _, found := q.Traffic[addr]; !found {
if v, found := q.Traffic.Load(addr); !found {
if isSent {
q.Traffic[addr] = &Traffic{Sent: pktSize}
q.Traffic.Store(addr, &Traffic{Sent: pktSize})
} else {
q.Traffic[addr] = &Traffic{Received: pktSize}
q.Traffic.Store(addr, &Traffic{Received: pktSize})
}
} else {
if isSent {
q.Traffic[addr].Sent += pktSize
v.(*Traffic).Sent += pktSize
} else {
q.Traffic[addr].Received += pktSize
v.(*Traffic).Received += pktSize
}
}
}
func (q *Queue) TrackPacket(size uint64) {
q.Stats.Lock()
defer q.Stats.Unlock()
q.Stats.PktReceived++
q.Stats.Received += size
atomic.AddUint64(&q.Stats.PktReceived, 1)
atomic.AddUint64(&q.Stats.Received, size)
}
func (q *Queue) TrackSent(size uint64) {
q.Stats.Lock()
defer q.Stats.Unlock()
q.Stats.Sent += size
atomic.AddUint64(&q.Stats.Sent, size)
}
func (q *Queue) TrackError() {
q.Stats.Lock()
defer q.Stats.Unlock()
q.Stats.Errors++
atomic.AddUint64(&q.Stats.Errors, 1)
}
func (q *Queue) getPacketMeta(pkt gopacket.Packet) map[string]string {