diff --git a/modules/api_rest.go b/modules/api_rest.go index 62f3eabf..3033a488 100644 --- a/modules/api_rest.go +++ b/modules/api_rest.go @@ -10,23 +10,33 @@ import ( "github.com/bettercap/bettercap/log" "github.com/bettercap/bettercap/session" "github.com/bettercap/bettercap/tls" + + "github.com/gorilla/websocket" ) type RestAPI struct { session.SessionModule - server *http.Server - username string - password string - certFile string - keyFile string - useWebsocket bool + server *http.Server + username string + password string + certFile string + keyFile string + useWebsocket bool + upgrader websocket.Upgrader + eventListener <-chan session.Event + quit chan bool } func NewRestAPI(s *session.Session) *RestAPI { api := &RestAPI{ SessionModule: session.NewSessionModule("api.rest", s), server: &http.Server{}, + quit: make(chan bool), useWebsocket: false, + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, } api.AddParam(session.NewStringParameter("api.rest.address", @@ -163,6 +173,10 @@ func (api *RestAPI) Start() error { func (api *RestAPI) Stop() error { return api.SetRunning(false, func() { + go func() { + api.quit <- true + }() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() api.server.Shutdown(ctx) diff --git a/modules/api_rest_controller.go b/modules/api_rest_controller.go index 717e1f74..d6eee401 100644 --- a/modules/api_rest_controller.go +++ b/modules/api_rest_controller.go @@ -5,8 +5,21 @@ import ( "encoding/json" "net/http" "strconv" + "time" + "github.com/bettercap/bettercap/log" "github.com/bettercap/bettercap/session" + + "github.com/gorilla/websocket" +) + +const ( + // Time allowed to write an event to the client. + writeWait = 10 * time.Second + // Time allowed to read the next pong message from the client. + pongWait = 60 * time.Second + // Send pings to client with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 ) type CommandRequest struct { @@ -18,7 +31,9 @@ type APIResponse struct { Message string `json:"msg"` } -func setAuthFailed(w http.ResponseWriter) { +func setAuthFailed(w http.ResponseWriter, r *http.Request) { + log.Warning("Unauthorized authentication attempt from %s", r.RemoteAddr) + w.Header().Set("WWW-Authenticate", `Basic realm="auth"`) w.WriteHeader(401) w.Write([]byte("Unauthorized")) @@ -66,27 +81,119 @@ func (api *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) { } } -func (api *RestAPI) showEvents(w http.ResponseWriter, r *http.Request) { - var err error +func (api *RestAPI) streamWriter(ws *websocket.Conn, w http.ResponseWriter, r *http.Request) { + defer ws.Close() + // first we stream what we already have events := session.I.Events.Sorted() - nmax := len(events) - n := nmax - - q := r.URL.Query() - vals := q["n"] - if len(vals) > 0 { - n, err = strconv.Atoi(q["n"][0]) - if err == nil { - if n > nmax { - n = nmax + n := len(events) + if n > 0 { + log.Info("Sending %d events.", n) + for _, event := range events { + msg, err := json.Marshal(event) + if err != nil { + log.Error("Error while creating websocket message: %s", err) + return + } + + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, msg); err != nil { + log.Error("Error while writing websocket message: %s", err) + return } - } else { - n = nmax } } - toJSON(w, events[0:n]) + session.I.Events.Clear() + + log.Info("Listening for events and streaming to ws endpoint ...") + + api.eventListener = api.Session.Events.Listen() + + pingTicker := time.NewTicker(pingPeriod) + + for { + select { + case <-pingTicker.C: + ws.SetWriteDeadline(time.Now().Add(writeWait)) + log.Info("Ping") + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + log.Error("Error while writing websocket ping message: %s", err) + return + } + case event := <-api.eventListener: + log.Info("Event") + msg, err := json.Marshal(event) + if err != nil { + log.Error("Error while creating websocket message: %s", err) + continue + } + + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, msg); err != nil { + log.Error("Error while writing websocket message: %s", err) + return + } + log.Info("Sent") + + case <-api.quit: + log.Info("Quit") + return + } + } +} + +func (api *RestAPI) streamReader(ws *websocket.Conn) { + defer ws.Close() + ws.SetReadLimit(512) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, _, err := ws.ReadMessage() + if err != nil { + log.Info("Closing reader.") + break + } + } +} + +func (api *RestAPI) showEvents(w http.ResponseWriter, r *http.Request) { + var err error + + if api.useWebsocket { + ws, err := api.upgrader.Upgrade(w, r, nil) + if err != nil { + if _, ok := err.(websocket.HandshakeError); !ok { + log.Error("Error while updating api.rest connection to websocket: %s", err) + } + return + } + + log.Info("Websocket streaming started for %s", r.RemoteAddr) + + go api.streamWriter(ws, w, r) + api.streamReader(ws) + } else { + + events := session.I.Events.Sorted() + nmax := len(events) + n := nmax + + q := r.URL.Query() + vals := q["n"] + if len(vals) > 0 { + n, err = strconv.Atoi(q["n"][0]) + if err == nil { + if n > nmax { + n = nmax + } + } else { + n = nmax + } + } + + toJSON(w, events[0:n]) + } } func (api *RestAPI) clearEvents(w http.ResponseWriter, r *http.Request) { @@ -97,7 +204,7 @@ func (api *RestAPI) sessionRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) if api.checkAuth(r) == false { - setAuthFailed(w) + setAuthFailed(w, r) } else if r.Method == "GET" { api.showSession(w, r) } else if r.Method == "POST" { @@ -111,7 +218,7 @@ func (api *RestAPI) eventsRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) if api.checkAuth(r) == false { - setAuthFailed(w) + setAuthFailed(w, r) } else if r.Method == "GET" { api.showEvents(w, r) } else if r.Method == "DELETE" { diff --git a/modules/events_stream.go b/modules/events_stream.go index 70e93b73..99d009e2 100644 --- a/modules/events_stream.go +++ b/modules/events_stream.go @@ -12,10 +12,11 @@ import ( type EventsStream struct { session.SessionModule - ignoreList *IgnoreList - waitFor string - waitChan chan *session.Event - quit chan bool + ignoreList *IgnoreList + waitFor string + waitChan chan *session.Event + eventListener <-chan session.Event + quit chan bool } func NewEventsStream(s *session.Session) *EventsStream { @@ -124,10 +125,12 @@ func (s *EventsStream) Configure() error { func (s *EventsStream) Start() error { return s.SetRunning(true, func() { + + s.eventListener = s.Session.Events.Listen() for { var e session.Event select { - case e = <-s.Session.Events.NewEvents: + case e = <-s.eventListener: if e.Tag == s.waitFor { s.waitFor = "" s.waitChan <- &e diff --git a/session/events.go b/session/events.go index 26e719f5..cfd4f605 100644 --- a/session/events.go +++ b/session/events.go @@ -39,21 +39,29 @@ func (e Event) Label() string { type EventPool struct { sync.Mutex - NewEvents chan Event debug bool silent bool events []Event + listeners []chan Event } func NewEventPool(debug bool, silent bool) *EventPool { return &EventPool{ - NewEvents: make(chan Event, 0xff), debug: debug, silent: silent, events: make([]Event, 0), + listeners: make([]chan Event, 0), } } +func (p *EventPool) Listen() <-chan Event { + p.Lock() + defer p.Unlock() + l := make(chan Event) + p.listeners = append(p.listeners, l) + return l +} + func (p *EventPool) SetSilent(s bool) { p.Lock() defer p.Unlock() @@ -71,7 +79,11 @@ func (p *EventPool) Add(tag string, data interface{}) { defer p.Unlock() e := NewEvent(tag, data) p.events = append([]Event{e}, p.events...) - p.NewEvents <- e + + // broadcast the event to every listener + for _, l := range p.listeners { + l <- e + } } func (p *EventPool) Log(level int, format string, args ...interface{}) {