// Mirror of https://github.com/bettercap/bettercap, synced 2025-07-30 03:29:57 -07:00 (Go, 644 lines, 15 KiB).

package net_recon
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/bettercap/bettercap/v2/modules/utils"
|
|
"github.com/bettercap/bettercap/v2/network"
|
|
"github.com/bettercap/bettercap/v2/packets"
|
|
"github.com/bettercap/bettercap/v2/session"
|
|
"github.com/evilsocket/islazy/data"
|
|
)
|
|
|
|
// Mock ArpUpdate function
|
|
var mockArpUpdateFunc func(string) (network.ArpTable, error)
|
|
|
|
// Override the network.ArpUpdate function for testing
|
|
func mockArpUpdate(iface string) (network.ArpTable, error) {
|
|
if mockArpUpdateFunc != nil {
|
|
return mockArpUpdateFunc(iface)
|
|
}
|
|
return make(network.ArpTable), nil
|
|
}
|
|
|
|
// MockLAN implements a mock version of the LAN interface
|
|
type MockLAN struct {
|
|
sync.RWMutex
|
|
hosts map[string]*network.Endpoint
|
|
wasMissed map[string]bool
|
|
addedHosts []string
|
|
removedHosts []string
|
|
}
|
|
|
|
func NewMockLAN() *MockLAN {
|
|
return &MockLAN{
|
|
hosts: make(map[string]*network.Endpoint),
|
|
wasMissed: make(map[string]bool),
|
|
addedHosts: []string{},
|
|
removedHosts: []string{},
|
|
}
|
|
}
|
|
|
|
func (m *MockLAN) AddIfNew(ip, mac string) {
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
|
|
if _, exists := m.hosts[mac]; !exists {
|
|
m.hosts[mac] = &network.Endpoint{
|
|
IpAddress: ip,
|
|
HwAddress: mac,
|
|
FirstSeen: time.Now(),
|
|
LastSeen: time.Now(),
|
|
}
|
|
m.addedHosts = append(m.addedHosts, mac)
|
|
}
|
|
}
|
|
|
|
func (m *MockLAN) Remove(ip, mac string) {
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
|
|
if _, exists := m.hosts[mac]; exists {
|
|
delete(m.hosts, mac)
|
|
m.removedHosts = append(m.removedHosts, mac)
|
|
}
|
|
}
|
|
|
|
func (m *MockLAN) Clear() {
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
|
|
m.hosts = make(map[string]*network.Endpoint)
|
|
m.wasMissed = make(map[string]bool)
|
|
m.addedHosts = []string{}
|
|
m.removedHosts = []string{}
|
|
}
|
|
|
|
func (m *MockLAN) EachHost(cb func(mac string, e *network.Endpoint)) {
|
|
m.RLock()
|
|
defer m.RUnlock()
|
|
|
|
for mac, host := range m.hosts {
|
|
cb(mac, host)
|
|
}
|
|
}
|
|
|
|
func (m *MockLAN) List() []*network.Endpoint {
|
|
m.RLock()
|
|
defer m.RUnlock()
|
|
|
|
list := make([]*network.Endpoint, 0, len(m.hosts))
|
|
for _, host := range m.hosts {
|
|
list = append(list, host)
|
|
}
|
|
return list
|
|
}
|
|
|
|
func (m *MockLAN) WasMissed(mac string) bool {
|
|
m.RLock()
|
|
defer m.RUnlock()
|
|
|
|
return m.wasMissed[mac]
|
|
}
|
|
|
|
func (m *MockLAN) Get(mac string) *network.Endpoint {
|
|
m.RLock()
|
|
defer m.RUnlock()
|
|
|
|
return m.hosts[mac]
|
|
}
|
|
|
|
// Create a mock session for testing
|
|
func createMockSession() *session.Session {
|
|
iface := &network.Endpoint{
|
|
IpAddress: "192.168.1.100",
|
|
HwAddress: "aa:bb:cc:dd:ee:ff",
|
|
Hostname: "eth0",
|
|
}
|
|
iface.SetIP("192.168.1.100")
|
|
iface.SetBits(24)
|
|
|
|
gateway := &network.Endpoint{
|
|
IpAddress: "192.168.1.1",
|
|
HwAddress: "11:22:33:44:55:66",
|
|
}
|
|
|
|
// Create environment
|
|
env, _ := session.NewEnvironment("")
|
|
|
|
sess := &session.Session{
|
|
Interface: iface,
|
|
Gateway: gateway,
|
|
StartedAt: time.Now(),
|
|
Active: true,
|
|
Env: env,
|
|
Queue: &packets.Queue{
|
|
Traffic: sync.Map{},
|
|
Stats: packets.Stats{},
|
|
},
|
|
Modules: make(session.ModuleList, 0),
|
|
}
|
|
|
|
// Initialize the Events field with a mock EventPool
|
|
sess.Events = session.NewEventPool(false, false)
|
|
|
|
return sess
|
|
}
|
|
|
|
func TestNewDiscovery(t *testing.T) {
|
|
sess := createMockSession()
|
|
mod := NewDiscovery(sess)
|
|
|
|
if mod == nil {
|
|
t.Fatal("NewDiscovery returned nil")
|
|
}
|
|
|
|
if mod.Name() != "net.recon" {
|
|
t.Errorf("expected module name 'net.recon', got '%s'", mod.Name())
|
|
}
|
|
|
|
if mod.Author() != "Simone Margaritelli <evilsocket@gmail.com>" {
|
|
t.Errorf("unexpected author: %s", mod.Author())
|
|
}
|
|
|
|
if mod.selector == nil {
|
|
t.Error("selector should be initialized")
|
|
}
|
|
}
|
|
|
|
func TestRunDiff(t *testing.T) {
|
|
// Test the basic diff functionality with a simpler approach
|
|
tests := []struct {
|
|
name string
|
|
initialHosts map[string]string // IP -> MAC
|
|
arpTable network.ArpTable
|
|
expectedAdded []string
|
|
expectedRemoved []string
|
|
}{
|
|
{
|
|
name: "no changes",
|
|
initialHosts: map[string]string{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
"192.168.1.20": "bb:bb:bb:bb:bb:bb",
|
|
},
|
|
arpTable: network.ArpTable{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
"192.168.1.20": "bb:bb:bb:bb:bb:bb",
|
|
},
|
|
expectedAdded: []string{},
|
|
expectedRemoved: []string{},
|
|
},
|
|
{
|
|
name: "new host discovered",
|
|
initialHosts: map[string]string{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
},
|
|
arpTable: network.ArpTable{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
"192.168.1.20": "bb:bb:bb:bb:bb:bb",
|
|
},
|
|
expectedAdded: []string{"bb:bb:bb:bb:bb:bb"},
|
|
expectedRemoved: []string{},
|
|
},
|
|
{
|
|
name: "host disappeared",
|
|
initialHosts: map[string]string{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
"192.168.1.20": "bb:bb:bb:bb:bb:bb",
|
|
},
|
|
arpTable: network.ArpTable{
|
|
"192.168.1.10": "aa:aa:aa:aa:aa:aa",
|
|
},
|
|
expectedAdded: []string{},
|
|
expectedRemoved: []string{"bb:bb:bb:bb:bb:bb"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sess := createMockSession()
|
|
|
|
// Track callbacks
|
|
addedHosts := []string{}
|
|
removedHosts := []string{}
|
|
|
|
newCb := func(e *network.Endpoint) {
|
|
addedHosts = append(addedHosts, e.HwAddress)
|
|
}
|
|
|
|
lostCb := func(e *network.Endpoint) {
|
|
removedHosts = append(removedHosts, e.HwAddress)
|
|
}
|
|
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, newCb, lostCb)
|
|
|
|
mod := &Discovery{
|
|
SessionModule: session.NewSessionModule("net.recon", sess),
|
|
}
|
|
|
|
// Add initial hosts
|
|
for ip, mac := range tt.initialHosts {
|
|
sess.Lan.AddIfNew(ip, mac)
|
|
}
|
|
|
|
// Reset tracking
|
|
addedHosts = []string{}
|
|
removedHosts = []string{}
|
|
|
|
// Add interface and gateway to ARP table to avoid them being removed
|
|
finalArpTable := make(network.ArpTable)
|
|
for k, v := range tt.arpTable {
|
|
finalArpTable[k] = v
|
|
}
|
|
finalArpTable[sess.Interface.IpAddress] = sess.Interface.HwAddress
|
|
finalArpTable[sess.Gateway.IpAddress] = sess.Gateway.HwAddress
|
|
|
|
// Run the diff multiple times to trigger actual removal (TTL countdown)
|
|
for i := 0; i < network.LANDefaultttl+1; i++ {
|
|
mod.runDiff(finalArpTable)
|
|
}
|
|
|
|
// Check results
|
|
if len(addedHosts) != len(tt.expectedAdded) {
|
|
t.Errorf("expected %d added hosts, got %d. Added: %v", len(tt.expectedAdded), len(addedHosts), addedHosts)
|
|
}
|
|
|
|
if len(removedHosts) != len(tt.expectedRemoved) {
|
|
t.Errorf("expected %d removed hosts, got %d. Removed: %v", len(tt.expectedRemoved), len(removedHosts), removedHosts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConfigure(t *testing.T) {
|
|
sess := createMockSession()
|
|
mod := NewDiscovery(sess)
|
|
|
|
err := mod.Configure()
|
|
if err != nil {
|
|
t.Errorf("Configure() returned error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStartStop(t *testing.T) {
|
|
sess := createMockSession()
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
|
|
|
|
mod := NewDiscovery(sess)
|
|
|
|
// Test starting the module
|
|
err := mod.Start()
|
|
if err != nil {
|
|
t.Errorf("Start() returned error: %v", err)
|
|
}
|
|
|
|
if !mod.Running() {
|
|
t.Error("module should be running after Start()")
|
|
}
|
|
|
|
// Let it run briefly
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Test stopping the module
|
|
err = mod.Stop()
|
|
if err != nil {
|
|
t.Errorf("Stop() returned error: %v", err)
|
|
}
|
|
|
|
if mod.Running() {
|
|
t.Error("module should not be running after Stop()")
|
|
}
|
|
}
|
|
|
|
// TestShowMethods is a placeholder. The show code paths need a session with
// readline initialized, which createMockSession does not provide, so the test
// is always skipped.
func TestShowMethods(t *testing.T) {
	t.Skip("Skipping TestShowMethods as it requires readline initialization")
}
|
|
|
|
func TestDoSelection(t *testing.T) {
|
|
sess := createMockSession()
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
|
|
|
|
// Add test endpoints
|
|
sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
|
|
sess.Lan.AddIfNew("192.168.1.20", "bb:bb:bb:bb:bb:bb")
|
|
sess.Lan.AddIfNew("192.168.1.30", "cc:cc:cc:cc:cc:cc")
|
|
|
|
// Get endpoints and set additional properties
|
|
if e, found := sess.Lan.Get("aa:aa:aa:aa:aa:aa"); found {
|
|
e.Hostname = "host1"
|
|
e.Vendor = "Vendor1"
|
|
}
|
|
|
|
if e, found := sess.Lan.Get("bb:bb:bb:bb:bb:bb"); found {
|
|
e.Alias = "mydevice"
|
|
e.Vendor = "Vendor2"
|
|
}
|
|
|
|
mod := NewDiscovery(sess)
|
|
mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show",
|
|
[]string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc")
|
|
|
|
tests := []struct {
|
|
name string
|
|
arg string
|
|
expectedCount int
|
|
expectedIPs []string
|
|
}{
|
|
{
|
|
name: "select all",
|
|
arg: "",
|
|
expectedCount: 3,
|
|
},
|
|
{
|
|
name: "select by IP",
|
|
arg: "192.168.1.10",
|
|
expectedCount: 1,
|
|
expectedIPs: []string{"192.168.1.10"},
|
|
},
|
|
{
|
|
name: "select by MAC",
|
|
arg: "aa:aa:aa:aa:aa:aa",
|
|
expectedCount: 1,
|
|
expectedIPs: []string{"192.168.1.10"},
|
|
},
|
|
{
|
|
name: "select multiple by comma",
|
|
arg: "192.168.1.10,192.168.1.20",
|
|
expectedCount: 2,
|
|
expectedIPs: []string{"192.168.1.10", "192.168.1.20"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err, targets := mod.doSelection(tt.arg)
|
|
if err != nil {
|
|
t.Errorf("doSelection returned error: %v", err)
|
|
}
|
|
|
|
if len(targets) != tt.expectedCount {
|
|
t.Errorf("expected %d targets, got %d", tt.expectedCount, len(targets))
|
|
}
|
|
|
|
if tt.expectedIPs != nil {
|
|
for _, expectedIP := range tt.expectedIPs {
|
|
found := false
|
|
for _, target := range targets {
|
|
if target.IpAddress == expectedIP {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("expected to find IP %s in targets", expectedIP)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandlers(t *testing.T) {
|
|
sess := createMockSession()
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
|
|
|
|
mod := NewDiscovery(sess)
|
|
|
|
handlers := []struct {
|
|
name string
|
|
handler string
|
|
args []string
|
|
setup func()
|
|
validate func() error
|
|
}{
|
|
{
|
|
name: "net.clear",
|
|
handler: "net.clear",
|
|
args: []string{},
|
|
setup: func() {
|
|
sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
|
|
},
|
|
validate: func() error {
|
|
// Check if hosts were cleared
|
|
hosts := sess.Lan.List()
|
|
if len(hosts) != 0 {
|
|
return fmt.Errorf("expected empty hosts after clear, got %d", len(hosts))
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range handlers {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.setup != nil {
|
|
tt.setup()
|
|
}
|
|
|
|
// Find and execute the handler
|
|
found := false
|
|
for _, h := range mod.Handlers() {
|
|
if h.Name == tt.handler {
|
|
found = true
|
|
err := h.Exec(tt.args)
|
|
if err != nil {
|
|
t.Errorf("handler %s returned error: %v", tt.handler, err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("handler %s not found", tt.handler)
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
if err := tt.validate(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRow(t *testing.T) {
|
|
sess := createMockSession()
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
|
|
|
|
mod := NewDiscovery(sess)
|
|
|
|
// Test endpoint with metadata
|
|
endpoint := &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
HwAddress: "aa:aa:aa:aa:aa:aa",
|
|
Hostname: "testhost",
|
|
Vendor: "Test Vendor",
|
|
FirstSeen: time.Now().Add(-time.Hour),
|
|
LastSeen: time.Now(),
|
|
Meta: network.NewMeta(),
|
|
}
|
|
endpoint.Meta.Set("key1", "value1")
|
|
endpoint.Meta.Set("key2", "value2")
|
|
|
|
// Test without meta
|
|
rows := mod.getRow(endpoint, false)
|
|
if len(rows) != 1 {
|
|
t.Errorf("expected 1 row without meta, got %d", len(rows))
|
|
}
|
|
if len(rows[0]) != 7 {
|
|
t.Errorf("expected 7 columns, got %d", len(rows[0]))
|
|
}
|
|
|
|
// Test with meta
|
|
rows = mod.getRow(endpoint, true)
|
|
if len(rows) != 2 { // One main row + one meta row per metadata entry
|
|
t.Errorf("expected 2 rows with meta, got %d", len(rows))
|
|
}
|
|
|
|
// Test interface endpoint
|
|
ifaceEndpoint := sess.Interface
|
|
rows = mod.getRow(ifaceEndpoint, false)
|
|
if len(rows) != 1 {
|
|
t.Errorf("expected 1 row for interface, got %d", len(rows))
|
|
}
|
|
|
|
// Test gateway endpoint
|
|
gatewayEndpoint := sess.Gateway
|
|
rows = mod.getRow(gatewayEndpoint, false)
|
|
if len(rows) != 1 {
|
|
t.Errorf("expected 1 row for gateway, got %d", len(rows))
|
|
}
|
|
}
|
|
|
|
func TestDoFilter(t *testing.T) {
|
|
sess := createMockSession()
|
|
mod := NewDiscovery(sess)
|
|
mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show",
|
|
[]string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc")
|
|
|
|
// Test that doFilter behavior matches the actual implementation
|
|
// When Expression is nil, it returns true (no filtering)
|
|
// When Expression is set, it matches against any of the fields
|
|
|
|
tests := []struct {
|
|
name string
|
|
filter string
|
|
endpoint *network.Endpoint
|
|
shouldMatch bool
|
|
}{
|
|
{
|
|
name: "no filter",
|
|
filter: "",
|
|
endpoint: &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
Meta: network.NewMeta(),
|
|
},
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "ip filter match",
|
|
filter: "192.168",
|
|
endpoint: &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
Meta: network.NewMeta(),
|
|
},
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "mac filter match",
|
|
filter: "aa:bb",
|
|
endpoint: &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
HwAddress: "aa:bb:cc:dd:ee:ff",
|
|
Meta: network.NewMeta(),
|
|
},
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "hostname filter match",
|
|
filter: "myhost",
|
|
endpoint: &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
Hostname: "myhost.local",
|
|
Meta: network.NewMeta(),
|
|
},
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "no match - testing unique string",
|
|
filter: "xyz123nomatch",
|
|
endpoint: &network.Endpoint{
|
|
IpAddress: "192.168.1.10",
|
|
Ip6Address: "",
|
|
HwAddress: "aa:bb:cc:dd:ee:ff",
|
|
Hostname: "host.local",
|
|
Alias: "",
|
|
Vendor: "",
|
|
Meta: network.NewMeta(),
|
|
},
|
|
shouldMatch: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Reset selector for each test
|
|
// Set the parameter value that Update() will read
|
|
sess.Env.Set("net.show.filter", tt.filter)
|
|
mod.selector.Expression = nil
|
|
|
|
// Update will read from the parameter
|
|
err := mod.selector.Update()
|
|
if err != nil {
|
|
t.Fatalf("selector.Update() failed: %v", err)
|
|
}
|
|
|
|
result := mod.doFilter(tt.endpoint)
|
|
if result != tt.shouldMatch {
|
|
if mod.selector.Expression != nil {
|
|
t.Errorf("expected doFilter to return %v, got %v. Regex: %s", tt.shouldMatch, result, mod.selector.Expression.String())
|
|
} else {
|
|
t.Errorf("expected doFilter to return %v, got %v. Expression is nil", tt.shouldMatch, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Benchmark the runDiff method
|
|
func BenchmarkRunDiff(b *testing.B) {
|
|
sess := createMockSession()
|
|
aliases, _ := data.NewUnsortedKV("", 0)
|
|
sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
|
|
|
|
mod := &Discovery{
|
|
SessionModule: session.NewSessionModule("net.recon", sess),
|
|
}
|
|
|
|
// Create a large ARP table
|
|
arpTable := make(network.ArpTable)
|
|
for i := 0; i < 100; i++ {
|
|
ip := fmt.Sprintf("192.168.1.%d", i)
|
|
mac := fmt.Sprintf("aa:bb:cc:dd:%02x:%02x", i/256, i%256)
|
|
arpTable[ip] = mac
|
|
|
|
// Add half to the existing LAN
|
|
if i < 50 {
|
|
sess.Lan.AddIfNew(ip, mac)
|
|
}
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
mod.runDiff(arpTable)
|
|
}
|
|
}
|