diff --git a/modules/syn_scan/syn_scan.go b/modules/syn_scan/syn_scan.go index 0033e650..73bdd446 100644 --- a/modules/syn_scan/syn_scan.go +++ b/modules/syn_scan/syn_scan.go @@ -3,7 +3,6 @@ package syn_scan import ( "fmt" "net" - "strconv" "sync" "sync/atomic" "time" @@ -11,10 +10,10 @@ import ( "github.com/bettercap/bettercap/packets" "github.com/bettercap/bettercap/session" - "github.com/malfunkt/iprange" - "github.com/evilsocket/islazy/async" - "github.com/evilsocket/islazy/str" + + "github.com/google/gopacket" + "github.com/google/gopacket/pcap" ) const synSourcePort = 666 @@ -33,6 +32,8 @@ type SynScanner struct { addresses []net.IP startPort int endPort int + handle *pcap.Handle + packets chan gopacket.Packet progressEvery time.Duration stats synScannerStats waitGroup *sync.WaitGroup @@ -97,44 +98,6 @@ func NewSynScanner(s *session.Session) *SynScanner { return mod } -func (mod *SynScanner) parseTargets(arg string) error { - if list, err := iprange.Parse(arg); err != nil { - return fmt.Errorf("error while parsing IP range '%s': %s", arg, err) - } else { - mod.addresses = list.Expand() - } - return nil -} - -func (mod *SynScanner) parsePorts(args []string) (err error) { - argc := len(args) - mod.stats.totProbes = 0 - mod.stats.doneProbes = 0 - mod.startPort = 1 - mod.endPort = 65535 - - if argc > 1 && str.Trim(args[1]) != "" { - if mod.startPort, err = strconv.Atoi(str.Trim(args[1])); err != nil { - return fmt.Errorf("invalid start port %s: %s", args[1], err) - } else if mod.startPort > 65535 { - mod.startPort = 65535 - } - mod.endPort = mod.startPort - } - - if argc > 2 && str.Trim(args[2]) != "" { - if mod.endPort, err = strconv.Atoi(str.Trim(args[2])); err != nil { - return fmt.Errorf("invalid end port %s: %s", args[2], err) - } - } - - if mod.endPort < mod.startPort { - return fmt.Errorf("end port %d is greater than start port %d", mod.endPort, mod.startPort) - } - - return -} - func (mod *SynScanner) Name() string { return "syn.scan" } @@ -147,7 +110,14 @@ func (mod *SynScanner) Author() string { return "Simone Margaritelli " } -func (mod *SynScanner) Configure() error { +func (mod *SynScanner) Configure() (err error) { + if mod.Running() { + return session.ErrAlreadyStarted(mod.Name()) + } else if mod.handle, err = pcap.OpenLive(mod.Session.Interface.Name(), 65536, true, pcap.BlockForever); err != nil { + return err + } else if err = mod.handle.SetBPFFilter(fmt.Sprintf("tcp dst port %d", synSourcePort)); err != nil { + return err + } return nil } @@ -180,10 +150,12 @@ func (mod *SynScanner) showProgress() error { func (mod *SynScanner) Stop() error { mod.Info("stopping ...") return mod.SetRunning(false, func() { + mod.packets <- nil mod.waitGroup.Wait() mod.showProgress() mod.addresses = []net.IP{} mod.State.Store("progress", 0.0) + mod.State.Store("scanning", &mod.addresses) }) } @@ -219,15 +191,22 @@ func (mod *SynScanner) scanWorker(job async.Job) { } func (mod *SynScanner) synScan() error { + if err := mod.Configure(); err != nil { + return err + } + mod.SetRunning(true, func() { + mod.waitGroup.Add(1) + defer mod.waitGroup.Done() + defer mod.SetRunning(false, func() { mod.addresses = []net.IP{} mod.State.Store("progress", 0.0) + mod.State.Store("scanning", &mod.addresses) + mod.packets <- nil + mod.handle.Close() }) - mod.waitGroup.Add(1) - defer mod.waitGroup.Done() - mod.stats.openPorts = 0 mod.stats.numPorts = uint64(mod.endPort - mod.startPort + 1) mod.stats.started = time.Now() @@ -247,9 +226,20 @@ func (mod *SynScanner) synScan() error { mod.State.Store("progress", 0.0) - // set the collector - mod.Session.Queue.OnPacket(mod.onPacket) - defer mod.Session.Queue.OnPacket(nil) + // start the collector + go func() { + mod.waitGroup.Add(1) + defer mod.waitGroup.Done() + + src := gopacket.NewPacketSource(mod.handle, mod.handle.LinkType()) + mod.packets = src.Packets() + for packet := range mod.packets { + if !mod.Running() { + break + } + mod.onPacket(packet) + } + }() // start to show progress every second go func() { diff --git a/modules/syn_scan/syn_scan_parsers.go b/modules/syn_scan/syn_scan_parsers.go new file mode 100644 index 00000000..25a8c24d --- /dev/null +++ b/modules/syn_scan/syn_scan_parsers.go @@ -0,0 +1,47 @@ +package syn_scan + +import ( + "fmt" + "strconv" + + "github.com/evilsocket/islazy/str" + "github.com/malfunkt/iprange" +) + +func (mod *SynScanner) parseTargets(arg string) error { + if list, err := iprange.Parse(arg); err != nil { + return fmt.Errorf("error while parsing IP range '%s': %s", arg, err) + } else { + mod.addresses = list.Expand() + } + return nil +} + +func (mod *SynScanner) parsePorts(args []string) (err error) { + argc := len(args) + mod.stats.totProbes = 0 + mod.stats.doneProbes = 0 + mod.startPort = 1 + mod.endPort = 65535 + + if argc > 1 && str.Trim(args[1]) != "" { + if mod.startPort, err = strconv.Atoi(str.Trim(args[1])); err != nil { + return fmt.Errorf("invalid start port %s: %s", args[1], err) + } else if mod.startPort > 65535 { + mod.startPort = 65535 + } + mod.endPort = mod.startPort + } + + if argc > 2 && str.Trim(args[2]) != "" { + if mod.endPort, err = strconv.Atoi(str.Trim(args[2])); err != nil { + return fmt.Errorf("invalid end port %s: %s", args[2], err) + } + } + + if mod.endPort < mod.startPort { + return fmt.Errorf("end port %d is greater than start port %d", mod.endPort, mod.startPort) + } + + return +} diff --git a/modules/syn_scan/syn_scan_reader.go b/modules/syn_scan/syn_scan_reader.go index a5a00706..b64a63b9 100644 --- a/modules/syn_scan/syn_scan_reader.go +++ b/modules/syn_scan/syn_scan_reader.go @@ -1,7 +1,6 @@ package syn_scan import ( - "net" "sync/atomic" "github.com/bettercap/bettercap/network" @@ -19,15 +18,6 @@ type OpenPort struct { Port int `json:"port"` } -func (mod *SynScanner) isAddressInRange(ip net.IP) bool { - for _, a := range mod.addresses { - if a.Equal(ip) { - return true - } - } - return false -} - func (mod *SynScanner) onPacket(pkt gopacket.Packet) { var eth layers.Ethernet var ip layers.IPv4 @@ -46,7 +36,7 @@ func (mod *SynScanner) onPacket(pkt gopacket.Packet) { return } - if mod.isAddressInRange(ip.SrcIP) && tcp.DstPort == synSourcePort && tcp.SYN && tcp.ACK { + if tcp.DstPort == synSourcePort && tcp.SYN && tcp.ACK { atomic.AddUint64(&mod.stats.openPorts, 1) from := ip.SrcIP.String() diff --git a/packets/queue.go b/packets/queue.go index 1a7eeaf0..23400d2c 100644 --- a/packets/queue.go +++ b/packets/queue.go @@ -33,8 +33,6 @@ type Stats struct { Errors uint64 `json:"errors"` } -type PacketCallback func(pkt gopacket.Packet) - type Queue struct { sync.RWMutex @@ -49,7 +47,6 @@ type Queue struct { source *gopacket.PacketSource srcChannel chan gopacket.Packet writes *sync.WaitGroup - pktCb PacketCallback active bool } @@ -69,7 +66,6 @@ func NewQueue(iface *network.Endpoint) (q *Queue, err error) { writes: &sync.WaitGroup{}, iface: iface, active: !iface.IsMonitor(), - pktCb: nil, } if q.active { @@ -107,21 +103,6 @@ func (q *Queue) MarshalJSON() ([]byte, error) { return json.Marshal(doc) } -func (q *Queue) OnPacket(cb PacketCallback) { - q.Lock() - defer q.Unlock() - q.pktCb = cb -} - -func (q *Queue) onPacketCallback(pkt gopacket.Packet) { - q.RLock() - defer q.RUnlock() - - if q.pktCb != nil { - q.pktCb(pkt) - } -} - func (q *Queue) trackProtocols(pkt gopacket.Packet) { // gather protocols stats pktLayers := pkt.Layers() @@ -206,7 +187,6 @@ func (q *Queue) worker() { pktSize := uint64(len(pkt.Data())) q.TrackPacket(pktSize) - q.onPacketCallback(pkt) // decode eth and ipv4 layers leth := pkt.Layer(layers.LayerTypeEthernet)