diff --git a/Makefile b/Makefile index e8c1b53a..db6a8eaf 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ TARGET=bettercap-ng BUILD_DATE=`date +%Y-%m-%d\ %H:%M` BUILD_FILE=core/build.go -all: fmt vet lint build +all: fmt vet build @echo "@ Done" @echo -n "\n" diff --git a/session/modules/api_rest.go b/session/modules/api_rest.go index 97c0cd80..82f0166c 100644 --- a/session/modules/api_rest.go +++ b/session/modules/api_rest.go @@ -1,12 +1,10 @@ package session_modules import ( - "encoding/base64" - "encoding/json" + "context" "fmt" "net/http" - "strconv" - "strings" + "time" "github.com/evilsocket/bettercap-ng/core" "github.com/evilsocket/bettercap-ng/session" @@ -71,8 +69,7 @@ func NewRestAPI(s *session.Session) *RestAPI { return api.Stop() })) - http.HandleFunc("/api/session", api.sessRoute) - http.HandleFunc("/api/events", api.eventsRoute) + api.setupRoutes() return api } @@ -85,149 +82,30 @@ type JSSessionResponse struct { Error string `json:"error"` } -func (api *RestAPI) sessRoute(w http.ResponseWriter, r *http.Request) { - if api.checkAuth(w, r) == false { - return - } - - if r.Method == "GET" { - js, err := json.Marshal(api.Session) - if err != nil { - api.Session.Events.Log(session.ERROR, "Error while returning session: %s", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(js) - } else if r.Method == "POST" && r.Body != nil { - var req JSSessionRequest - var res JSSessionResponse - - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - http.Error(w, err.Error(), 400) - return - } - - err = api.Session.Run(req.Command) - if err != nil { - res.Error = err.Error() - } - js, err := json.Marshal(res) - if err != nil { - api.Session.Events.Log(session.ERROR, "Error while returning response: %s", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(js) - } else { - http.Error(w, "Not Found", 404) - } -} - -func (api *RestAPI) eventsRoute(w http.ResponseWriter, r *http.Request) { - if api.checkAuth(w, r) == false { - return - } - - if r.Method == "GET" { - var err error - - events := api.Session.Events.Events() - nmax := len(events) - n := nmax - - keys, ok := r.URL.Query()["n"] - if len(keys) == 1 && ok { - sn := keys[0] - n, err = strconv.Atoi(sn) - if err == nil { - if n > nmax { - n = nmax - } - } else { - n = nmax - } - } - - js, err := json.Marshal(events[0:n]) - if err != nil { - api.Session.Events.Log(session.ERROR, "Error while returning events: %s", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(js) - } else if r.Method == "DELETE" { - api.Session.Events.Clear() - api.Session.Events.Add("sys.log.cleared", nil) - } else { - http.Error(w, "Not Found", 404) - } -} - -func (api RestAPI) checkAuth(w http.ResponseWriter, r *http.Request) bool { - if api.Authenticated(w, r) == false { - api.Session.Events.Log(session.WARNING, "Unauthenticated access!") - http.Error(w, "Not authorized", 401) - return false - } - return true -} - -func (api RestAPI) Authenticated(w http.ResponseWriter, r *http.Request) bool { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - - parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(parts) != 2 { - return false - } - - b, err := base64.StdEncoding.DecodeString(parts[1]) - if err != nil { - return false - } - - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return false - } - - if pair[0] != api.username || pair[1] != api.password { - return false - } - - return true -} - -func (api RestAPI) Name() string { +func (api *RestAPI) Name() string { return "REST API" } -func (api RestAPI) Description() string { +func (api *RestAPI) Description() string { return "Expose a RESTful API." } -func (api RestAPI) Author() string { +func (api *RestAPI) Author() string { return "Simone Margaritelli " } -func (api RestAPI) OnSessionStarted(s *session.Session) { +func (api *RestAPI) OnSessionStarted(s *session.Session) { // refresh the address after session has been created s.Env.Set("api.rest.address", s.Interface.IpAddress) } -func (api RestAPI) OnSessionEnded(s *session.Session) { +func (api *RestAPI) OnSessionEnded(s *session.Session) { if api.Running() { api.Stop() } } -func (api *RestAPI) Start() error { +func (api *RestAPI) configure() error { var address string var port int @@ -243,6 +121,8 @@ func (api *RestAPI) Start() error { port = v.(int) } + api.server.Addr = fmt.Sprintf("%s:%d", address, port) + if err, v := api.Param("api.rest.certificate").Get(api.Session); err != nil { return err } else { @@ -290,9 +170,16 @@ func (api *RestAPI) Start() error { api.Session.Events.Log(session.INFO, "Loading TLS certificate from %s", api.certFile) } + return nil +} + +func (api *RestAPI) Start() error { + if err := api.configure(); err != nil { + return err + } + if api.Running() == false { api.SetRunning(true) - api.server.Addr = fmt.Sprintf("%s:%d", address, port) go func() { api.Session.Events.Log(session.INFO, "API server starting on https://%s", api.server.Addr) err := api.server.ListenAndServeTLS(api.certFile, api.keyFile) @@ -302,15 +189,17 @@ func (api *RestAPI) Start() error { }() return nil - } else { - return fmt.Errorf("REST API server already started.") } + + return fmt.Errorf("REST API server already started.") } func (api *RestAPI) Stop() error { if api.Running() == true { api.SetRunning(false) - return api.server.Shutdown(nil) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + return api.server.Shutdown(ctx) } else { return fmt.Errorf("REST API server already stopped.") } diff --git a/session/modules/api_rest_routes.go b/session/modules/api_rest_routes.go new file mode 100644 index 00000000..22834588 --- /dev/null +++ b/session/modules/api_rest_routes.go @@ -0,0 +1,135 @@ +package session_modules + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/evilsocket/bettercap-ng/session" +) + +func (api *RestAPI) setupRoutes() { + http.HandleFunc("/api/session", api.sessRoute) + http.HandleFunc("/api/events", api.eventsRoute) +} + +func (api RestAPI) checkAuth(w http.ResponseWriter, r *http.Request) bool { + if api.Authenticated(w, r) == false { + api.Session.Events.Log(session.WARNING, "Unauthenticated access!") + http.Error(w, "Not authorized", 401) + return false + } + return true +} + +func (api RestAPI) Authenticated(w http.ResponseWriter, r *http.Request) bool { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + + parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(parts) != 2 { + return false + } + + b, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return false + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return false + } + + if pair[0] != api.username || pair[1] != api.password { + return false + } + + return true +} + +func (api *RestAPI) sessRoute(w http.ResponseWriter, r *http.Request) { + if api.checkAuth(w, r) == false { + return + } + + if r.Method == "GET" { + js, err := json.Marshal(api.Session) + if err != nil { + api.Session.Events.Log(session.ERROR, "Error while returning session: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(js) + } else if r.Method == "POST" && r.Body != nil { + var req JSSessionRequest + var res JSSessionResponse + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), 400) + return + } + + err = api.Session.Run(req.Command) + if err != nil { + res.Error = err.Error() + } + js, err := json.Marshal(res) + if err != nil { + api.Session.Events.Log(session.ERROR, "Error while returning response: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(js) + } else { + http.Error(w, "Not Found", 404) + } +} + +func (api *RestAPI) eventsRoute(w http.ResponseWriter, r *http.Request) { + if api.checkAuth(w, r) == false { + return + } + + if r.Method == "GET" { + var err error + + events := api.Session.Events.Events() + nmax := len(events) + n := nmax + + keys, ok := r.URL.Query()["n"] + if len(keys) == 1 && ok { + sn := keys[0] + n, err = strconv.Atoi(sn) + if err == nil { + if n > nmax { + n = nmax + } + } else { + n = nmax + } + } + + js, err := json.Marshal(events[0:n]) + if err != nil { + api.Session.Events.Log(session.ERROR, "Error while returning events: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(js) + } else if r.Method == "DELETE" { + api.Session.Events.Clear() + api.Session.Events.Add("sys.log.cleared", nil) + } else { + http.Error(w, "Not Found", 404) + } +}