diff --git a/modules/api_rest.go b/modules/api_rest.go index 25fc3398..28c504f7 100644 --- a/modules/api_rest.go +++ b/modules/api_rest.go @@ -10,19 +10,34 @@ import ( "github.com/bettercap/bettercap/log" "github.com/bettercap/bettercap/session" "github.com/bettercap/bettercap/tls" + + "github.com/gorilla/websocket" ) type RestAPI struct { session.SessionModule - server *http.Server - certFile string - keyFile string + server *http.Server + username string + password string + certFile string + keyFile string + useWebsocket bool + upgrader websocket.Upgrader + eventListener <-chan session.Event + quit chan bool } func NewRestAPI(s *session.Session) *RestAPI { api := &RestAPI{ SessionModule: session.NewSessionModule("api.rest", s), server: &http.Server{}, + quit: make(chan bool), + useWebsocket: false, + eventListener: s.Events.Listen(), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, } api.AddParam(session.NewStringParameter("api.rest.address", @@ -54,6 +69,10 @@ func NewRestAPI(s *session.Session) *RestAPI { "", "API TLS key")) + api.AddParam(session.NewBoolParameter("api.rest.websocket", + "false", + "If true the /api/events route will be available as a websocket endpoint instead of HTTPS.")) + api.AddHandler(session.NewModuleHandler("api.rest on", "", "Start REST API server.", func(args []string) error { @@ -106,9 +125,11 @@ func (api *RestAPI) Configure() error { return err } else if api.keyFile, err = core.ExpandPath(api.keyFile); err != nil { return err - } else if err, ApiUsername = api.StringParam("api.rest.username"); err != nil { + } else if err, api.username = api.StringParam("api.rest.username"); err != nil { return err - } else if err, ApiPassword = api.StringParam("api.rest.password"); err != nil { + } else if err, api.password = api.StringParam("api.rest.password"); err != nil { + return err + } else if err, api.useWebsocket = api.BoolParam("api.rest.websocket"); err != nil { return err } else if core.Exists(api.certFile) == false || core.Exists(api.keyFile) == false { log.Info("Generating TLS key to %s", api.keyFile) @@ -125,8 +146,8 @@ func (api *RestAPI) Configure() error { router := http.NewServeMux() - router.HandleFunc("/api/session", SessionRoute) - router.HandleFunc("/api/events", EventsRoute) + router.HandleFunc("/api/session", api.sessionRoute) + router.HandleFunc("/api/events", api.eventsRoute) api.server.Handler = router @@ -153,6 +174,10 @@ func (api *RestAPI) Start() error { func (api *RestAPI) Stop() error { return api.SetRunning(false, func() { + go func() { + api.quit <- true + }() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() api.server.Shutdown(ctx) diff --git a/modules/api_rest_controller.go b/modules/api_rest_controller.go index 05457ae4..53480ea8 100644 --- a/modules/api_rest_controller.go +++ b/modules/api_rest_controller.go @@ -5,13 +5,22 @@ import ( "encoding/json" "net/http" "strconv" + "strings" + "time" + "github.com/bettercap/bettercap/log" "github.com/bettercap/bettercap/session" + + "github.com/gorilla/websocket" ) -var ( - ApiUsername = "" - ApiPassword = "" +const ( + // Time allowed to write an event to the client. + writeWait = 10 * time.Second + // Time allowed to read the next pong message from the client. + pongWait = 60 * time.Second + // Send pings to client with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 ) type CommandRequest struct { @@ -23,18 +32,9 @@ type APIResponse struct { Message string `json:"msg"` } -func checkAuth(r *http.Request) bool { - user, pass, _ := r.BasicAuth() - // timing attack my ass - if subtle.ConstantTimeCompare([]byte(user), []byte(ApiUsername)) != 1 { - return false - } else if subtle.ConstantTimeCompare([]byte(pass), []byte(ApiPassword)) != 1 { - return false - } - return true -} +func setAuthFailed(w http.ResponseWriter, r *http.Request) { + log.Warning("Unauthorized authentication attempt from %s", r.RemoteAddr) -func setAuthFailed(w http.ResponseWriter) { w.Header().Set("WWW-Authenticate", `Basic realm="auth"`) w.WriteHeader(401) w.Write([]byte("Unauthorized")) @@ -52,11 +52,22 @@ func toJSON(w http.ResponseWriter, o interface{}) { json.NewEncoder(w).Encode(o) } -func showSession(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) checkAuth(r *http.Request) bool { + user, pass, _ := r.BasicAuth() + // timing attack my ass + if subtle.ConstantTimeCompare([]byte(user), []byte(api.username)) != 1 { + return false + } else if subtle.ConstantTimeCompare([]byte(pass), []byte(api.password)) != 1 { + return false + } + return true +} + +func (api *RestAPI) showSession(w http.ResponseWriter, r *http.Request) { toJSON(w, session.I) } -func runSessionCommand(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) { var err error var cmd CommandRequest @@ -71,56 +82,151 @@ func runSessionCommand(w http.ResponseWriter, r *http.Request) { } } -func showEvents(w http.ResponseWriter, r *http.Request) { - var err error +func (api *RestAPI) streamEvent(ws *websocket.Conn, event session.Event) error { + msg, err := json.Marshal(event) + if err != nil { + log.Error("Error while creating websocket message: %s", err) + return err + } - events := session.I.Events.Sorted() - nmax := len(events) - n := nmax - - q := r.URL.Query() - vals := q["n"] - if len(vals) > 0 { - n, err = strconv.Atoi(q["n"][0]) - if err == nil { - if n > nmax { - n = nmax - } - } else { - n = nmax + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, msg); err != nil { + if !strings.Contains(err.Error(), "closed connection") { + log.Error("Error while writing websocket message: %s", err) + return err } } - toJSON(w, events[0:n]) + return nil } -func clearEvents(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) sendPing(ws *websocket.Conn) error { + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + log.Error("Error while writing websocket ping message: %s", err) + return err + } + return nil +} + +func (api *RestAPI) streamWriter(ws *websocket.Conn, w http.ResponseWriter, r *http.Request) { + defer ws.Close() + + // first we stream what we already have + events := session.I.Events.Sorted() + n := len(events) + if n > 0 { + log.Debug("Sending %d events.", n) + for _, event := range events { + if err := api.streamEvent(ws, event); err != nil { + return + } + } + } + + session.I.Events.Clear() + + log.Debug("Listening for events and streaming to ws endpoint ...") + + pingTicker := time.NewTicker(pingPeriod) + + for { + select { + case <-pingTicker.C: + if err := api.sendPing(ws); err != nil { + return + } + case event := <-api.eventListener: + if err := api.streamEvent(ws, event); err != nil { + return + } + case <-api.quit: + log.Info("Stopping websocket events streamer ...") + return + } + } +} + +func (api *RestAPI) streamReader(ws *websocket.Conn) { + defer ws.Close() + ws.SetReadLimit(512) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, _, err := ws.ReadMessage() + if err != nil { + log.Debug("Closing websocket reader.") + break + } + } +} + +func (api *RestAPI) showEvents(w http.ResponseWriter, r *http.Request) { + var err error + + if api.useWebsocket { + ws, err := api.upgrader.Upgrade(w, r, nil) + if err != nil { + if _, ok := err.(websocket.HandshakeError); !ok { + log.Error("Error while updating api.rest connection to websocket: %s", err) + } + return + } + + log.Debug("Websocket streaming started for %s", r.RemoteAddr) + + go api.streamWriter(ws, w, r) + api.streamReader(ws) + } else { + + events := session.I.Events.Sorted() + nmax := len(events) + n := nmax + + q := r.URL.Query() + vals := q["n"] + if len(vals) > 0 { + n, err = strconv.Atoi(q["n"][0]) + if err == nil { + if n > nmax { + n = nmax + } + } else { + n = nmax + } + } + + toJSON(w, events[0:n]) + } +} + +func (api *RestAPI) clearEvents(w http.ResponseWriter, r *http.Request) { session.I.Events.Clear() } -func SessionRoute(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) sessionRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) - if checkAuth(r) == false { - setAuthFailed(w) + if api.checkAuth(r) == false { + setAuthFailed(w, r) } else if r.Method == "GET" { - showSession(w, r) + api.showSession(w, r) } else if r.Method == "POST" { - runSessionCommand(w, r) + api.runSessionCommand(w, r) } else { http.Error(w, "Bad Request", 400) } } -func EventsRoute(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) eventsRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) - if checkAuth(r) == false { - setAuthFailed(w) + if api.checkAuth(r) == false { + setAuthFailed(w, r) } else if r.Method == "GET" { - showEvents(w, r) + api.showEvents(w, r) } else if r.Method == "DELETE" { - clearEvents(w, r) + api.clearEvents(w, r) } else { http.Error(w, "Bad Request", 400) } diff --git a/modules/events_stream.go b/modules/events_stream.go index 70e93b73..99d009e2 100644 --- a/modules/events_stream.go +++ b/modules/events_stream.go @@ -12,10 +12,11 @@ import ( type EventsStream struct { session.SessionModule - ignoreList *IgnoreList - waitFor string - waitChan chan *session.Event - quit chan bool + ignoreList *IgnoreList + waitFor string + waitChan chan *session.Event + eventListener <-chan session.Event + quit chan bool } func NewEventsStream(s *session.Session) *EventsStream { @@ -124,10 +125,12 @@ func (s *EventsStream) Configure() error { func (s *EventsStream) Start() error { return s.SetRunning(true, func() { + + s.eventListener = s.Session.Events.Listen() for { var e session.Event select { - case e = <-s.Session.Events.NewEvents: + case e = <-s.eventListener: if e.Tag == s.waitFor { s.waitFor = "" s.waitChan <- &e diff --git a/network/ble_device.go b/network/ble_device.go index 50bf097a..40a72199 100644 --- a/network/ble_device.go +++ b/network/ble_device.go @@ -4,6 +4,7 @@ package network import ( + "encoding/json" "time" "github.com/bettercap/gatt" @@ -11,10 +12,18 @@ import ( type BLEDevice struct { LastSeen time.Time - Device gatt.Peripheral Vendor string - Advertisement *gatt.Advertisement RSSI int + Device gatt.Peripheral + Advertisement *gatt.Advertisement +} + +type bleDeviceJSON struct { + LastSeen time.Time `json:"last_seen"` + Name string `json:"name"` + MAC string `json:"mac"` + Vendor string `json:"vendor"` + RSSI int `json:"rssi"` } func NewBLEDevice(p gatt.Peripheral, a *gatt.Advertisement, rssi int) *BLEDevice { @@ -26,3 +35,15 @@ func NewBLEDevice(p gatt.Peripheral, a *gatt.Advertisement, rssi int) *BLEDevice RSSI: rssi, } } + +func (d *BLEDevice) MarshalJSON() ([]byte, error) { + doc := bleDeviceJSON{ + LastSeen: d.LastSeen, + Name: d.Device.Name(), + MAC: d.Device.ID(), + Vendor: d.Vendor, + RSSI: d.RSSI, + } + + return json.Marshal(doc) +} diff --git a/session/events.go b/session/events.go index 26e719f5..24fe1899 100644 --- a/session/events.go +++ b/session/events.go @@ -39,21 +39,29 @@ func (e Event) Label() string { type EventPool struct { sync.Mutex - NewEvents chan Event debug bool silent bool events []Event + listeners []chan Event } func NewEventPool(debug bool, silent bool) *EventPool { return &EventPool{ - NewEvents: make(chan Event, 0xff), debug: debug, silent: silent, events: make([]Event, 0), + listeners: make([]chan Event, 0), } } +func (p *EventPool) Listen() <-chan Event { + p.Lock() + defer p.Unlock() + l := make(chan Event, 1) + p.listeners = append(p.listeners, l) + return l +} + func (p *EventPool) SetSilent(s bool) { p.Lock() defer p.Unlock() @@ -69,9 +77,17 @@ func (p *EventPool) SetDebug(d bool) { func (p *EventPool) Add(tag string, data interface{}) { p.Lock() defer p.Unlock() + e := NewEvent(tag, data) p.events = append([]Event{e}, p.events...) - p.NewEvents <- e + + // broadcast the event to every listener + for _, l := range p.listeners { + select { + case l <- e: + default: + } + } } func (p *EventPool) Log(level int, format string, args ...interface{}) {