diff --git a/modules/http_proxy_script_test.go b/modules/http_proxy_script_benchmark_test.go similarity index 93% rename from modules/http_proxy_script_test.go rename to modules/http_proxy_script_benchmark_test.go index 98e1d4ae..9d47c5c8 100644 --- a/modules/http_proxy_script_test.go +++ b/modules/http_proxy_script_benchmark_test.go @@ -10,7 +10,7 @@ import ( func getScript(src string) *HttpProxyScript { sess := session.Session{} - sess.Env = session.NewEnvironment(&sess, "") + sess.Env, _ = session.NewEnvironment("") err, script := LoadHttpProxyScriptSource("", src, &sess) if err != nil { diff --git a/session/command_handler_test.go b/session/command_handler_test.go index be546800..21a4f9bb 100644 --- a/session/command_handler_test.go +++ b/session/command_handler_test.go @@ -39,15 +39,16 @@ func TestSessionCommandHandler(t *testing.T) { if u.panic { assertPanic(t, "", func() { _ = NewCommandHandler("", u.expr, "", nil) + t.Fatal("panic expected") }) } else { c := NewCommandHandler("", u.expr, "", nil) shouldNotParse := "simple123" shouldParse := "simple 123" - if parsed, parts := c.Parse(shouldNotParse); parsed { + if parsed, _ := c.Parse(shouldNotParse); parsed { t.Fatalf("should not parse '%s'", shouldNotParse) - } else if parsed, parts = c.Parse(shouldParse); !parsed { + } else if parsed, parts := c.Parse(shouldParse); !parsed { t.Fatalf("should parse '%s'", shouldParse) } else if !sameStrings(parts, u.parsed) { t.Fatalf("expected '%v', got '%v'", u.parsed, parts) diff --git a/session/environment.go b/session/environment.go index ea3a7e1d..a47496bd 100644 --- a/session/environment.go +++ b/session/environment.go @@ -11,32 +11,30 @@ import ( "github.com/bettercap/bettercap/core" ) -type SetCallback func(newValue string) +type EnvironmentChangedCallback func(newValue string) type Environment struct { sync.Mutex Data map[string]string `json:"data"` - cbs map[string]SetCallback - sess *Session + cbs map[string]EnvironmentChangedCallback } -func NewEnvironment(s *Session, envFile string) *Environment { +func NewEnvironment(envFile string) (*Environment, error) { env := &Environment{ Data: make(map[string]string), - sess: s, - cbs: make(map[string]SetCallback), + cbs: make(map[string]EnvironmentChangedCallback), } if envFile != "" { envFile, _ := core.ExpandPath(envFile) if core.Exists(envFile) { if err := env.Load(envFile); err != nil { - fmt.Printf("Error while loading %s: %s\n", envFile, err) + return nil, err } } } - return env + return env, nil } func (env *Environment) Load(fileName string) error { @@ -48,7 +46,10 @@ func (env *Environment) Load(fileName string) error { return err } - return json.Unmarshal(raw, &env.Data) + if len(raw) > 0 { + return json.Unmarshal(raw, &env.Data) + } + return nil } func (env *Environment) Save(fileName string) error { @@ -72,15 +73,15 @@ func (env *Environment) Has(name string) bool { return found } -func (env *Environment) SetCallback(name string, cb SetCallback) { +func (env *Environment) addCb(name string, cb EnvironmentChangedCallback) { env.Lock() defer env.Unlock() env.cbs[name] = cb } -func (env *Environment) WithCallback(name, value string, cb SetCallback) string { +func (env *Environment) WithCallback(name, value string, cb EnvironmentChangedCallback) string { + env.addCb(name, cb) ret := env.Set(name, value) - env.SetCallback(name, cb) return ret } @@ -95,8 +96,6 @@ func (env *Environment) Set(name, value string) string { cb(value) } - env.sess.Events.Log(core.DEBUG, "env.change: %s -> '%s'", name, value) - return old } diff --git a/session/environment_test.go b/session/environment_test.go new file mode 100644 index 00000000..b3b3bb85 --- /dev/null +++ b/session/environment_test.go @@ -0,0 +1,282 @@ +package session + +import ( + "encoding/json" + "io/ioutil" + "os" + "reflect" + "testing" +) + +var ( + testEnvFile = "/tmp/test.env" + testEnvData = map[string]string{ + "people": "shit", + "moo": "boo", + "foo": "bar", + } + testEnvSorted = []string{"foo", "moo", "people"} +) + +func setup(t testing.TB, envFile bool, envFileData bool) { + teardown(t) + + if envFile { + if fp, err := os.OpenFile(testEnvFile, os.O_RDONLY|os.O_CREATE, 0666); err == nil { + fp.Close() + } else { + panic(err) + } + } + + if envFileData { + if raw, err := json.Marshal(testEnvData); err != nil { + panic(err) + } else if err = ioutil.WriteFile(testEnvFile, raw, 0755); err != nil { + panic(err) + } + } +} + +func teardown(t testing.TB) { + if err := os.RemoveAll(testEnvFile); err != nil { + panic(err) + } +} + +func TestSessionEnvironmentWithoutFile(t *testing.T) { + if env, err := NewEnvironment(""); env == nil { + t.Fatal("expected valid environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != 0 { + t.Fatalf("expected empty environment, found %d elements", len(env.Data)) + } +} + +func TestSessionEnvironmentWithInvalidFile(t *testing.T) { + if env, err := NewEnvironment("/idontexist"); env == nil { + t.Fatal("expected valid environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != 0 { + t.Fatalf("expected empty environment, found %d elements", len(env.Data)) + } +} + +func TestSessionEnvironmentWithEmptyFile(t *testing.T) { + setup(t, true, false) + defer teardown(t) + + if env, err := NewEnvironment(testEnvFile); env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != 0 { + t.Fatalf("expected empty environment, found %d elements", len(env.Data)) + } +} + +func TestSessionEnvironmentWithDataFile(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + if env, err := NewEnvironment(testEnvFile); env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != len(testEnvData) { + t.Fatalf("expected %d, found %d", len(testEnvData), len(env.Data)) + } else if !reflect.DeepEqual(env.Data, testEnvData) { + t.Fatalf("unexpected contents: %v", env.Data) + } +} + +func TestSessionEnvironmentSaveWithError(t *testing.T) { + setup(t, false, false) + defer teardown(t) + + if env, err := NewEnvironment(testEnvFile); env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if err = env.Save("/lulz/nope"); err == nil { + t.Fatal("expected error") + } +} + +func TestSessionEnvironmentSave(t *testing.T) { + setup(t, false, false) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + env.Data["new"] = "value" + if err = env.Save(testEnvFile); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if env, err := NewEnvironment(testEnvFile); env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if !reflect.DeepEqual(env.Data, map[string]string{"new": "value"}) { + t.Fatalf("unexpected contents: %v", env.Data) + } +} + +func TestSessionEnvironmentHas(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != len(testEnvData) { + t.Fatalf("expected %d, found %d", len(testEnvData), len(env.Data)) + } + + for k := range testEnvData { + if !env.Has(k) { + t.Fatalf("could not find key '%s'", k) + } + } + + for _, k := range []string{"these", "keys", "should", "not", "be", "found"} { + if env.Has(k) { + t.Fatalf("unexpected key '%s'", k) + } + } +} + +func TestSessionEnvironmentSet(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if old := env.Set("people", "ok"); old != "shit" { + t.Fatalf("unexpected old value: %s", old) + } else if env.Data["people"] != "ok" { + t.Fatalf("unexpected new value: %s", env.Data["people"]) + } else if old := env.Set("newkey", "nk"); old != "" { + t.Fatalf("unexpected old value: %s", old) + } else if env.Data["newkey"] != "nk" { + t.Fatalf("unexpected new value: %s", env.Data["newkey"]) + } +} + +func TestSessionEnvironmentSetWithCallback(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cbCalled := false + old := env.WithCallback("people", "ok", func(newValue string) { + cbCalled = true + }) + if old != "shit" { + t.Fatalf("unexpected old value: %s", old) + } + + cbCalled = false + old = env.Set("people", "shitagain") + if old != "ok" { + t.Fatalf("unexpected old value: %s", old) + } else if !cbCalled { + t.Fatal("callback has not been called") + } + + cbCalled = false + env.Set("something", "else") + if cbCalled { + t.Fatal("callback should not have been called") + } +} + +func TestSessionEnvironmentGet(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != len(testEnvData) { + t.Fatalf("expected %d, found %d", len(testEnvData), len(env.Data)) + } + + for k, v := range testEnvData { + if found, vv := env.Get(k); !found { + t.Fatalf("should have found %s", k) + } else if v != vv { + t.Fatalf("unexpected value found: %s", vv) + } + } + + for _, k := range []string{"these", "keys", "should", "not", "be", "found"} { + if found, _ := env.Get(k); found { + t.Fatalf("should not have found %s", k) + } + } +} + +func TestSessionEnvironmentGetInt(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != len(testEnvData) { + t.Fatalf("expected %d, found %d", len(testEnvData), len(env.Data)) + } + + for k := range testEnvData { + if err, _ := env.GetInt(k); err == nil { + t.Fatal("expected error") + } + } + + env.Data["num"] = "1234" + if err, i := env.GetInt("num"); err != nil { + t.Fatalf("unexpected error: %v", err) + } else if i != 1234 { + t.Fatalf("unexpected integer: %d", i) + } +} + +func TestSessionEnvironmentSorted(t *testing.T) { + setup(t, true, true) + defer teardown(t) + + env, err := NewEnvironment(testEnvFile) + if env == nil { + t.Fatal("expected environment") + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } else if len(env.Data) != len(testEnvData) { + t.Fatalf("expected %d, found %d", len(testEnvData), len(env.Data)) + } else if sorted := env.Sorted(); !reflect.DeepEqual(sorted, testEnvSorted) { + t.Fatalf("unexpected sorted keys: %v", sorted) + } +} diff --git a/session/session.go b/session/session.go index 7fc626e0..22ea459d 100644 --- a/session/session.go +++ b/session/session.go @@ -157,7 +157,10 @@ func New() (*Session, error) { } } - s.Env = NewEnvironment(s, *s.Options.EnvFile) + if s.Env, err = NewEnvironment(*s.Options.EnvFile); err != nil { + return nil, err + } + s.Events = NewEventPool(*s.Options.Debug, *s.Options.Silent) s.registerCoreHandlers()