diff --git a/modules/net_probe_udp.go b/modules/net_probe_udp.go index 14e299a9..aa6a70d4 100644 --- a/modules/net_probe_udp.go +++ b/modules/net_probe_udp.go @@ -19,9 +19,9 @@ func (p *Prober) sendProbeUDP(from net.IP, from_hw net.HardwareAddr, ip net.IP) wrote, _ := con.Write([]byte{0x00}) if wrote > 0 { - p.Session.Queue.Stats.Lock() - p.Session.Queue.Stats.Sent += uint64(wrote) - p.Session.Queue.Stats.Unlock() + p.Session.Queue.TrackSent(uint64(wrote)) + } else { + p.Session.Queue.TrackError() } } } diff --git a/modules/wifi.go b/modules/wifi.go index 7eefcc40..2e1daabd 100644 --- a/modules/wifi.go +++ b/modules/wifi.go @@ -273,17 +273,6 @@ func (w *WiFiModule) updateStats(dot11 *layers.Dot11, packet gopacket.Packet) { } } -func (w *WiFiModule) trackPacket(pkt gopacket.Packet) { - pktSize := uint64(len(pkt.Data())) - - w.Session.Queue.Stats.Lock() - - w.Session.Queue.Stats.PktReceived++ - w.Session.Queue.Stats.Received += pktSize - - w.Session.Queue.Stats.Unlock() -} - func (w *WiFiModule) Start() error { if err := w.Configure(); err != nil { return err @@ -306,13 +295,11 @@ func (w *WiFiModule) Start() error { for packet := range w.pktSourceChan { if w.Running() == false { break - } - - if packet == nil { + } else if packet == nil { continue } - w.trackPacket(packet) + w.Session.Queue.TrackPacket(uint64(len(packet.Data()))) // perform initial dot11 parsing and layers validation if ok, radiotap, dot11 := packets.Dot11Parse(packet); ok == true { diff --git a/modules/wifi_deauth.go b/modules/wifi_deauth.go index 602701a0..7f48ef86 100644 --- a/modules/wifi_deauth.go +++ b/modules/wifi_deauth.go @@ -13,14 +13,9 @@ import ( func (w *WiFiModule) injectPacket(data []byte) { if err := w.handle.WritePacketData(data); err != nil { log.Error("Could not inject WiFi packet: %s", err) - - w.Session.Queue.Stats.Lock() - w.Session.Queue.Stats.Errors++ - w.Session.Queue.Stats.Unlock() + w.Session.Queue.TrackError() } else { - w.Session.Queue.Stats.Lock() - w.Session.Queue.Stats.Sent += uint64(len(data)) - w.Session.Queue.Stats.Unlock() + w.Session.Queue.TrackSent(uint64(len(data))) } // let the network card breath a little time.Sleep(10 * time.Millisecond) diff --git a/packets/queue.go b/packets/queue.go index 5d49d4ab..2b4c2749 100644 --- a/packets/queue.go +++ b/packets/queue.go @@ -141,6 +141,28 @@ func (q *Queue) trackActivity(eth *layers.Ethernet, ip4 *layers.IPv4, address ne } } +func (q *Queue) TrackPacket(size uint64) { + q.Stats.Lock() + defer q.Stats.Unlock() + + q.Stats.PktReceived++ + q.Stats.Received += size +} + +func (q *Queue) TrackSent(size uint64) { + q.Stats.Lock() + defer q.Stats.Unlock() + + q.Stats.Sent += size +} + +func (q *Queue) TrackError() { + q.Stats.Lock() + defer q.Stats.Unlock() + + q.Stats.Errors++ +} + func (q *Queue) worker() { for pkt := range q.srcChannel { if q.active == false { @@ -151,13 +173,7 @@ func (q *Queue) worker() { pktSize := uint64(len(pkt.Data())) - q.Stats.Lock() - - q.Stats.PktReceived++ - q.Stats.Received += pktSize - - q.Stats.Unlock() - + q.TrackPacket(pktSize) q.onPacketCallback(pkt) // decode eth and ipv4 layers @@ -200,14 +216,10 @@ func (q *Queue) Send(raw []byte) error { defer q.writes.Done() if err := q.handle.WritePacketData(raw); err != nil { - q.Stats.Lock() - q.Stats.Errors++ - q.Stats.Unlock() + q.TrackError() return err } else { - q.Stats.Lock() - q.Stats.Sent += uint64(len(raw)) - q.Stats.Unlock() + q.TrackSent(uint64(len(raw))) } return nil