diff --git a/modules/api_rest.go b/modules/api_rest.go index 25fc3398..62f3eabf 100644 --- a/modules/api_rest.go +++ b/modules/api_rest.go @@ -14,15 +14,19 @@ import ( type RestAPI struct { session.SessionModule - server *http.Server - certFile string - keyFile string + server *http.Server + username string + password string + certFile string + keyFile string + useWebsocket bool } func NewRestAPI(s *session.Session) *RestAPI { api := &RestAPI{ SessionModule: session.NewSessionModule("api.rest", s), server: &http.Server{}, + useWebsocket: false, } api.AddParam(session.NewStringParameter("api.rest.address", @@ -54,6 +58,10 @@ func NewRestAPI(s *session.Session) *RestAPI { "", "API TLS key")) + api.AddParam(session.NewBoolParameter("api.rest.websocket", + "false", + "If true the /api/events route will be available as a websocket endpoint instead of HTTPS.")) + api.AddHandler(session.NewModuleHandler("api.rest on", "", "Start REST API server.", func(args []string) error { @@ -106,9 +114,11 @@ func (api *RestAPI) Configure() error { return err } else if api.keyFile, err = core.ExpandPath(api.keyFile); err != nil { return err - } else if err, ApiUsername = api.StringParam("api.rest.username"); err != nil { + } else if err, api.username = api.StringParam("api.rest.username"); err != nil { return err - } else if err, ApiPassword = api.StringParam("api.rest.password"); err != nil { + } else if err, api.password = api.StringParam("api.rest.password"); err != nil { + return err + } else if err, api.useWebsocket = api.BoolParam("api.rest.websocket"); err != nil { return err } else if core.Exists(api.certFile) == false || core.Exists(api.keyFile) == false { log.Info("Generating TLS key to %s", api.keyFile) @@ -125,8 +135,8 @@ func (api *RestAPI) Configure() error { router := http.NewServeMux() - router.HandleFunc("/api/session", SessionRoute) - router.HandleFunc("/api/events", EventsRoute) + router.HandleFunc("/api/session", api.sessionRoute) + router.HandleFunc("/api/events", api.eventsRoute) api.server.Handler = router diff --git a/modules/api_rest_controller.go b/modules/api_rest_controller.go index 05457ae4..717e1f74 100644 --- a/modules/api_rest_controller.go +++ b/modules/api_rest_controller.go @@ -9,11 +9,6 @@ import ( "github.com/bettercap/bettercap/session" ) -var ( - ApiUsername = "" - ApiPassword = "" -) - type CommandRequest struct { Command string `json:"cmd"` } @@ -23,17 +18,6 @@ type APIResponse struct { Message string `json:"msg"` } -func checkAuth(r *http.Request) bool { - user, pass, _ := r.BasicAuth() - // timing attack my ass - if subtle.ConstantTimeCompare([]byte(user), []byte(ApiUsername)) != 1 { - return false - } else if subtle.ConstantTimeCompare([]byte(pass), []byte(ApiPassword)) != 1 { - return false - } - return true -} - func setAuthFailed(w http.ResponseWriter) { w.Header().Set("WWW-Authenticate", `Basic realm="auth"`) w.WriteHeader(401) @@ -52,11 +36,22 @@ func toJSON(w http.ResponseWriter, o interface{}) { json.NewEncoder(w).Encode(o) } -func showSession(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) checkAuth(r *http.Request) bool { + user, pass, _ := r.BasicAuth() + // timing attack my ass + if subtle.ConstantTimeCompare([]byte(user), []byte(api.username)) != 1 { + return false + } else if subtle.ConstantTimeCompare([]byte(pass), []byte(api.password)) != 1 { + return false + } + return true +} + +func (api *RestAPI) showSession(w http.ResponseWriter, r *http.Request) { toJSON(w, session.I) } -func runSessionCommand(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) { var err error var cmd CommandRequest @@ -71,7 +66,7 @@ func runSessionCommand(w http.ResponseWriter, r *http.Request) { } } -func showEvents(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) showEvents(w http.ResponseWriter, r *http.Request) { var err error events := session.I.Events.Sorted() @@ -94,33 +89,33 @@ func showEvents(w http.ResponseWriter, r *http.Request) { toJSON(w, events[0:n]) } -func clearEvents(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) clearEvents(w http.ResponseWriter, r *http.Request) { session.I.Events.Clear() } -func SessionRoute(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) sessionRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) - if checkAuth(r) == false { + if api.checkAuth(r) == false { setAuthFailed(w) } else if r.Method == "GET" { - showSession(w, r) + api.showSession(w, r) } else if r.Method == "POST" { - runSessionCommand(w, r) + api.runSessionCommand(w, r) } else { http.Error(w, "Bad Request", 400) } } -func EventsRoute(w http.ResponseWriter, r *http.Request) { +func (api *RestAPI) eventsRoute(w http.ResponseWriter, r *http.Request) { setSecurityHeaders(w) - if checkAuth(r) == false { + if api.checkAuth(r) == false { setAuthFailed(w) } else if r.Method == "GET" { - showEvents(w, r) + api.showEvents(w, r) } else if r.Method == "DELETE" { - clearEvents(w, r) + api.clearEvents(w, r) } else { http.Error(w, "Bad Request", 400) }