Merge pull request #799 from bonedaddy/wifi#lock-optimize

WiFi Network Locking Optimizations
This commit is contained in:
Simone Margaritelli 2021-01-06 10:51:36 +01:00 committed by GitHub
commit e6ecd6504f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 43 deletions

View file

@ -43,7 +43,7 @@ type APNewCallback func(ap *AccessPoint)
type APLostCallback func(ap *AccessPoint) type APLostCallback func(ap *AccessPoint)
type WiFi struct { type WiFi struct {
sync.Mutex sync.RWMutex
aliases *data.UnsortedKV aliases *data.UnsortedKV
aps map[string]*AccessPoint aps map[string]*AccessPoint
@ -67,8 +67,12 @@ func NewWiFi(iface *Endpoint, aliases *data.UnsortedKV, newcb APNewCallback, los
} }
func (w *WiFi) MarshalJSON() ([]byte, error) { func (w *WiFi) MarshalJSON() ([]byte, error) {
w.RLock()
defer w.RUnlock()
doc := wifiJSON{ doc := wifiJSON{
AccessPoints: make([]*AccessPoint, 0), // we know the length so preallocate to reduce memory allocations
AccessPoints: make([]*AccessPoint, 0, len(w.aps)),
} }
for _, ap := range w.aps { for _, ap := range w.aps {
@ -88,10 +92,11 @@ func (w *WiFi) EachAccessPoint(cb func(mac string, ap *AccessPoint)) {
} }
func (w *WiFi) Stations() (list []*Station) { func (w *WiFi) Stations() (list []*Station) {
w.Lock() w.RLock()
defer w.Unlock() defer w.RUnlock()
list = make([]*Station, 0, len(w.aps))
list = make([]*Station, 0)
for _, ap := range w.aps { for _, ap := range w.aps {
list = append(list, ap.Station) list = append(list, ap.Station)
} }
@ -99,10 +104,11 @@ func (w *WiFi) Stations() (list []*Station) {
} }
func (w *WiFi) List() (list []*AccessPoint) { func (w *WiFi) List() (list []*AccessPoint) {
w.Lock() w.RLock()
defer w.Unlock() defer w.RUnlock()
list = make([]*AccessPoint, 0, len(w.aps))
list = make([]*AccessPoint, 0)
for _, ap := range w.aps { for _, ap := range w.aps {
list = append(list, ap) list = append(list, ap)
} }
@ -167,8 +173,8 @@ func (w *WiFi) AddIfNew(ssid, mac string, frequency int, rssi int8) (*AccessPoin
} }
func (w *WiFi) Get(mac string) (*AccessPoint, bool) { func (w *WiFi) Get(mac string) (*AccessPoint, bool) {
w.Lock() w.RLock()
defer w.Unlock() defer w.RUnlock()
mac = NormalizeMac(mac) mac = NormalizeMac(mac)
ap, found := w.aps[mac] ap, found := w.aps[mac]
@ -176,8 +182,8 @@ func (w *WiFi) Get(mac string) (*AccessPoint, bool) {
} }
func (w *WiFi) GetClient(mac string) (*Station, bool) { func (w *WiFi) GetClient(mac string) (*Station, bool) {
w.Lock() w.RLock()
defer w.Unlock() defer w.RUnlock()
mac = NormalizeMac(mac) mac = NormalizeMac(mac)
for _, ap := range w.aps { for _, ap := range w.aps {
@ -196,8 +202,8 @@ func (w *WiFi) Clear() {
} }
func (w *WiFi) NumHandshakes() int { func (w *WiFi) NumHandshakes() int {
w.Lock() w.RLock()
defer w.Unlock() defer w.RUnlock()
sum := 0 sum := 0
for _, ap := range w.aps { for _, ap := range w.aps {
@ -212,9 +218,6 @@ func (w *WiFi) NumHandshakes() int {
} }
func (w *WiFi) SaveHandshakesTo(fileName string, linkType layers.LinkType) error { func (w *WiFi) SaveHandshakesTo(fileName string, linkType layers.LinkType) error {
w.Lock()
defer w.Unlock()
// check if folder exists first // check if folder exists first
dirName := filepath.Dir(fileName) dirName := filepath.Dir(fileName)
if _, err := os.Stat(dirName); err != nil { if _, err := os.Stat(dirName); err != nil {
@ -238,6 +241,9 @@ func (w *WiFi) SaveHandshakesTo(fileName string, linkType layers.LinkType) error
} }
} }
w.RLock()
defer w.RUnlock()
for _, ap := range w.aps { for _, ap := range w.aps {
for _, station := range ap.Clients() { for _, station := range ap.Clients() {
// if half (which includes also complete) or has pmkid // if half (which includes also complete) or has pmkid

View file

@ -10,7 +10,7 @@ import (
type AccessPoint struct { type AccessPoint struct {
*Station *Station
sync.Mutex sync.RWMutex
aliases *data.UnsortedKV aliases *data.UnsortedKV
clients map[string]*Station clients map[string]*Station
@ -32,12 +32,12 @@ func NewAccessPoint(essid, bssid string, frequency int, rssi int8, aliases *data
} }
func (ap *AccessPoint) MarshalJSON() ([]byte, error) { func (ap *AccessPoint) MarshalJSON() ([]byte, error) {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
doc := apJSON{ doc := apJSON{
Station: ap.Station, Station: ap.Station,
Clients: make([]*Station, 0), Clients: make([]*Station, 0, len(ap.clients)),
Handshake: ap.withKeyMaterial, Handshake: ap.withKeyMaterial,
} }
@ -49,8 +49,8 @@ func (ap *AccessPoint) MarshalJSON() ([]byte, error) {
} }
func (ap *AccessPoint) Get(bssid string) (*Station, bool) { func (ap *AccessPoint) Get(bssid string) (*Station, bool) {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
bssid = NormalizeMac(bssid) bssid = NormalizeMac(bssid)
if s, found := ap.clients[bssid]; found { if s, found := ap.clients[bssid]; found {
@ -97,16 +97,16 @@ func (ap *AccessPoint) AddClientIfNew(bssid string, frequency int, rssi int8) (*
} }
func (ap *AccessPoint) NumClients() int { func (ap *AccessPoint) NumClients() int {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
return len(ap.clients) return len(ap.clients)
} }
func (ap *AccessPoint) Clients() (list []*Station) { func (ap *AccessPoint) Clients() (list []*Station) {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
list = make([]*Station, 0) list = make([]*Station, 0, len(ap.clients))
for _, c := range ap.clients { for _, c := range ap.clients {
list = append(list, c) list = append(list, c)
} }
@ -130,15 +130,15 @@ func (ap *AccessPoint) WithKeyMaterial(state bool) {
} }
func (ap *AccessPoint) HasKeyMaterial() bool { func (ap *AccessPoint) HasKeyMaterial() bool {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
return ap.withKeyMaterial return ap.withKeyMaterial
} }
func (ap *AccessPoint) NumHandshakes() int { func (ap *AccessPoint) NumHandshakes() int {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
sum := 0 sum := 0
@ -156,8 +156,8 @@ func (ap *AccessPoint) HasHandshakes() bool {
} }
func (ap *AccessPoint) HasPMKID() bool { func (ap *AccessPoint) HasPMKID() bool {
ap.Lock() ap.RLock()
defer ap.Unlock() defer ap.RUnlock()
for _, c := range ap.clients { for _, c := range ap.clients {
if c.Handshake.HasPMKID() { if c.Handshake.HasPMKID() {

View file

@ -8,7 +8,7 @@ import (
) )
type Handshake struct { type Handshake struct {
sync.Mutex sync.RWMutex
Beacon gopacket.Packet Beacon gopacket.Packet
Challenges []gopacket.Packet Challenges []gopacket.Packet
@ -80,8 +80,8 @@ func (h *Handshake) AddFrame(n int, pkt gopacket.Packet) {
} }
func (h *Handshake) Complete() bool { func (h *Handshake) Complete() bool {
h.Lock() h.RLock()
defer h.Unlock() defer h.RUnlock()
nChal := len(h.Challenges) nChal := len(h.Challenges)
nResp := len(h.Responses) nResp := len(h.Responses)
@ -91,8 +91,8 @@ func (h *Handshake) Complete() bool {
} }
func (h *Handshake) Half() bool { func (h *Handshake) Half() bool {
h.Lock() h.RLock()
defer h.Unlock() defer h.RUnlock()
/* /*
* You can use every combination of the handshake to crack the net: * You can use every combination of the handshake to crack the net:
@ -110,14 +110,14 @@ func (h *Handshake) Half() bool {
} }
func (h *Handshake) HasPMKID() bool { func (h *Handshake) HasPMKID() bool {
h.Lock() h.RLock()
defer h.Unlock() defer h.RUnlock()
return h.hasPMKID return h.hasPMKID
} }
func (h *Handshake) NumUnsaved() int { func (h *Handshake) NumUnsaved() int {
h.Lock() h.RLock()
defer h.Unlock() defer h.RUnlock()
return len(h.unsaved) return len(h.unsaved)
} }