diff --git a/modules/api_rest.go b/modules/api_rest.go index b99b2173..811fa6b4 100644 --- a/modules/api_rest.go +++ b/modules/api_rest.go @@ -17,15 +17,14 @@ import ( type RestAPI struct { session.SessionModule - server *http.Server - username string - password string - certFile string - keyFile string - useWebsocket bool - upgrader websocket.Upgrader - eventListener <-chan session.Event - quit chan bool + server *http.Server + username string + password string + certFile string + keyFile string + useWebsocket bool + upgrader websocket.Upgrader + quit chan bool } func NewRestAPI(s *session.Session) *RestAPI { @@ -34,7 +33,6 @@ func NewRestAPI(s *session.Session) *RestAPI { server: &http.Server{}, quit: make(chan bool), useWebsocket: false, - eventListener: s.Events.Listen(), upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, diff --git a/modules/api_rest_ws.go b/modules/api_rest_ws.go index b32a8bed..931ea9c3 100644 --- a/modules/api_rest_ws.go +++ b/modules/api_rest_ws.go @@ -68,6 +68,8 @@ func (api *RestAPI) streamWriter(ws *websocket.Conn, w http.ResponseWriter, r *h log.Debug("Listening for events and streaming to ws endpoint ...") pingTicker := time.NewTicker(pingPeriod) + listener := session.I.Events.Listen() + defer session.I.Events.Unlisten(listener) for { select { @@ -75,7 +77,7 @@ func (api *RestAPI) streamWriter(ws *websocket.Conn, w http.ResponseWriter, r *h if err := api.sendPing(ws); err != nil { return } - case event := <-api.eventListener: + case event := <-listener: if err := api.streamEvent(ws, event); err != nil { return } diff --git a/session/events.go b/session/events.go index b919ade3..733d12f9 100644 --- a/session/events.go +++ b/session/events.go @@ -62,6 +62,19 @@ func (p *EventPool) Listen() <-chan Event { return l } +func (p *EventPool) Unlisten(listener <-chan Event) { + p.Lock() + defer p.Unlock() + + for i, l := range p.listeners { + if l == listener { + close(l) + p.listeners = append(p.listeners[:i], p.listeners[i+1:]...) + return + } + } +} + func (p *EventPool) SetSilent(s bool) { p.Lock() defer p.Unlock() @@ -83,16 +96,7 @@ func (p *EventPool) Add(tag string, data interface{}) { // broadcast the event to every listener for _, l := range p.listeners { - select { - case l <- e: - // NOTE: Without this 'default', errors in sending the event - // to the listener would not empty the channel, therefore - // all operations would be stuck at some point (after the first - // event if not buffered or after the first N events if buffered) - // - // See https://github.com/bettercap/bettercap/issues/198 - default: - } + l <- e } }