diff --git a/network/wifi.go b/network/wifi.go index 0fc57d35..0a452d20 100644 --- a/network/wifi.go +++ b/network/wifi.go @@ -43,7 +43,7 @@ type APNewCallback func(ap *AccessPoint) type APLostCallback func(ap *AccessPoint) type WiFi struct { - sync.Mutex + sync.RWMutex aliases *data.UnsortedKV aps map[string]*AccessPoint @@ -67,8 +67,12 @@ func NewWiFi(iface *Endpoint, aliases *data.UnsortedKV, newcb APNewCallback, los } func (w *WiFi) MarshalJSON() ([]byte, error) { + w.RLock() + defer w.RUnlock() + doc := wifiJSON{ - AccessPoints: make([]*AccessPoint, 0), + // we know the length so preallocate to reduce memory allocations + AccessPoints: make([]*AccessPoint, len(w.aps)), } for _, ap := range w.aps { @@ -88,10 +92,11 @@ func (w *WiFi) EachAccessPoint(cb func(mac string, ap *AccessPoint)) { } func (w *WiFi) Stations() (list []*Station) { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() + + list = make([]*Station, len(w.aps)) - list = make([]*Station, 0) for _, ap := range w.aps { list = append(list, ap.Station) } @@ -99,10 +104,11 @@ func (w *WiFi) Stations() (list []*Station) { } func (w *WiFi) List() (list []*AccessPoint) { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() + + list = make([]*AccessPoint, len(w.aps)) - list = make([]*AccessPoint, 0) for _, ap := range w.aps { list = append(list, ap) } @@ -167,8 +173,8 @@ func (w *WiFi) AddIfNew(ssid, mac string, frequency int, rssi int8) (*AccessPoin } func (w *WiFi) Get(mac string) (*AccessPoint, bool) { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() mac = NormalizeMac(mac) ap, found := w.aps[mac] @@ -176,8 +182,8 @@ func (w *WiFi) Get(mac string) (*AccessPoint, bool) { } func (w *WiFi) GetClient(mac string) (*Station, bool) { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() mac = NormalizeMac(mac) for _, ap := range w.aps { @@ -196,8 +202,8 @@ func (w *WiFi) Clear() { } func (w *WiFi) NumHandshakes() int { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() sum := 0 for _, ap := range w.aps { @@ -212,8 +218,8 @@ func (w *WiFi) NumHandshakes() int { } func (w *WiFi) SaveHandshakesTo(fileName string, linkType layers.LinkType) error { - w.Lock() - defer w.Unlock() + w.RLock() + defer w.RUnlock() // check if folder exists first dirName := filepath.Dir(fileName) diff --git a/network/wifi_ap.go b/network/wifi_ap.go index 58143cc0..fc889036 100644 --- a/network/wifi_ap.go +++ b/network/wifi_ap.go @@ -10,7 +10,7 @@ import ( type AccessPoint struct { *Station - sync.Mutex + sync.RWMutex aliases *data.UnsortedKV clients map[string]*Station @@ -32,12 +32,12 @@ func NewAccessPoint(essid, bssid string, frequency int, rssi int8, aliases *data } func (ap *AccessPoint) MarshalJSON() ([]byte, error) { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() doc := apJSON{ Station: ap.Station, - Clients: make([]*Station, 0), + Clients: make([]*Station, len(ap.clients)), Handshake: ap.withKeyMaterial, } @@ -49,8 +49,8 @@ func (ap *AccessPoint) MarshalJSON() ([]byte, error) { } func (ap *AccessPoint) Get(bssid string) (*Station, bool) { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() bssid = NormalizeMac(bssid) if s, found := ap.clients[bssid]; found { @@ -97,16 +97,16 @@ func (ap *AccessPoint) AddClientIfNew(bssid string, frequency int, rssi int8) (* } func (ap *AccessPoint) NumClients() int { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() return len(ap.clients) } func (ap *AccessPoint) Clients() (list []*Station) { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() - list = make([]*Station, 0) + list = make([]*Station, len(ap.clients)) for _, c := range ap.clients { list = append(list, c) } @@ -130,15 +130,15 @@ func (ap *AccessPoint) WithKeyMaterial(state bool) { } func (ap *AccessPoint) HasKeyMaterial() bool { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() return ap.withKeyMaterial } func (ap *AccessPoint) NumHandshakes() int { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() sum := 0 @@ -156,8 +156,8 @@ func (ap *AccessPoint) HasHandshakes() bool { } func (ap *AccessPoint) HasPMKID() bool { - ap.Lock() - defer ap.Unlock() + ap.RLock() + defer ap.RUnlock() for _, c := range ap.clients { if c.Handshake.HasPMKID() { diff --git a/network/wifi_handshake.go b/network/wifi_handshake.go index 341a4f9e..7181baec 100644 --- a/network/wifi_handshake.go +++ b/network/wifi_handshake.go @@ -8,7 +8,7 @@ import ( ) type Handshake struct { - sync.Mutex + sync.RWMutex Beacon gopacket.Packet Challenges []gopacket.Packet @@ -80,8 +80,8 @@ func (h *Handshake) AddFrame(n int, pkt gopacket.Packet) { } func (h *Handshake) Complete() bool { - h.Lock() - defer h.Unlock() + h.RLock() + defer h.RUnlock() nChal := len(h.Challenges) nResp := len(h.Responses) @@ -91,8 +91,8 @@ func (h *Handshake) Complete() bool { } func (h *Handshake) Half() bool { - h.Lock() - defer h.Unlock() + h.RLock() + defer h.RUnlock() /* * You can use every combination of the handshake to crack the net: @@ -110,14 +110,14 @@ func (h *Handshake) Half() bool { } func (h *Handshake) HasPMKID() bool { - h.Lock() - defer h.Unlock() + h.RLock() + defer h.RUnlock() return h.hasPMKID } func (h *Handshake) NumUnsaved() int { - h.Lock() - defer h.Unlock() + h.RLock() + defer h.RUnlock() return len(h.unsaved) }