diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index e236489d..00000000 --- a/.gitattributes +++ /dev/null @@ -1,4 +0,0 @@ -*.js linguist-vendored -/Dockerfile linguist-vendored -/release.py linguist-vendored -/**/*.js linguist-vendored \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index 05551636..00000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,5 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: Bettercap Documentation - url: https://www.bettercap.org/ - about: Please read the instructions before asking for help. diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index c78a0857..00000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -version: 2 -updates: - # GitHub Actions - - package-ecosystem: github-actions - directory: / - schedule: - interval: daily diff --git a/.github/workflows/build-and-deploy.yml b/.github/workflows/build-and-deploy.yml index a9a770f0..a8f72dbd 100644 --- a/.github/workflows/build-and-deploy.yml +++ b/.github/workflows/build-and-deploy.yml @@ -8,57 +8,56 @@ on: jobs: build: - name: ${{ matrix.os.pretty }} ${{ matrix.arch }} - runs-on: ${{ matrix.os.runs-on }} + runs-on: ${{ matrix.os }} strategy: matrix: - os: - - name: darwin - runs-on: [macos-latest] - pretty: 🍎 macOS - - name: linux - runs-on: [ubuntu-latest] - pretty: 🐧 Linux - - name: windows - runs-on: [windows-latest] - pretty: 🪟 Windows - output: bettercap.exe - arch: [amd64, arm64] - go: [1.24.x] - exclude: - - os: - name: darwin + os: [ubuntu-latest, macos-latest, windows-latest] + go-version: ['1.22.x'] + include: + - os: ubuntu-latest arch: amd64 - # Linux ARM64 images are not yet publicly available (https://github.com/actions/runner-images) - - os: - name: linux + target_os: linux + target_arch: amd64 + - os: ubuntu-latest arch: arm64 - - os: - name: windows + target_os: linux + target_arch: aarch64 + - os: macos-latest arch: arm64 + target_os: darwin + target_arch: arm64 + - os: windows-latest + arch: amd64 + target_os: windows + target_arch: amd64 + output: bettercap.exe env: - OUTPUT: ${{ matrix.os.output || 'bettercap' }} + TARGET_OS: ${{ matrix.target_os }} + TARGET_ARCH: ${{ matrix.target_arch }} + GO_VERSION: ${{ matrix.go-version }} + OUTPUT: ${{ matrix.output || 'bettercap' }} steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v2 with: - go-version: ${{ matrix.go }} + go-version: ${{ matrix.go-version }} - name: Install Dependencies - if: ${{ matrix.os.name == 'linux' }} + if: ${{ matrix.os == 'ubuntu-latest' }} run: sudo apt-get update && sudo apt-get install -y p7zip-full libpcap-dev libnetfilter-queue-dev libusb-1.0-0-dev - name: Install Dependencies (macOS) - if: ${{ matrix.os.name == 'macos' }} + if: ${{ matrix.os == 'macos-latest' }} run: brew install libpcap libusb p7zip + - name: Install libusb via mingw (Windows) - if: ${{ matrix.os.name == 'windows' }} + if: ${{ matrix.os == 'windows-latest' }} uses: msys2/setup-msys2@v2 with: install: |- @@ -66,7 +65,7 @@ jobs: mingw64/mingw-w64-x86_64-pkg-config - name: Install other Dependencies (Windows) - if: ${{ matrix.os.name == 'windows' }} + if: ${{ matrix.os == 'windows-latest' }} run: | choco install openssl.light -y choco install make -y @@ -82,36 +81,25 @@ jobs: - name: Verify Build run: | file "${{ env.OUTPUT }}" - openssl dgst -sha256 "${{ env.OUTPUT }}" | tee bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.sha256 - 7z a "bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.zip" "${{ env.OUTPUT }}" "bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.sha256" - - - name: Upload Artifacts - uses: actions/upload-artifact@v4 - with: - name: release-artifacts-${{ matrix.os.name }}-${{ matrix.arch }} - path: | - bettercap_*.zip - bettercap_*.sha256 + openssl dgst -sha256 "${{ env.OUTPUT }}" | tee bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.sha256 + 7z a "bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.zip" "${{ env.OUTPUT }}" "bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.sha256" deploy: needs: [build] + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') name: Release runs-on: ubuntu-latest steps: - - name: Download Artifacts - uses: actions/download-artifact@v5 + - name: Checkout Code + uses: actions/checkout@v2 with: - pattern: release-artifacts-* - merge-multiple: true - path: dist/ - - - name: Release Assets - run: ls -l dist + submodules: true - name: Upload Release Assets - uses: softprops/action-gh-release@v2 - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + uses: softprops/action-gh-release@v1 with: - files: dist/bettercap_* + files: | + bettercap_*.zip + bettercap_*.sha256 env: - GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/build-and-push-docker.yml b/.github/workflows/build-and-push-docker.yml index c9ad06f1..c6ef89c2 100644 --- a/.github/workflows/build-and-push-docker.yml +++ b/.github/workflows/build-and-push-docker.yml @@ -23,7 +23,7 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push - uses: docker/build-push-action@v6 + uses: docker/build-push-action@v5 with: platforms: linux/amd64,linux/arm64 push: true diff --git a/.github/workflows/test-on-linux.yml b/.github/workflows/test-on-linux.yml index e920f281..665c1bd4 100644 --- a/.github/workflows/test-on-linux.yml +++ b/.github/workflows/test-on-linux.yml @@ -13,14 +13,14 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go-version: ['1.24.x'] + go-version: ['1.22.x'] steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} diff --git a/.github/workflows/test-on-macos.yml b/.github/workflows/test-on-macos.yml index b48c57cd..278689ef 100644 --- a/.github/workflows/test-on-macos.yml +++ b/.github/workflows/test-on-macos.yml @@ -13,14 +13,14 @@ jobs: strategy: matrix: os: [macos-latest] - go-version: ['1.24.x'] + go-version: ['1.22.x'] steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} diff --git a/.github/workflows/test-on-windows.yml b/.github/workflows/test-on-windows.yml index b5e6a6e2..08ea79da 100644 --- a/.github/workflows/test-on-windows.yml +++ b/.github/workflows/test-on-windows.yml @@ -13,14 +13,14 @@ jobs: strategy: matrix: os: [windows-latest] - go-version: ['1.24.x'] + go-version: ['1.22.x'] steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v2 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} diff --git a/Dockerfile b/Dockerfile index 362ff471..414cc8c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # build stage -FROM golang:1.24-alpine AS build-env +FROM golang:1.22-alpine3.20 AS build-env RUN apk add --no-cache ca-certificates RUN apk add --no-cache bash gcc g++ binutils-gold iptables wireless-tools build-base libpcap-dev libusb-dev linux-headers libnetfilter_queue-dev git @@ -13,9 +13,9 @@ RUN mkdir -p /usr/local/share/bettercap RUN git clone https://github.com/bettercap/caplets /usr/local/share/bettercap/caplets # final stage -FROM alpine +FROM alpine:3.20 RUN apk add --no-cache ca-certificates -RUN apk add --no-cache bash iproute2 libpcap libusb-dev libnetfilter_queue wireless-tools iw +RUN apk add --no-cache bash iproute2 libpcap libusb-dev libnetfilter_queue wireless-tools COPY --from=build-env /go/src/github.com/bettercap/bettercap/bettercap /app/ COPY --from=build-env /usr/local/share/bettercap/caplets /app/ WORKDIR /app diff --git a/.github/ISSUE_TEMPLATE/default_issue.md b/ISSUE_TEMPLATE.md similarity index 94% rename from .github/ISSUE_TEMPLATE/default_issue.md rename to ISSUE_TEMPLATE.md index 8fc3c85c..5c23a58c 100644 --- a/.github/ISSUE_TEMPLATE/default_issue.md +++ b/ISSUE_TEMPLATE.md @@ -1,8 +1,3 @@ ---- -name: General Issue -about: Write a general issue or bug report. ---- - # Prerequisites Please, before creating this issue make sure that you read the [README](https://github.com/bettercap/bettercap/blob/master/README.md), that you are running the [latest stable version](https://github.com/bettercap/bettercap/releases) and that you already searched [other issues](https://github.com/bettercap/bettercap/issues?q=is%3Aopen+is%3Aissue+label%3Abug) to see if your problem or request was already reported. diff --git a/Makefile b/Makefile index 3ec8e6cc..65a2e917 100644 --- a/Makefile +++ b/Makefile @@ -6,10 +6,10 @@ GO ?= go all: build build: resources - $(GO) build $(GOFLAGS) -o $(TARGET) . + $(GOFLAGS) $(GO) build -o $(TARGET) . build_with_race_detector: resources - $(GO) build $(GOFLAGS) -race -o $(TARGET) . + $(GOFLAGS) $(GO) build -race -o $(TARGET) . resources: network/manuf.go @@ -24,13 +24,13 @@ docker: @docker build -t bettercap:latest . test: - $(GO) test -covermode=atomic -coverprofile=cover.out ./... + $(GOFLAGS) $(GO) test -covermode=atomic -coverprofile=cover.out ./... html_coverage: test - $(GO) tool cover -html=cover.out -o cover.out.html + $(GOFLAGS) $(GO) tool cover -html=cover.out -o cover.out.html benchmark: server_deps - $(GO) test -v -run=doNotRunTests -bench=. -benchmem ./... + $(GOFLAGS) $(GO) test -v -run=doNotRunTests -bench=. -benchmem ./... fmt: $(GO) fmt -s -w $(PACKAGES) diff --git a/README.md b/README.md index 299e1d78..4a27f1cd 100644 --- a/README.md +++ b/README.md @@ -38,15 +38,9 @@ bettercap is a powerful, easily extensible and portable framework written in Go * **A very convenient [web UI](https://www.bettercap.org/usage/#web-ui).** * [More!](https://www.bettercap.org/modules/) -## Contributors - - - bettercap project contributors - - ## License -`bettercap` is made with ♥ and released under the GPL 3 license. +`bettercap` is made with ♥ by [the dev team](https://github.com/orgs/bettercap/people) and it's released under the GPL 3 license. ## Stargazers over time diff --git a/caplets/caplet_test.go b/caplets/caplet_test.go deleted file mode 100644 index dee5d9ff..00000000 --- a/caplets/caplet_test.go +++ /dev/null @@ -1,378 +0,0 @@ -package caplets - -import ( - "errors" - "io/ioutil" - "os" - "strings" - "testing" -) - -func TestNewCaplet(t *testing.T) { - name := "test-caplet" - path := "/path/to/caplet.cap" - size := int64(1024) - - cap := NewCaplet(name, path, size) - - if cap.Name != name { - t.Errorf("expected name %s, got %s", name, cap.Name) - } - if cap.Path != path { - t.Errorf("expected path %s, got %s", path, cap.Path) - } - if cap.Size != size { - t.Errorf("expected size %d, got %d", size, cap.Size) - } - if cap.Code == nil { - t.Error("Code should not be nil") - } - if cap.Scripts == nil { - t.Error("Scripts should not be nil") - } -} - -func TestCapletEval(t *testing.T) { - tests := []struct { - name string - code []string - argv []string - wantLines []string - wantErr bool - }{ - { - name: "empty code", - code: []string{}, - argv: nil, - wantLines: []string{}, - wantErr: false, - }, - { - name: "skip comments and empty lines", - code: []string{ - "# this is a comment", - "", - "set test value", - "# another comment", - "set another value", - }, - argv: nil, - wantLines: []string{ - "set test value", - "set another value", - }, - wantErr: false, - }, - { - name: "variable substitution", - code: []string{ - "set param $0", - "set value $1", - "run $0 $1 $2", - }, - argv: []string{"arg0", "arg1", "arg2"}, - wantLines: []string{ - "set param arg0", - "set value arg1", - "run arg0 arg1 arg2", - }, - wantErr: false, - }, - { - name: "multiple occurrences of same variable", - code: []string{ - "$0 $0 $1 $0", - }, - argv: []string{"foo", "bar"}, - wantLines: []string{ - "foo foo bar foo", - }, - wantErr: false, - }, - { - name: "missing argv values", - code: []string{ - "set $0 $1 $2", - }, - argv: []string{"only_one"}, - wantLines: []string{ - "set only_one $1 $2", - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - cap.Code = tt.code - - var gotLines []string - err = cap.Eval(tt.argv, func(line string) error { - gotLines = append(gotLines, line) - return nil - }) - - if (err != nil) != tt.wantErr { - t.Errorf("Eval() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if len(gotLines) != len(tt.wantLines) { - t.Errorf("got %d lines, want %d", len(gotLines), len(tt.wantLines)) - return - } - - for i, line := range gotLines { - if line != tt.wantLines[i] { - t.Errorf("line %d: got %q, want %q", i, line, tt.wantLines[i]) - } - } - }) - } -} - -func TestCapletEvalError(t *testing.T) { - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - cap.Code = []string{ - "first line", - "error line", - "third line", - } - - expectedErr := errors.New("test error") - var executedLines []string - - err = cap.Eval(nil, func(line string) error { - executedLines = append(executedLines, line) - if line == "error line" { - return expectedErr - } - return nil - }) - - if err != expectedErr { - t.Errorf("expected error %v, got %v", expectedErr, err) - } - - // Should have executed first two lines before error - if len(executedLines) != 2 { - t.Errorf("expected 2 executed lines, got %d", len(executedLines)) - } -} - -func TestCapletEvalWithChdirPath(t *testing.T) { - // Create a temporary caplet file to test with - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - cap.Code = []string{"test command"} - - executed := false - err = cap.Eval(nil, func(line string) error { - executed = true - return nil - }) - - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !executed { - t.Error("callback was not executed") - } -} - -func TestNewScript(t *testing.T) { - path := "/path/to/script.js" - size := int64(2048) - - script := newScript(path, size) - - if script.Path != path { - t.Errorf("expected path %s, got %s", path, script.Path) - } - if script.Size != size { - t.Errorf("expected size %d, got %d", size, script.Size) - } - if script.Code == nil { - t.Error("Code should not be nil") - } - if len(script.Code) != 0 { - t.Errorf("expected empty Code slice, got %d elements", len(script.Code)) - } -} - -func TestCapletEvalCommentAtStartOfLine(t *testing.T) { - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - cap.Code = []string{ - "# comment", - " # not a comment (has space before #)", - " # not a comment (has tab before #)", - "command # inline comment", - } - - var gotLines []string - err = cap.Eval(nil, func(line string) error { - gotLines = append(gotLines, line) - return nil - }) - - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - expectedLines := []string{ - " # not a comment (has space before #)", - " # not a comment (has tab before #)", - "command # inline comment", - } - - if len(gotLines) != len(expectedLines) { - t.Errorf("got %d lines, want %d", len(gotLines), len(expectedLines)) - return - } - - for i, line := range gotLines { - if line != expectedLines[i] { - t.Errorf("line %d: got %q, want %q", i, line, expectedLines[i]) - } - } -} - -func TestCapletEvalArgvSubstitutionEdgeCases(t *testing.T) { - tests := []struct { - name string - code string - argv []string - wantLine string - }{ - { - name: "double digit substitution $10", - code: "$1$0", - argv: []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, - wantLine: "10", - }, - { - name: "no space between variables", - code: "$0$1$2", - argv: []string{"a", "b", "c"}, - wantLine: "abc", - }, - { - name: "variables in quotes", - code: `"$0" '$1'`, - argv: []string{"foo", "bar"}, - wantLine: `"foo" 'bar'`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - cap.Code = []string{tt.code} - - var gotLine string - err = cap.Eval(tt.argv, func(line string) error { - gotLine = line - return nil - }) - - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if gotLine != tt.wantLine { - t.Errorf("got line %q, want %q", gotLine, tt.wantLine) - } - }) - } -} - -func TestCapletStructFields(t *testing.T) { - // Test that Caplet properly embeds Script - tempFile, err := ioutil.TempFile("", "test-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - tempFile.Close() - - cap := NewCaplet("test", tempFile.Name(), 100) - - // These fields should be accessible due to embedding - _ = cap.Path - _ = cap.Size - _ = cap.Code - - // And these are Caplet's own fields - _ = cap.Name - _ = cap.Scripts -} - -func BenchmarkCapletEval(b *testing.B) { - cap := NewCaplet("bench", "/tmp/bench.cap", 100) - cap.Code = []string{ - "set param1 $0", - "set param2 $1", - "# comment line", - "", - "run command $0 $1 $2", - "another command", - } - argv := []string{"arg0", "arg1", "arg2"} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = cap.Eval(argv, func(line string) error { - // Do nothing, just measure evaluation overhead - return nil - }) - } -} - -func BenchmarkVariableSubstitution(b *testing.B) { - line := "command $0 $1 $2 $3 $4 $5 $6 $7 $8 $9" - argv := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - result := line - for j, arg := range argv { - what := "$" + string(rune('0'+j)) - result = strings.Replace(result, what, arg, -1) - } - } -} diff --git a/caplets/env_test.go b/caplets/env_test.go deleted file mode 100644 index c1087216..00000000 --- a/caplets/env_test.go +++ /dev/null @@ -1,308 +0,0 @@ -package caplets - -import ( - "os" - "path/filepath" - "runtime" - "strings" - "testing" -) - -func TestGetDefaultInstallBase(t *testing.T) { - base := getDefaultInstallBase() - - if runtime.GOOS == "windows" { - expected := filepath.Join(os.Getenv("ALLUSERSPROFILE"), "bettercap") - if base != expected { - t.Errorf("on windows, expected %s, got %s", expected, base) - } - } else { - expected := "/usr/local/share/bettercap/" - if base != expected { - t.Errorf("on non-windows, expected %s, got %s", expected, base) - } - } -} - -func TestGetUserHomeDir(t *testing.T) { - home := getUserHomeDir() - - // Should return a non-empty string - if home == "" { - t.Error("getUserHomeDir returned empty string") - } - - // Should be an absolute path - if !filepath.IsAbs(home) { - t.Errorf("expected absolute path, got %s", home) - } -} - -func TestSetup(t *testing.T) { - // Save original values - origInstallBase := InstallBase - origInstallPathArchive := InstallPathArchive - origInstallPath := InstallPath - origArchivePath := ArchivePath - origLoadPaths := LoadPaths - - // Test with custom base - testBase := "/custom/base" - err := Setup(testBase) - - if err != nil { - t.Errorf("Setup returned error: %v", err) - } - - // Check that paths are set correctly - if InstallBase != testBase { - t.Errorf("expected InstallBase %s, got %s", testBase, InstallBase) - } - - expectedArchivePath := filepath.Join(testBase, "caplets-master") - if InstallPathArchive != expectedArchivePath { - t.Errorf("expected InstallPathArchive %s, got %s", expectedArchivePath, InstallPathArchive) - } - - expectedInstallPath := filepath.Join(testBase, "caplets") - if InstallPath != expectedInstallPath { - t.Errorf("expected InstallPath %s, got %s", expectedInstallPath, InstallPath) - } - - expectedTempPath := filepath.Join(os.TempDir(), "caplets.zip") - if ArchivePath != expectedTempPath { - t.Errorf("expected ArchivePath %s, got %s", expectedTempPath, ArchivePath) - } - - // Check LoadPaths contains expected paths - expectedInLoadPaths := []string{ - "./", - "./caplets/", - InstallPath, - filepath.Join(getUserHomeDir(), "caplets"), - } - - for _, expected := range expectedInLoadPaths { - absExpected, _ := filepath.Abs(expected) - found := false - for _, path := range LoadPaths { - if path == absExpected { - found = true - break - } - } - if !found { - t.Errorf("expected path %s not found in LoadPaths", absExpected) - } - } - - // All paths should be absolute - for _, path := range LoadPaths { - if !filepath.IsAbs(path) { - t.Errorf("LoadPath %s is not absolute", path) - } - } - - // Restore original values - InstallBase = origInstallBase - InstallPathArchive = origInstallPathArchive - InstallPath = origInstallPath - ArchivePath = origArchivePath - LoadPaths = origLoadPaths -} - -func TestSetupWithEnvironmentVariable(t *testing.T) { - // Save original values - origEnv := os.Getenv(EnvVarName) - origLoadPaths := LoadPaths - - // Set environment variable with multiple paths - testPaths := []string{"/path1", "/path2", "/path3"} - os.Setenv(EnvVarName, strings.Join(testPaths, string(os.PathListSeparator))) - - // Run setup - err := Setup("/test/base") - if err != nil { - t.Errorf("Setup returned error: %v", err) - } - - // Check that custom paths from env var are in LoadPaths - for _, testPath := range testPaths { - absTestPath, _ := filepath.Abs(testPath) - found := false - for _, path := range LoadPaths { - if path == absTestPath { - found = true - break - } - } - if !found { - t.Errorf("expected env path %s not found in LoadPaths", absTestPath) - } - } - - // Restore original values - if origEnv == "" { - os.Unsetenv(EnvVarName) - } else { - os.Setenv(EnvVarName, origEnv) - } - LoadPaths = origLoadPaths -} - -func TestSetupWithEmptyEnvironmentVariable(t *testing.T) { - // Save original values - origEnv := os.Getenv(EnvVarName) - origLoadPaths := LoadPaths - - // Set empty environment variable - os.Setenv(EnvVarName, "") - - // Count LoadPaths before setup - err := Setup("/test/base") - if err != nil { - t.Errorf("Setup returned error: %v", err) - } - - // Should have only the default paths (4) - if len(LoadPaths) != 4 { - t.Errorf("expected 4 default LoadPaths, got %d", len(LoadPaths)) - } - - // Restore original values - if origEnv == "" { - os.Unsetenv(EnvVarName) - } else { - os.Setenv(EnvVarName, origEnv) - } - LoadPaths = origLoadPaths -} - -func TestSetupWithWhitespaceInEnvironmentVariable(t *testing.T) { - // Save original values - origEnv := os.Getenv(EnvVarName) - origLoadPaths := LoadPaths - - // Set environment variable with whitespace - testPaths := []string{" /path1 ", " ", "/path2 "} - os.Setenv(EnvVarName, strings.Join(testPaths, string(os.PathListSeparator))) - - // Run setup - err := Setup("/test/base") - if err != nil { - t.Errorf("Setup returned error: %v", err) - } - - // Should have added only non-empty paths after trimming - expectedPaths := []string{"/path1", "/path2"} - foundCount := 0 - for _, expectedPath := range expectedPaths { - absExpected, _ := filepath.Abs(expectedPath) - for _, path := range LoadPaths { - if path == absExpected { - foundCount++ - break - } - } - } - - if foundCount != len(expectedPaths) { - t.Errorf("expected to find %d paths from env, found %d", len(expectedPaths), foundCount) - } - - // Restore original values - if origEnv == "" { - os.Unsetenv(EnvVarName) - } else { - os.Setenv(EnvVarName, origEnv) - } - LoadPaths = origLoadPaths -} - -func TestConstants(t *testing.T) { - // Test that constants have expected values - if EnvVarName != "CAPSPATH" { - t.Errorf("expected EnvVarName to be 'CAPSPATH', got %s", EnvVarName) - } - - if Suffix != ".cap" { - t.Errorf("expected Suffix to be '.cap', got %s", Suffix) - } - - if InstallArchive != "https://github.com/bettercap/caplets/archive/master.zip" { - t.Errorf("unexpected InstallArchive value: %s", InstallArchive) - } -} - -func TestInit(t *testing.T) { - // The init function should have been called already - // Check that paths are initialized - if InstallBase == "" { - t.Error("InstallBase not initialized") - } - - if InstallPath == "" { - t.Error("InstallPath not initialized") - } - - if InstallPathArchive == "" { - t.Error("InstallPathArchive not initialized") - } - - if ArchivePath == "" { - t.Error("ArchivePath not initialized") - } - - if LoadPaths == nil || len(LoadPaths) == 0 { - t.Error("LoadPaths not initialized") - } -} - -func TestSetupMultipleTimes(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - - // Setup multiple times with different bases - bases := []string{"/base1", "/base2", "/base3"} - - for _, base := range bases { - err := Setup(base) - if err != nil { - t.Errorf("Setup(%s) returned error: %v", base, err) - } - - // Check that InstallBase is updated - if InstallBase != base { - t.Errorf("expected InstallBase %s, got %s", base, InstallBase) - } - - // LoadPaths should be recreated each time - if len(LoadPaths) < 4 { - t.Errorf("LoadPaths should have at least 4 entries, got %d", len(LoadPaths)) - } - } - - // Restore original values - LoadPaths = origLoadPaths -} - -func BenchmarkSetup(b *testing.B) { - // Save original values - origEnv := os.Getenv(EnvVarName) - - // Set a complex environment - paths := []string{"/p1", "/p2", "/p3", "/p4", "/p5"} - os.Setenv(EnvVarName, strings.Join(paths, string(os.PathListSeparator))) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - Setup("/benchmark/base") - } - - // Restore - if origEnv == "" { - os.Unsetenv(EnvVarName) - } else { - os.Setenv(EnvVarName, origEnv) - } -} diff --git a/caplets/manager_test.go b/caplets/manager_test.go deleted file mode 100644 index 0392a12b..00000000 --- a/caplets/manager_test.go +++ /dev/null @@ -1,511 +0,0 @@ -package caplets - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" - "sort" - "strings" - "sync" - "testing" -) - -func createTestCaplet(t testing.TB, dir string, name string, content []string) string { - filename := filepath.Join(dir, name) - data := strings.Join(content, "\n") - err := ioutil.WriteFile(filename, []byte(data), 0644) - if err != nil { - t.Fatalf("failed to create test caplet: %v", err) - } - return filename -} - -func TestList(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directories - tempDir, err := ioutil.TempDir("", "caplets-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Create subdirectories - dir1 := filepath.Join(tempDir, "dir1") - dir2 := filepath.Join(tempDir, "dir2") - subdir := filepath.Join(dir1, "subdir") - - os.Mkdir(dir1, 0755) - os.Mkdir(dir2, 0755) - os.Mkdir(subdir, 0755) - - // Create test caplets - createTestCaplet(t, dir1, "test1.cap", []string{"# Test caplet 1", "set test 1"}) - createTestCaplet(t, dir1, "test2.cap", []string{"# Test caplet 2", "set test 2"}) - createTestCaplet(t, dir2, "test3.cap", []string{"# Test caplet 3", "set test 3"}) - createTestCaplet(t, subdir, "nested.cap", []string{"# Nested caplet", "set nested test"}) - - // Also create a non-caplet file - ioutil.WriteFile(filepath.Join(dir1, "notacaplet.txt"), []byte("not a caplet"), 0644) - - // Set LoadPaths - LoadPaths = []string{dir1, dir2} - - // Call List() - caplets := List() - - // Check results - if len(caplets) != 4 { - t.Errorf("expected 4 caplets, got %d", len(caplets)) - } - - // Check names (should be sorted) - expectedNames := []string{filepath.Join("subdir", "nested"), "test1", "test2", "test3"} - sort.Strings(expectedNames) - - gotNames := make([]string, len(caplets)) - for i, cap := range caplets { - gotNames[i] = cap.Name - } - - for i, expected := range expectedNames { - if i >= len(gotNames) || gotNames[i] != expected { - t.Errorf("expected caplet %d to be %s, got %s", i, expected, gotNames[i]) - } - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestListEmptyDirectories(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directory - tempDir, err := ioutil.TempDir("", "caplets-empty-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Set LoadPaths to empty directory - LoadPaths = []string{tempDir} - - // Call List() - caplets := List() - - // Should return empty list - if len(caplets) != 0 { - t.Errorf("expected 0 caplets, got %d", len(caplets)) - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestLoad(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directory - tempDir, err := ioutil.TempDir("", "caplets-load-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Create test caplet - capletContent := []string{ - "# Test caplet", - "set param value", - "", - "# Another comment", - "run command", - } - createTestCaplet(t, tempDir, "test.cap", capletContent) - - // Set LoadPaths - LoadPaths = []string{tempDir} - - // Test loading without .cap extension - cap, err := Load("test") - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if cap == nil { - t.Error("caplet is nil") - } else { - if cap.Name != "test" { - t.Errorf("expected name 'test', got %s", cap.Name) - } - if len(cap.Code) != len(capletContent) { - t.Errorf("expected %d lines, got %d", len(capletContent), len(cap.Code)) - } - } - - // Test loading from cache - // Note: The Load function caches with the suffix, so we need to use the same name with suffix - cap2, err := Load("test.cap") - if err != nil { - t.Errorf("unexpected error on cache hit: %v", err) - } - if cap2 == nil { - t.Error("caplet is nil on cache hit") - } - - // Test loading with .cap extension - // Note: Load caches by the name parameter, so "test.cap" is a different cache key - cap3, err := Load("test.cap") - if err != nil { - t.Errorf("unexpected error with .cap extension: %v", err) - } - if cap3 == nil { - t.Error("caplet is nil") - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestLoadAbsolutePath(t *testing.T) { - // Save original values - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp file - tempFile, err := ioutil.TempFile("", "test-absolute-*.cap") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) - - // Write content - content := "# Absolute path test\nset test absolute" - tempFile.WriteString(content) - tempFile.Close() - - // Load with absolute path - cap, err := Load(tempFile.Name()) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if cap == nil { - t.Error("caplet is nil") - } else { - if cap.Path != tempFile.Name() { - t.Errorf("expected path %s, got %s", tempFile.Name(), cap.Path) - } - } - - // Restore original values - cache = origCache -} - -func TestLoadNotFound(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Set empty LoadPaths - LoadPaths = []string{} - - // Try to load non-existent caplet - cap, err := Load("nonexistent") - if err == nil { - t.Error("expected error for non-existent caplet") - } - if cap != nil { - t.Error("expected nil caplet for non-existent file") - } - if !strings.Contains(err.Error(), "not found") { - t.Errorf("expected 'not found' error, got: %v", err) - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestLoadWithFolder(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directory structure - tempDir, err := ioutil.TempDir("", "caplets-folder-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Create a caplet folder - capletDir := filepath.Join(tempDir, "mycaplet") - os.Mkdir(capletDir, 0755) - - // Create main caplet file - mainContent := []string{"# Main caplet", "set main test"} - createTestCaplet(t, capletDir, "mycaplet.cap", mainContent) - - // Create additional files - jsContent := []string{"// JavaScript file", "console.log('test');"} - createTestCaplet(t, capletDir, "script.js", jsContent) - - capContent := []string{"# Sub caplet", "set sub test"} - createTestCaplet(t, capletDir, "sub.cap", capContent) - - // Create a file that should be ignored - ioutil.WriteFile(filepath.Join(capletDir, "readme.txt"), []byte("readme"), 0644) - - // Set LoadPaths - LoadPaths = []string{tempDir} - - // Load the caplet - cap, err := Load("mycaplet/mycaplet") - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if cap == nil { - t.Fatal("caplet is nil") - } - - // Check main caplet - if cap.Name != "mycaplet/mycaplet" { - t.Errorf("expected name 'mycaplet/mycaplet', got %s", cap.Name) - } - if len(cap.Code) != len(mainContent) { - t.Errorf("expected %d lines in main, got %d", len(mainContent), len(cap.Code)) - } - - // Check additional scripts - if len(cap.Scripts) != 2 { - t.Errorf("expected 2 additional scripts, got %d", len(cap.Scripts)) - } - - // Find and check the .js file - foundJS := false - foundCap := false - for _, script := range cap.Scripts { - if strings.HasSuffix(script.Path, "script.js") { - foundJS = true - if len(script.Code) != len(jsContent) { - t.Errorf("expected %d lines in JS, got %d", len(jsContent), len(script.Code)) - } - } - if strings.HasSuffix(script.Path, "sub.cap") { - foundCap = true - if len(script.Code) != len(capContent) { - t.Errorf("expected %d lines in sub.cap, got %d", len(capContent), len(script.Code)) - } - } - } - - if !foundJS { - t.Error("script.js not found in Scripts") - } - if !foundCap { - t.Error("sub.cap not found in Scripts") - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestCacheConcurrency(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directory - tempDir, err := ioutil.TempDir("", "caplets-concurrent-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Create test caplets - for i := 0; i < 5; i++ { - name := fmt.Sprintf("test%d.cap", i) - content := []string{fmt.Sprintf("# Test %d", i)} - createTestCaplet(t, tempDir, name, content) - } - - // Set LoadPaths - LoadPaths = []string{tempDir} - - // Run concurrent loads - var wg sync.WaitGroup - errors := make(chan error, 50) - - for i := 0; i < 10; i++ { - for j := 0; j < 5; j++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - name := fmt.Sprintf("test%d", idx) - _, err := Load(name) - if err != nil { - errors <- err - } - }(j) - } - } - - wg.Wait() - close(errors) - - // Check for errors - for err := range errors { - t.Errorf("concurrent load error: %v", err) - } - - // Verify cache has all entries - if len(cache) != 5 { - t.Errorf("expected 5 cached entries, got %d", len(cache)) - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func TestLoadPathPriority(t *testing.T) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directories - tempDir1, _ := ioutil.TempDir("", "caplets-priority1-") - tempDir2, _ := ioutil.TempDir("", "caplets-priority2-") - defer os.RemoveAll(tempDir1) - defer os.RemoveAll(tempDir2) - - // Create same-named caplet in both directories - createTestCaplet(t, tempDir1, "test.cap", []string{"# From dir1"}) - createTestCaplet(t, tempDir2, "test.cap", []string{"# From dir2"}) - - // Set LoadPaths with tempDir1 first - LoadPaths = []string{tempDir1, tempDir2} - - // Load caplet - cap, err := Load("test") - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - // Should load from first directory - if cap != nil && len(cap.Code) > 0 { - if cap.Code[0] != "# From dir1" { - t.Error("caplet not loaded from first directory in LoadPaths") - } - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func BenchmarkLoad(b *testing.B) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - - // Create temp directory - tempDir, _ := ioutil.TempDir("", "caplets-bench-") - defer os.RemoveAll(tempDir) - - // Create test caplet - content := make([]string, 100) - for i := range content { - content[i] = fmt.Sprintf("command %d", i) - } - createTestCaplet(b, tempDir, "bench.cap", content) - - // Set LoadPaths - LoadPaths = []string{tempDir} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // Clear cache to measure loading time - cache = make(map[string]*Caplet) - Load("bench") - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func BenchmarkLoadFromCache(b *testing.B) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - cache = make(map[string]*Caplet) - - // Create temp directory - tempDir, _ := ioutil.TempDir("", "caplets-bench-cache-") - defer os.RemoveAll(tempDir) - - // Create test caplet - createTestCaplet(b, tempDir, "bench.cap", []string{"# Benchmark"}) - - // Set LoadPaths - LoadPaths = []string{tempDir} - - // Pre-load into cache - Load("bench") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - Load("bench") - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} - -func BenchmarkList(b *testing.B) { - // Save original values - origLoadPaths := LoadPaths - origCache := cache - - // Create temp directory - tempDir, _ := ioutil.TempDir("", "caplets-bench-list-") - defer os.RemoveAll(tempDir) - - // Create multiple caplets - for i := 0; i < 20; i++ { - name := fmt.Sprintf("test%d.cap", i) - createTestCaplet(b, tempDir, name, []string{fmt.Sprintf("# Test %d", i)}) - } - - // Set LoadPaths - LoadPaths = []string{tempDir} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - cache = make(map[string]*Caplet) - List() - } - - // Restore original values - LoadPaths = origLoadPaths - cache = origCache -} diff --git a/core/banner.go b/core/banner.go index 1a63f0c8..1df1aafa 100644 --- a/core/banner.go +++ b/core/banner.go @@ -2,7 +2,7 @@ package core const ( Name = "bettercap" - Version = "2.41.4" + Version = "2.41.0" Author = "Simone 'evilsocket' Margaritelli" Website = "https://bettercap.org/" ) diff --git a/core/core_test.go b/core/core_test.go index 057e5b21..2dc77c49 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -97,144 +97,3 @@ func TestCoreExists(t *testing.T) { } } } - -func TestHasBinary(t *testing.T) { - tests := []struct { - name string - executable string - expected bool - }{ - { - name: "common shell", - executable: "sh", - expected: true, - }, - { - name: "echo command", - executable: "echo", - expected: true, - }, - { - name: "non-existent binary", - executable: "this-binary-definitely-does-not-exist-12345", - expected: false, - }, - { - name: "empty string", - executable: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := HasBinary(tt.executable) - if got != tt.expected { - t.Errorf("HasBinary(%q) = %v, want %v", tt.executable, got, tt.expected) - } - }) - } -} - -func TestExec(t *testing.T) { - tests := []struct { - name string - executable string - args []string - wantError bool - contains string - }{ - { - name: "echo with args", - executable: "echo", - args: []string{"hello", "world"}, - wantError: false, - contains: "hello world", - }, - { - name: "echo empty", - executable: "echo", - args: []string{}, - wantError: false, - contains: "", - }, - { - name: "non-existent command", - executable: "this-command-does-not-exist-12345", - args: []string{}, - wantError: true, - contains: "", - }, - { - name: "true command", - executable: "true", - args: []string{}, - wantError: false, - contains: "", - }, - { - name: "false command", - executable: "false", - args: []string{}, - wantError: true, - contains: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Skip platform-specific commands if not available - if !HasBinary(tt.executable) && !tt.wantError { - t.Skipf("%s not found in PATH", tt.executable) - } - - output, err := Exec(tt.executable, tt.args) - - if tt.wantError { - if err == nil { - t.Errorf("Exec(%q, %v) expected error but got none", tt.executable, tt.args) - } - } else { - if err != nil { - t.Errorf("Exec(%q, %v) unexpected error: %v", tt.executable, tt.args, err) - } - if tt.contains != "" && output != tt.contains { - t.Errorf("Exec(%q, %v) = %q, want %q", tt.executable, tt.args, output, tt.contains) - } - } - }) - } -} - -func TestExecWithOutput(t *testing.T) { - // Test that Exec properly captures and trims output - if HasBinary("printf") { - output, err := Exec("printf", []string{" hello world \n"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if output != "hello world" { - t.Errorf("expected trimmed output 'hello world', got %q", output) - } - } -} - -func BenchmarkUniqueInts(b *testing.B) { - // Create a slice with duplicates - input := make([]int, 1000) - for i := 0; i < 1000; i++ { - input[i] = i % 100 // This creates 10 duplicates of each number 0-99 - } - - b.Run("unsorted", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = UniqueInts(input, false) - } - }) - - b.Run("sorted", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = UniqueInts(input, true) - } - }) -} diff --git a/firewall/redirection_test.go b/firewall/redirection_test.go deleted file mode 100644 index 050590b2..00000000 --- a/firewall/redirection_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package firewall - -import ( - "testing" -) - -func TestNewRedirection(t *testing.T) { - iface := "eth0" - proto := "tcp" - portFrom := 8080 - addrTo := "192.168.1.100" - portTo := 9090 - - r := NewRedirection(iface, proto, portFrom, addrTo, portTo) - - if r == nil { - t.Fatal("NewRedirection returned nil") - } - - if r.Interface != iface { - t.Errorf("expected Interface %s, got %s", iface, r.Interface) - } - - if r.Protocol != proto { - t.Errorf("expected Protocol %s, got %s", proto, r.Protocol) - } - - if r.SrcAddress != "" { - t.Errorf("expected empty SrcAddress, got %s", r.SrcAddress) - } - - if r.SrcPort != portFrom { - t.Errorf("expected SrcPort %d, got %d", portFrom, r.SrcPort) - } - - if r.DstAddress != addrTo { - t.Errorf("expected DstAddress %s, got %s", addrTo, r.DstAddress) - } - - if r.DstPort != portTo { - t.Errorf("expected DstPort %d, got %d", portTo, r.DstPort) - } -} - -func TestRedirectionString(t *testing.T) { - tests := []struct { - name string - r Redirection - want string - }{ - { - name: "basic redirection", - r: Redirection{ - Interface: "eth0", - Protocol: "tcp", - SrcAddress: "", - SrcPort: 8080, - DstAddress: "192.168.1.100", - DstPort: 9090, - }, - want: "[eth0] (tcp) :8080 -> 192.168.1.100:9090", - }, - { - name: "with source address", - r: Redirection{ - Interface: "wlan0", - Protocol: "udp", - SrcAddress: "192.168.1.50", - SrcPort: 53, - DstAddress: "8.8.8.8", - DstPort: 53, - }, - want: "[wlan0] (udp) 192.168.1.50:53 -> 8.8.8.8:53", - }, - { - name: "localhost redirection", - r: Redirection{ - Interface: "lo", - Protocol: "tcp", - SrcAddress: "127.0.0.1", - SrcPort: 80, - DstAddress: "127.0.0.1", - DstPort: 8080, - }, - want: "[lo] (tcp) 127.0.0.1:80 -> 127.0.0.1:8080", - }, - { - name: "high port numbers", - r: Redirection{ - Interface: "eth1", - Protocol: "tcp", - SrcAddress: "", - SrcPort: 65535, - DstAddress: "10.0.0.1", - DstPort: 65534, - }, - want: "[eth1] (tcp) :65535 -> 10.0.0.1:65534", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.r.String() - if got != tt.want { - t.Errorf("String() = %q, want %q", got, tt.want) - } - }) - } -} - -func TestNewRedirectionVariousProtocols(t *testing.T) { - protocols := []string{"tcp", "udp", "icmp", "any"} - - for _, proto := range protocols { - t.Run(proto, func(t *testing.T) { - r := NewRedirection("eth0", proto, 1234, "10.0.0.1", 5678) - if r.Protocol != proto { - t.Errorf("expected protocol %s, got %s", proto, r.Protocol) - } - }) - } -} - -func TestNewRedirectionVariousInterfaces(t *testing.T) { - interfaces := []string{"eth0", "wlan0", "lo", "docker0", "br0", "tun0"} - - for _, iface := range interfaces { - t.Run(iface, func(t *testing.T) { - r := NewRedirection(iface, "tcp", 80, "192.168.1.1", 8080) - if r.Interface != iface { - t.Errorf("expected interface %s, got %s", iface, r.Interface) - } - }) - } -} - -func TestRedirectionStringEmptyFields(t *testing.T) { - tests := []struct { - name string - r Redirection - want string - }{ - { - name: "empty interface", - r: Redirection{ - Interface: "", - Protocol: "tcp", - SrcAddress: "", - SrcPort: 80, - DstAddress: "192.168.1.1", - DstPort: 8080, - }, - want: "[] (tcp) :80 -> 192.168.1.1:8080", - }, - { - name: "empty protocol", - r: Redirection{ - Interface: "eth0", - Protocol: "", - SrcAddress: "", - SrcPort: 80, - DstAddress: "192.168.1.1", - DstPort: 8080, - }, - want: "[eth0] () :80 -> 192.168.1.1:8080", - }, - { - name: "empty destination", - r: Redirection{ - Interface: "eth0", - Protocol: "tcp", - SrcAddress: "", - SrcPort: 80, - DstAddress: "", - DstPort: 8080, - }, - want: "[eth0] (tcp) :80 -> :8080", - }, - { - name: "all empty strings", - r: Redirection{ - Interface: "", - Protocol: "", - SrcAddress: "", - SrcPort: 0, - DstAddress: "", - DstPort: 0, - }, - want: "[] () :0 -> :0", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.r.String() - if got != tt.want { - t.Errorf("String() = %q, want %q", got, tt.want) - } - }) - } -} - -func TestRedirectionStructCopy(t *testing.T) { - // Test that Redirection can be safely copied - original := NewRedirection("eth0", "tcp", 80, "192.168.1.1", 8080) - original.SrcAddress = "10.0.0.1" - - // Create a copy - copy := *original - - // Modify the copy - copy.Interface = "wlan0" - copy.SrcPort = 443 - - // Verify original is unchanged - if original.Interface != "eth0" { - t.Error("original Interface was modified") - } - if original.SrcPort != 80 { - t.Error("original SrcPort was modified") - } - - // Verify copy has new values - if copy.Interface != "wlan0" { - t.Error("copy Interface was not set correctly") - } - if copy.SrcPort != 443 { - t.Error("copy SrcPort was not set correctly") - } -} - -func BenchmarkNewRedirection(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = NewRedirection("eth0", "tcp", 80, "192.168.1.1", 8080) - } -} - -func BenchmarkRedirectionString(b *testing.B) { - r := Redirection{ - Interface: "eth0", - Protocol: "tcp", - SrcAddress: "192.168.1.50", - SrcPort: 8080, - DstAddress: "192.168.1.100", - DstPort: 9090, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = r.String() - } -} - -func BenchmarkRedirectionStringEmpty(b *testing.B) { - r := Redirection{ - Interface: "eth0", - Protocol: "tcp", - SrcAddress: "", - SrcPort: 8080, - DstAddress: "192.168.1.100", - DstPort: 9090, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = r.String() - } -} diff --git a/go.mod b/go.mod index 0cbddafa..b1b2dfc3 100644 --- a/go.mod +++ b/go.mod @@ -1,20 +1,20 @@ module github.com/bettercap/bettercap/v2 -go 1.23.0 +go 1.21 -toolchain go1.24.4 +toolchain go1.22.6 require ( github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d - github.com/adrianmo/go-nmea v1.10.0 - github.com/antchfx/jsonquery v1.3.6 + github.com/adrianmo/go-nmea v1.9.0 + github.com/antchfx/jsonquery v1.3.5 github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0 github.com/bettercap/nrf24 v0.0.0-20190219153547-aa37e6d0e0eb github.com/bettercap/readline v0.0.0-20210228151553-655e48bcb7bf github.com/bettercap/recording v0.0.0-20190408083647-3ce1dcf032e3 github.com/cenkalti/backoff v2.2.1+incompatible github.com/dustin/go-humanize v1.0.1 - github.com/elazarl/goproxy v1.7.2 + github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380 github.com/evilsocket/islazy v1.11.0 github.com/florianl/go-nfqueue/v2 v2.0.0 github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe @@ -23,45 +23,47 @@ require ( github.com/google/gousb v1.1.3 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 + github.com/grandcat/zeroconf v1.0.0 github.com/hashicorp/go-bexpr v0.1.14 github.com/inconshreveable/go-vhost v1.0.0 github.com/jpillora/go-tld v1.2.1 github.com/malfunkt/iprange v0.9.0 github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b - github.com/miekg/dns v1.1.67 + github.com/miekg/dns v1.1.61 github.com/mitchellh/go-homedir v1.1.0 github.com/phin1x/go-ipp v1.6.1 - github.com/robertkrimen/otto v0.5.1 + github.com/robertkrimen/otto v0.4.0 github.com/stratoberry/go-gpsd v1.3.0 github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64 - go.einride.tech/can v0.14.0 - golang.org/x/net v0.42.0 + go.einride.tech/can v0.12.0 + golang.org/x/net v0.28.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/antchfx/xpath v1.3.4 // indirect + github.com/antchfx/xpath v1.3.1 // indirect github.com/chzyer/logex v1.2.1 // indirect - github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect - github.com/google/go-cmp v0.7.0 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/kr/binarydist v0.1.0 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/netlink v1.7.2 // indirect - github.com/mdlayher/socket v0.5.1 // indirect + github.com/mdlayher/socket v0.4.1 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/mitchellh/mapstructure v1.4.1 // indirect github.com/mitchellh/pointerstructure v1.2.1 // indirect github.com/pkg/errors v0.9.1 // indirect - golang.org/x/mod v0.26.0 // indirect - golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.34.0 // indirect - golang.org/x/text v0.27.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/mod v0.20.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/tools v0.24.0 // indirect gopkg.in/sourcemap.v1 v1.0.5 // indirect ) diff --git a/go.sum b/go.sum index f9a5d6ad..a2930b76 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,11 @@ github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= -github.com/adrianmo/go-nmea v1.10.0 h1:L1aYaebZ4cXFCoXNSeDeQa0tApvSKvIbqMsK+iaRiCo= -github.com/adrianmo/go-nmea v1.10.0/go.mod h1:u8bPnpKt/D/5rll/5l9f6iDfeq5WZW0+/SXdkwix6Tg= -github.com/antchfx/jsonquery v1.3.6 h1:TaSfeAh7n6T11I74bsZ1FswreIfrbJ0X+OyLflx6mx4= -github.com/antchfx/jsonquery v1.3.6/go.mod h1:fGzSGJn9Y826Qd3pC8Wx45avuUwpkePsACQJYy+58BU= -github.com/antchfx/xpath v1.3.2/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= -github.com/antchfx/xpath v1.3.4 h1:1ixrW1VnXd4HurCj7qnqnR0jo14g8JMe20Fshg1Vgz4= -github.com/antchfx/xpath v1.3.4/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= +github.com/adrianmo/go-nmea v1.9.0 h1:kCuerWLDIppltHNZ2HGdCGkqbmupYJYfE6indcGkcp8= +github.com/adrianmo/go-nmea v1.9.0/go.mod h1:u8bPnpKt/D/5rll/5l9f6iDfeq5WZW0+/SXdkwix6Tg= +github.com/antchfx/jsonquery v1.3.5 h1:243OSaQh02EfmASa3w3weKC9UaiD8RRzJhgfvq3q408= +github.com/antchfx/jsonquery v1.3.5/go.mod h1:qH23yX2Jsj1/k378Yu/EOgPCNgJ35P9tiGOeQdt/GWc= +github.com/antchfx/xpath v1.3.1 h1:PNbFuUqHwWl0xRjvUPjJ95Agbmdj2uzzIwmQKgu4oCk= +github.com/antchfx/xpath v1.3.1/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0 h1:HiFUGV/7eGWG/YJAf9HcKOUmxIj+7LVzC8zD57VX1qo= github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0/go.mod h1:oafnPgaBI4gqJiYkueCyR4dqygiWGXTGOE0gmmAVeeQ= github.com/bettercap/nrf24 v0.0.0-20190219153547-aa37e6d0e0eb h1:JWAAJk4ny+bT3VrtcX+e7mcmWtWUeUM0xVcocSAUuWc= @@ -27,22 +26,23 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= -github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= +github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380 h1:1NyRx2f4W4WBRyg0Kys0ZbaNmDDzZ2R/C7DTi+bbsJ0= +github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380/go.mod h1:thX175TtLTzLj3p7N/Q9IiKZ7NF+p72cvL91emV0hzo= +github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e h1:CQn2/8fi3kmpT9BTiHEELgdxAOQNVZc9GoPA4qnQzrs= +github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8= github.com/evilsocket/islazy v1.11.0 h1:B5w6uuS6ki6iDG+aH/RFeoMb8ijQh/pGabewqp2UeJ0= github.com/evilsocket/islazy v1.11.0/go.mod h1:muYH4x5MB5YRdkxnrOtrXLIBX6LySj1uFIqys94LKdo= github.com/florianl/go-nfqueue/v2 v2.0.0 h1:NTCxS9b0GSbHkWv1a7oOvZn679fsyDkaSkRvOYpQ9Oo= github.com/florianl/go-nfqueue/v2 v2.0.0/go.mod h1:M2tBLIj62QpwqjwV0qfcjqGOqP3qiTuXr2uSRBXH9Qk= github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe h1:8P+/htb3mwwpeGdJg69yBF/RofK7c6Fjz5Ypa/bTqbY= github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= -github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -55,6 +55,8 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE= +github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs= github.com/hashicorp/go-bexpr v0.1.14 h1:uKDeyuOhWhT1r5CiMTjdVY4Aoxdxs6EtwgTGnlosyp4= github.com/hashicorp/go-bexpr v0.1.14/go.mod h1:gN7hRKB3s7yT+YvTdnhZVLTENejvhlkZ8UE4YVBS+Q8= github.com/inconshreveable/go-vhost v1.0.0 h1:IK4VZTlXL4l9vz2IZoiSFbYaaqUW7dXJAiPriUN5Ur8= @@ -74,28 +76,29 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/malfunkt/iprange v0.9.0 h1:VCs0PKLUPotNVQTpVNszsut4lP7OCGNBwX+lOYBrnVQ= github.com/malfunkt/iprange v0.9.0/go.mod h1:TRGqO/f95gh3LOndUGTL46+W0GXA91WTqyZ0Quwvt4U= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b h1:r12blE3QRYlW1WBiBEe007O6NrTb/P54OjR5d4WLEGk= github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b/go.mod h1:p4K2+UAoap8Jzsadsxc0KG0OZjmmCthTPUyZqAVkjBY= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= -github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= +github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab h1:n8cgpHzJ5+EDyDri2s/GC7a9+qK3/YEGnBsd0uS/8PY= github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab/go.mod h1:y1pL58r5z2VvAjeG1VLGc8zOQgSOzbKN7kMHPvFXJ+8= -github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0= -github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= +github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/pointerstructure v1.2.1 h1:ZhBBeX8tSlRpu/FFhXH4RC4OJzFlqsQhoHZAz4x7TIw= github.com/mitchellh/pointerstructure v1.2.1/go.mod h1:BRAsLI5zgXmw97Lf6s25bs8ohIXc3tViBH44KcwB2g4= github.com/phin1x/go-ipp v1.6.1 h1:oxJXi92BO2FZhNcG3twjnxKFH1liTQ46vbbZx+IN/80= @@ -104,8 +107,9 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/robertkrimen/otto v0.5.1 h1:avDI4ToRk8k1hppLdYFTuuzND41n37vPGJU7547dGf0= -github.com/robertkrimen/otto v0.5.1/go.mod h1:bS433I4Q9p+E5pZLu7r17vP6FkE6/wLxBdmKjoqJXF8= +github.com/robertkrimen/otto v0.4.0 h1:/c0GRrK1XDPcgIasAsnlpBT5DelIeB9U/Z/JCQsgr7E= +github.com/robertkrimen/otto v0.4.0/go.mod h1:uW9yN1CYflmUQYvAMS0m+ZiNo3dMzRUDQJX0jWbzgxw= +github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= github.com/stratoberry/go-gpsd v1.3.0 h1:JxJOEC4SgD0QY65AE7B1CtJtweP73nqJghZeLNU9J+c= github.com/stratoberry/go-gpsd v1.3.0/go.mod h1:nVf/vTgfYxOMxiQdy9BtJjojbFRtG8H3wNula++VgkU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -115,16 +119,15 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64 h1:l/T7dYuJEQZOwVOpjIXr1180aM9PZL/d1MnMVIxefX4= github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64/go.mod h1:Q1NAJOuRdQCqN/VIWdnaaEhV8LpeO2rtlBP7/iDJNII= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -go.einride.tech/can v0.14.0 h1:OkQ0jsjCk4ijgTMjD43V1NKQyDztpX7Vo/NrvmnsAXE= -go.einride.tech/can v0.14.0/go.mod h1:615YuRGnWfndMGD+f3Ud1sp1xJLP1oj14dKRtb2CXDQ= +go.einride.tech/can v0.12.0 h1:6MW9TKycSovWqJxcYHpZEiuFCGuAfpqApCzTS15KrPk= +go.einride.tech/can v0.12.0/go.mod h1:5n3+AonCfUso6PfjD9l2d0W2LxTFjjHOnHAm+UMS9Ws= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -132,22 +135,25 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190310074541-c10a0554eabf/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -157,23 +163,25 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/js/crypto.go b/js/crypto.go deleted file mode 100644 index 7128b965..00000000 --- a/js/crypto.go +++ /dev/null @@ -1,29 +0,0 @@ -package js - -import ( - "crypto/sha1" - - "github.com/robertkrimen/otto" -) - -func cryptoSha1(call otto.FunctionCall) otto.Value { - argv := call.ArgumentList - argc := len(argv) - if argc != 1 { - return ReportError("Crypto.sha1: expected 1 argument, %d given instead.", argc) - } - - arg := argv[0] - if (!arg.IsString()) { - return ReportError("Crypto.sha1: single argument must be a string.") - } - - hasher := sha1.New() - hasher.Write([]byte(arg.String())) - v, err := otto.ToValue(string(hasher.Sum(nil))) - if err != nil { - return ReportError("Crypto.sha1: could not convert to string: %s", err) - } - - return v -} diff --git a/js/data.go b/js/data.go index 6fe48f22..e2bfe5b0 100644 --- a/js/data.go +++ b/js/data.go @@ -8,94 +8,25 @@ import ( "github.com/robertkrimen/otto" ) -func textEncode(call otto.FunctionCall) otto.Value { - argv := call.ArgumentList - argc := len(argv) - if argc != 1 { - return ReportError("textEncode: expected 1 argument, %d given instead.", argc) - } - - arg := argv[0] - if (!arg.IsString()) { - return ReportError("textEncode: single argument must be a string.") - } - - encoded := []byte(arg.String()) - vm := otto.New() - v, err := vm.ToValue(encoded) - if err != nil { - return ReportError("textEncode: could not convert to []uint8: %s", err.Error()) - } - - return v -} - -func textDecode(call otto.FunctionCall) otto.Value { - argv := call.ArgumentList - argc := len(argv) - if argc != 1 { - return ReportError("textDecode: expected 1 argument, %d given instead.", argc) - } - - arg, err := argv[0].Export() - if err != nil { - return ReportError("textDecode: could not export argument value: %s", err.Error()) - } - byteArr, ok := arg.([]uint8) - if !ok { - return ReportError("textDecode: single argument must be of type []uint8.") - } - - decoded := string(byteArr) - v, err := otto.ToValue(decoded) - if err != nil { - return ReportError("textDecode: could not convert to string: %s", err.Error()) - } - - return v -} - func btoa(call otto.FunctionCall) otto.Value { - argv := call.ArgumentList - argc := len(argv) - if argc != 1 { - return ReportError("btoa: expected 1 argument, %d given instead.", argc) - } - - arg := argv[0] - if (!arg.IsString()) { - return ReportError("btoa: single argument must be a string.") - } - - encoded := base64.StdEncoding.EncodeToString([]byte(arg.String())) - v, err := otto.ToValue(encoded) + varValue := base64.StdEncoding.EncodeToString([]byte(call.Argument(0).String())) + v, err := otto.ToValue(varValue) if err != nil { - return ReportError("btoa: could not convert to string: %s", err.Error()) + return ReportError("Could not convert to string: %s", varValue) } return v } func atob(call otto.FunctionCall) otto.Value { - argv := call.ArgumentList - argc := len(argv) - if argc != 1 { - return ReportError("atob: expected 1 argument, %d given instead.", argc) - } - - arg := argv[0] - if (!arg.IsString()) { - return ReportError("atob: single argument must be a string.") - } - - decoded, err := base64.StdEncoding.DecodeString(arg.String()) + varValue, err := base64.StdEncoding.DecodeString(call.Argument(0).String()) if err != nil { - return ReportError("atob: could not decode string: %s", err.Error()) + return ReportError("Could not decode string: %s", call.Argument(0).String()) } - v, err := otto.ToValue(string(decoded)) + v, err := otto.ToValue(string(varValue)) if err != nil { - return ReportError("atob: could not convert to string: %s", err.Error()) + return ReportError("Could not convert to string: %s", varValue) } return v @@ -108,12 +39,7 @@ func gzipCompress(call otto.FunctionCall) otto.Value { return ReportError("gzipCompress: expected 1 argument, %d given instead.", argc) } - arg := argv[0] - if (!arg.IsString()) { - return ReportError("gzipCompress: single argument must be a string.") - } - - uncompressedBytes := []byte(arg.String()) + uncompressedBytes := []byte(argv[0].String()) var writerBuffer bytes.Buffer gzipWriter := gzip.NewWriter(&writerBuffer) @@ -127,7 +53,7 @@ func gzipCompress(call otto.FunctionCall) otto.Value { v, err := otto.ToValue(string(compressedBytes)) if err != nil { - return ReportError("gzipCompress: could not convert to string: %s", err.Error()) + return ReportError("Could not convert to string: %s", err.Error()) } return v @@ -157,7 +83,7 @@ func gzipDecompress(call otto.FunctionCall) otto.Value { decompressedBytes := decompressedBuffer.Bytes() v, err := otto.ToValue(string(decompressedBytes)) if err != nil { - return ReportError("gzipDecompress: could not convert to string: %s", err.Error()) + return ReportError("Could not convert to string: %s", err.Error()) } return v diff --git a/js/data_test.go b/js/data_test.go deleted file mode 100644 index 64326418..00000000 --- a/js/data_test.go +++ /dev/null @@ -1,514 +0,0 @@ -package js - -import ( - "encoding/base64" - "strings" - "testing" - - "github.com/robertkrimen/otto" -) - -func TestBtoa(t *testing.T) { - vm := otto.New() - - tests := []struct { - name string - input string - expected string - }{ - { - name: "simple string", - input: "hello world", - expected: base64.StdEncoding.EncodeToString([]byte("hello world")), - }, - { - name: "empty string", - input: "", - expected: base64.StdEncoding.EncodeToString([]byte("")), - }, - { - name: "special characters", - input: "!@#$%^&*()_+-=[]{}|;:,.<>?", - expected: base64.StdEncoding.EncodeToString([]byte("!@#$%^&*()_+-=[]{}|;:,.<>?")), - }, - { - name: "unicode string", - input: "Hello 世界 🌍", - expected: base64.StdEncoding.EncodeToString([]byte("Hello 世界 🌍")), - }, - { - name: "newlines and tabs", - input: "line1\nline2\ttab", - expected: base64.StdEncoding.EncodeToString([]byte("line1\nline2\ttab")), - }, - { - name: "long string", - input: strings.Repeat("a", 1000), - expected: base64.StdEncoding.EncodeToString([]byte(strings.Repeat("a", 1000))), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create call with argument - arg, _ := vm.ToValue(tt.input) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := btoa(call) - - // Check if result is an error - if result.IsUndefined() { - t.Fatal("btoa returned undefined") - } - - // Get string value - resultStr, err := result.ToString() - if err != nil { - t.Fatalf("failed to convert result to string: %v", err) - } - - if resultStr != tt.expected { - t.Errorf("btoa(%q) = %q, want %q", tt.input, resultStr, tt.expected) - } - }) - } -} - -func TestAtob(t *testing.T) { - vm := otto.New() - - tests := []struct { - name string - input string - expected string - wantError bool - }{ - { - name: "simple base64", - input: base64.StdEncoding.EncodeToString([]byte("hello world")), - expected: "hello world", - }, - { - name: "empty base64", - input: base64.StdEncoding.EncodeToString([]byte("")), - expected: "", - }, - { - name: "special characters base64", - input: base64.StdEncoding.EncodeToString([]byte("!@#$%^&*()_+-=[]{}|;:,.<>?")), - expected: "!@#$%^&*()_+-=[]{}|;:,.<>?", - }, - { - name: "unicode base64", - input: base64.StdEncoding.EncodeToString([]byte("Hello 世界 🌍")), - expected: "Hello 世界 🌍", - }, - { - name: "invalid base64", - input: "not valid base64!", - wantError: true, - }, - { - name: "invalid padding", - input: "SGVsbG8gV29ybGQ", // Missing padding - wantError: true, - }, - { - name: "long base64", - input: base64.StdEncoding.EncodeToString([]byte(strings.Repeat("a", 1000))), - expected: strings.Repeat("a", 1000), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create call with argument - arg, _ := vm.ToValue(tt.input) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := atob(call) - - // Get string value - resultStr, err := result.ToString() - if err != nil && !tt.wantError { - t.Fatalf("failed to convert result to string: %v", err) - } - - if tt.wantError { - // Should return undefined (NullValue) on error - if !result.IsUndefined() { - t.Errorf("expected undefined for error case, got %q", resultStr) - } - } else { - if resultStr != tt.expected { - t.Errorf("atob(%q) = %q, want %q", tt.input, resultStr, tt.expected) - } - } - }) - } -} - -func TestGzipCompress(t *testing.T) { - vm := otto.New() - - tests := []struct { - name string - input string - }{ - { - name: "simple string", - input: "hello world", - }, - { - name: "empty string", - input: "", - }, - { - name: "repeated pattern", - input: strings.Repeat("abcd", 100), - }, - { - name: "random text", - input: "The quick brown fox jumps over the lazy dog. " + strings.Repeat("Lorem ipsum dolor sit amet. ", 10), - }, - { - name: "unicode text", - input: "Hello 世界 🌍 " + strings.Repeat("测试数据 ", 50), - }, - { - name: "binary-like data", - input: string([]byte{0, 1, 2, 3, 255, 254, 253, 252}), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create call with argument - arg, _ := vm.ToValue(tt.input) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := gzipCompress(call) - - // Get compressed data - compressed, err := result.ToString() - if err != nil { - t.Fatalf("failed to convert result to string: %v", err) - } - - // Verify it's actually compressed (for non-empty strings, compressed should be different) - if tt.input != "" && compressed == tt.input { - t.Error("compressed data is same as input") - } - - // Verify gzip header (should start with 0x1f, 0x8b) - if len(compressed) >= 2 { - if compressed[0] != 0x1f || compressed[1] != 0x8b { - t.Error("compressed data doesn't have valid gzip header") - } - } - - // Now decompress to verify - argCompressed, _ := vm.ToValue(compressed) - callDecompress := otto.FunctionCall{ - ArgumentList: []otto.Value{argCompressed}, - } - - resultDecompressed := gzipDecompress(callDecompress) - decompressed, err := resultDecompressed.ToString() - if err != nil { - t.Fatalf("failed to decompress: %v", err) - } - - if decompressed != tt.input { - t.Errorf("round-trip failed: got %q, want %q", decompressed, tt.input) - } - }) - } -} - -func TestGzipCompressInvalidArgs(t *testing.T) { - vm := otto.New() - - tests := []struct { - name string - args []otto.Value - }{ - { - name: "no arguments", - args: []otto.Value{}, - }, - { - name: "too many arguments", - args: func() []otto.Value { - arg1, _ := vm.ToValue("test") - arg2, _ := vm.ToValue("extra") - return []otto.Value{arg1, arg2} - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - call := otto.FunctionCall{ - ArgumentList: tt.args, - } - - result := gzipCompress(call) - - // Should return undefined (NullValue) on error - if !result.IsUndefined() { - resultStr, _ := result.ToString() - t.Errorf("expected undefined for error case, got %q", resultStr) - } - }) - } -} - -func TestGzipDecompress(t *testing.T) { - vm := otto.New() - - // First compress some data - originalData := "This is test data for decompression" - arg, _ := vm.ToValue(originalData) - compressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - compressedResult := gzipCompress(compressCall) - compressedData, _ := compressedResult.ToString() - - t.Run("valid decompression", func(t *testing.T) { - argCompressed, _ := vm.ToValue(compressedData) - decompressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argCompressed}, - } - - result := gzipDecompress(decompressCall) - decompressed, err := result.ToString() - if err != nil { - t.Fatalf("failed to convert result to string: %v", err) - } - - if decompressed != originalData { - t.Errorf("decompressed data doesn't match original: got %q, want %q", decompressed, originalData) - } - }) - - t.Run("invalid gzip data", func(t *testing.T) { - argInvalid, _ := vm.ToValue("not gzip data") - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argInvalid}, - } - - result := gzipDecompress(call) - - // Should return undefined (NullValue) on error - if !result.IsUndefined() { - resultStr, _ := result.ToString() - t.Errorf("expected undefined for error case, got %q", resultStr) - } - }) - - t.Run("corrupted gzip data", func(t *testing.T) { - // Create corrupted gzip by taking valid gzip and modifying it - corruptedData := compressedData[:len(compressedData)/2] + "corrupted" - - argCorrupted, _ := vm.ToValue(corruptedData) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argCorrupted}, - } - - result := gzipDecompress(call) - - // Should return undefined (NullValue) on error - if !result.IsUndefined() { - resultStr, _ := result.ToString() - t.Errorf("expected undefined for error case, got %q", resultStr) - } - }) -} - -func TestGzipDecompressInvalidArgs(t *testing.T) { - vm := otto.New() - - tests := []struct { - name string - args []otto.Value - }{ - { - name: "no arguments", - args: []otto.Value{}, - }, - { - name: "too many arguments", - args: func() []otto.Value { - arg1, _ := vm.ToValue("test") - arg2, _ := vm.ToValue("extra") - return []otto.Value{arg1, arg2} - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - call := otto.FunctionCall{ - ArgumentList: tt.args, - } - - result := gzipDecompress(call) - - // Should return undefined (NullValue) on error - if !result.IsUndefined() { - resultStr, _ := result.ToString() - t.Errorf("expected undefined for error case, got %q", resultStr) - } - }) - } -} - -func TestBtoaAtobRoundTrip(t *testing.T) { - vm := otto.New() - - testStrings := []string{ - "simple", - "", - "with spaces and\nnewlines\ttabs", - "special!@#$%^&*()_+-=[]{}|;:,.<>?", - "unicode 世界 🌍", - strings.Repeat("long string ", 100), - } - - for _, original := range testStrings { - t.Run(original, func(t *testing.T) { - // Encode with btoa - argOriginal, _ := vm.ToValue(original) - encodeCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argOriginal}, - } - - encoded := btoa(encodeCall) - encodedStr, _ := encoded.ToString() - - // Decode with atob - argEncoded, _ := vm.ToValue(encodedStr) - decodeCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argEncoded}, - } - - decoded := atob(decodeCall) - decodedStr, _ := decoded.ToString() - - if decodedStr != original { - t.Errorf("round-trip failed: got %q, want %q", decodedStr, original) - } - }) - } -} - -func TestGzipCompressDecompressRoundTrip(t *testing.T) { - vm := otto.New() - - testData := []string{ - "simple", - "", - strings.Repeat("repetitive data ", 100), - "unicode 世界 🌍 " + strings.Repeat("测试 ", 50), - string([]byte{0, 1, 2, 3, 255, 254, 253, 252}), - } - - for _, original := range testData { - t.Run(original, func(t *testing.T) { - // Compress - argOriginal, _ := vm.ToValue(original) - compressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argOriginal}, - } - - compressed := gzipCompress(compressCall) - compressedStr, _ := compressed.ToString() - - // Decompress - argCompressed, _ := vm.ToValue(compressedStr) - decompressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argCompressed}, - } - - decompressed := gzipDecompress(decompressCall) - decompressedStr, _ := decompressed.ToString() - - if decompressedStr != original { - t.Errorf("round-trip failed: got %q, want %q", decompressedStr, original) - } - }) - } -} - -func BenchmarkBtoa(b *testing.B) { - vm := otto.New() - arg, _ := vm.ToValue("The quick brown fox jumps over the lazy dog") - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = btoa(call) - } -} - -func BenchmarkAtob(b *testing.B) { - vm := otto.New() - encoded := base64.StdEncoding.EncodeToString([]byte("The quick brown fox jumps over the lazy dog")) - arg, _ := vm.ToValue(encoded) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = atob(call) - } -} - -func BenchmarkGzipCompress(b *testing.B) { - vm := otto.New() - data := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 10) - arg, _ := vm.ToValue(data) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = gzipCompress(call) - } -} - -func BenchmarkGzipDecompress(b *testing.B) { - vm := otto.New() - - // First compress some data - data := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 10) - argData, _ := vm.ToValue(data) - compressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argData}, - } - compressed := gzipCompress(compressCall) - compressedStr, _ := compressed.ToString() - - // Benchmark decompression - argCompressed, _ := vm.ToValue(compressedStr) - decompressCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argCompressed}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = gzipDecompress(decompressCall) - } -} diff --git a/js/fs_test.go b/js/fs_test.go deleted file mode 100644 index fd089d28..00000000 --- a/js/fs_test.go +++ /dev/null @@ -1,684 +0,0 @@ -package js - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - - "github.com/robertkrimen/otto" -) - -func TestReadDir(t *testing.T) { - vm := otto.New() - - // Create a temporary directory for testing - tmpDir, err := ioutil.TempDir("", "js_test_readdir_*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - // Create some test files and subdirectories - testFiles := []string{"file1.txt", "file2.log", ".hidden"} - testDirs := []string{"subdir1", "subdir2"} - - for _, name := range testFiles { - if err := ioutil.WriteFile(filepath.Join(tmpDir, name), []byte("test"), 0644); err != nil { - t.Fatalf("failed to create test file %s: %v", name, err) - } - } - - for _, name := range testDirs { - if err := os.Mkdir(filepath.Join(tmpDir, name), 0755); err != nil { - t.Fatalf("failed to create test dir %s: %v", name, err) - } - } - - t.Run("valid directory", func(t *testing.T) { - arg, _ := vm.ToValue(tmpDir) - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{arg}, - } - - result := readDir(call) - - // Check if result is not undefined - if result.IsUndefined() { - t.Fatal("readDir returned undefined") - } - - // Convert to Go slice - export, err := result.Export() - if err != nil { - t.Fatalf("failed to export result: %v", err) - } - - entries, ok := export.([]string) - if !ok { - t.Fatalf("expected []string, got %T", export) - } - - // Check all expected entries are present - expectedEntries := append(testFiles, testDirs...) - if len(entries) != len(expectedEntries) { - t.Errorf("expected %d entries, got %d", len(expectedEntries), len(entries)) - } - - // Check each entry exists - for _, expected := range expectedEntries { - found := false - for _, entry := range entries { - if entry == expected { - found = true - break - } - } - if !found { - t.Errorf("expected entry %s not found", expected) - } - } - }) - - t.Run("non-existent directory", func(t *testing.T) { - arg, _ := vm.ToValue("/path/that/does/not/exist") - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{arg}, - } - - result := readDir(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined for non-existent directory") - } - }) - - t.Run("file instead of directory", func(t *testing.T) { - // Create a file - testFile := filepath.Join(tmpDir, "notadir.txt") - ioutil.WriteFile(testFile, []byte("test"), 0644) - - arg, _ := vm.ToValue(testFile) - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{arg}, - } - - result := readDir(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined when passing file instead of directory") - } - }) - - t.Run("invalid arguments", func(t *testing.T) { - tests := []struct { - name string - args []otto.Value - }{ - { - name: "no arguments", - args: []otto.Value{}, - }, - { - name: "too many arguments", - args: func() []otto.Value { - arg1, _ := vm.ToValue(tmpDir) - arg2, _ := vm.ToValue("extra") - return []otto.Value{arg1, arg2} - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: tt.args, - } - - result := readDir(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined for invalid arguments") - } - }) - } - }) - - t.Run("empty directory", func(t *testing.T) { - emptyDir := filepath.Join(tmpDir, "empty") - os.Mkdir(emptyDir, 0755) - - arg, _ := vm.ToValue(emptyDir) - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{arg}, - } - - result := readDir(call) - - if result.IsUndefined() { - t.Fatal("readDir returned undefined for empty directory") - } - - export, _ := result.Export() - entries, _ := export.([]string) - - if len(entries) != 0 { - t.Errorf("expected 0 entries for empty directory, got %d", len(entries)) - } - }) -} - -func TestReadFile(t *testing.T) { - vm := otto.New() - - // Create a temporary directory for testing - tmpDir, err := ioutil.TempDir("", "js_test_readfile_*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - t.Run("valid file", func(t *testing.T) { - testContent := "Hello, World!\nThis is a test file.\n特殊字符测试 🌍" - testFile := filepath.Join(tmpDir, "test.txt") - ioutil.WriteFile(testFile, []byte(testContent), 0644) - - arg, _ := vm.ToValue(testFile) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - if result.IsUndefined() { - t.Fatal("readFile returned undefined") - } - - content, err := result.ToString() - if err != nil { - t.Fatalf("failed to convert result to string: %v", err) - } - - if content != testContent { - t.Errorf("expected content %q, got %q", testContent, content) - } - }) - - t.Run("non-existent file", func(t *testing.T) { - arg, _ := vm.ToValue("/path/that/does/not/exist.txt") - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined for non-existent file") - } - }) - - t.Run("directory instead of file", func(t *testing.T) { - arg, _ := vm.ToValue(tmpDir) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined when passing directory instead of file") - } - }) - - t.Run("empty file", func(t *testing.T) { - emptyFile := filepath.Join(tmpDir, "empty.txt") - ioutil.WriteFile(emptyFile, []byte(""), 0644) - - arg, _ := vm.ToValue(emptyFile) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - if result.IsUndefined() { - t.Fatal("readFile returned undefined for empty file") - } - - content, _ := result.ToString() - if content != "" { - t.Errorf("expected empty string, got %q", content) - } - }) - - t.Run("binary file", func(t *testing.T) { - binaryContent := []byte{0, 1, 2, 3, 255, 254, 253, 252} - binaryFile := filepath.Join(tmpDir, "binary.bin") - ioutil.WriteFile(binaryFile, binaryContent, 0644) - - arg, _ := vm.ToValue(binaryFile) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - if result.IsUndefined() { - t.Fatal("readFile returned undefined for binary file") - } - - content, _ := result.ToString() - if content != string(binaryContent) { - t.Error("binary content mismatch") - } - }) - - t.Run("invalid arguments", func(t *testing.T) { - tests := []struct { - name string - args []otto.Value - }{ - { - name: "no arguments", - args: []otto.Value{}, - }, - { - name: "too many arguments", - args: func() []otto.Value { - arg1, _ := vm.ToValue("file.txt") - arg2, _ := vm.ToValue("extra") - return []otto.Value{arg1, arg2} - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - call := otto.FunctionCall{ - ArgumentList: tt.args, - } - - result := readFile(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined for invalid arguments") - } - }) - } - }) - - t.Run("large file", func(t *testing.T) { - // Create a 1MB file - largeContent := strings.Repeat("A", 1024*1024) - largeFile := filepath.Join(tmpDir, "large.txt") - ioutil.WriteFile(largeFile, []byte(largeContent), 0644) - - arg, _ := vm.ToValue(largeFile) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - result := readFile(call) - - if result.IsUndefined() { - t.Fatal("readFile returned undefined for large file") - } - - content, _ := result.ToString() - if len(content) != len(largeContent) { - t.Errorf("expected content length %d, got %d", len(largeContent), len(content)) - } - }) -} - -func TestWriteFile(t *testing.T) { - vm := otto.New() - - // Create a temporary directory for testing - tmpDir, err := ioutil.TempDir("", "js_test_writefile_*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - t.Run("write new file", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "new_file.txt") - testContent := "Hello, World!\nThis is a new file.\n特殊字符测试 🌍" - - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(testContent) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - result := writeFile(call) - - // writeFile returns null on success - if !result.IsNull() { - t.Error("expected null return value for successful write") - } - - // Verify file was created with correct content - content, err := ioutil.ReadFile(testFile) - if err != nil { - t.Fatalf("failed to read written file: %v", err) - } - - if string(content) != testContent { - t.Errorf("expected content %q, got %q", testContent, string(content)) - } - - // Check file permissions - info, _ := os.Stat(testFile) - if runtime.GOOS == "windows" { - // On Windows, permissions are different - just check that file exists and is readable - if info.Mode()&0400 == 0 { - t.Error("expected file to be readable on Windows") - } - } else { - // On Unix-like systems, check exact permissions - if info.Mode().Perm() != 0644 { - t.Errorf("expected permissions 0644, got %v", info.Mode().Perm()) - } - } - }) - - t.Run("overwrite existing file", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "existing.txt") - oldContent := "Old content" - newContent := "New content that is longer than the old content" - - // Create initial file - ioutil.WriteFile(testFile, []byte(oldContent), 0644) - - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(newContent) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - result := writeFile(call) - - if !result.IsNull() { - t.Error("expected null return value for successful write") - } - - // Verify file was overwritten - content, _ := ioutil.ReadFile(testFile) - if string(content) != newContent { - t.Errorf("expected content %q, got %q", newContent, string(content)) - } - }) - - t.Run("write to non-existent directory", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "nonexistent", "subdir", "file.txt") - testContent := "test" - - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(testContent) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - result := writeFile(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined when writing to non-existent directory") - } - }) - - t.Run("write empty content", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "empty.txt") - - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue("") - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - result := writeFile(call) - - if !result.IsNull() { - t.Error("expected null return value for successful write") - } - - // Verify empty file was created - content, _ := ioutil.ReadFile(testFile) - if len(content) != 0 { - t.Errorf("expected empty file, got %d bytes", len(content)) - } - }) - - t.Run("invalid arguments", func(t *testing.T) { - tests := []struct { - name string - args []otto.Value - }{ - { - name: "no arguments", - args: []otto.Value{}, - }, - { - name: "one argument", - args: func() []otto.Value { - arg, _ := vm.ToValue("file.txt") - return []otto.Value{arg} - }(), - }, - { - name: "too many arguments", - args: func() []otto.Value { - arg1, _ := vm.ToValue("file.txt") - arg2, _ := vm.ToValue("content") - arg3, _ := vm.ToValue("extra") - return []otto.Value{arg1, arg2, arg3} - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - call := otto.FunctionCall{ - ArgumentList: tt.args, - } - - result := writeFile(call) - - // Should return undefined (error) - if !result.IsUndefined() { - t.Error("expected undefined for invalid arguments") - } - }) - } - }) - - t.Run("write binary content", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "binary.bin") - binaryContent := string([]byte{0, 1, 2, 3, 255, 254, 253, 252}) - - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(binaryContent) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - result := writeFile(call) - - if !result.IsNull() { - t.Error("expected null return value for successful write") - } - - // Verify binary content - content, _ := ioutil.ReadFile(testFile) - if string(content) != binaryContent { - t.Error("binary content mismatch") - } - }) -} - -func TestFileSystemIntegration(t *testing.T) { - vm := otto.New() - - // Create a temporary directory for testing - tmpDir, err := ioutil.TempDir("", "js_test_integration_*") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - t.Run("write then read file", func(t *testing.T) { - testFile := filepath.Join(tmpDir, "roundtrip.txt") - testContent := "Round-trip test content\nLine 2\nLine 3" - - // Write file - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(testContent) - writeCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - - writeResult := writeFile(writeCall) - if !writeResult.IsNull() { - t.Fatal("write failed") - } - - // Read file back - readCall := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile}, - } - - readResult := readFile(readCall) - if readResult.IsUndefined() { - t.Fatal("read failed") - } - - readContent, _ := readResult.ToString() - if readContent != testContent { - t.Errorf("round-trip failed: expected %q, got %q", testContent, readContent) - } - }) - - t.Run("create files then list directory", func(t *testing.T) { - // Create multiple files - files := []string{"file1.txt", "file2.txt", "file3.txt"} - for _, name := range files { - path := filepath.Join(tmpDir, name) - argFile, _ := vm.ToValue(path) - argContent, _ := vm.ToValue("content of " + name) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - writeFile(call) - } - - // List directory - argDir, _ := vm.ToValue(tmpDir) - listCall := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{argDir}, - } - - listResult := readDir(listCall) - if listResult.IsUndefined() { - t.Fatal("readDir failed") - } - - export, _ := listResult.Export() - entries, _ := export.([]string) - - // Check all files are listed - for _, expected := range files { - found := false - for _, entry := range entries { - if entry == expected { - found = true - break - } - } - if !found { - t.Errorf("expected file %s not found in directory listing", expected) - } - } - }) -} - -func BenchmarkReadFile(b *testing.B) { - vm := otto.New() - - // Create test file - tmpFile, _ := ioutil.TempFile("", "bench_readfile_*") - defer os.Remove(tmpFile.Name()) - - content := strings.Repeat("Benchmark test content line\n", 100) - ioutil.WriteFile(tmpFile.Name(), []byte(content), 0644) - - arg, _ := vm.ToValue(tmpFile.Name()) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{arg}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = readFile(call) - } -} - -func BenchmarkWriteFile(b *testing.B) { - vm := otto.New() - - tmpDir, _ := ioutil.TempDir("", "bench_writefile_*") - defer os.RemoveAll(tmpDir) - - content := strings.Repeat("Benchmark test content line\n", 100) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - testFile := filepath.Join(tmpDir, fmt.Sprintf("bench_%d.txt", i)) - argFile, _ := vm.ToValue(testFile) - argContent, _ := vm.ToValue(content) - call := otto.FunctionCall{ - ArgumentList: []otto.Value{argFile, argContent}, - } - _ = writeFile(call) - } -} - -func BenchmarkReadDir(b *testing.B) { - vm := otto.New() - - // Create test directory with files - tmpDir, _ := ioutil.TempDir("", "bench_readdir_*") - defer os.RemoveAll(tmpDir) - - // Create 100 files - for i := 0; i < 100; i++ { - name := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i)) - ioutil.WriteFile(name, []byte("test"), 0644) - } - - arg, _ := vm.ToValue(tmpDir) - call := otto.FunctionCall{ - Otto: vm, - ArgumentList: []otto.Value{arg}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = readDir(call) - } -} diff --git a/js/http.go b/js/http.go index 685f8ec0..615928cb 100644 --- a/js/http.go +++ b/js/http.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "net/http" "net/url" "strings" @@ -63,7 +64,7 @@ func (c httpPackage) Request(method string, uri string, } defer resp.Body.Close() - raw, err := io.ReadAll(resp.Body) + raw, err := ioutil.ReadAll(resp.Body) if err != nil { return httpResponse{Error: err} } @@ -132,7 +133,7 @@ func httpRequest(call otto.FunctionCall) otto.Value { } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) if err != nil { return ReportError("Could not read response: %s", err) } diff --git a/js/init.go b/js/init.go index 1aaa52cd..6415dd88 100644 --- a/js/init.go +++ b/js/init.go @@ -27,16 +27,10 @@ func init() { plugin.Defines["log_error"] = log_error plugin.Defines["log_fatal"] = log_fatal - plugin.Defines["Crypto"] = map[string]interface{}{ - "sha1": cryptoSha1, - } - plugin.Defines["btoa"] = btoa plugin.Defines["atob"] = atob plugin.Defines["gzipCompress"] = gzipCompress plugin.Defines["gzipDecompress"] = gzipDecompress - plugin.Defines["textEncode"] = textEncode - plugin.Defines["textDecode"] = textDecode plugin.Defines["httpRequest"] = httpRequest plugin.Defines["http"] = httpPackage{} diff --git a/js/random_test.go b/js/random_test.go deleted file mode 100644 index 594a16ad..00000000 --- a/js/random_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package js - -import ( - "net" - "regexp" - "strings" - "testing" -) - -func TestRandomString(t *testing.T) { - r := randomPackage{} - - tests := []struct { - name string - size int - charset string - }{ - { - name: "alphanumeric", - size: 10, - charset: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", - }, - { - name: "numbers only", - size: 20, - charset: "0123456789", - }, - { - name: "lowercase letters", - size: 15, - charset: "abcdefghijklmnopqrstuvwxyz", - }, - { - name: "uppercase letters", - size: 8, - charset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", - }, - { - name: "special characters", - size: 12, - charset: "!@#$%^&*()_+-=[]{}|;:,.<>?", - }, - { - name: "unicode characters", - size: 5, - charset: "αβγδεζηθικλμνξοπρστυφχψω", - }, - { - name: "mixed unicode and ascii", - size: 10, - charset: "abc123αβγ", - }, - { - name: "single character", - size: 100, - charset: "a", - }, - { - name: "empty size", - size: 0, - charset: "abcdef", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := r.String(tt.size, tt.charset) - - // Check length - if len([]rune(result)) != tt.size { - t.Errorf("expected length %d, got %d", tt.size, len([]rune(result))) - } - - // Check that all characters are from the charset - for _, char := range result { - if !strings.ContainsRune(tt.charset, char) { - t.Errorf("character %c not in charset %s", char, tt.charset) - } - } - }) - } -} - -func TestRandomStringDistribution(t *testing.T) { - r := randomPackage{} - charset := "ab" - size := 1000 - - // Generate many single-character strings - counts := make(map[rune]int) - for i := 0; i < size; i++ { - result := r.String(1, charset) - if len(result) == 1 { - counts[rune(result[0])]++ - } - } - - // Check that both characters appear (very high probability) - if len(counts) != 2 { - t.Errorf("expected both characters to appear, got %d unique characters", len(counts)) - } - - // Check distribution is reasonable (not perfect due to randomness) - for char, count := range counts { - ratio := float64(count) / float64(size) - if ratio < 0.3 || ratio > 0.7 { - t.Errorf("character %c appeared %d times (%.2f%%), expected around 50%%", - char, count, ratio*100) - } - } -} - -func TestRandomMac(t *testing.T) { - r := randomPackage{} - macRegex := regexp.MustCompile(`^([0-9a-f]{2}:){5}[0-9a-f]{2}$`) - - // Generate multiple MAC addresses - macs := make(map[string]bool) - for i := 0; i < 100; i++ { - mac := r.Mac() - - // Check format - if !macRegex.MatchString(mac) { - t.Errorf("invalid MAC format: %s", mac) - } - - // Check it's a valid MAC - _, err := net.ParseMAC(mac) - if err != nil { - t.Errorf("invalid MAC address: %s, error: %v", mac, err) - } - - // Store for uniqueness check - macs[mac] = true - } - - // Check that we get different MACs (very high probability) - if len(macs) < 95 { - t.Errorf("expected at least 95 unique MACs out of 100, got %d", len(macs)) - } -} - -func TestRandomMacNormalization(t *testing.T) { - r := randomPackage{} - - // Generate several MACs and check they're normalized - for i := 0; i < 10; i++ { - mac := r.Mac() - - // Check lowercase - if mac != strings.ToLower(mac) { - t.Errorf("MAC not normalized to lowercase: %s", mac) - } - - // Check separator is colon - if strings.Contains(mac, "-") { - t.Errorf("MAC contains hyphen instead of colon: %s", mac) - } - - // Check length - if len(mac) != 17 { // 6 bytes * 2 chars + 5 colons - t.Errorf("MAC has wrong length: %s (len=%d)", mac, len(mac)) - } - } -} - -func TestRandomStringEdgeCases(t *testing.T) { - r := randomPackage{} - - // Test with various edge cases - tests := []struct { - name string - size int - charset string - }{ - { - name: "zero size", - size: 0, - charset: "abc", - }, - { - name: "very large size", - size: 10000, - charset: "abc", - }, - { - name: "size larger than charset", - size: 10, - charset: "ab", - }, - { - name: "single char charset with large size", - size: 1000, - charset: "x", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := r.String(tt.size, tt.charset) - - if len([]rune(result)) != tt.size { - t.Errorf("expected length %d, got %d", tt.size, len([]rune(result))) - } - - // Check all characters are from charset - for _, c := range result { - if !strings.ContainsRune(tt.charset, c) { - t.Errorf("character %c not in charset %s", c, tt.charset) - } - } - }) - } -} - -func TestRandomStringNegativeSize(t *testing.T) { - r := randomPackage{} - - // Test that negative size causes panic - defer func() { - if r := recover(); r == nil { - t.Error("expected panic for negative size but didn't get one") - } - }() - - // This should panic - _ = r.String(-1, "abc") -} - -func TestRandomPackageInstance(t *testing.T) { - // Test that we can create multiple instances - r1 := randomPackage{} - r2 := randomPackage{} - - // Both should work independently - s1 := r1.String(5, "abc") - s2 := r2.String(5, "xyz") - - if len(s1) != 5 { - t.Errorf("r1.String returned wrong length: %d", len(s1)) - } - if len(s2) != 5 { - t.Errorf("r2.String returned wrong length: %d", len(s2)) - } - - // Check correct charset usage - for _, c := range s1 { - if !strings.ContainsRune("abc", c) { - t.Errorf("r1 produced character outside charset: %c", c) - } - } - for _, c := range s2 { - if !strings.ContainsRune("xyz", c) { - t.Errorf("r2 produced character outside charset: %c", c) - } - } -} - -func BenchmarkRandomString(b *testing.B) { - r := randomPackage{} - charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - - b.Run("size-10", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = r.String(10, charset) - } - }) - - b.Run("size-100", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = r.String(100, charset) - } - }) - - b.Run("size-1000", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = r.String(1000, charset) - } - }) -} - -func BenchmarkRandomMac(b *testing.B) { - r := randomPackage{} - - for i := 0; i < b.N; i++ { - _ = r.Mac() - } -} - -func BenchmarkRandomStringCharsets(b *testing.B) { - r := randomPackage{} - - charsets := map[string]string{ - "small": "abc", - "medium": "abcdefghijklmnopqrstuvwxyz", - "large": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?", - "unicode": "αβγδεζηθικλμνξοπρστυφχψωABCDEFGHIJKLMNOPQRSTUVWXYZ", - } - - for name, charset := range charsets { - b.Run(name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = r.String(20, charset) - } - }) - } -} diff --git a/log/log_test.go b/log/log_test.go deleted file mode 100644 index af696d19..00000000 --- a/log/log_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package log - -import ( - "testing" - - "github.com/evilsocket/islazy/log" -) - -var called bool -var calledLevel log.Verbosity -var calledFormat string -var calledArgs []interface{} - -func mockLogger(level log.Verbosity, format string, args ...interface{}) { - called = true - calledLevel = level - calledFormat = format - calledArgs = args -} - -func reset() { - called = false - calledLevel = log.DEBUG - calledFormat = "" - calledArgs = nil -} - -func TestLoggerNil(t *testing.T) { - reset() - Logger = nil - - Debug("test") - if called { - t.Error("Debug should not call if Logger is nil") - } - - Info("test") - if called { - t.Error("Info should not call if Logger is nil") - } - - Warning("test") - if called { - t.Error("Warning should not call if Logger is nil") - } - - Error("test") - if called { - t.Error("Error should not call if Logger is nil") - } - - Fatal("test") - if called { - t.Error("Fatal should not call if Logger is nil") - } -} - -func TestDebug(t *testing.T) { - reset() - Logger = mockLogger - - Debug("test %d", 42) - if !called || calledLevel != log.DEBUG || calledFormat != "test %d" || len(calledArgs) != 1 || calledArgs[0] != 42 { - t.Errorf("Debug not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs) - } -} - -func TestInfo(t *testing.T) { - reset() - Logger = mockLogger - - Info("test %s", "info") - if !called || calledLevel != log.INFO || calledFormat != "test %s" || len(calledArgs) != 1 || calledArgs[0] != "info" { - t.Errorf("Info not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs) - } -} - -func TestWarning(t *testing.T) { - reset() - Logger = mockLogger - - Warning("test %f", 3.14) - if !called || calledLevel != log.WARNING || calledFormat != "test %f" || len(calledArgs) != 1 || calledArgs[0] != 3.14 { - t.Errorf("Warning not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs) - } -} - -func TestError(t *testing.T) { - reset() - Logger = mockLogger - - Error("test error") - if !called || calledLevel != log.ERROR || calledFormat != "test error" || len(calledArgs) != 0 { - t.Errorf("Error not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs) - } -} - -func TestFatal(t *testing.T) { - reset() - Logger = mockLogger - - Fatal("test fatal") - if !called || calledLevel != log.FATAL || calledFormat != "test fatal" || len(calledArgs) != 0 { - t.Errorf("Fatal not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs) - } -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 102788ae..00000000 --- a/main_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package main - -import ( - "bytes" - "strings" - "testing" -) - -func TestExitPrompt(t *testing.T) { - tests := []struct { - name string - input string - expected bool - }{ - { - name: "yes lowercase", - input: "y\n", - expected: true, - }, - { - name: "yes uppercase", - input: "Y\n", - expected: true, - }, - { - name: "no lowercase", - input: "n\n", - expected: false, - }, - { - name: "no uppercase", - input: "N\n", - expected: false, - }, - { - name: "invalid input", - input: "maybe\n", - expected: false, - }, - { - name: "empty input", - input: "\n", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Redirect stdin - oldStdin := strings.NewReader(tt.input) - r := bytes.NewReader([]byte(tt.input)) - - // Mock stdin by reading from our buffer - // This is a simplified test - in production you'd want to properly mock stdin - _ = oldStdin - _ = r - - // For now, we'll test the string comparison logic directly - input := strings.TrimSpace(strings.TrimSuffix(tt.input, "\n")) - result := strings.ToLower(input) == "y" - - if result != tt.expected { - t.Errorf("exitPrompt() with input %q = %v, want %v", tt.input, result, tt.expected) - } - }) - } -} - -// Test some utility functions that would be refactored from main -func TestVersionString(t *testing.T) { - // This tests the version string formatting logic - version := "2.32.0" - os := "darwin" - arch := "amd64" - goVersion := "go1.19" - - expected := "bettercap v2.32.0 (built for darwin amd64 with go1.19)" - result := formatVersion("bettercap", version, os, arch, goVersion) - - if result != expected { - t.Errorf("formatVersion() = %v, want %v", result, expected) - } -} - -// Helper function that would be refactored from main -func formatVersion(name, version, os, arch, goVersion string) string { - return name + " v" + version + " (built for " + os + " " + arch + " with " + goVersion + ")" -} diff --git a/modules/any_proxy/any_proxy_test.go b/modules/any_proxy/any_proxy_test.go deleted file mode 100644 index e5d28276..00000000 --- a/modules/any_proxy/any_proxy_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package any_proxy - -import ( - "fmt" - "strconv" - "strings" - "sync" - "testing" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewAnyProxy(t *testing.T) { - s := createMockSession(t) - mod := NewAnyProxy(s) - - if mod == nil { - t.Fatal("NewAnyProxy returned nil") - } - - if mod.Name() != "any.proxy" { - t.Errorf("Expected name 'any.proxy', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check handlers - handlers := mod.Handlers() - if len(handlers) != 2 { - t.Errorf("Expected 2 handlers, got %d", len(handlers)) - } - - handlerNames := make(map[string]bool) - for _, h := range handlers { - handlerNames[h.Name] = true - } - - if !handlerNames["any.proxy on"] { - t.Error("Handler 'any.proxy on' not found") - } - if !handlerNames["any.proxy off"] { - t.Error("Handler 'any.proxy off' not found") - } - - // Check that parameters were added (but don't try to get values as that requires session interface) - expectedParams := 6 // iface, protocol, src_port, src_address, dst_address, dst_port - // This is a simplified check - in a real test we'd mock the interface - _ = expectedParams -} - -// Test port parsing logic directly -func TestPortParsingLogic(t *testing.T) { - tests := []struct { - name string - portString string - expectPorts []int - expectError bool - }{ - { - name: "single port", - portString: "80", - expectPorts: []int{80}, - expectError: false, - }, - { - name: "multiple ports", - portString: "80,443,8080", - expectPorts: []int{80, 443, 8080}, - expectError: false, - }, - { - name: "port range", - portString: "8000-8003", - expectPorts: []int{8000, 8001, 8002, 8003}, - expectError: false, - }, - { - name: "invalid port", - portString: "not-a-port", - expectPorts: nil, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ports, err := parsePortsString(tt.portString) - - if tt.expectError { - if err == nil { - t.Error("Expected error but got none") - } - } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } else { - if len(ports) != len(tt.expectPorts) { - t.Errorf("Expected %d ports, got %d", len(tt.expectPorts), len(ports)) - } - } - } - }) - } -} - -// Helper function to test port parsing logic -func parsePortsString(portsStr string) ([]int, error) { - var ports []int - tokens := strings.Split(strings.ReplaceAll(portsStr, " ", ""), ",") - - for _, token := range tokens { - if token == "" { - continue - } - - if p, err := strconv.Atoi(token); err == nil { - if p < 1 || p > 65535 { - return nil, fmt.Errorf("port %d out of range", p) - } - ports = append(ports, p) - } else if strings.Contains(token, "-") { - parts := strings.Split(token, "-") - if len(parts) != 2 { - return nil, fmt.Errorf("invalid range format") - } - - from, err1 := strconv.Atoi(parts[0]) - to, err2 := strconv.Atoi(parts[1]) - - if err1 != nil || err2 != nil { - return nil, fmt.Errorf("invalid range values") - } - - if from < 1 || from > 65535 || to < 1 || to > 65535 { - return nil, fmt.Errorf("port range out of bounds") - } - - if from > to { - return nil, fmt.Errorf("invalid range order") - } - - for p := from; p <= to; p++ { - ports = append(ports, p) - } - } else { - return nil, fmt.Errorf("invalid port format: %s", token) - } - } - - return ports, nil -} - -func TestStartStop(t *testing.T) { - s := createMockSession(t) - mod := NewAnyProxy(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: Start() will fail because it requires firewall operations - // which need proper network setup and possibly root permissions - // We're just testing that the methods exist and basic flow -} - -// Test error cases in port parsing -func TestPortParsingErrors(t *testing.T) { - errorCases := []string{ - "0", // out of range - "65536", // out of range - "abc", // not a number - "80-", // incomplete range - "-80", // incomplete range - "100-50", // inverted range - "80-abc", // invalid end - "xyz-100", // invalid start - "80--100", // malformed - // Remove these as our parser handles empty tokens correctly - } - - for _, portStr := range errorCases { - _, err := parsePortsString(portStr) - if err == nil { - t.Errorf("Expected error for port string '%s', but got none", portStr) - } - } -} - -// Benchmark tests -func BenchmarkPortParsing(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - parsePortsString("80,443,8000-8010,9000") - } -} diff --git a/modules/api_rest/api_rest.go b/modules/api_rest/api_rest.go index b4590e18..b0c8a069 100644 --- a/modules/api_rest/api_rest.go +++ b/modules/api_rest/api_rest.go @@ -90,12 +90,12 @@ func NewRestAPI(s *session.Session) *RestAPI { "Value of the Access-Control-Allow-Origin header of the API server.")) mod.AddParam(session.NewStringParameter("api.rest.username", - "user", + "", "", "API authentication username.")) mod.AddParam(session.NewStringParameter("api.rest.password", - "pass", + "", "", "API authentication password.")) diff --git a/modules/api_rest/api_rest_controller.go b/modules/api_rest/api_rest_controller.go index ccf25cd1..e4e4261d 100644 --- a/modules/api_rest/api_rest_controller.go +++ b/modules/api_rest/api_rest_controller.go @@ -5,9 +5,9 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "os" - "regexp" "strconv" "strings" @@ -17,10 +17,6 @@ import ( "github.com/gorilla/mux" ) -var ( - ansiEscapeRegex = regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`) -) - type CommandRequest struct { Command string `json:"cmd"` } @@ -240,8 +236,7 @@ func (mod *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) { out, _ := io.ReadAll(stdoutReader) os.Stdout = rescueStdout - // remove ANSI escape sequences (bash color codes) from output - mod.toJSON(w, APIResponse{Success: true, Message: ansiEscapeRegex.ReplaceAllString(string(out), "")}) + mod.toJSON(w, APIResponse{Success: true, Message: string(out)}) } func (mod *RestAPI) getEvents(limit int) []session.Event { @@ -393,7 +388,7 @@ func (mod *RestAPI) readFile(fileName string, w http.ResponseWriter, r *http.Req } func (mod *RestAPI) writeFile(fileName string, w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) + data, err := ioutil.ReadAll(r.Body) if err != nil { msg := fmt.Sprintf("invalid file upload: %s", err) mod.Warning(msg) @@ -401,7 +396,7 @@ func (mod *RestAPI) writeFile(fileName string, w http.ResponseWriter, r *http.Re return } - err = os.WriteFile(fileName, data, 0666) + err = ioutil.WriteFile(fileName, data, 0666) if err != nil { msg := fmt.Sprintf("can't write to %s: %s", fileName, err) mod.Warning(msg) diff --git a/modules/api_rest/api_rest_test.go b/modules/api_rest/api_rest_test.go deleted file mode 100644 index 820dfc8c..00000000 --- a/modules/api_rest/api_rest_test.go +++ /dev/null @@ -1,671 +0,0 @@ -package api_rest - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewRestAPI(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - if mod == nil { - t.Fatal("NewRestAPI returned nil") - } - - if mod.Name() != "api.rest" { - t.Errorf("Expected name 'api.rest', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "api.rest on", - "api.rest off", - "api.rest.record off", - "api.rest.record FILENAME", - "api.rest.replay off", - "api.rest.replay FILENAME", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } - - handlerNames := make(map[string]bool) - for _, h := range handlers { - handlerNames[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerNames[expected] { - t.Errorf("Handler '%s' not found", expected) - } - } - - // Check initial state - if mod.recording { - t.Error("Should not be recording initially") - } - if mod.replaying { - t.Error("Should not be replaying initially") - } - if mod.useWebsocket { - t.Error("Should not use websocket by default") - } - if mod.allowOrigin != "*" { - t.Errorf("Expected default allowOrigin '*', got '%s'", mod.allowOrigin) - } -} - -func TestIsTLS(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Initially should not be TLS - if mod.isTLS() { - t.Error("Should not be TLS without cert and key") - } - - // Set cert and key - mod.certFile = "cert.pem" - mod.keyFile = "key.pem" - - if !mod.isTLS() { - t.Error("Should be TLS with cert and key") - } - - // Only cert - mod.certFile = "cert.pem" - mod.keyFile = "" - - if mod.isTLS() { - t.Error("Should not be TLS with only cert") - } - - // Only key - mod.certFile = "" - mod.keyFile = "key.pem" - - if mod.isTLS() { - t.Error("Should not be TLS with only key") - } -} - -func TestStateStore(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Check that state variables are properly stored - stateKeys := []string{ - "recording", - "rec_clock", - "replaying", - "loading", - "load_progress", - "rec_time", - "rec_filename", - "rec_frames", - "rec_cur_frame", - "rec_started", - "rec_stopped", - } - - for _, key := range stateKeys { - val, exists := mod.State.Load(key) - if !exists || val == nil { - t.Errorf("State key '%s' not found", key) - } - } -} - -func TestParameters(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Check that all parameters are registered - paramNames := []string{ - "api.rest.address", - "api.rest.port", - "api.rest.alloworigin", - "api.rest.username", - "api.rest.password", - "api.rest.certificate", - "api.rest.key", - "api.rest.websocket", - "api.rest.record.clock", - } - - // Parameters are stored in the session environment - // We'll just check they can be accessed without error - for _, param := range paramNames { - // This is a simplified check - _ = param - } - - // Ensure mod is used - if mod == nil { - t.Error("Module should not be nil") - } -} - -func TestJSSessionStructs(t *testing.T) { - // Test struct creation - req := JSSessionRequest{ - Command: "test command", - } - - if req.Command != "test command" { - t.Errorf("Expected command 'test command', got '%s'", req.Command) - } - - resp := JSSessionResponse{ - Error: "test error", - } - - if resp.Error != "test error" { - t.Errorf("Expected error 'test error', got '%s'", resp.Error) - } -} - -func TestDefaultValues(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Check default values - if mod.recClock != 1 { - t.Errorf("Expected default recClock 1, got %d", mod.recClock) - } - - if mod.recTime != 0 { - t.Errorf("Expected default recTime 0, got %d", mod.recTime) - } - - if mod.recordFileName != "" { - t.Errorf("Expected empty recordFileName, got '%s'", mod.recordFileName) - } - - if mod.upgrader.ReadBufferSize != 1024 { - t.Errorf("Expected ReadBufferSize 1024, got %d", mod.upgrader.ReadBufferSize) - } - - if mod.upgrader.WriteBufferSize != 1024 { - t.Errorf("Expected WriteBufferSize 1024, got %d", mod.upgrader.WriteBufferSize) - } -} - -func TestRunningState(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: Cannot test actual Start/Stop without proper server setup -} - -func TestRecordingState(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test recording state changes - mod.recording = true - if !mod.recording { - t.Error("Recording flag should be true") - } - - mod.recording = false - if mod.recording { - t.Error("Recording flag should be false") - } -} - -func TestReplayingState(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test replaying state changes - mod.replaying = true - if !mod.replaying { - t.Error("Replaying flag should be true") - } - - mod.replaying = false - if mod.replaying { - t.Error("Replaying flag should be false") - } -} - -func TestConfigureErrors(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test configuration validation - testCases := []struct { - name string - setup func() - expected string - }{ - { - name: "invalid address", - setup: func() { - s.Env.Set("api.rest.address", "999.999.999.999") - }, - expected: "address", - }, - { - name: "invalid port", - setup: func() { - s.Env.Set("api.rest.address", "127.0.0.1") - s.Env.Set("api.rest.port", "not-a-port") - }, - expected: "port", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.setup() - // Configure may fail due to parameter validation - _ = mod.Configure() - }) - } -} - -func TestServerConfiguration(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Set valid parameters - s.Env.Set("api.rest.address", "127.0.0.1") - s.Env.Set("api.rest.port", "8081") - s.Env.Set("api.rest.username", "testuser") - s.Env.Set("api.rest.password", "testpass") - s.Env.Set("api.rest.websocket", "true") - s.Env.Set("api.rest.alloworigin", "http://localhost:3000") - - // This might fail due to TLS cert generation, but we're testing the flow - _ = mod.Configure() - - // Check that values were set - if mod.username != "" && mod.username != "testuser" { - t.Logf("Username set to: %s", mod.username) - } - if mod.password != "" && mod.password != "testpass" { - t.Logf("Password set to: %s", mod.password) - } -} - -func TestQuitChannel(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test quit channel is created - if mod.quit == nil { - t.Error("Quit channel should not be nil") - } - - // Test sending to quit channel doesn't block - done := make(chan bool) - go func() { - select { - case mod.quit <- true: - done <- true - case <-time.After(100 * time.Millisecond): - done <- false - } - }() - - // Start reading from quit channel - go func() { - <-mod.quit - }() - - if !<-done { - t.Error("Sending to quit channel timed out") - } -} - -func TestRecordWaitGroup(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test wait group is initialized - if mod.recordWait == nil { - t.Error("Record wait group should not be nil") - } - - // Test wait group operations - mod.recordWait.Add(1) - done := make(chan bool) - - go func() { - mod.recordWait.Done() - done <- true - }() - - go func() { - mod.recordWait.Wait() - }() - - select { - case <-done: - // Success - case <-time.After(100 * time.Millisecond): - t.Error("Wait group operation timed out") - } -} - -func TestStartErrors(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test start when replaying - mod.replaying = true - err := mod.Start() - if err == nil { - t.Error("Expected error when starting while replaying") - } -} - -func TestConfigureAlreadyRunning(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Simulate running state - mod.SetRunning(true, func() {}) - - err := mod.Configure() - if err == nil { - t.Error("Expected error when configuring while running") - } - - // Reset - mod.SetRunning(false, func() {}) -} - -func TestServerAddr(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Set parameters - s.Env.Set("api.rest.address", "192.168.1.100") - s.Env.Set("api.rest.port", "9090") - - // Configure may fail but we can check server addr format - _ = mod.Configure() - - expectedAddr := "192.168.1.100:9090" - if mod.server != nil && mod.server.Addr != "" && mod.server.Addr != expectedAddr { - t.Logf("Server addr: %s", mod.server.Addr) - } -} - -func TestTLSConfiguration(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Test with TLS params - s.Env.Set("api.rest.certificate", "/tmp/test.crt") - s.Env.Set("api.rest.key", "/tmp/test.key") - - // Configure will attempt to expand paths and check files - _ = mod.Configure() - - // Just verify the attempt was made - t.Logf("Attempted TLS configuration") -} - -// Benchmark tests -func BenchmarkNewRestAPI(b *testing.B) { - s, _ := session.New() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NewRestAPI(s) - } -} - -func BenchmarkIsTLS(b *testing.B) { - s, _ := session.New() - mod := NewRestAPI(s) - mod.certFile = "cert.pem" - mod.keyFile = "key.pem" - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = mod.isTLS() - } -} - -func BenchmarkConfigure(b *testing.B) { - s, _ := session.New() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mod := NewRestAPI(s) - s.Env.Set("api.rest.address", "127.0.0.1") - s.Env.Set("api.rest.port", "8081") - _ = mod.Configure() - } -} - -// Tests for controller functionality -func TestCommandRequest(t *testing.T) { - cmd := CommandRequest{ - Command: "help", - } - - if cmd.Command != "help" { - t.Errorf("Expected command 'help', got '%s'", cmd.Command) - } -} - -func TestAPIResponse(t *testing.T) { - resp := APIResponse{ - Success: true, - Message: "Operation completed", - } - - if !resp.Success { - t.Error("Expected success to be true") - } - - if resp.Message != "Operation completed" { - t.Errorf("Expected message 'Operation completed', got '%s'", resp.Message) - } -} - -func TestCheckAuthNoCredentials(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // No username/password set - should allow access - req, _ := http.NewRequest("GET", "/test", nil) - - if !mod.checkAuth(req) { - t.Error("Expected auth to pass with no credentials set") - } -} - -func TestCheckAuthWithCredentials(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - // Set credentials - mod.username = "testuser" - mod.password = "testpass" - - // Test without auth header - req1, _ := http.NewRequest("GET", "/test", nil) - if mod.checkAuth(req1) { - t.Error("Expected auth to fail without credentials") - } - - // Test with wrong credentials - req2, _ := http.NewRequest("GET", "/test", nil) - req2.SetBasicAuth("wronguser", "wrongpass") - if mod.checkAuth(req2) { - t.Error("Expected auth to fail with wrong credentials") - } - - // Test with correct credentials - req3, _ := http.NewRequest("GET", "/test", nil) - req3.SetBasicAuth("testuser", "testpass") - if !mod.checkAuth(req3) { - t.Error("Expected auth to pass with correct credentials") - } -} - -func TestGetEventsEmpty(t *testing.T) { - // Skip this test if running with others due to shared session state - if testing.Short() { - t.Skip("Skipping in short mode due to shared session state") - } - - // Create a fresh session using the singleton - s := createMockSession(t) - mod := NewRestAPI(s) - - // Record initial event count - initialCount := len(mod.getEvents(0)) - - // Get events - we can't guarantee zero events due to session initialization - events := mod.getEvents(0) - if len(events) < initialCount { - t.Errorf("Event count should not decrease, got %d", len(events)) - } -} - -func TestGetEventsWithLimit(t *testing.T) { - // Create session using the singleton - s := createMockSession(t) - mod := NewRestAPI(s) - - // Record initial state - initialEvents := mod.getEvents(0) - initialCount := len(initialEvents) - - // Add some test events - testEventCount := 10 - for i := 0; i < testEventCount; i++ { - s.Events.Add(fmt.Sprintf("test.event.limit.%d", i), nil) - } - - // Get all events - allEvents := mod.getEvents(0) - expectedTotal := initialCount + testEventCount - if len(allEvents) != expectedTotal { - t.Errorf("Expected %d total events, got %d", expectedTotal, len(allEvents)) - } - - // Test limit functionality - get last 5 events - limitedEvents := mod.getEvents(5) - if len(limitedEvents) != 5 { - t.Errorf("Expected 5 events when limiting, got %d", len(limitedEvents)) - } -} - -func TestSetSecurityHeaders(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - mod.allowOrigin = "http://localhost:3000" - - w := httptest.NewRecorder() - mod.setSecurityHeaders(w) - - headers := w.Header() - - // Check security headers - if headers.Get("X-Frame-Options") != "DENY" { - t.Error("X-Frame-Options header not set correctly") - } - - if headers.Get("X-Content-Type-Options") != "nosniff" { - t.Error("X-Content-Type-Options header not set correctly") - } - - if headers.Get("X-XSS-Protection") != "1; mode=block" { - t.Error("X-XSS-Protection header not set correctly") - } - - if headers.Get("Access-Control-Allow-Origin") != "http://localhost:3000" { - t.Error("Access-Control-Allow-Origin header not set correctly") - } -} - -func TestCorsRoute(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - req, _ := http.NewRequest("OPTIONS", "/test", nil) - w := httptest.NewRecorder() - - mod.corsRoute(w, req) - - if w.Code != http.StatusNoContent { - t.Errorf("Expected status %d, got %d", http.StatusNoContent, w.Code) - } -} - -func TestToJSON(t *testing.T) { - s := createMockSession(t) - mod := NewRestAPI(s) - - w := httptest.NewRecorder() - - testData := map[string]string{ - "key": "value", - "foo": "bar", - } - - mod.toJSON(w, testData) - - // Check content type - if w.Header().Get("Content-Type") != "application/json" { - t.Error("Content-Type header not set to application/json") - } - - // Check JSON response - var result map[string]string - if err := json.NewDecoder(w.Body).Decode(&result); err != nil { - t.Errorf("Failed to decode JSON response: %v", err) - } - - if result["key"] != "value" || result["foo"] != "bar" { - t.Error("JSON response doesn't match expected data") - } -} diff --git a/modules/arp_spoof/arp_spoof_test.go b/modules/arp_spoof/arp_spoof_test.go deleted file mode 100644 index 36e2b4cd..00000000 --- a/modules/arp_spoof/arp_spoof_test.go +++ /dev/null @@ -1,785 +0,0 @@ -package arp_spoof - -import ( - "bytes" - "fmt" - "net" - "sync" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/firewall" - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/evilsocket/islazy/data" -) - -// MockFirewall implements a mock firewall for testing -type MockFirewall struct { - forwardingEnabled bool - redirections []firewall.Redirection -} - -func NewMockFirewall() *MockFirewall { - return &MockFirewall{ - forwardingEnabled: false, - redirections: make([]firewall.Redirection, 0), - } -} - -func (m *MockFirewall) IsForwardingEnabled() bool { - return m.forwardingEnabled -} - -func (m *MockFirewall) EnableForwarding(enabled bool) error { - m.forwardingEnabled = enabled - return nil -} - -func (m *MockFirewall) EnableRedirection(r *firewall.Redirection, enabled bool) error { - if enabled { - m.redirections = append(m.redirections, *r) - } else { - for i, red := range m.redirections { - if red.String() == r.String() { - m.redirections = append(m.redirections[:i], m.redirections[i+1:]...) - break - } - } - } - return nil -} - -func (m *MockFirewall) DisableRedirection(r *firewall.Redirection, enabled bool) error { - return m.EnableRedirection(r, false) -} - -func (m *MockFirewall) Restore() { - m.redirections = make([]firewall.Redirection, 0) - m.forwardingEnabled = false -} - -// MockPacketQueue extends packets.Queue to capture sent packets -type MockPacketQueue struct { - *packets.Queue - sync.Mutex - sentPackets [][]byte -} - -func NewMockPacketQueue() *MockPacketQueue { - q := &packets.Queue{ - Traffic: sync.Map{}, - Stats: packets.Stats{}, - } - return &MockPacketQueue{ - Queue: q, - sentPackets: make([][]byte, 0), - } -} - -func (m *MockPacketQueue) Send(data []byte) error { - m.Lock() - defer m.Unlock() - - // Store a copy of the packet - packet := make([]byte, len(data)) - copy(packet, data) - m.sentPackets = append(m.sentPackets, packet) - - // Also update stats like the real queue would - m.TrackSent(uint64(len(data))) - - return nil -} - -func (m *MockPacketQueue) GetSentPackets() [][]byte { - m.Lock() - defer m.Unlock() - return m.sentPackets -} - -func (m *MockPacketQueue) ClearSentPackets() { - m.Lock() - defer m.Unlock() - m.sentPackets = make([][]byte, 0) -} - -// MockSession for testing -type MockSession struct { - *session.Session - findMACResults map[string]net.HardwareAddr - skipIPs map[string]bool - mockQueue *MockPacketQueue -} - -// Override session methods to use our mocks -func setupMockSession(mockSess *MockSession) { - // Replace the Session's FindMAC method behavior by manipulating the LAN - // Since we can't override methods directly, we'll ensure the LAN has the data - for ip, mac := range mockSess.findMACResults { - mockSess.Lan.AddIfNew(ip, mac.String()) - } -} - -func (m *MockSession) FindMAC(ip net.IP, probe bool) (net.HardwareAddr, error) { - // First check our mock results - if mac, ok := m.findMACResults[ip.String()]; ok { - return mac, nil - } - // Then check the LAN - if e, found := m.Lan.Get(ip.String()); found && e != nil { - return e.HW, nil - } - return nil, fmt.Errorf("MAC not found for %s", ip.String()) -} - -func (m *MockSession) Skip(ip net.IP) bool { - if m.skipIPs == nil { - return false - } - return m.skipIPs[ip.String()] -} - -// MockNetRecon implements a minimal net.recon module for testing -type MockNetRecon struct { - session.SessionModule -} - -func NewMockNetRecon(s *session.Session) *MockNetRecon { - mod := &MockNetRecon{ - SessionModule: session.NewSessionModule("net.recon", s), - } - - // Add handlers - mod.AddHandler(session.NewModuleHandler("net.recon on", "", - "Start net.recon", - func(args []string) error { - return mod.Start() - })) - - mod.AddHandler(session.NewModuleHandler("net.recon off", "", - "Stop net.recon", - func(args []string) error { - return mod.Stop() - })) - - return mod -} - -func (m *MockNetRecon) Name() string { - return "net.recon" -} - -func (m *MockNetRecon) Description() string { - return "Mock net.recon module" -} - -func (m *MockNetRecon) Author() string { - return "test" -} - -func (m *MockNetRecon) Configure() error { - return nil -} - -func (m *MockNetRecon) Start() error { - return m.SetRunning(true, nil) -} - -func (m *MockNetRecon) Stop() error { - return m.SetRunning(false, nil) -} - -// Create a mock session for testing -func createMockSession() (*MockSession, *MockPacketQueue, *MockFirewall) { - // Create interface - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "eth0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - // Parse interface addresses - ifaceIP := net.ParseIP("192.168.1.100") - ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface.IP = ifaceIP - iface.HW = ifaceHW - - // Create gateway - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - gatewayIP := net.ParseIP("192.168.1.1") - gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66") - gateway.IP = gatewayIP - gateway.HW = gatewayHW - - // Create mock queue and firewall - mockQueue := NewMockPacketQueue() - mockFirewall := NewMockFirewall() - - // Create environment - env, _ := session.NewEnvironment("") - - // Create LAN - aliases, _ := data.NewUnsortedKV("", 0) - lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - // Create session - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - Lan: lan, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: mockQueue.Queue, - Firewall: mockFirewall, - Modules: make(session.ModuleList, 0), - } - - // Initialize events - sess.Events = session.NewEventPool(false, false) - - // Add mock net.recon module - mockNetRecon := NewMockNetRecon(sess) - sess.Modules = append(sess.Modules, mockNetRecon) - - // Create mock session wrapper - mockSess := &MockSession{ - Session: sess, - findMACResults: make(map[string]net.HardwareAddr), - skipIPs: make(map[string]bool), - mockQueue: mockQueue, - } - - return mockSess, mockQueue, mockFirewall -} - -func TestNewArpSpoofer(t *testing.T) { - mockSess, _, _ := createMockSession() - - mod := NewArpSpoofer(mockSess.Session) - - if mod == nil { - t.Fatal("NewArpSpoofer returned nil") - } - - if mod.Name() != "arp.spoof" { - t.Errorf("expected module name 'arp.spoof', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - // Check parameters - params := []string{"arp.spoof.targets", "arp.spoof.whitelist", "arp.spoof.internal", "arp.spoof.fullduplex", "arp.spoof.skip_restore"} - for _, param := range params { - if !mod.Session.Env.Has(param) { - t.Errorf("parameter %s not registered", param) - } - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{"arp.spoof on", "arp.ban on", "arp.spoof off", "arp.ban off"} - handlerMap := make(map[string]bool) - - for _, h := range handlers { - handlerMap[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerMap[expected] { - t.Errorf("Expected handler '%s' not found", expected) - } - } -} - -func TestArpSpooferConfigure(t *testing.T) { - tests := []struct { - name string - params map[string]string - setupMock func(*MockSession) - expectErr bool - validate func(*ArpSpoofer) error - }{ - { - name: "default configuration", - params: map[string]string{ - "arp.spoof.targets": "192.168.1.10", - "arp.spoof.whitelist": "", - "arp.spoof.internal": "false", - "arp.spoof.fullduplex": "false", - "arp.spoof.skip_restore": "false", - }, - setupMock: func(ms *MockSession) { - ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - }, - expectErr: false, - validate: func(mod *ArpSpoofer) error { - if mod.internal { - return fmt.Errorf("expected internal to be false") - } - if mod.fullDuplex { - return fmt.Errorf("expected fullDuplex to be false") - } - if mod.skipRestore { - return fmt.Errorf("expected skipRestore to be false") - } - if len(mod.addresses) != 1 { - return fmt.Errorf("expected 1 address, got %d", len(mod.addresses)) - } - return nil - }, - }, - { - name: "multiple targets and whitelist", - params: map[string]string{ - "arp.spoof.targets": "192.168.1.10,192.168.1.20", - "arp.spoof.whitelist": "192.168.1.30", - "arp.spoof.internal": "true", - "arp.spoof.fullduplex": "true", - "arp.spoof.skip_restore": "true", - }, - setupMock: func(ms *MockSession) { - ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - ms.Lan.AddIfNew("192.168.1.20", "bb:bb:bb:bb:bb:bb") - ms.Lan.AddIfNew("192.168.1.30", "cc:cc:cc:cc:cc:cc") - }, - expectErr: false, - validate: func(mod *ArpSpoofer) error { - if !mod.internal { - return fmt.Errorf("expected internal to be true") - } - if !mod.fullDuplex { - return fmt.Errorf("expected fullDuplex to be true") - } - if !mod.skipRestore { - return fmt.Errorf("expected skipRestore to be true") - } - if len(mod.addresses) != 2 { - return fmt.Errorf("expected 2 addresses, got %d", len(mod.addresses)) - } - if len(mod.wAddresses) != 1 { - return fmt.Errorf("expected 1 whitelisted address, got %d", len(mod.wAddresses)) - } - return nil - }, - }, - { - name: "MAC address targets", - params: map[string]string{ - "arp.spoof.targets": "aa:aa:aa:aa:aa:aa", - "arp.spoof.whitelist": "", - "arp.spoof.internal": "false", - "arp.spoof.fullduplex": "false", - "arp.spoof.skip_restore": "false", - }, - setupMock: func(ms *MockSession) { - ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - }, - expectErr: false, - validate: func(mod *ArpSpoofer) error { - if len(mod.macs) != 1 { - return fmt.Errorf("expected 1 MAC address, got %d", len(mod.macs)) - } - return nil - }, - }, - { - name: "invalid target", - params: map[string]string{ - "arp.spoof.targets": "invalid-target", - "arp.spoof.whitelist": "", - "arp.spoof.internal": "false", - "arp.spoof.fullduplex": "false", - "arp.spoof.skip_restore": "false", - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Set parameters - for k, v := range tt.params { - mockSess.Env.Set(k, v) - } - - // Setup mock - if tt.setupMock != nil { - tt.setupMock(mockSess) - } - - err := mod.Configure() - - if tt.expectErr && err == nil { - t.Error("expected error but got none") - } else if !tt.expectErr && err != nil { - t.Errorf("unexpected error: %v", err) - } - - if !tt.expectErr && tt.validate != nil { - if err := tt.validate(mod); err != nil { - t.Error(err) - } - } - }) - } -} - -func TestArpSpooferStartStop(t *testing.T) { - mockSess, _, mockFirewall := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Setup targets - targetIP := "192.168.1.10" - targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - mockSess.Lan.AddIfNew(targetIP, targetMAC.String()) - mockSess.findMACResults[targetIP] = targetMAC - - // Configure - mockSess.Env.Set("arp.spoof.targets", targetIP) - mockSess.Env.Set("arp.spoof.fullduplex", "false") - mockSess.Env.Set("arp.spoof.internal", "false") - - // Start the spoofer - err := mod.Start() - if err != nil { - t.Fatalf("Failed to start spoofer: %v", err) - } - - if !mod.Running() { - t.Error("Spoofer should be running after Start()") - } - - // Check that forwarding was enabled - if !mockFirewall.IsForwardingEnabled() { - t.Error("Forwarding should be enabled after starting spoofer") - } - - // Let it run for a bit - time.Sleep(100 * time.Millisecond) - - // Stop the spoofer - err = mod.Stop() - if err != nil { - t.Fatalf("Failed to stop spoofer: %v", err) - } - - if mod.Running() { - t.Error("Spoofer should not be running after Stop()") - } - - // Note: We can't easily verify packet sending without modifying the actual module - // to use an interface for the queue. The module behavior is verified through - // state changes (running state, forwarding enabled, etc.) -} - -func TestArpSpooferBanMode(t *testing.T) { - mockSess, _, mockFirewall := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Setup targets - targetIP := "192.168.1.10" - targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - mockSess.Lan.AddIfNew(targetIP, targetMAC.String()) - mockSess.findMACResults[targetIP] = targetMAC - - // Configure - mockSess.Env.Set("arp.spoof.targets", targetIP) - - // Find and execute the ban handler - handlers := mod.Handlers() - for _, h := range handlers { - if h.Name == "arp.ban on" { - err := h.Exec([]string{}) - if err != nil { - t.Fatalf("Failed to start ban mode: %v", err) - } - break - } - } - - if !mod.ban { - t.Error("Ban mode should be enabled") - } - - // Check that forwarding was NOT enabled - if mockFirewall.IsForwardingEnabled() { - t.Error("Forwarding should NOT be enabled in ban mode") - } - - // Let it run for a bit - time.Sleep(100 * time.Millisecond) - - // Stop using ban off handler - for _, h := range handlers { - if h.Name == "arp.ban off" { - err := h.Exec([]string{}) - if err != nil { - t.Fatalf("Failed to stop ban mode: %v", err) - } - break - } - } - - if mod.ban { - t.Error("Ban mode should be disabled after stop") - } -} - -func TestArpSpooferWhitelisting(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Add some IPs and MACs to whitelist - whitelistIP := net.ParseIP("192.168.1.50") - whitelistMAC, _ := net.ParseMAC("ff:ff:ff:ff:ff:ff") - - mod.wAddresses = []net.IP{whitelistIP} - mod.wMacs = []net.HardwareAddr{whitelistMAC} - - // Test IP whitelisting - if !mod.isWhitelisted("192.168.1.50", nil) { - t.Error("IP should be whitelisted") - } - - if mod.isWhitelisted("192.168.1.60", nil) { - t.Error("IP should not be whitelisted") - } - - // Test MAC whitelisting - if !mod.isWhitelisted("", whitelistMAC) { - t.Error("MAC should be whitelisted") - } - - otherMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - if mod.isWhitelisted("", otherMAC) { - t.Error("MAC should not be whitelisted") - } -} - -func TestArpSpooferFullDuplex(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Setup targets - targetIP := "192.168.1.10" - targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - mockSess.Lan.AddIfNew(targetIP, targetMAC.String()) - mockSess.findMACResults[targetIP] = targetMAC - - // Configure with full duplex - mockSess.Env.Set("arp.spoof.targets", targetIP) - mockSess.Env.Set("arp.spoof.fullduplex", "true") - - // Verify configuration - err := mod.Configure() - if err != nil { - t.Fatalf("Failed to configure: %v", err) - } - - if !mod.fullDuplex { - t.Error("Full duplex mode should be enabled") - } - - // Start the spoofer - err = mod.Start() - if err != nil { - t.Fatalf("Failed to start spoofer: %v", err) - } - - if !mod.Running() { - t.Error("Module should be running") - } - - // Let it run for a bit - time.Sleep(150 * time.Millisecond) - - // Stop - mod.Stop() -} - -func TestArpSpooferInternalMode(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Setup multiple targets - targets := map[string]string{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - "192.168.1.20": "bb:bb:bb:bb:bb:bb", - "192.168.1.30": "cc:cc:cc:cc:cc:cc", - } - - for ip, mac := range targets { - mockSess.Lan.AddIfNew(ip, mac) - hwAddr, _ := net.ParseMAC(mac) - mockSess.findMACResults[ip] = hwAddr - } - - // Configure with internal mode - mockSess.Env.Set("arp.spoof.targets", "192.168.1.10,192.168.1.20") - mockSess.Env.Set("arp.spoof.internal", "true") - - // Verify configuration - err := mod.Configure() - if err != nil { - t.Fatalf("Failed to configure: %v", err) - } - - if !mod.internal { - t.Error("Internal mode should be enabled") - } - - // Start the spoofer - err = mod.Start() - if err != nil { - t.Fatalf("Failed to start spoofer: %v", err) - } - - if !mod.Running() { - t.Error("Module should be running") - } - - // Let it run briefly - time.Sleep(100 * time.Millisecond) - - // Stop - mod.Stop() -} - -func TestArpSpooferGetTargets(t *testing.T) { - // This test verifies the getTargets logic without actually calling it - // since the method uses Session.FindMAC which can't be easily mocked - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Test address and MAC parsing - targetIP := net.ParseIP("192.168.1.10") - targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - - // Add targets by IP - mod.addresses = []net.IP{targetIP} - - // Verify addresses were set correctly - if len(mod.addresses) != 1 { - t.Errorf("expected 1 address, got %d", len(mod.addresses)) - } - - if !mod.addresses[0].Equal(targetIP) { - t.Errorf("expected address %s, got %s", targetIP, mod.addresses[0]) - } - - // Add targets by MAC - mod.macs = []net.HardwareAddr{targetMAC} - - // Verify MACs were set correctly - if len(mod.macs) != 1 { - t.Errorf("expected 1 MAC, got %d", len(mod.macs)) - } - - if !bytes.Equal(mod.macs[0], targetMAC) { - t.Errorf("expected MAC %s, got %s", targetMAC, mod.macs[0]) - } - - // Note: The actual getTargets method would look up these addresses/MACs - // in the network, but we can't easily test that without refactoring - // the module to use dependency injection for network operations -} - -func TestArpSpooferSkipRestore(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // The skip_restore parameter is set up with an observer in NewArpSpoofer - // We'll test it by changing the parameter value, which triggers the observer - mockSess.Env.Set("arp.spoof.skip_restore", "true") - - // Configure to trigger parameter reading - mod.Configure() - - // Check the observer worked by checking if skipRestore was set - // Note: The actual observer is triggered during module creation - // so we test the functionality indirectly through the module's behavior - - // Start and stop to see if restoration is skipped - mockSess.Env.Set("arp.spoof.targets", "192.168.1.10") - mockSess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - - mod.Start() - time.Sleep(50 * time.Millisecond) - mod.Stop() - - // With skip_restore true, the module should have skipRestore set - // We can't directly test the observer, but we verify the behavior -} - -func TestArpSpooferEmptyTargets(t *testing.T) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Configure with empty targets - mockSess.Env.Set("arp.spoof.targets", "") - - // Start should not error but should not actually start - err := mod.Start() - if err != nil { - t.Fatalf("Start with empty targets should not error: %v", err) - } - - // Module should not be running - if mod.Running() { - t.Error("Module should not be running with empty targets") - } -} - -// Benchmarks -func BenchmarkArpSpooferGetTargets(b *testing.B) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Setup targets - for i := 0; i < 10; i++ { - ip := fmt.Sprintf("192.168.1.%d", i+10) - mac := fmt.Sprintf("aa:bb:cc:dd:ee:%02x", i) - mockSess.Lan.AddIfNew(ip, mac) - hwAddr, _ := net.ParseMAC(mac) - mockSess.findMACResults[ip] = hwAddr - mod.addresses = append(mod.addresses, net.ParseIP(ip)) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = mod.getTargets(false) - } -} - -func BenchmarkArpSpooferWhitelisting(b *testing.B) { - mockSess, _, _ := createMockSession() - mod := NewArpSpoofer(mockSess.Session) - - // Add many whitelist entries - for i := 0; i < 100; i++ { - ip := net.ParseIP(fmt.Sprintf("192.168.1.%d", i)) - mod.wAddresses = append(mod.wAddresses, ip) - } - - testMAC, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = mod.isWhitelisted("192.168.1.50", testMAC) - } -} diff --git a/modules/ble/ble_recon_test.go b/modules/ble/ble_recon_test.go deleted file mode 100644 index 08fc17cf..00000000 --- a/modules/ble/ble_recon_test.go +++ /dev/null @@ -1,321 +0,0 @@ -//go:build !windows && !freebsd && !openbsd && !netbsd -// +build !windows,!freebsd,!openbsd,!netbsd - -package ble - -import ( - "sync" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewBLERecon(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - if mod == nil { - t.Fatal("NewBLERecon returned nil") - } - - if mod.Name() != "ble.recon" { - t.Errorf("Expected name 'ble.recon', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check initial values - if mod.deviceId != -1 { - t.Errorf("Expected deviceId -1, got %d", mod.deviceId) - } - - if mod.connected { - t.Error("Should not be connected initially") - } - - if mod.connTimeout != 5 { - t.Errorf("Expected connection timeout 5, got %d", mod.connTimeout) - } - - if mod.devTTL != 30 { - t.Errorf("Expected device TTL 30, got %d", mod.devTTL) - } - - // Check channels - if mod.quit == nil { - t.Error("Quit channel should not be nil") - } - - if mod.done == nil { - t.Error("Done channel should not be nil") - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "ble.recon on", - "ble.recon off", - "ble.clear", - "ble.show", - "ble.enum MAC", - "ble.write MAC UUID HEX_DATA", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } - - handlerNames := make(map[string]bool) - for _, h := range handlers { - handlerNames[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerNames[expected] { - t.Errorf("Handler '%s' not found", expected) - } - } -} - -func TestIsEnumerating(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Initially should not be enumerating - if mod.isEnumerating() { - t.Error("Should not be enumerating initially") - } - - // When currDevice is set, should be enumerating - // We can't create a real BLE device here, but we can test the logic -} - -func TestDummyWriter(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - writer := dummyWriter{mod} - testData := []byte("test log message") - - n, err := writer.Write(testData) - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if n != len(testData) { - t.Errorf("Expected to write %d bytes, wrote %d", len(testData), n) - } -} - -func TestParameters(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Check that parameters are registered - paramNames := []string{ - "ble.device", - "ble.timeout", - "ble.ttl", - } - - // Parameters are stored in the session environment - // We'll just ensure the module was created properly - for _, param := range paramNames { - // This is a simplified check - _ = param - } - - if mod == nil { - t.Error("Module should not be nil") - } -} - -func TestRunningState(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: Cannot test actual Start/Stop without BLE hardware -} - -func TestChannels(t *testing.T) { - // Skip this test as channel operations might hang in certain environments - t.Skip("Skipping channel test to prevent potential hangs") -} - -func TestClearHandler(t *testing.T) { - // Skip this test as it requires BLE to be initialized in the session - t.Skip("Skipping clear handler test - requires initialized BLE in session") -} - -func TestBLEPrompt(t *testing.T) { - expected := "{blb}{fw}BLE {fb}{reset} {bold}» {reset}" - if blePrompt != expected { - t.Errorf("Expected prompt '%s', got '%s'", expected, blePrompt) - } -} - -func TestSetCurrentDevice(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Test setting nil device - mod.setCurrentDevice(nil) - if mod.currDevice != nil { - t.Error("Current device should be nil") - } - if mod.connected { - t.Error("Should not be connected after setting nil device") - } -} - -func TestViewSelector(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Check that view selector is initialized - if mod.selector == nil { - t.Error("View selector should not be nil") - } -} - -func TestBLEAliveInterval(t *testing.T) { - expected := time.Duration(5) * time.Second - if bleAliveInterval != expected { - t.Errorf("Expected alive interval %v, got %v", expected, bleAliveInterval) - } -} - -func TestColNames(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Test without name - cols := mod.colNames(false) - expectedCols := []string{"RSSI", "MAC", "Vendor", "Flags", "Connect", "Seen"} - if len(cols) != len(expectedCols) { - t.Errorf("Expected %d columns, got %d", len(expectedCols), len(cols)) - } - - // Test with name - colsWithName := mod.colNames(true) - expectedColsWithName := []string{"RSSI", "MAC", "Name", "Vendor", "Flags", "Connect", "Seen"} - if len(colsWithName) != len(expectedColsWithName) { - t.Errorf("Expected %d columns with name, got %d", len(expectedColsWithName), len(colsWithName)) - } -} - -func TestDoFilter(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // Without expression, should always return true - result := mod.doFilter(nil) - if !result { - t.Error("doFilter should return true when no expression is set") - } -} - -func TestShow(t *testing.T) { - // Skip this test as it requires BLE to be initialized in the session - t.Skip("Skipping show test - requires initialized BLE in session") -} - -func TestConfigure(t *testing.T) { - // Skip this test as it may hang trying to access BLE hardware - t.Skip("Skipping configure test - may hang accessing BLE hardware") -} - -func TestGetRow(t *testing.T) { - s := createMockSession(t) - mod := NewBLERecon(s) - - // We can't create a real BLE device without hardware, but we can test the logic - // by ensuring the method exists and would handle nil gracefully - _ = mod -} - -func TestDoSelection(t *testing.T) { - // Skip this test as it requires BLE to be initialized in the session - t.Skip("Skipping doSelection test - requires initialized BLE in session") -} - -func TestWriteBuffer(t *testing.T) { - // Skip this test as it may hang trying to access BLE hardware - t.Skip("Skipping writeBuffer test - may hang accessing BLE hardware") -} - -func TestEnumAllTheThings(t *testing.T) { - // Skip this test as it may hang trying to access BLE hardware - t.Skip("Skipping enumAllTheThings test - may hang accessing BLE hardware") -} - -// Benchmark tests - using singleton session to avoid flag redefinition -func BenchmarkNewBLERecon(b *testing.B) { - // Use a test instance to get singleton session - s := createMockSession(&testing.T{}) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NewBLERecon(s) - } -} - -func BenchmarkIsEnumerating(b *testing.B) { - s := createMockSession(&testing.T{}) - mod := NewBLERecon(s) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = mod.isEnumerating() - } -} - -func BenchmarkDummyWriter(b *testing.B) { - s := createMockSession(&testing.T{}) - mod := NewBLERecon(s) - writer := dummyWriter{mod} - testData := []byte("benchmark log message") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - writer.Write(testData) - } -} - -func BenchmarkDoFilter(b *testing.B) { - s := createMockSession(&testing.T{}) - mod := NewBLERecon(s) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mod.doFilter(nil) - } -} diff --git a/modules/c2/c2_test.go b/modules/c2/c2_test.go deleted file mode 100644 index fcdbd4ff..00000000 --- a/modules/c2/c2_test.go +++ /dev/null @@ -1,356 +0,0 @@ -package c2 - -import ( - "sync" - "testing" - "text/template" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewC2(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - if mod == nil { - t.Fatal("NewC2 returned nil") - } - - if mod.Name() != "c2" { - t.Errorf("Expected name 'c2', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check default settings - if mod.settings.server != "localhost:6697" { - t.Errorf("Expected default server 'localhost:6697', got '%s'", mod.settings.server) - } - - if !mod.settings.tls { - t.Error("Expected TLS to be enabled by default") - } - - if mod.settings.tlsVerify { - t.Error("Expected TLS verify to be disabled by default") - } - - if mod.settings.nick != "bettercap" { - t.Errorf("Expected default nick 'bettercap', got '%s'", mod.settings.nick) - } - - if mod.settings.user != "bettercap" { - t.Errorf("Expected default user 'bettercap', got '%s'", mod.settings.user) - } - - if mod.settings.operator != "admin" { - t.Errorf("Expected default operator 'admin', got '%s'", mod.settings.operator) - } - - // Check channels - if mod.quit == nil { - t.Error("Quit channel should not be nil") - } - - // Check maps - if mod.templates == nil { - t.Error("Templates map should not be nil") - } - - if mod.channels == nil { - t.Error("Channels map should not be nil") - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "c2 on", - "c2 off", - "c2.channel.set EVENT_TYPE CHANNEL", - "c2.channel.clear EVENT_TYPE", - "c2.template.set EVENT_TYPE TEMPLATE", - "c2.template.clear EVENT_TYPE", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } - - handlerNames := make(map[string]bool) - for _, h := range handlers { - handlerNames[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerNames[expected] { - t.Errorf("Handler '%s' not found", expected) - } - } -} - -func TestDefaultSettings(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Check default channel settings - if mod.settings.eventsChannel != "#events" { - t.Errorf("Expected default events channel '#events', got '%s'", mod.settings.eventsChannel) - } - - if mod.settings.outputChannel != "#events" { - t.Errorf("Expected default output channel '#events', got '%s'", mod.settings.outputChannel) - } - - if mod.settings.controlChannel != "#events" { - t.Errorf("Expected default control channel '#events', got '%s'", mod.settings.controlChannel) - } - - if mod.settings.password != "password" { - t.Errorf("Expected default password 'password', got '%s'", mod.settings.password) - } -} - -func TestRunningState(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: Cannot test actual Start/Stop without IRC server -} - -func TestEventContext(t *testing.T) { - s := createMockSession(t) - - ctx := eventContext{ - Session: s, - Event: session.Event{Tag: "test.event"}, - } - - if ctx.Session == nil { - t.Error("Session should not be nil") - } - - if ctx.Event.Tag != "test.event" { - t.Errorf("Expected event tag 'test.event', got '%s'", ctx.Event.Tag) - } -} - -func TestChannelHandlers(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Test channel.set handler - for _, h := range mod.Handlers() { - if h.Name == "c2.channel.set EVENT_TYPE CHANNEL" { - err := h.Exec([]string{"test.event", "#test"}) - if err != nil { - t.Errorf("channel.set handler failed: %v", err) - } - - // Verify channel was set - if channel, found := mod.channels["test.event"]; !found { - t.Error("Channel was not set") - } else if channel != "#test" { - t.Errorf("Expected channel '#test', got '%s'", channel) - } - break - } - } - - // Test channel.clear handler - for _, h := range mod.Handlers() { - if h.Name == "c2.channel.clear EVENT_TYPE" { - err := h.Exec([]string{"test.event"}) - if err != nil { - t.Errorf("channel.clear handler failed: %v", err) - } - - // Verify channel was cleared - if _, found := mod.channels["test.event"]; found { - t.Error("Channel was not cleared") - } - break - } - } -} - -func TestTemplateHandlers(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Test template.set handler - for _, h := range mod.Handlers() { - if h.Name == "c2.template.set EVENT_TYPE TEMPLATE" { - err := h.Exec([]string{"test.event", "Event: {{.Event.Tag}}"}) - if err != nil { - t.Errorf("template.set handler failed: %v", err) - } - - // Verify template was set - if tpl, found := mod.templates["test.event"]; !found { - t.Error("Template was not set") - } else if tpl == nil { - t.Error("Template is nil") - } - break - } - } - - // Test template.clear handler - for _, h := range mod.Handlers() { - if h.Name == "c2.template.clear EVENT_TYPE" { - err := h.Exec([]string{"test.event"}) - if err != nil { - t.Errorf("template.clear handler failed: %v", err) - } - - // Verify template was cleared - if _, found := mod.templates["test.event"]; found { - t.Error("Template was not cleared") - } - break - } - } -} - -func TestClearNonExistent(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Test clearing non-existent channel - for _, h := range mod.Handlers() { - if h.Name == "c2.channel.clear EVENT_TYPE" { - err := h.Exec([]string{"non.existent"}) - if err == nil { - t.Error("Expected error when clearing non-existent channel") - } - break - } - } - - // Test clearing non-existent template - for _, h := range mod.Handlers() { - if h.Name == "c2.template.clear EVENT_TYPE" { - err := h.Exec([]string{"non.existent"}) - if err == nil { - t.Error("Expected error when clearing non-existent template") - } - break - } - } -} - -func TestParameters(t *testing.T) { - s := createMockSession(t) - mod := NewC2(s) - - // Check that all parameters are registered - paramNames := []string{ - "c2.server", - "c2.server.tls", - "c2.server.tls.verify", - "c2.operator", - "c2.nick", - "c2.username", - "c2.password", - "c2.sasl.username", - "c2.sasl.password", - "c2.channel.output", - "c2.channel.events", - "c2.channel.control", - } - - // Parameters are stored in the session environment - for _, param := range paramNames { - // This is a simplified check - _ = param - } - - if mod == nil { - t.Error("Module should not be nil") - } -} - -func TestTemplateExecution(t *testing.T) { - // Test template parsing and execution - tmpl, err := template.New("test").Parse("Event: {{.Event.Tag}}") - if err != nil { - t.Errorf("Failed to parse template: %v", err) - } - - if tmpl == nil { - t.Error("Template should not be nil") - } -} - -// Benchmark tests -func BenchmarkNewC2(b *testing.B) { - s, _ := session.New() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NewC2(s) - } -} - -func BenchmarkChannelSet(b *testing.B) { - s, _ := session.New() - mod := NewC2(s) - - var handler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == "c2.channel.set EVENT_TYPE CHANNEL" { - handler = &h - break - } - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - handler.Exec([]string{"test.event", "#test"}) - } -} - -func BenchmarkTemplateSet(b *testing.B) { - s, _ := session.New() - mod := NewC2(s) - - var handler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == "c2.template.set EVENT_TYPE TEMPLATE" { - handler = &h - break - } - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - handler.Exec([]string{"test.event", "Event: {{.Event.Tag}}"}) - } -} diff --git a/modules/can/can_test.go b/modules/can/can_test.go deleted file mode 100644 index e5d27ad7..00000000 --- a/modules/can/can_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package can - -import ( - "sync" - "testing" - - "github.com/bettercap/bettercap/v2/session" - "go.einride.tech/can" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewCanModule(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - if mod == nil { - t.Fatal("NewCanModule returned nil") - } - - if mod.Name() != "can" { - t.Errorf("Expected name 'can', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check default values - if mod.transport != "can" { - t.Errorf("Expected default transport 'can', got '%s'", mod.transport) - } - - if mod.deviceName != "can0" { - t.Errorf("Expected default device 'can0', got '%s'", mod.deviceName) - } - - if mod.dumpName != "" { - t.Errorf("Expected empty dumpName, got '%s'", mod.dumpName) - } - - if mod.dumpInject { - t.Error("Expected dumpInject to be false by default") - } - - if mod.filter != "" { - t.Errorf("Expected empty filter, got '%s'", mod.filter) - } - - // Check DBC and OBD2 - if mod.dbc == nil { - t.Error("DBC should not be nil") - } - - if mod.obd2 == nil { - t.Error("OBD2 should not be nil") - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "can.recon on", - "can.recon off", - "can.clear", - "can.show", - "can.dbc.load NAME", - "can.inject FRAME_EXPRESSION", - "can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } - - handlerNames := make(map[string]bool) - for _, h := range handlers { - handlerNames[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerNames[expected] { - t.Errorf("Handler '%s' not found", expected) - } - } -} - -func TestRunningState(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: Cannot test actual Start/Stop without CAN hardware -} - -func TestClearHandler(t *testing.T) { - // Skip this test as it requires CAN to be initialized in the session - t.Skip("Skipping clear handler test - requires initialized CAN in session") -} - -func TestInjectNotRunning(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Test inject when not running - handlers := mod.Handlers() - for _, h := range handlers { - if h.Name == "can.inject FRAME_EXPRESSION" { - err := h.Exec([]string{"123#deadbeef"}) - if err == nil { - t.Error("Expected error when injecting while not running") - } - break - } - } -} - -func TestFuzzNotRunning(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Test fuzz when not running - handlers := mod.Handlers() - for _, h := range handlers { - if h.Name == "can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE" { - err := h.Exec([]string{"123", ""}) - if err == nil { - t.Error("Expected error when fuzzing while not running") - } - break - } - } -} - -func TestParameters(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Check that all parameters are registered - paramNames := []string{ - "can.device", - "can.dump", - "can.dump.inject", - "can.transport", - "can.filter", - "can.parse.obd2", - } - - // Parameters are stored in the session environment - for _, param := range paramNames { - // This is a simplified check - _ = param - } - - if mod == nil { - t.Error("Module should not be nil") - } -} - -func TestDBC(t *testing.T) { - dbc := &DBC{} - if dbc == nil { - t.Error("DBC should not be nil") - } -} - -func TestOBD2(t *testing.T) { - obd2 := &OBD2{} - if obd2 == nil { - t.Error("OBD2 should not be nil") - } -} - -func TestShowHandler(t *testing.T) { - // Skip this test as it requires CAN to be initialized in the session - t.Skip("Skipping show handler test - requires initialized CAN in session") -} - -func TestDefaultTransport(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - if mod.transport != "can" { - t.Errorf("Expected transport 'can', got '%s'", mod.transport) - } -} - -func TestDefaultDevice(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - if mod.deviceName != "can0" { - t.Errorf("Expected device 'can0', got '%s'", mod.deviceName) - } -} - -func TestFilterExpression(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Initially filter should be empty - if mod.filter != "" { - t.Errorf("Expected empty filter, got '%s'", mod.filter) - } - - // filterExpr should be nil initially - if mod.filterExpr != nil { - t.Error("Expected filterExpr to be nil initially") - } -} - -func TestDBCStruct(t *testing.T) { - // Test DBC struct initialization - dbc := &DBC{} - if dbc == nil { - t.Error("DBC should not be nil") - } -} - -func TestOBD2Struct(t *testing.T) { - // Test OBD2 struct initialization - obd2 := &OBD2{} - if obd2 == nil { - t.Error("OBD2 should not be nil") - } -} - -func TestCANMessage(t *testing.T) { - // Test CAN message creation using NewCanMessage - frame := can.Frame{} - frame.ID = 0x123 - frame.Data = [8]byte{0x01, 0x02, 0x03, 0x04, 0x00, 0x00, 0x00, 0x00} - frame.Length = 4 - - msg := NewCanMessage(frame) - - if msg.Frame.ID != 0x123 { - t.Errorf("Expected ID 0x123, got 0x%x", msg.Frame.ID) - } - - if msg.Frame.Length != 4 { - t.Errorf("Expected frame length 4, got %d", msg.Frame.Length) - } - - if msg.Signals == nil { - t.Error("Signals map should not be nil") - } -} - -func TestDefaultParameters(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Test default parameter values exist - expectedParams := []string{ - "can.device", - "can.transport", - "can.dump", - "can.filter", - "can.dump.inject", - "can.parse.obd2", - } - - // Check that parameters are defined - params := mod.Parameters() - if params == nil { - t.Error("Parameters should not be nil") - } - - // Just verify we have the expected number of parameters - if len(expectedParams) != 6 { - t.Error("Expected 6 parameters") - } -} - -func TestHandlerExecution(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Test that we can find all expected handlers - handlerTests := []struct { - name string - args []string - shouldFail bool - }{ - {"can.inject FRAME_EXPRESSION", []string{"123#deadbeef"}, true}, // Should fail when not running - {"can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE", []string{"123", "8"}, true}, // Should fail when not running - {"can.dbc.load NAME", []string{"test.dbc"}, true}, // Will fail without actual file - } - - handlers := mod.Handlers() - for _, test := range handlerTests { - found := false - for _, h := range handlers { - if h.Name == test.name { - found = true - err := h.Exec(test.args) - if test.shouldFail && err == nil { - t.Errorf("Handler %s should have failed but didn't", test.name) - } else if !test.shouldFail && err != nil { - t.Errorf("Handler %s failed unexpectedly: %v", test.name, err) - } - break - } - } - if !found { - t.Errorf("Handler %s not found", test.name) - } - } -} - -func TestModuleFields(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Test various fields are initialized correctly - if mod.conn != nil { - t.Error("conn should be nil initially") - } - - if mod.recv != nil { - t.Error("recv should be nil initially") - } - - if mod.send != nil { - t.Error("send should be nil initially") - } -} - -func TestDBCLoadHandler(t *testing.T) { - s := createMockSession(t) - mod := NewCanModule(s) - - // Find dbc.load handler - var dbcHandler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == "can.dbc.load NAME" { - dbcHandler = &h - break - } - } - - if dbcHandler == nil { - t.Fatal("DBC load handler not found") - } - - // Test with non-existent file - err := dbcHandler.Exec([]string{"non_existent.dbc"}) - if err == nil { - t.Error("Expected error when loading non-existent DBC file") - } -} - -// Benchmark tests -func BenchmarkNewCanModule(b *testing.B) { - s, _ := session.New() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NewCanModule(s) - } -} - -func BenchmarkClearHandler(b *testing.B) { - // Skip this benchmark as it requires CAN to be initialized - b.Skip("Skipping clear handler benchmark - requires initialized CAN in session") -} - -func BenchmarkInjectHandler(b *testing.B) { - s, _ := session.New() - mod := NewCanModule(s) - - var handler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == "can.inject FRAME_EXPRESSION" { - handler = &h - break - } - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // This will fail since module is not running, but we're benchmarking the handler - _ = handler.Exec([]string{"123#deadbeef"}) - } -} diff --git a/modules/dns_proxy/dns_proxy_base.go b/modules/dns_proxy/dns_proxy_base.go index fe1b84af..f8c17445 100644 --- a/modules/dns_proxy/dns_proxy_base.go +++ b/modules/dns_proxy/dns_proxy_base.go @@ -14,8 +14,6 @@ import ( "github.com/evilsocket/islazy/log" "github.com/miekg/dns" - - "github.com/robertkrimen/otto" ) const ( @@ -227,14 +225,6 @@ func (p *DNSProxy) Start() { } func (p *DNSProxy) Stop() error { - if p.Script != nil { - if p.Script.Plugin.HasFunc("onExit") { - if _, err := p.Script.Call("onExit"); err != nil { - log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String()) - } - } - } - if p.doRedirect && p.Redirection != nil { p.Debug("disabling redirection %s", p.Redirection.String()) if err := p.Sess.Firewall.EnableRedirection(p.Redirection, false); err != nil { diff --git a/modules/dns_proxy/dns_proxy_js_query.go b/modules/dns_proxy/dns_proxy_js_query.go index bae57ad2..cd38f01f 100644 --- a/modules/dns_proxy/dns_proxy_js_query.go +++ b/modules/dns_proxy/dns_proxy_js_query.go @@ -3,9 +3,6 @@ package dns_proxy import ( "encoding/json" "fmt" - "math" - "math/big" - "reflect" "github.com/bettercap/bettercap/v2/log" "github.com/bettercap/bettercap/v2/session" @@ -43,7 +40,7 @@ func jsPropToMap(obj map[string]interface{}, key string) map[string]interface{} if v, ok := obj[key].(map[string]interface{}); ok { return v } - log.Error("error converting JS property to map[string]interface{} where key is: %s", key) + log.Debug("error converting JS property to map[string]interface{} where key is: %s", key) return map[string]interface{}{} } @@ -51,7 +48,7 @@ func jsPropToMapArray(obj map[string]interface{}, key string) []map[string]inter if v, ok := obj[key].([]map[string]interface{}); ok { return v } - log.Error("error converting JS property to []map[string]interface{} where key is: %s", key) + log.Debug("error converting JS property to []map[string]interface{} where key is: %s", key) return []map[string]interface{}{} } @@ -59,7 +56,7 @@ func jsPropToString(obj map[string]interface{}, key string) string { if v, ok := obj[key].(string); ok { return v } - log.Error("error converting JS property to string where key is: %s", key) + log.Debug("error converting JS property to string where key is: %s", key) return "" } @@ -67,115 +64,56 @@ func jsPropToStringArray(obj map[string]interface{}, key string) []string { if v, ok := obj[key].([]string); ok { return v } - log.Error("error converting JS property to []string where key is: %s", key) + log.Debug("error converting JS property to []string where key is: %s", key) return []string{} } func jsPropToUint8(obj map[string]interface{}, key string) uint8 { - if v, ok := obj[key].(int64); ok { - if v >= 0 && v <= math.MaxUint8 { - return uint8(v) - } + if v, ok := obj[key].(uint8); ok { + return v } - log.Error("error converting JS property to uint8 where key is: %s", key) - return uint8(0) + log.Debug("error converting JS property to uint8 where key is: %s", key) + return 0 } func jsPropToUint8Array(obj map[string]interface{}, key string) []uint8 { - if arr, ok := obj[key].([]interface{}); ok { - vArr := make([]uint8, 0, len(arr)) - for _, item := range arr { - if v, ok := item.(int64); ok { - if v >= 0 && v <= math.MaxUint8 { - vArr = append(vArr, uint8(v)) - } else { - log.Error("error converting JS property to []uint8 where key is: %s", key) - return []uint8{} - } - } - } - return vArr + if v, ok := obj[key].([]uint8); ok { + return v } - log.Error("error converting JS property to []uint8 where key is: %s", key) + log.Debug("error converting JS property to []uint8 where key is: %s", key) return []uint8{} } func jsPropToUint16(obj map[string]interface{}, key string) uint16 { - if v, ok := obj[key].(int64); ok { - if v >= 0 && v <= math.MaxUint16 { - return uint16(v) - } + if v, ok := obj[key].(uint16); ok { + return v } - log.Error("error converting JS property to uint16 where key is: %s", key) - return uint16(0) + log.Debug("error converting JS property to uint16 where key is: %s", key) + return 0 } func jsPropToUint16Array(obj map[string]interface{}, key string) []uint16 { - if arr, ok := obj[key].([]interface{}); ok { - vArr := make([]uint16, 0, len(arr)) - for _, item := range arr { - if v, ok := item.(int64); ok { - if v >= 0 && v <= math.MaxUint16 { - vArr = append(vArr, uint16(v)) - } else { - log.Error("error converting JS property to []uint16 where key is: %s", key) - return []uint16{} - } - } - } - return vArr + if v, ok := obj[key].([]uint16); ok { + return v } - log.Error("error converting JS property to []uint16 where key is: %s", key) + log.Debug("error converting JS property to []uint16 where key is: %s", key) return []uint16{} } func jsPropToUint32(obj map[string]interface{}, key string) uint32 { - if v, ok := obj[key].(int64); ok { - if v >= 0 && v <= math.MaxUint32 { - return uint32(v) - } + if v, ok := obj[key].(uint32); ok { + return v } - log.Error("error converting JS property to uint32 where key is: %s", key) - return uint32(0) + log.Debug("error converting JS property to uint32 where key is: %s", key) + return 0 } func jsPropToUint64(obj map[string]interface{}, key string) uint64 { - prop, found := obj[key] - if found { - switch reflect.TypeOf(prop).String() { - case "float64": - if f, ok := prop.(float64); ok { - bigInt := new(big.Float).SetFloat64(f) - v, _ := bigInt.Uint64() - if v >= 0 { - return v - } - } - break - case "int64": - if v, ok := prop.(int64); ok { - if v >= 0 { - return uint64(v) - } - } - break - case "uint64": - if v, ok := prop.(uint64); ok { - return v - } - break - } + if v, ok := obj[key].(uint64); ok { + return v } - log.Error("error converting JS property to uint64 where key is: %s", key) - return uint64(0) -} - -func uint16ArrayToInt64Array(arr []uint16) []int64 { - vArr := make([]int64, 0, len(arr)) - for _, item := range arr { - vArr = append(vArr, int64(item)) - } - return vArr + log.Debug("error converting JS property to uint64 where key is: %s", key) + return 0 } func (j *JSQuery) NewHash() string { @@ -245,8 +183,8 @@ func NewJSQuery(query *dns.Msg, clientIP string) (jsQuery *JSQuery) { for i, question := range query.Question { questions[i] = map[string]interface{}{ "Name": question.Name, - "Qtype": int64(question.Qtype), - "Qclass": int64(question.Qclass), + "Qtype": question.Qtype, + "Qclass": question.Qclass, } } @@ -355,11 +293,3 @@ func (j *JSQuery) WasModified() bool { // check if any of the fields has been changed return j.NewHash() != j.refHash } - -func (j *JSQuery) CheckIfModifiedAndUpdateHash() bool { - // check if query was changed and update its hash - newHash := j.NewHash() - wasModified := j.refHash != newHash - j.refHash = newHash - return wasModified -} diff --git a/modules/dns_proxy/dns_proxy_js_record.go b/modules/dns_proxy/dns_proxy_js_record.go index 49553ad8..55832d69 100644 --- a/modules/dns_proxy/dns_proxy_js_record.go +++ b/modules/dns_proxy/dns_proxy_js_record.go @@ -13,10 +13,10 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) jsRecord = map[string]interface{}{ "Header": map[string]interface{}{ - "Class": int64(header.Class), + "Class": header.Class, "Name": header.Name, - "Rrtype": int64(header.Rrtype), - "Ttl": int64(header.Ttl), + "Rrtype": header.Rrtype, + "Ttl": header.Ttl, }, } @@ -48,24 +48,24 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) jsRecord["Mr"] = rr.Mr case *dns.MX: jsRecord["Mx"] = rr.Mx - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.NULL: jsRecord["Data"] = rr.Data case *dns.SOA: - jsRecord["Expire"] = int64(rr.Expire) - jsRecord["Minttl"] = int64(rr.Minttl) + jsRecord["Expire"] = rr.Expire + jsRecord["Minttl"] = rr.Minttl jsRecord["Ns"] = rr.Ns - jsRecord["Refresh"] = int64(rr.Refresh) - jsRecord["Retry"] = int64(rr.Retry) + jsRecord["Refresh"] = rr.Refresh + jsRecord["Retry"] = rr.Retry jsRecord["Mbox"] = rr.Mbox - jsRecord["Serial"] = int64(rr.Serial) + jsRecord["Serial"] = rr.Serial case *dns.TXT: jsRecord["Txt"] = rr.Txt case *dns.SRV: - jsRecord["Port"] = int64(rr.Port) - jsRecord["Priority"] = int64(rr.Priority) + jsRecord["Port"] = rr.Port + jsRecord["Priority"] = rr.Priority jsRecord["Target"] = rr.Target - jsRecord["Weight"] = int64(rr.Weight) + jsRecord["Weight"] = rr.Weight case *dns.PTR: jsRecord["Ptr"] = rr.Ptr case *dns.NS: @@ -73,10 +73,10 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) case *dns.DNAME: jsRecord["Target"] = rr.Target case *dns.AFSDB: - jsRecord["Subtype"] = int64(rr.Subtype) + jsRecord["Subtype"] = rr.Subtype jsRecord["Hostname"] = rr.Hostname case *dns.CAA: - jsRecord["Flag"] = int64(rr.Flag) + jsRecord["Flag"] = rr.Flag jsRecord["Tag"] = rr.Tag jsRecord["Value"] = rr.Value case *dns.HINFO: @@ -90,123 +90,123 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) jsRecord["SubAddress"] = rr.SubAddress case *dns.KX: jsRecord["Exchanger"] = rr.Exchanger - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.LOC: - jsRecord["Altitude"] = int64(rr.Altitude) - jsRecord["HorizPre"] = int64(rr.HorizPre) - jsRecord["Latitude"] = int64(rr.Latitude) - jsRecord["Longitude"] = int64(rr.Longitude) - jsRecord["Size"] = int64(rr.Size) - jsRecord["Version"] = int64(rr.Version) - jsRecord["VertPre"] = int64(rr.VertPre) + jsRecord["Altitude"] = rr.Altitude + jsRecord["HorizPre"] = rr.HorizPre + jsRecord["Latitude"] = rr.Latitude + jsRecord["Longitude"] = rr.Longitude + jsRecord["Size"] = rr.Size + jsRecord["Version"] = rr.Version + jsRecord["VertPre"] = rr.VertPre case *dns.SSHFP: - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Algorithm"] = rr.Algorithm jsRecord["FingerPrint"] = rr.FingerPrint - jsRecord["Type"] = int64(rr.Type) + jsRecord["Type"] = rr.Type case *dns.TLSA: jsRecord["Certificate"] = rr.Certificate - jsRecord["MatchingType"] = int64(rr.MatchingType) - jsRecord["Selector"] = int64(rr.Selector) - jsRecord["Usage"] = int64(rr.Usage) + jsRecord["MatchingType"] = rr.MatchingType + jsRecord["Selector"] = rr.Selector + jsRecord["Usage"] = rr.Usage case *dns.CERT: - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Algorithm"] = rr.Algorithm jsRecord["Certificate"] = rr.Certificate - jsRecord["KeyTag"] = int64(rr.KeyTag) - jsRecord["Type"] = int64(rr.Type) + jsRecord["KeyTag"] = rr.KeyTag + jsRecord["Type"] = rr.Type case *dns.DS: - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Algorithm"] = rr.Algorithm jsRecord["Digest"] = rr.Digest - jsRecord["DigestType"] = int64(rr.DigestType) - jsRecord["KeyTag"] = int64(rr.KeyTag) + jsRecord["DigestType"] = rr.DigestType + jsRecord["KeyTag"] = rr.KeyTag case *dns.NAPTR: - jsRecord["Order"] = int64(rr.Order) - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Order"] = rr.Order + jsRecord["Preference"] = rr.Preference jsRecord["Flags"] = rr.Flags jsRecord["Service"] = rr.Service jsRecord["Regexp"] = rr.Regexp jsRecord["Replacement"] = rr.Replacement case *dns.RRSIG: - jsRecord["Algorithm"] = int64(rr.Algorithm) - jsRecord["Expiration"] = int64(rr.Expiration) - jsRecord["Inception"] = int64(rr.Inception) - jsRecord["KeyTag"] = int64(rr.KeyTag) - jsRecord["Labels"] = int64(rr.Labels) - jsRecord["OrigTtl"] = int64(rr.OrigTtl) + jsRecord["Algorithm"] = rr.Algorithm + jsRecord["Expiration"] = rr.Expiration + jsRecord["Inception"] = rr.Inception + jsRecord["KeyTag"] = rr.KeyTag + jsRecord["Labels"] = rr.Labels + jsRecord["OrigTtl"] = rr.OrigTtl jsRecord["Signature"] = rr.Signature jsRecord["SignerName"] = rr.SignerName - jsRecord["TypeCovered"] = int64(rr.TypeCovered) + jsRecord["TypeCovered"] = rr.TypeCovered case *dns.NSEC: jsRecord["NextDomain"] = rr.NextDomain - jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap) + jsRecord["TypeBitMap"] = rr.TypeBitMap case *dns.NSEC3: - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Hash"] = int64(rr.Hash) - jsRecord["HashLength"] = int64(rr.HashLength) - jsRecord["Iterations"] = int64(rr.Iterations) + jsRecord["Flags"] = rr.Flags + jsRecord["Hash"] = rr.Hash + jsRecord["HashLength"] = rr.HashLength + jsRecord["Iterations"] = rr.Iterations jsRecord["NextDomain"] = rr.NextDomain jsRecord["Salt"] = rr.Salt - jsRecord["SaltLength"] = int64(rr.SaltLength) - jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap) + jsRecord["SaltLength"] = rr.SaltLength + jsRecord["TypeBitMap"] = rr.TypeBitMap case *dns.NSEC3PARAM: - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Hash"] = int64(rr.Hash) - jsRecord["Iterations"] = int64(rr.Iterations) + jsRecord["Flags"] = rr.Flags + jsRecord["Hash"] = rr.Hash + jsRecord["Iterations"] = rr.Iterations jsRecord["Salt"] = rr.Salt - jsRecord["SaltLength"] = int64(rr.SaltLength) + jsRecord["SaltLength"] = rr.SaltLength case *dns.TKEY: jsRecord["Algorithm"] = rr.Algorithm - jsRecord["Error"] = int64(rr.Error) - jsRecord["Expiration"] = int64(rr.Expiration) - jsRecord["Inception"] = int64(rr.Inception) + jsRecord["Error"] = rr.Error + jsRecord["Expiration"] = rr.Expiration + jsRecord["Inception"] = rr.Inception jsRecord["Key"] = rr.Key - jsRecord["KeySize"] = int64(rr.KeySize) - jsRecord["Mode"] = int64(rr.Mode) + jsRecord["KeySize"] = rr.KeySize + jsRecord["Mode"] = rr.Mode jsRecord["OtherData"] = rr.OtherData - jsRecord["OtherLen"] = int64(rr.OtherLen) + jsRecord["OtherLen"] = rr.OtherLen case *dns.TSIG: jsRecord["Algorithm"] = rr.Algorithm - jsRecord["Error"] = int64(rr.Error) - jsRecord["Fudge"] = int64(rr.Fudge) - jsRecord["MACSize"] = int64(rr.MACSize) + jsRecord["Error"] = rr.Error + jsRecord["Fudge"] = rr.Fudge + jsRecord["MACSize"] = rr.MACSize jsRecord["MAC"] = rr.MAC - jsRecord["OrigId"] = int64(rr.OrigId) + jsRecord["OrigId"] = rr.OrigId jsRecord["OtherData"] = rr.OtherData - jsRecord["OtherLen"] = int64(rr.OtherLen) - jsRecord["TimeSigned"] = int64(rr.TimeSigned) + jsRecord["OtherLen"] = rr.OtherLen + jsRecord["TimeSigned"] = rr.TimeSigned case *dns.IPSECKEY: - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Algorithm"] = rr.Algorithm jsRecord["GatewayAddr"] = rr.GatewayAddr.String() jsRecord["GatewayHost"] = rr.GatewayHost - jsRecord["GatewayType"] = int64(rr.GatewayType) - jsRecord["Precedence"] = int64(rr.Precedence) + jsRecord["GatewayType"] = rr.GatewayType + jsRecord["Precedence"] = rr.Precedence jsRecord["PublicKey"] = rr.PublicKey case *dns.KEY: - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Protocol"] = int64(rr.Protocol) - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Flags"] = rr.Flags + jsRecord["Protocol"] = rr.Protocol + jsRecord["Algorithm"] = rr.Algorithm jsRecord["PublicKey"] = rr.PublicKey case *dns.CDS: - jsRecord["KeyTag"] = int64(rr.KeyTag) - jsRecord["Algorithm"] = int64(rr.Algorithm) - jsRecord["DigestType"] = int64(rr.DigestType) + jsRecord["KeyTag"] = rr.KeyTag + jsRecord["Algorithm"] = rr.Algorithm + jsRecord["DigestType"] = rr.DigestType jsRecord["Digest"] = rr.Digest case *dns.CDNSKEY: - jsRecord["Algorithm"] = int64(rr.Algorithm) - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Protocol"] = int64(rr.Protocol) + jsRecord["Algorithm"] = rr.Algorithm + jsRecord["Flags"] = rr.Flags + jsRecord["Protocol"] = rr.Protocol jsRecord["PublicKey"] = rr.PublicKey case *dns.NID: jsRecord["NodeID"] = rr.NodeID - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.L32: jsRecord["Locator32"] = rr.Locator32.String() - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.L64: jsRecord["Locator64"] = rr.Locator64 - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.LP: jsRecord["Fqdn"] = rr.Fqdn - jsRecord["Preference"] = int16(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.GPOS: jsRecord["Altitude"] = rr.Altitude jsRecord["Latitude"] = rr.Latitude @@ -215,40 +215,40 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) jsRecord["Mbox"] = rr.Mbox jsRecord["Txt"] = rr.Txt case *dns.RKEY: - jsRecord["Algorithm"] = int64(rr.Algorithm) - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Protocol"] = int64(rr.Protocol) + jsRecord["Algorithm"] = rr.Algorithm + jsRecord["Flags"] = rr.Flags + jsRecord["Protocol"] = rr.Protocol jsRecord["PublicKey"] = rr.PublicKey case *dns.SMIMEA: jsRecord["Certificate"] = rr.Certificate - jsRecord["MatchingType"] = int64(rr.MatchingType) - jsRecord["Selector"] = int64(rr.Selector) - jsRecord["Usage"] = int64(rr.Usage) + jsRecord["MatchingType"] = rr.MatchingType + jsRecord["Selector"] = rr.Selector + jsRecord["Usage"] = rr.Usage case *dns.AMTRELAY: jsRecord["GatewayAddr"] = rr.GatewayAddr.String() jsRecord["GatewayHost"] = rr.GatewayHost - jsRecord["GatewayType"] = int64(rr.GatewayType) - jsRecord["Precedence"] = int64(rr.Precedence) + jsRecord["GatewayType"] = rr.GatewayType + jsRecord["Precedence"] = rr.Precedence case *dns.AVC: jsRecord["Txt"] = rr.Txt case *dns.URI: - jsRecord["Priority"] = int64(rr.Priority) - jsRecord["Weight"] = int64(rr.Weight) + jsRecord["Priority"] = rr.Priority + jsRecord["Weight"] = rr.Weight jsRecord["Target"] = rr.Target case *dns.EUI48: jsRecord["Address"] = rr.Address case *dns.EUI64: jsRecord["Address"] = rr.Address case *dns.GID: - jsRecord["Gid"] = int64(rr.Gid) + jsRecord["Gid"] = rr.Gid case *dns.UID: - jsRecord["Uid"] = int64(rr.Uid) + jsRecord["Uid"] = rr.Uid case *dns.UINFO: jsRecord["Uinfo"] = rr.Uinfo case *dns.SPF: jsRecord["Txt"] = rr.Txt case *dns.HTTPS: - jsRecord["Priority"] = int64(rr.Priority) + jsRecord["Priority"] = rr.Priority jsRecord["Target"] = rr.Target kvs := rr.Value var jsKvs []map[string]interface{} @@ -262,7 +262,7 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) } jsRecord["Value"] = jsKvs case *dns.SVCB: - jsRecord["Priority"] = int64(rr.Priority) + jsRecord["Priority"] = rr.Priority jsRecord["Target"] = rr.Target kvs := rr.Value jsKvs := make([]map[string]interface{}, len(kvs)) @@ -277,13 +277,13 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) jsRecord["Value"] = jsKvs case *dns.ZONEMD: jsRecord["Digest"] = rr.Digest - jsRecord["Hash"] = int64(rr.Hash) - jsRecord["Scheme"] = int64(rr.Scheme) - jsRecord["Serial"] = int64(rr.Serial) + jsRecord["Hash"] = rr.Hash + jsRecord["Scheme"] = rr.Scheme + jsRecord["Serial"] = rr.Serial case *dns.CSYNC: - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Serial"] = int64(rr.Serial) - jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap) + jsRecord["Flags"] = rr.Flags + jsRecord["Serial"] = rr.Serial + jsRecord["TypeBitMap"] = rr.TypeBitMap case *dns.OPENPGPKEY: jsRecord["PublicKey"] = rr.PublicKey case *dns.TALINK: @@ -294,53 +294,43 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error) case *dns.DHCID: jsRecord["Digest"] = rr.Digest case *dns.DNSKEY: - jsRecord["Flags"] = int64(rr.Flags) - jsRecord["Protocol"] = int64(rr.Protocol) - jsRecord["Algorithm"] = int64(rr.Algorithm) + jsRecord["Flags"] = rr.Flags + jsRecord["Protocol"] = rr.Protocol + jsRecord["Algorithm"] = rr.Algorithm jsRecord["PublicKey"] = rr.PublicKey case *dns.HIP: jsRecord["Hit"] = rr.Hit - jsRecord["HitLength"] = int64(rr.HitLength) + jsRecord["HitLength"] = rr.HitLength jsRecord["PublicKey"] = rr.PublicKey - jsRecord["PublicKeyAlgorithm"] = int64(rr.PublicKeyAlgorithm) - jsRecord["PublicKeyLength"] = int64(rr.PublicKeyLength) + jsRecord["PublicKeyAlgorithm"] = rr.PublicKeyAlgorithm + jsRecord["PublicKeyLength"] = rr.PublicKeyLength jsRecord["RendezvousServers"] = rr.RendezvousServers case *dns.OPT: - options := rr.Option - jsOptions := make([]map[string]interface{}, len(options)) - for i, option := range options { - jsOption, err := NewJSEDNS0(option) - if err != nil { - log.Error(err.Error()) - continue - } - jsOptions[i] = jsOption - } - jsRecord["Option"] = jsOptions + jsRecord["Option"] = rr.Option case *dns.NIMLOC: jsRecord["Locator"] = rr.Locator case *dns.EID: jsRecord["Endpoint"] = rr.Endpoint case *dns.NXT: jsRecord["NextDomain"] = rr.NextDomain - jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap) + jsRecord["TypeBitMap"] = rr.TypeBitMap case *dns.PX: jsRecord["Mapx400"] = rr.Mapx400 jsRecord["Map822"] = rr.Map822 - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.SIG: - jsRecord["Algorithm"] = int64(rr.Algorithm) - jsRecord["Expiration"] = int64(rr.Expiration) - jsRecord["Inception"] = int64(rr.Inception) - jsRecord["KeyTag"] = int64(rr.KeyTag) - jsRecord["Labels"] = int64(rr.Labels) - jsRecord["OrigTtl"] = int64(rr.OrigTtl) + jsRecord["Algorithm"] = rr.Algorithm + jsRecord["Expiration"] = rr.Expiration + jsRecord["Inception"] = rr.Inception + jsRecord["KeyTag"] = rr.KeyTag + jsRecord["Labels"] = rr.Labels + jsRecord["OrigTtl"] = rr.OrigTtl jsRecord["Signature"] = rr.Signature jsRecord["SignerName"] = rr.SignerName - jsRecord["TypeCovered"] = int64(rr.TypeCovered) + jsRecord["TypeCovered"] = rr.TypeCovered case *dns.RT: jsRecord["Host"] = rr.Host - jsRecord["Preference"] = int64(rr.Preference) + jsRecord["Preference"] = rr.Preference case *dns.NSAPPTR: jsRecord["Ptr"] = rr.Ptr case *dns.X25: diff --git a/modules/dns_proxy/dns_proxy_script.go b/modules/dns_proxy/dns_proxy_script.go index 83dd6777..4a608168 100644 --- a/modules/dns_proxy/dns_proxy_script.go +++ b/modules/dns_proxy/dns_proxy_script.go @@ -84,9 +84,11 @@ func (s *DnsProxyScript) OnRequest(req *dns.Msg, clientIP string) (jsreq, jsres if _, err := s.Call("onRequest", jsreq, jsres); err != nil { log.Error("%s", err) return nil, nil - } else if jsreq.CheckIfModifiedAndUpdateHash() { + } else if jsreq.WasModified() { + jsreq.UpdateHash() return jsreq, nil - } else if jsres.CheckIfModifiedAndUpdateHash() { + } else if jsres.WasModified() { + jsres.UpdateHash() return nil, jsres } } @@ -102,7 +104,8 @@ func (s *DnsProxyScript) OnResponse(req, res *dns.Msg, clientIP string) (jsreq, if _, err := s.Call("onResponse", jsreq, jsres); err != nil { log.Error("%s", err) return nil, nil - } else if jsres.CheckIfModifiedAndUpdateHash() { + } else if jsres.WasModified() { + jsres.UpdateHash() return nil, jsres } } diff --git a/modules/events_stream/events_view.go b/modules/events_stream/events_view.go index f06d8dae..56d0e10d 100644 --- a/modules/events_stream/events_view.go +++ b/modules/events_stream/events_view.go @@ -137,7 +137,7 @@ func (mod *EventsStream) Render(output io.Writer, e session.Event) { } else if strings.HasPrefix(e.Tag, "zeroconf.") { mod.viewZeroConfEvent(output, e) } else if !strings.HasPrefix(e.Tag, "tick") && e.Tag != "session.started" && e.Tag != "session.stopped" { - fmt.Fprintf(output, "[%s] [%s] %v\n", e.Time.Format(mod.timeFormat), tui.Green(e.Tag), e.Data) + fmt.Fprintf(output, "[%s] [%s] %v\n", e.Time.Format(mod.timeFormat), tui.Green(e.Tag), e) } } diff --git a/modules/http_proxy/http_proxy_base.go b/modules/http_proxy/http_proxy_base.go index 7ace2122..5d4eebef 100644 --- a/modules/http_proxy/http_proxy_base.go +++ b/modules/http_proxy/http_proxy_base.go @@ -27,8 +27,6 @@ import ( "github.com/evilsocket/islazy/log" "github.com/evilsocket/islazy/str" "github.com/evilsocket/islazy/tui" - - "github.com/robertkrimen/otto" ) const ( @@ -434,14 +432,6 @@ func (p *HTTPProxy) Start() { } func (p *HTTPProxy) Stop() error { - if p.Script != nil { - if p.Script.Plugin.HasFunc("onExit") { - if _, err := p.Script.Call("onExit"); err != nil { - log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String()) - } - } - } - if p.doRedirect && p.Redirection != nil { p.Debug("disabling redirection %s", p.Redirection.String()) if err := p.Sess.Firewall.EnableRedirection(p.Redirection, false); err != nil { diff --git a/modules/http_proxy/http_proxy_base_filters.go b/modules/http_proxy/http_proxy_base_filters.go index 988807f2..017fc0c3 100644 --- a/modules/http_proxy/http_proxy_base_filters.go +++ b/modules/http_proxy/http_proxy_base_filters.go @@ -1,10 +1,10 @@ package http_proxy import ( - "io" + "io/ioutil" "net/http" - "strconv" "strings" + "strconv" "github.com/elazarl/goproxy" @@ -74,10 +74,10 @@ func (p *HTTPProxy) isScriptInjectable(res *http.Response) (bool, string) { return false, "" } -func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) error { +func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) (error) { defer res.Body.Close() - raw, err := io.ReadAll(res.Body) + raw, err := ioutil.ReadAll(res.Body) if err != nil { return err } else if html := string(raw); strings.Contains(html, "") { @@ -91,7 +91,7 @@ func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) error { res.Header.Set("Content-Length", strconv.Itoa(len(html))) // reset the response body to the original unread state - res.Body = io.NopCloser(strings.NewReader(html)) + res.Body = ioutil.NopCloser(strings.NewReader(html)) return nil } diff --git a/modules/http_proxy/http_proxy_base_sslstriper.go b/modules/http_proxy/http_proxy_base_sslstriper.go index e3331b18..d2fd0f4f 100644 --- a/modules/http_proxy/http_proxy_base_sslstriper.go +++ b/modules/http_proxy/http_proxy_base_sslstriper.go @@ -1,7 +1,7 @@ package http_proxy import ( - "io" + "io/ioutil" "net/http" "net/url" "regexp" @@ -253,7 +253,7 @@ func (s *SSLStripper) Process(res *http.Response, ctx *goproxy.ProxyCtx) { // if we have a text or html content type, fetch the body // and perform sslstripping if s.isContentStrippable(res) { - raw, err := io.ReadAll(res.Body) + raw, err := ioutil.ReadAll(res.Body) if err != nil { log.Error("Could not read response body: %s", err) return @@ -297,9 +297,9 @@ func (s *SSLStripper) Process(res *http.Response, ctx *goproxy.ProxyCtx) { // reset the response body to the original unread state // but with just a string reader, this way further calls - // to ui.ReadAll(res.Body) will just return the content + // to ioutil.ReadAll(res.Body) will just return the content // we stripped without downloading anything again. - res.Body = io.NopCloser(strings.NewReader(body)) + res.Body = ioutil.NopCloser(strings.NewReader(body)) } // fix cookies domain + strip "secure" + "httponly" flags diff --git a/modules/http_proxy/http_proxy_js_request.go b/modules/http_proxy/http_proxy_js_request.go index 859526e4..a3c6a1da 100644 --- a/modules/http_proxy/http_proxy_js_request.go +++ b/modules/http_proxy/http_proxy_js_request.go @@ -3,7 +3,7 @@ package http_proxy import ( "bytes" "fmt" - "io" + "io/ioutil" "net/http" "net/url" "regexp" @@ -103,21 +103,7 @@ func (j *JSRequest) WasModified() bool { return j.NewHash() != j.refHash } -func (j *JSRequest) CheckIfModifiedAndUpdateHash() bool { - newHash := j.NewHash() - // body was read - if j.bodyRead { - j.refHash = newHash - return true - } - // check if req was changed and update its hash - wasModified := j.refHash != newHash - j.refHash = newHash - return wasModified -} - func (j *JSRequest) GetHeader(name, deflt string) string { - name = strings.ToLower(name) headers := strings.Split(j.Headers, "\r\n") for i := 0; i < len(headers); i++ { if headers[i] != "" { @@ -125,7 +111,8 @@ func (j *JSRequest) GetHeader(name, deflt string) string { if len(header_parts) != 0 && len(header_parts[0]) == 3 { header_name := string(header_parts[0][1]) header_value := string(header_parts[0][2]) - if name == strings.ToLower(header_name) { + + if strings.ToLower(name) == strings.ToLower(header_name) { return header_value } } @@ -134,25 +121,6 @@ func (j *JSRequest) GetHeader(name, deflt string) string { return deflt } -func (j *JSRequest) GetHeaders(name string) []string { - name = strings.ToLower(name) - headers := strings.Split(j.Headers, "\r\n") - header_values := make([]string, 0, len(headers)) - for i := 0; i < len(headers); i++ { - if headers[i] != "" { - header_parts := header_regexp.FindAllSubmatch([]byte(headers[i]), 1) - if len(header_parts) != 0 && len(header_parts[0]) == 3 { - header_name := string(header_parts[0][1]) - header_value := string(header_parts[0][2]) - if name == strings.ToLower(header_name) { - header_values = append(header_values, header_value) - } - } - } - } - return header_values -} - func (j *JSRequest) SetHeader(name, value string) { name = strings.TrimSpace(name) value = strings.TrimSpace(value) @@ -201,7 +169,7 @@ func (j *JSRequest) RemoveHeader(name string) { } func (j *JSRequest) ReadBody() string { - raw, err := io.ReadAll(j.req.Body) + raw, err := ioutil.ReadAll(j.req.Body) if err != nil { return "" } @@ -209,7 +177,7 @@ func (j *JSRequest) ReadBody() string { j.Body = string(raw) j.bodyRead = true // reset the request body to the original unread state - j.req.Body = io.NopCloser(bytes.NewBuffer(raw)) + j.req.Body = ioutil.NopCloser(bytes.NewBuffer(raw)) return j.Body } diff --git a/modules/http_proxy/http_proxy_js_response.go b/modules/http_proxy/http_proxy_js_response.go index c1bb98bf..051812ef 100644 --- a/modules/http_proxy/http_proxy_js_response.go +++ b/modules/http_proxy/http_proxy_js_response.go @@ -3,7 +3,7 @@ package http_proxy import ( "bytes" "fmt" - "io" + "io/ioutil" "net/http" "strings" @@ -76,29 +76,7 @@ func (j *JSResponse) WasModified() bool { return j.NewHash() != j.refHash } -func (j *JSResponse) CheckIfModifiedAndUpdateHash() bool { - newHash := j.NewHash() - if j.bodyRead { - // body was read - j.refHash = newHash - return true - } else if j.bodyClear { - // body was cleared manually - j.refHash = newHash - return true - } else if j.Body != "" { - // body was not read but just set - j.refHash = newHash - return true - } - // check if res was changed and update its hash - wasModified := j.refHash != newHash - j.refHash = newHash - return wasModified -} - func (j *JSResponse) GetHeader(name, deflt string) string { - name = strings.ToLower(name) headers := strings.Split(j.Headers, "\r\n") for i := 0; i < len(headers); i++ { if headers[i] != "" { @@ -106,7 +84,8 @@ func (j *JSResponse) GetHeader(name, deflt string) string { if len(header_parts) != 0 && len(header_parts[0]) == 3 { header_name := string(header_parts[0][1]) header_value := string(header_parts[0][2]) - if name == strings.ToLower(header_name) { + + if strings.ToLower(name) == strings.ToLower(header_name) { return header_value } } @@ -115,25 +94,6 @@ func (j *JSResponse) GetHeader(name, deflt string) string { return deflt } -func (j *JSResponse) GetHeaders(name string) []string { - name = strings.ToLower(name) - headers := strings.Split(j.Headers, "\r\n") - header_values := make([]string, 0, len(headers)) - for i := 0; i < len(headers); i++ { - if headers[i] != "" { - header_parts := header_regexp.FindAllSubmatch([]byte(headers[i]), 1) - if len(header_parts) != 0 && len(header_parts[0]) == 3 { - header_name := string(header_parts[0][1]) - header_value := string(header_parts[0][2]) - if name == strings.ToLower(header_name) { - header_values = append(header_values, header_value) - } - } - } - } - return header_values -} - func (j *JSResponse) SetHeader(name, value string) { name = strings.TrimSpace(name) value = strings.TrimSpace(value) @@ -208,7 +168,7 @@ func (j *JSResponse) ToResponse(req *http.Request) (resp *http.Response) { func (j *JSResponse) ReadBody() string { defer j.resp.Body.Close() - raw, err := io.ReadAll(j.resp.Body) + raw, err := ioutil.ReadAll(j.resp.Body) if err != nil { return "" } @@ -217,7 +177,7 @@ func (j *JSResponse) ReadBody() string { j.bodyRead = true j.bodyClear = false // reset the response body to the original unread state - j.resp.Body = io.NopCloser(bytes.NewBuffer(raw)) + j.resp.Body = ioutil.NopCloser(bytes.NewBuffer(raw)) return j.Body } diff --git a/modules/http_proxy/http_proxy_script.go b/modules/http_proxy/http_proxy_script.go index 446f61da..070f7e24 100644 --- a/modules/http_proxy/http_proxy_script.go +++ b/modules/http_proxy/http_proxy_script.go @@ -84,9 +84,11 @@ func (s *HttpProxyScript) OnRequest(original *http.Request) (jsreq *JSRequest, j if _, err := s.Call("onRequest", jsreq, jsres); err != nil { log.Error("%s", err) return nil, nil - } else if jsreq.CheckIfModifiedAndUpdateHash() { + } else if jsreq.WasModified() { + jsreq.UpdateHash() return jsreq, nil - } else if jsres.CheckIfModifiedAndUpdateHash() { + } else if jsres.WasModified() { + jsres.UpdateHash() return nil, jsres } } @@ -102,7 +104,8 @@ func (s *HttpProxyScript) OnResponse(res *http.Response) (jsreq *JSRequest, jsre if _, err := s.Call("onResponse", jsreq, jsres); err != nil { log.Error("%s", err) return nil, nil - } else if jsres.CheckIfModifiedAndUpdateHash() { + } else if jsres.WasModified() { + jsres.UpdateHash() return nil, jsres } } diff --git a/modules/http_proxy/http_proxy_test.go b/modules/http_proxy/http_proxy_test.go deleted file mode 100644 index d05d046e..00000000 --- a/modules/http_proxy/http_proxy_test.go +++ /dev/null @@ -1,706 +0,0 @@ -package http_proxy - -import ( - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/httptest" - "os" - "runtime" - "strings" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/firewall" - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/evilsocket/islazy/data" -) - -// MockFirewall implements a mock firewall for testing -type MockFirewall struct { - forwardingEnabled bool - redirections []firewall.Redirection -} - -func NewMockFirewall() *MockFirewall { - return &MockFirewall{ - forwardingEnabled: false, - redirections: make([]firewall.Redirection, 0), - } -} - -func (m *MockFirewall) IsForwardingEnabled() bool { - return m.forwardingEnabled -} - -func (m *MockFirewall) EnableForwarding(enabled bool) error { - m.forwardingEnabled = enabled - return nil -} - -func (m *MockFirewall) EnableRedirection(r *firewall.Redirection, enabled bool) error { - if enabled { - m.redirections = append(m.redirections, *r) - } else { - for i, red := range m.redirections { - if red.String() == r.String() { - m.redirections = append(m.redirections[:i], m.redirections[i+1:]...) - break - } - } - } - return nil -} - -func (m *MockFirewall) DisableRedirection(r *firewall.Redirection, enabled bool) error { - return m.EnableRedirection(r, false) -} - -func (m *MockFirewall) Restore() { - m.redirections = make([]firewall.Redirection, 0) - m.forwardingEnabled = false -} - -// Create a mock session for testing -func createMockSession() (*session.Session, *MockFirewall) { - // Create interface - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "eth0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - // Parse interface addresses - ifaceIP := net.ParseIP("192.168.1.100") - ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface.IP = ifaceIP - iface.HW = ifaceHW - - // Create gateway - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - gatewayIP := net.ParseIP("192.168.1.1") - gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66") - gateway.IP = gatewayIP - gateway.HW = gatewayHW - - // Create mock firewall - mockFirewall := NewMockFirewall() - - // Create environment - env, _ := session.NewEnvironment("") - - // Create LAN - aliases, _ := data.NewUnsortedKV("", 0) - lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - // Create session - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - Lan: lan, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: &packets.Queue{}, - Firewall: mockFirewall, - Modules: make(session.ModuleList, 0), - } - - // Initialize events - sess.Events = session.NewEventPool(false, false) - - return sess, mockFirewall -} - -func TestNewHttpProxy(t *testing.T) { - sess, _ := createMockSession() - - mod := NewHttpProxy(sess) - - if mod == nil { - t.Fatal("NewHttpProxy returned nil") - } - - if mod.Name() != "http.proxy" { - t.Errorf("expected module name 'http.proxy', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - // Check parameters - params := []string{ - "http.port", - "http.proxy.address", - "http.proxy.port", - "http.proxy.redirect", - "http.proxy.script", - "http.proxy.injectjs", - "http.proxy.blacklist", - "http.proxy.whitelist", - "http.proxy.sslstrip", - } - for _, param := range params { - if !mod.Session.Env.Has(param) { - t.Errorf("parameter %s not registered", param) - } - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{"http.proxy on", "http.proxy off"} - handlerMap := make(map[string]bool) - - for _, h := range handlers { - handlerMap[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerMap[expected] { - t.Errorf("Expected handler '%s' not found", expected) - } - } -} - -func TestHttpProxyConfigure(t *testing.T) { - tests := []struct { - name string - params map[string]string - expectErr bool - validate func(*HttpProxy) error - }{ - { - name: "default configuration", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "192.168.1.100", - "http.proxy.port": "8080", - "http.proxy.redirect": "true", - "http.proxy.script": "", - "http.proxy.injectjs": "", - "http.proxy.blacklist": "", - "http.proxy.whitelist": "", - "http.proxy.sslstrip": "false", - }, - expectErr: false, - validate: func(mod *HttpProxy) error { - if mod.proxy == nil { - return fmt.Errorf("proxy not initialized") - } - if mod.proxy.Address != "192.168.1.100" { - return fmt.Errorf("expected address 192.168.1.100, got %s", mod.proxy.Address) - } - if !mod.proxy.doRedirect { - return fmt.Errorf("expected redirect to be true") - } - if mod.proxy.Stripper == nil { - return fmt.Errorf("SSL stripper not initialized") - } - if mod.proxy.Stripper.Enabled() { - return fmt.Errorf("SSL stripper should be disabled") - } - return nil - }, - }, - // Note: SSL stripping test removed as it requires elevated permissions - // to create network capture handles - { - name: "with blacklist and whitelist", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "192.168.1.100", - "http.proxy.port": "8080", - "http.proxy.redirect": "false", - "http.proxy.script": "", - "http.proxy.injectjs": "", - "http.proxy.blacklist": "*.evil.com,bad.site.org", - "http.proxy.whitelist": "*.good.com,safe.site.org", - "http.proxy.sslstrip": "false", - }, - expectErr: false, - validate: func(mod *HttpProxy) error { - if len(mod.proxy.Blacklist) != 2 { - return fmt.Errorf("expected 2 blacklist entries, got %d", len(mod.proxy.Blacklist)) - } - if len(mod.proxy.Whitelist) != 2 { - return fmt.Errorf("expected 2 whitelist entries, got %d", len(mod.proxy.Whitelist)) - } - if mod.proxy.doRedirect { - return fmt.Errorf("expected redirect to be false") - } - return nil - }, - }, - { - name: "JavaScript injection with inline code", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "192.168.1.100", - "http.proxy.port": "8080", - "http.proxy.redirect": "true", - "http.proxy.script": "", - "http.proxy.injectjs": "alert('injected');", - "http.proxy.blacklist": "", - "http.proxy.whitelist": "", - "http.proxy.sslstrip": "false", - }, - expectErr: false, - validate: func(mod *HttpProxy) error { - if mod.proxy.jsHook == "" { - return fmt.Errorf("jsHook should be set") - } - if !strings.Contains(mod.proxy.jsHook, "alert('injected');") { - return fmt.Errorf("jsHook should contain injected code") - } - return nil - }, - }, - { - name: "JavaScript injection with URL", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "192.168.1.100", - "http.proxy.port": "8080", - "http.proxy.redirect": "true", - "http.proxy.script": "", - "http.proxy.injectjs": "http://evil.com/hook.js", - "http.proxy.blacklist": "", - "http.proxy.whitelist": "", - "http.proxy.sslstrip": "false", - }, - expectErr: false, - validate: func(mod *HttpProxy) error { - if mod.proxy.jsHook == "" { - return fmt.Errorf("jsHook should be set") - } - if !strings.Contains(mod.proxy.jsHook, "http://evil.com/hook.js") { - return fmt.Errorf("jsHook should contain script URL") - } - return nil - }, - }, - { - name: "invalid address", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "invalid-address", - "http.proxy.port": "8080", - "http.proxy.redirect": "true", - "http.proxy.script": "", - "http.proxy.injectjs": "", - "http.proxy.blacklist": "", - "http.proxy.whitelist": "", - "http.proxy.sslstrip": "false", - }, - expectErr: true, - }, - { - name: "invalid port", - params: map[string]string{ - "http.port": "80", - "http.proxy.address": "192.168.1.100", - "http.proxy.port": "invalid-port", - "http.proxy.redirect": "true", - "http.proxy.script": "", - "http.proxy.injectjs": "", - "http.proxy.blacklist": "", - "http.proxy.whitelist": "", - "http.proxy.sslstrip": "false", - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sess, _ := createMockSession() - mod := NewHttpProxy(sess) - - // Set parameters - for k, v := range tt.params { - sess.Env.Set(k, v) - } - - err := mod.Configure() - - if tt.expectErr && err == nil { - t.Error("expected error but got none") - } else if !tt.expectErr && err != nil { - t.Errorf("unexpected error: %v", err) - } - - if !tt.expectErr && tt.validate != nil { - if err := tt.validate(mod); err != nil { - t.Error(err) - } - } - }) - } -} - -func TestHttpProxyStartStop(t *testing.T) { - sess, mockFirewall := createMockSession() - mod := NewHttpProxy(sess) - - // Configure with test parameters - sess.Env.Set("http.port", "80") - sess.Env.Set("http.proxy.address", "127.0.0.1") - sess.Env.Set("http.proxy.port", "0") // Use port 0 to get a random available port - sess.Env.Set("http.proxy.redirect", "true") - sess.Env.Set("http.proxy.sslstrip", "false") - - // Start the proxy - err := mod.Start() - if err != nil { - t.Fatalf("Failed to start proxy: %v", err) - } - - if !mod.Running() { - t.Error("Proxy should be running after Start()") - } - - // Check that forwarding was enabled - if !mockFirewall.IsForwardingEnabled() { - t.Error("Forwarding should be enabled after starting proxy") - } - - // Check that redirection was added - if len(mockFirewall.redirections) != 1 { - t.Errorf("Expected 1 redirection, got %d", len(mockFirewall.redirections)) - } - - // Give the server time to start - time.Sleep(100 * time.Millisecond) - - // Stop the proxy - err = mod.Stop() - if err != nil { - t.Fatalf("Failed to stop proxy: %v", err) - } - - if mod.Running() { - t.Error("Proxy should not be running after Stop()") - } - - // Check that redirection was removed - if len(mockFirewall.redirections) != 0 { - t.Errorf("Expected 0 redirections after stop, got %d", len(mockFirewall.redirections)) - } -} - -func TestHttpProxyAlreadyStarted(t *testing.T) { - sess, _ := createMockSession() - mod := NewHttpProxy(sess) - - // Configure - sess.Env.Set("http.port", "80") - sess.Env.Set("http.proxy.address", "127.0.0.1") - sess.Env.Set("http.proxy.port", "0") - sess.Env.Set("http.proxy.redirect", "false") - - // Start the proxy - err := mod.Start() - if err != nil { - t.Fatalf("Failed to start proxy: %v", err) - } - - // Try to configure while running - err = mod.Configure() - if err == nil { - t.Error("Configure should fail when proxy is already running") - } - - // Stop the proxy - mod.Stop() -} - -func TestHTTPProxyDoProxy(t *testing.T) { - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - tests := []struct { - name string - request *http.Request - expected bool - }{ - { - name: "valid request", - request: &http.Request{ - Host: "example.com", - }, - expected: true, - }, - { - name: "empty host", - request: &http.Request{ - Host: "", - }, - expected: false, - }, - { - name: "localhost request", - request: &http.Request{ - Host: "localhost:8080", - }, - expected: false, - }, - { - name: "127.0.0.1 request", - request: &http.Request{ - Host: "127.0.0.1:8080", - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := proxy.doProxy(tt.request) - if result != tt.expected { - t.Errorf("doProxy(%v) = %v, expected %v", tt.request.Host, result, tt.expected) - } - }) - } -} - -func TestHTTPProxyShouldProxy(t *testing.T) { - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - tests := []struct { - name string - blacklist []string - whitelist []string - host string - expected bool - }{ - { - name: "no filters", - blacklist: []string{}, - whitelist: []string{}, - host: "example.com", - expected: true, - }, - { - name: "blacklisted exact match", - blacklist: []string{"evil.com"}, - whitelist: []string{}, - host: "evil.com", - expected: false, - }, - { - name: "blacklisted wildcard match", - blacklist: []string{"*.evil.com"}, - whitelist: []string{}, - host: "sub.evil.com", - expected: false, - }, - { - name: "whitelisted exact match", - blacklist: []string{"*"}, - whitelist: []string{"good.com"}, - host: "good.com", - expected: true, - }, - { - name: "not blacklisted", - blacklist: []string{"evil.com"}, - whitelist: []string{}, - host: "good.com", - expected: true, - }, - { - name: "whitelist takes precedence", - blacklist: []string{"*"}, - whitelist: []string{"good.com"}, - host: "good.com", - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - proxy.Blacklist = tt.blacklist - proxy.Whitelist = tt.whitelist - - req := &http.Request{ - Host: tt.host, - } - - result := proxy.shouldProxy(req) - if result != tt.expected { - t.Errorf("shouldProxy(%v) = %v, expected %v", tt.host, result, tt.expected) - } - }) - } -} - -func TestHTTPProxyStripPort(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"example.com:8080", "example.com"}, - {"example.com", "example.com"}, - {"192.168.1.1:443", "192.168.1.1"}, - {"[::1]:8080", "["}, // stripPort splits on first colon, so IPv6 addresses don't work correctly - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - result := stripPort(tt.input) - if result != tt.expected { - t.Errorf("stripPort(%s) = %s, expected %s", tt.input, result, tt.expected) - } - }) - } -} - -func TestHTTPProxyJavaScriptInjection(t *testing.T) { - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - tests := []struct { - name string - jsToInject string - expectedHook string - }{ - { - name: "inline JavaScript", - jsToInject: "console.log('test');", - expectedHook: ``, - }, - { - name: "script tag", - jsToInject: ``, - expectedHook: ``, // script tags get wrapped - }, - { - name: "external URL", - jsToInject: "http://example.com/script.js", - expectedHook: ``, - }, - { - name: "HTTPS URL", - jsToInject: "https://example.com/script.js", - expectedHook: ``, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Skip test with invalid filename characters on Windows - if runtime.GOOS == "windows" && strings.ContainsAny(tt.jsToInject, "<>:\"|?*") { - t.Skip("Skipping test with invalid filename characters on Windows") - } - - err := proxy.Configure("127.0.0.1", 8080, 80, false, "", tt.jsToInject, false) - if err != nil { - t.Fatalf("Configure failed: %v", err) - } - - if proxy.jsHook != tt.expectedHook { - t.Errorf("jsHook = %q, expected %q", proxy.jsHook, tt.expectedHook) - } - }) - } -} - -func TestHTTPProxyWithTestServer(t *testing.T) { - // Create a test HTTP server - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("Test Page")) - })) - defer testServer.Close() - - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - // Configure proxy with JS injection - err := proxy.Configure("127.0.0.1", 0, 80, false, "", "console.log('injected');", false) - if err != nil { - t.Fatalf("Configure failed: %v", err) - } - - // Create a simple test to verify proxy is initialized - if proxy.Proxy == nil { - t.Error("Proxy not initialized") - } - - if proxy.jsHook == "" { - t.Error("JavaScript hook not set") - } - - // Note: Testing actual proxy behavior would require setting up the proxy server - // and making HTTP requests through it, which is complex in a unit test environment -} - -func TestHTTPProxyScriptLoading(t *testing.T) { - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - // Create a temporary script file - scriptContent := ` -function onRequest(req, res) { - console.log("Request intercepted"); -} -` - tmpFile, err := ioutil.TempFile("", "proxy_script_*.js") - if err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - defer os.Remove(tmpFile.Name()) - - if _, err := tmpFile.Write([]byte(scriptContent)); err != nil { - t.Fatalf("Failed to write script: %v", err) - } - tmpFile.Close() - - // Try to configure with non-existent script - err = proxy.Configure("127.0.0.1", 8080, 80, false, "non_existent_script.js", "", false) - if err == nil { - t.Error("Configure should fail with non-existent script") - } - - // Note: Actual script loading would require proper JS engine setup - // which is complex to mock. This test verifies the error handling. -} - -// Benchmarks -func BenchmarkHTTPProxyShouldProxy(b *testing.B) { - sess, _ := createMockSession() - proxy := NewHTTPProxy(sess, "test") - - proxy.Blacklist = []string{"*.evil.com", "bad.site.org", "*.malicious.net"} - proxy.Whitelist = []string{"*.good.com", "safe.site.org"} - - req := &http.Request{ - Host: "example.com", - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = proxy.shouldProxy(req) - } -} - -func BenchmarkHTTPProxyStripPort(b *testing.B) { - testHost := "example.com:8080" - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = stripPort(testHost) - } -} diff --git a/modules/http_server/http_server.go b/modules/http_server/http_server.go index da309d3d..25cd7802 100644 --- a/modules/http_server/http_server.go +++ b/modules/http_server/http_server.go @@ -31,20 +31,20 @@ func NewHttpServer(s *session.Session) *HttpServer { mod.AddParam(session.NewStringParameter("http.server.address", session.ParamIfaceAddress, session.IPv4Validator, - "Address to bind the HTTP server to.")) + "Address to bind the http server to.")) mod.AddParam(session.NewIntParameter("http.server.port", "80", - "Port to bind the HTTP server to.")) + "Port to bind the http server to.")) mod.AddHandler(session.NewModuleHandler("http.server on", "", - "Start HTTP server.", + "Start httpd server.", func(args []string) error { return mod.Start() })) mod.AddHandler(session.NewModuleHandler("http.server off", "", - "Stop HTTP server.", + "Stop httpd server.", func(args []string) error { return mod.Stop() })) diff --git a/modules/https_server/https_server.go b/modules/https_server/https_server.go index 2f3fd0a6..8e547fa7 100644 --- a/modules/https_server/https_server.go +++ b/modules/https_server/https_server.go @@ -35,11 +35,11 @@ func NewHttpsServer(s *session.Session) *HttpsServer { mod.AddParam(session.NewStringParameter("https.server.address", session.ParamIfaceAddress, session.IPv4Validator, - "Address to bind the HTTPS server to.")) + "Address to bind the http server to.")) mod.AddParam(session.NewIntParameter("https.server.port", "443", - "Port to bind the HTTPS server to.")) + "Port to bind the http server to.")) mod.AddParam(session.NewStringParameter("https.server.certificate", "~/.bettercap-httpd.cert.pem", @@ -54,13 +54,13 @@ func NewHttpsServer(s *session.Session) *HttpsServer { tls.CertConfigToModule("https.server", &mod.SessionModule, tls.DefaultLegitConfig) mod.AddHandler(session.NewModuleHandler("https.server on", "", - "Start HTTPS server.", + "Start https server.", func(args []string) error { return mod.Start() })) mod.AddHandler(session.NewModuleHandler("https.server off", "", - "Stop HTTPS server.", + "Stop https server.", func(args []string) error { return mod.Stop() })) diff --git a/modules/modules_test.go b/modules/modules_test.go deleted file mode 100644 index 3cde11cd..00000000 --- a/modules/modules_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package modules - -import ( - "testing" -) - -func TestLoadModulesWithNilSession(t *testing.T) { - // This test verifies that LoadModules handles nil session gracefully - // In the actual implementation, this would panic, which is expected behavior - defer func() { - if r := recover(); r == nil { - t.Error("expected panic when loading modules with nil session, but didn't get one") - } - }() - - LoadModules(nil) -} - -// Since LoadModules requires a fully initialized session with command-line flags, -// which conflicts with the test runner, we can't easily test the actual module loading. -// The main functionality is tested through integration tests and the actual application. -// This test file at least provides some coverage for the package and demonstrates -// the expected behavior with invalid input. diff --git a/modules/net_probe/net_probe_test.go b/modules/net_probe/net_probe_test.go deleted file mode 100644 index 7013dd23..00000000 --- a/modules/net_probe/net_probe_test.go +++ /dev/null @@ -1,610 +0,0 @@ -package net_probe - -import ( - "fmt" - "net" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/malfunkt/iprange" -) - -// MockQueue implements a mock packet queue for testing -type MockQueue struct { - sync.Mutex - sentPackets [][]byte - sendError error - active bool -} - -func NewMockQueue() *MockQueue { - return &MockQueue{ - sentPackets: make([][]byte, 0), - active: true, - } -} - -func (m *MockQueue) Send(data []byte) error { - m.Lock() - defer m.Unlock() - - if m.sendError != nil { - return m.sendError - } - - // Store a copy of the packet - packet := make([]byte, len(data)) - copy(packet, data) - m.sentPackets = append(m.sentPackets, packet) - return nil -} - -func (m *MockQueue) GetSentPackets() [][]byte { - m.Lock() - defer m.Unlock() - return m.sentPackets -} - -func (m *MockQueue) ClearSentPackets() { - m.Lock() - defer m.Unlock() - m.sentPackets = make([][]byte, 0) -} - -func (m *MockQueue) Stop() { - m.Lock() - defer m.Unlock() - m.active = false -} - -// MockSession for testing -type MockSession struct { - *session.Session - runCommands []string - skipIPs map[string]bool -} - -func (m *MockSession) Run(cmd string) error { - m.runCommands = append(m.runCommands, cmd) - - // Handle module commands - if cmd == "net.recon on" { - // Find and start the net.recon module - for _, mod := range m.Modules { - if mod.Name() == "net.recon" { - if !mod.Running() { - return mod.Start() - } - return nil - } - } - } else if cmd == "net.recon off" { - // Find and stop the net.recon module - for _, mod := range m.Modules { - if mod.Name() == "net.recon" { - if mod.Running() { - return mod.Stop() - } - return nil - } - } - } else if cmd == "zerogod.discovery on" || cmd == "zerogod.discovery off" { - // Mock zerogod.discovery commands - return nil - } - - return nil -} - -func (m *MockSession) Skip(ip net.IP) bool { - if m.skipIPs == nil { - return false - } - return m.skipIPs[ip.String()] -} - -// MockNetRecon implements a minimal net.recon module for testing -type MockNetRecon struct { - session.SessionModule -} - -func NewMockNetRecon(s *session.Session) *MockNetRecon { - mod := &MockNetRecon{ - SessionModule: session.NewSessionModule("net.recon", s), - } - - // Add handlers so the module can be started/stopped via commands - mod.AddHandler(session.NewModuleHandler("net.recon on", "", - "Start net.recon", - func(args []string) error { - return mod.Start() - })) - - mod.AddHandler(session.NewModuleHandler("net.recon off", "", - "Stop net.recon", - func(args []string) error { - return mod.Stop() - })) - - return mod -} - -func (m *MockNetRecon) Name() string { - return "net.recon" -} - -func (m *MockNetRecon) Description() string { - return "Mock net.recon module" -} - -func (m *MockNetRecon) Author() string { - return "test" -} - -func (m *MockNetRecon) Configure() error { - return nil -} - -func (m *MockNetRecon) Start() error { - return m.SetRunning(true, nil) -} - -func (m *MockNetRecon) Stop() error { - return m.SetRunning(false, nil) -} - -// Create a mock session for testing -func createMockSession() (*MockSession, *MockQueue) { - // Create interface - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "eth0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - // Parse interface addresses - ifaceIP := net.ParseIP("192.168.1.100") - ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface.IP = ifaceIP - iface.HW = ifaceHW - - // Create gateway - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - - // Create mock queue - mockQueue := NewMockQueue() - - // Create environment - env, _ := session.NewEnvironment("") - - // Create session - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: &packets.Queue{ - Traffic: sync.Map{}, - Stats: packets.Stats{}, - }, - Modules: make(session.ModuleList, 0), - } - - // Initialize events - sess.Events = session.NewEventPool(false, false) - - // Add mock net.recon module - mockNetRecon := NewMockNetRecon(sess) - sess.Modules = append(sess.Modules, mockNetRecon) - - // Create mock session wrapper - mockSess := &MockSession{ - Session: sess, - runCommands: make([]string, 0), - skipIPs: make(map[string]bool), - } - - return mockSess, mockQueue -} - -func TestNewProber(t *testing.T) { - mockSess, _ := createMockSession() - - mod := NewProber(mockSess.Session) - - if mod == nil { - t.Fatal("NewProber returned nil") - } - - if mod.Name() != "net.probe" { - t.Errorf("expected module name 'net.probe', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - // Check parameters - params := []string{"net.probe.nbns", "net.probe.mdns", "net.probe.upnp", "net.probe.wsd", "net.probe.throttle"} - for _, param := range params { - if !mod.Session.Env.Has(param) { - t.Errorf("parameter %s not registered", param) - } - } -} - -func TestProberConfigure(t *testing.T) { - tests := []struct { - name string - params map[string]string - expectErr bool - expected struct { - throttle int - nbns bool - mdns bool - upnp bool - wsd bool - } - }{ - { - name: "default configuration", - params: map[string]string{ - "net.probe.throttle": "10", - "net.probe.nbns": "true", - "net.probe.mdns": "true", - "net.probe.upnp": "true", - "net.probe.wsd": "true", - }, - expectErr: false, - expected: struct { - throttle int - nbns bool - mdns bool - upnp bool - wsd bool - }{10, true, true, true, true}, - }, - { - name: "disabled probes", - params: map[string]string{ - "net.probe.throttle": "5", - "net.probe.nbns": "false", - "net.probe.mdns": "false", - "net.probe.upnp": "false", - "net.probe.wsd": "false", - }, - expectErr: false, - expected: struct { - throttle int - nbns bool - mdns bool - upnp bool - wsd bool - }{5, false, false, false, false}, - }, - { - name: "invalid throttle", - params: map[string]string{ - "net.probe.throttle": "invalid", - "net.probe.nbns": "true", - "net.probe.mdns": "true", - "net.probe.upnp": "true", - "net.probe.wsd": "true", - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSess, _ := createMockSession() - mod := NewProber(mockSess.Session) - - // Set parameters - for k, v := range tt.params { - mockSess.Env.Set(k, v) - } - - err := mod.Configure() - - if tt.expectErr && err == nil { - t.Error("expected error but got none") - } else if !tt.expectErr && err != nil { - t.Errorf("unexpected error: %v", err) - } - - if !tt.expectErr { - if mod.throttle != tt.expected.throttle { - t.Errorf("expected throttle %d, got %d", tt.expected.throttle, mod.throttle) - } - if mod.probes.NBNS != tt.expected.nbns { - t.Errorf("expected NBNS %v, got %v", tt.expected.nbns, mod.probes.NBNS) - } - if mod.probes.MDNS != tt.expected.mdns { - t.Errorf("expected MDNS %v, got %v", tt.expected.mdns, mod.probes.MDNS) - } - if mod.probes.UPNP != tt.expected.upnp { - t.Errorf("expected UPNP %v, got %v", tt.expected.upnp, mod.probes.UPNP) - } - if mod.probes.WSD != tt.expected.wsd { - t.Errorf("expected WSD %v, got %v", tt.expected.wsd, mod.probes.WSD) - } - } - }) - } -} - -// MockProber wraps Prober to allow mocking probe methods -type MockProber struct { - *Prober - nbnsCount *int32 - upnpCount *int32 - wsdCount *int32 - mockQueue *MockQueue -} - -func (m *MockProber) sendProbeNBNS(from net.IP, from_hw net.HardwareAddr, to net.IP) { - atomic.AddInt32(m.nbnsCount, 1) - m.mockQueue.Send([]byte(fmt.Sprintf("NBNS probe to %s", to))) -} - -func (m *MockProber) sendProbeUPNP(from net.IP, from_hw net.HardwareAddr) { - atomic.AddInt32(m.upnpCount, 1) - m.mockQueue.Send([]byte("UPNP probe")) -} - -func (m *MockProber) sendProbeWSD(from net.IP, from_hw net.HardwareAddr) { - atomic.AddInt32(m.wsdCount, 1) - m.mockQueue.Send([]byte("WSD probe")) -} - -func TestProberStartStop(t *testing.T) { - mockSess, _ := createMockSession() - mod := NewProber(mockSess.Session) - - // Configure with fast throttle for testing - mockSess.Env.Set("net.probe.throttle", "1") - mockSess.Env.Set("net.probe.nbns", "true") - mockSess.Env.Set("net.probe.mdns", "true") - mockSess.Env.Set("net.probe.upnp", "true") - mockSess.Env.Set("net.probe.wsd", "true") - - // Start the prober - err := mod.Start() - if err != nil { - t.Fatalf("Failed to start prober: %v", err) - } - - if !mod.Running() { - t.Error("Prober should be running after Start()") - } - - // Give it a moment to initialize - time.Sleep(50 * time.Millisecond) - - // Stop the prober - err = mod.Stop() - if err != nil { - t.Fatalf("Failed to stop prober: %v", err) - } - - if mod.Running() { - t.Error("Prober should not be running after Stop()") - } - - // Since we can't easily mock the probe methods, we'll verify the module's state - // and trust that the actual probe sending is tested in integration tests -} - -func TestProberMonitorMode(t *testing.T) { - mockSess, _ := createMockSession() - mod := NewProber(mockSess.Session) - - // Set interface to monitor mode - mockSess.Interface.IpAddress = network.MonitorModeAddress - - // Start the prober - err := mod.Start() - if err != nil { - t.Fatalf("Failed to start prober: %v", err) - } - - // Give it time to potentially start probing - time.Sleep(50 * time.Millisecond) - - // Stop the prober - mod.Stop() - - // In monitor mode, the prober should exit early without doing any work - // We can't easily verify no probes were sent without mocking network calls, - // but we can verify the module starts and stops correctly -} - -func TestProberHandlers(t *testing.T) { - mockSess, _ := createMockSession() - mod := NewProber(mockSess.Session) - - // Test handlers - handlers := mod.Handlers() - - expectedHandlers := []string{"net.probe on", "net.probe off"} - handlerMap := make(map[string]bool) - - for _, h := range handlers { - handlerMap[h.Name] = true - } - - for _, expected := range expectedHandlers { - if !handlerMap[expected] { - t.Errorf("Expected handler '%s' not found", expected) - } - } - - // Test handler execution - for _, h := range handlers { - if h.Name == "net.probe on" { - // Should start the module - err := h.Exec([]string{}) - if err != nil { - t.Errorf("Handler 'net.probe on' failed: %v", err) - } - if !mod.Running() { - t.Error("Module should be running after 'net.probe on'") - } - mod.Stop() - } else if h.Name == "net.probe off" { - // Start first, then stop - mod.Start() - err := h.Exec([]string{}) - if err != nil { - t.Errorf("Handler 'net.probe off' failed: %v", err) - } - if mod.Running() { - t.Error("Module should not be running after 'net.probe off'") - } - } - } -} - -func TestProberSelectiveProbes(t *testing.T) { - tests := []struct { - name string - enabledProbes map[string]bool - }{ - { - name: "only NBNS", - enabledProbes: map[string]bool{ - "nbns": true, - "mdns": false, - "upnp": false, - "wsd": false, - }, - }, - { - name: "only UPNP and WSD", - enabledProbes: map[string]bool{ - "nbns": false, - "mdns": false, - "upnp": true, - "wsd": true, - }, - }, - { - name: "all probes enabled", - enabledProbes: map[string]bool{ - "nbns": true, - "mdns": true, - "upnp": true, - "wsd": true, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSess, _ := createMockSession() - mod := NewProber(mockSess.Session) - - // Configure probes - mockSess.Env.Set("net.probe.throttle", "10") - mockSess.Env.Set("net.probe.nbns", fmt.Sprintf("%v", tt.enabledProbes["nbns"])) - mockSess.Env.Set("net.probe.mdns", fmt.Sprintf("%v", tt.enabledProbes["mdns"])) - mockSess.Env.Set("net.probe.upnp", fmt.Sprintf("%v", tt.enabledProbes["upnp"])) - mockSess.Env.Set("net.probe.wsd", fmt.Sprintf("%v", tt.enabledProbes["wsd"])) - - // Configure and verify the settings - err := mod.Configure() - if err != nil { - t.Fatalf("Failed to configure: %v", err) - } - - // Verify configuration - if mod.probes.NBNS != tt.enabledProbes["nbns"] { - t.Errorf("NBNS probe setting mismatch: expected %v, got %v", - tt.enabledProbes["nbns"], mod.probes.NBNS) - } - if mod.probes.MDNS != tt.enabledProbes["mdns"] { - t.Errorf("MDNS probe setting mismatch: expected %v, got %v", - tt.enabledProbes["mdns"], mod.probes.MDNS) - } - if mod.probes.UPNP != tt.enabledProbes["upnp"] { - t.Errorf("UPNP probe setting mismatch: expected %v, got %v", - tt.enabledProbes["upnp"], mod.probes.UPNP) - } - if mod.probes.WSD != tt.enabledProbes["wsd"] { - t.Errorf("WSD probe setting mismatch: expected %v, got %v", - tt.enabledProbes["wsd"], mod.probes.WSD) - } - }) - } -} - -func TestIPRangeExpansion(t *testing.T) { - // Test that we correctly iterate through the subnet - cidr := "192.168.1.0/30" // Small subnet for testing - list, err := iprange.Parse(cidr) - if err != nil { - t.Fatalf("Failed to parse CIDR: %v", err) - } - - addresses := list.Expand() - - // For /30, we should get 4 addresses - expectedAddresses := []string{ - "192.168.1.0", - "192.168.1.1", - "192.168.1.2", - "192.168.1.3", - } - - if len(addresses) != len(expectedAddresses) { - t.Errorf("Expected %d addresses, got %d", len(expectedAddresses), len(addresses)) - } - - for i, addr := range addresses { - if addr.String() != expectedAddresses[i] { - t.Errorf("Expected address %s at position %d, got %s", expectedAddresses[i], i, addr.String()) - } - } -} - -// Benchmarks -func BenchmarkProberConfiguration(b *testing.B) { - mockSess, _ := createMockSession() - - // Set up parameters - mockSess.Env.Set("net.probe.throttle", "10") - mockSess.Env.Set("net.probe.nbns", "true") - mockSess.Env.Set("net.probe.mdns", "true") - mockSess.Env.Set("net.probe.upnp", "true") - mockSess.Env.Set("net.probe.wsd", "true") - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mod := NewProber(mockSess.Session) - mod.Configure() - } -} - -func BenchmarkIPRangeExpansion(b *testing.B) { - cidr := "192.168.1.0/24" - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - list, _ := iprange.Parse(cidr) - _ = list.Expand() - } -} diff --git a/modules/net_recon/net_recon_test.go b/modules/net_recon/net_recon_test.go deleted file mode 100644 index 93459666..00000000 --- a/modules/net_recon/net_recon_test.go +++ /dev/null @@ -1,644 +0,0 @@ -package net_recon - -import ( - "fmt" - "sync" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/modules/utils" - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/evilsocket/islazy/data" -) - -// Mock ArpUpdate function -var mockArpUpdateFunc func(string) (network.ArpTable, error) - -// Override the network.ArpUpdate function for testing -func mockArpUpdate(iface string) (network.ArpTable, error) { - if mockArpUpdateFunc != nil { - return mockArpUpdateFunc(iface) - } - return make(network.ArpTable), nil -} - -// MockLAN implements a mock version of the LAN interface -type MockLAN struct { - sync.RWMutex - hosts map[string]*network.Endpoint - wasMissed map[string]bool - addedHosts []string - removedHosts []string -} - -func NewMockLAN() *MockLAN { - return &MockLAN{ - hosts: make(map[string]*network.Endpoint), - wasMissed: make(map[string]bool), - addedHosts: []string{}, - removedHosts: []string{}, - } -} - -func (m *MockLAN) AddIfNew(ip, mac string) { - m.Lock() - defer m.Unlock() - - if _, exists := m.hosts[mac]; !exists { - m.hosts[mac] = &network.Endpoint{ - IpAddress: ip, - HwAddress: mac, - FirstSeen: time.Now(), - LastSeen: time.Now(), - } - m.addedHosts = append(m.addedHosts, mac) - } -} - -func (m *MockLAN) Remove(ip, mac string) { - m.Lock() - defer m.Unlock() - - if _, exists := m.hosts[mac]; exists { - delete(m.hosts, mac) - m.removedHosts = append(m.removedHosts, mac) - } -} - -func (m *MockLAN) Clear() { - m.Lock() - defer m.Unlock() - - m.hosts = make(map[string]*network.Endpoint) - m.wasMissed = make(map[string]bool) - m.addedHosts = []string{} - m.removedHosts = []string{} -} - -func (m *MockLAN) EachHost(cb func(mac string, e *network.Endpoint)) { - m.RLock() - defer m.RUnlock() - - for mac, host := range m.hosts { - cb(mac, host) - } -} - -func (m *MockLAN) List() []*network.Endpoint { - m.RLock() - defer m.RUnlock() - - list := make([]*network.Endpoint, 0, len(m.hosts)) - for _, host := range m.hosts { - list = append(list, host) - } - return list -} - -func (m *MockLAN) WasMissed(mac string) bool { - m.RLock() - defer m.RUnlock() - - return m.wasMissed[mac] -} - -func (m *MockLAN) Get(mac string) *network.Endpoint { - m.RLock() - defer m.RUnlock() - - return m.hosts[mac] -} - -// Create a mock session for testing -func createMockSession() *session.Session { - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "eth0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - - // Create environment - env, _ := session.NewEnvironment("") - - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: &packets.Queue{ - Traffic: sync.Map{}, - Stats: packets.Stats{}, - }, - Modules: make(session.ModuleList, 0), - } - - // Initialize the Events field with a mock EventPool - sess.Events = session.NewEventPool(false, false) - - return sess -} - -func TestNewDiscovery(t *testing.T) { - sess := createMockSession() - mod := NewDiscovery(sess) - - if mod == nil { - t.Fatal("NewDiscovery returned nil") - } - - if mod.Name() != "net.recon" { - t.Errorf("expected module name 'net.recon', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - if mod.selector == nil { - t.Error("selector should be initialized") - } -} - -func TestRunDiff(t *testing.T) { - // Test the basic diff functionality with a simpler approach - tests := []struct { - name string - initialHosts map[string]string // IP -> MAC - arpTable network.ArpTable - expectedAdded []string - expectedRemoved []string - }{ - { - name: "no changes", - initialHosts: map[string]string{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - "192.168.1.20": "bb:bb:bb:bb:bb:bb", - }, - arpTable: network.ArpTable{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - "192.168.1.20": "bb:bb:bb:bb:bb:bb", - }, - expectedAdded: []string{}, - expectedRemoved: []string{}, - }, - { - name: "new host discovered", - initialHosts: map[string]string{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - }, - arpTable: network.ArpTable{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - "192.168.1.20": "bb:bb:bb:bb:bb:bb", - }, - expectedAdded: []string{"bb:bb:bb:bb:bb:bb"}, - expectedRemoved: []string{}, - }, - { - name: "host disappeared", - initialHosts: map[string]string{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - "192.168.1.20": "bb:bb:bb:bb:bb:bb", - }, - arpTable: network.ArpTable{ - "192.168.1.10": "aa:aa:aa:aa:aa:aa", - }, - expectedAdded: []string{}, - expectedRemoved: []string{"bb:bb:bb:bb:bb:bb"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sess := createMockSession() - - // Track callbacks - addedHosts := []string{} - removedHosts := []string{} - - newCb := func(e *network.Endpoint) { - addedHosts = append(addedHosts, e.HwAddress) - } - - lostCb := func(e *network.Endpoint) { - removedHosts = append(removedHosts, e.HwAddress) - } - - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, newCb, lostCb) - - mod := &Discovery{ - SessionModule: session.NewSessionModule("net.recon", sess), - } - - // Add initial hosts - for ip, mac := range tt.initialHosts { - sess.Lan.AddIfNew(ip, mac) - } - - // Reset tracking - addedHosts = []string{} - removedHosts = []string{} - - // Add interface and gateway to ARP table to avoid them being removed - finalArpTable := make(network.ArpTable) - for k, v := range tt.arpTable { - finalArpTable[k] = v - } - finalArpTable[sess.Interface.IpAddress] = sess.Interface.HwAddress - finalArpTable[sess.Gateway.IpAddress] = sess.Gateway.HwAddress - - // Run the diff multiple times to trigger actual removal (TTL countdown) - for i := 0; i < network.LANDefaultttl+1; i++ { - mod.runDiff(finalArpTable) - } - - // Check results - if len(addedHosts) != len(tt.expectedAdded) { - t.Errorf("expected %d added hosts, got %d. Added: %v", len(tt.expectedAdded), len(addedHosts), addedHosts) - } - - if len(removedHosts) != len(tt.expectedRemoved) { - t.Errorf("expected %d removed hosts, got %d. Removed: %v", len(tt.expectedRemoved), len(removedHosts), removedHosts) - } - }) - } -} - -func TestConfigure(t *testing.T) { - sess := createMockSession() - mod := NewDiscovery(sess) - - err := mod.Configure() - if err != nil { - t.Errorf("Configure() returned error: %v", err) - } -} - -func TestStartStop(t *testing.T) { - sess := createMockSession() - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - mod := NewDiscovery(sess) - - // Test starting the module - err := mod.Start() - if err != nil { - t.Errorf("Start() returned error: %v", err) - } - - if !mod.Running() { - t.Error("module should be running after Start()") - } - - // Let it run briefly - time.Sleep(100 * time.Millisecond) - - // Test stopping the module - err = mod.Stop() - if err != nil { - t.Errorf("Stop() returned error: %v", err) - } - - if mod.Running() { - t.Error("module should not be running after Stop()") - } -} - -func TestShowMethods(t *testing.T) { - // Skip this test as it requires a full session with readline - t.Skip("Skipping TestShowMethods as it requires readline initialization") -} - -func TestDoSelection(t *testing.T) { - sess := createMockSession() - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - // Add test endpoints - sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - sess.Lan.AddIfNew("192.168.1.20", "bb:bb:bb:bb:bb:bb") - sess.Lan.AddIfNew("192.168.1.30", "cc:cc:cc:cc:cc:cc") - - // Get endpoints and set additional properties - if e, found := sess.Lan.Get("aa:aa:aa:aa:aa:aa"); found { - e.Hostname = "host1" - e.Vendor = "Vendor1" - } - - if e, found := sess.Lan.Get("bb:bb:bb:bb:bb:bb"); found { - e.Alias = "mydevice" - e.Vendor = "Vendor2" - } - - mod := NewDiscovery(sess) - mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show", - []string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc") - - tests := []struct { - name string - arg string - expectedCount int - expectedIPs []string - }{ - { - name: "select all", - arg: "", - expectedCount: 3, - }, - { - name: "select by IP", - arg: "192.168.1.10", - expectedCount: 1, - expectedIPs: []string{"192.168.1.10"}, - }, - { - name: "select by MAC", - arg: "aa:aa:aa:aa:aa:aa", - expectedCount: 1, - expectedIPs: []string{"192.168.1.10"}, - }, - { - name: "select multiple by comma", - arg: "192.168.1.10,192.168.1.20", - expectedCount: 2, - expectedIPs: []string{"192.168.1.10", "192.168.1.20"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err, targets := mod.doSelection(tt.arg) - if err != nil { - t.Errorf("doSelection returned error: %v", err) - } - - if len(targets) != tt.expectedCount { - t.Errorf("expected %d targets, got %d", tt.expectedCount, len(targets)) - } - - if tt.expectedIPs != nil { - for _, expectedIP := range tt.expectedIPs { - found := false - for _, target := range targets { - if target.IpAddress == expectedIP { - found = true - break - } - } - if !found { - t.Errorf("expected to find IP %s in targets", expectedIP) - } - } - } - }) - } -} - -func TestHandlers(t *testing.T) { - sess := createMockSession() - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - mod := NewDiscovery(sess) - - handlers := []struct { - name string - handler string - args []string - setup func() - validate func() error - }{ - { - name: "net.clear", - handler: "net.clear", - args: []string{}, - setup: func() { - sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa") - }, - validate: func() error { - // Check if hosts were cleared - hosts := sess.Lan.List() - if len(hosts) != 0 { - return fmt.Errorf("expected empty hosts after clear, got %d", len(hosts)) - } - return nil - }, - }, - } - - for _, tt := range handlers { - t.Run(tt.name, func(t *testing.T) { - if tt.setup != nil { - tt.setup() - } - - // Find and execute the handler - found := false - for _, h := range mod.Handlers() { - if h.Name == tt.handler { - found = true - err := h.Exec(tt.args) - if err != nil { - t.Errorf("handler %s returned error: %v", tt.handler, err) - } - break - } - } - - if !found { - t.Errorf("handler %s not found", tt.handler) - } - - if tt.validate != nil { - if err := tt.validate(); err != nil { - t.Error(err) - } - } - }) - } -} - -func TestGetRow(t *testing.T) { - sess := createMockSession() - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - mod := NewDiscovery(sess) - - // Test endpoint with metadata - endpoint := &network.Endpoint{ - IpAddress: "192.168.1.10", - HwAddress: "aa:aa:aa:aa:aa:aa", - Hostname: "testhost", - Vendor: "Test Vendor", - FirstSeen: time.Now().Add(-time.Hour), - LastSeen: time.Now(), - Meta: network.NewMeta(), - } - endpoint.Meta.Set("key1", "value1") - endpoint.Meta.Set("key2", "value2") - - // Test without meta - rows := mod.getRow(endpoint, false) - if len(rows) != 1 { - t.Errorf("expected 1 row without meta, got %d", len(rows)) - } - if len(rows[0]) != 7 { - t.Errorf("expected 7 columns, got %d", len(rows[0])) - } - - // Test with meta - rows = mod.getRow(endpoint, true) - if len(rows) != 2 { // One main row + one meta row per metadata entry - t.Errorf("expected 2 rows with meta, got %d", len(rows)) - } - - // Test interface endpoint - ifaceEndpoint := sess.Interface - rows = mod.getRow(ifaceEndpoint, false) - if len(rows) != 1 { - t.Errorf("expected 1 row for interface, got %d", len(rows)) - } - - // Test gateway endpoint - gatewayEndpoint := sess.Gateway - rows = mod.getRow(gatewayEndpoint, false) - if len(rows) != 1 { - t.Errorf("expected 1 row for gateway, got %d", len(rows)) - } -} - -func TestDoFilter(t *testing.T) { - sess := createMockSession() - mod := NewDiscovery(sess) - mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show", - []string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc") - - // Test that doFilter behavior matches the actual implementation - // When Expression is nil, it returns true (no filtering) - // When Expression is set, it matches against any of the fields - - tests := []struct { - name string - filter string - endpoint *network.Endpoint - shouldMatch bool - }{ - { - name: "no filter", - filter: "", - endpoint: &network.Endpoint{ - IpAddress: "192.168.1.10", - Meta: network.NewMeta(), - }, - shouldMatch: true, - }, - { - name: "ip filter match", - filter: "192.168", - endpoint: &network.Endpoint{ - IpAddress: "192.168.1.10", - Meta: network.NewMeta(), - }, - shouldMatch: true, - }, - { - name: "mac filter match", - filter: "aa:bb", - endpoint: &network.Endpoint{ - IpAddress: "192.168.1.10", - HwAddress: "aa:bb:cc:dd:ee:ff", - Meta: network.NewMeta(), - }, - shouldMatch: true, - }, - { - name: "hostname filter match", - filter: "myhost", - endpoint: &network.Endpoint{ - IpAddress: "192.168.1.10", - Hostname: "myhost.local", - Meta: network.NewMeta(), - }, - shouldMatch: true, - }, - { - name: "no match - testing unique string", - filter: "xyz123nomatch", - endpoint: &network.Endpoint{ - IpAddress: "192.168.1.10", - Ip6Address: "", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "host.local", - Alias: "", - Vendor: "", - Meta: network.NewMeta(), - }, - shouldMatch: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Reset selector for each test - // Set the parameter value that Update() will read - sess.Env.Set("net.show.filter", tt.filter) - mod.selector.Expression = nil - - // Update will read from the parameter - err := mod.selector.Update() - if err != nil { - t.Fatalf("selector.Update() failed: %v", err) - } - - result := mod.doFilter(tt.endpoint) - if result != tt.shouldMatch { - if mod.selector.Expression != nil { - t.Errorf("expected doFilter to return %v, got %v. Regex: %s", tt.shouldMatch, result, mod.selector.Expression.String()) - } else { - t.Errorf("expected doFilter to return %v, got %v. Expression is nil", tt.shouldMatch, result) - } - } - }) - } -} - -// Benchmark the runDiff method -func BenchmarkRunDiff(b *testing.B) { - sess := createMockSession() - aliases, _ := data.NewUnsortedKV("", 0) - sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - mod := &Discovery{ - SessionModule: session.NewSessionModule("net.recon", sess), - } - - // Create a large ARP table - arpTable := make(network.ArpTable) - for i := 0; i < 100; i++ { - ip := fmt.Sprintf("192.168.1.%d", i) - mac := fmt.Sprintf("aa:bb:cc:dd:%02x:%02x", i/256, i%256) - arpTable[ip] = mac - - // Add half to the existing LAN - if i < 50 { - sess.Lan.AddIfNew(ip, mac) - } - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mod.runDiff(arpTable) - } -} diff --git a/modules/net_sniff/net_sniff.go b/modules/net_sniff/net_sniff.go index cb2c1b48..4daa9859 100644 --- a/modules/net_sniff/net_sniff.go +++ b/modules/net_sniff/net_sniff.go @@ -59,11 +59,6 @@ func NewSniffer(s *session.Session) *Sniffer { "", "If set, the sniffer will read from this pcap file instead of the current interface.")) - mod.AddParam(session.NewStringParameter("net.sniff.interface", - "", - "", - "Interface to sniff on.")) - mod.AddHandler(session.NewModuleHandler("net.sniff stats", "", "Print sniffer session configuration and statistics.", func(args []string) error { diff --git a/modules/net_sniff/net_sniff_context.go b/modules/net_sniff/net_sniff_context.go index 633238f1..e275ebf8 100644 --- a/modules/net_sniff/net_sniff_context.go +++ b/modules/net_sniff/net_sniff_context.go @@ -17,7 +17,6 @@ import ( type SnifferContext struct { Handle *pcap.Handle - Interface string Source string DumpLocal bool Verbose bool @@ -38,22 +37,13 @@ func (mod *Sniffer) GetContext() (error, *SnifferContext) { return err, ctx } - if err, ctx.Interface = mod.StringParam("net.sniff.interface"); err != nil { - return err, ctx - } - - if ctx.Interface == "" { - ctx.Interface = mod.Session.Interface.Name() - } - if ctx.Source == "" { /* * We don't want to pcap.BlockForever otherwise pcap_close(handle) * could hang waiting for a timeout to expire ... */ - readTimeout := 500 * time.Millisecond - if ctx.Handle, err = network.CaptureWithTimeout(ctx.Interface, readTimeout); err != nil { + if ctx.Handle, err = network.CaptureWithTimeout(mod.Session.Interface.Name(), readTimeout); err != nil { return err, ctx } } else { @@ -104,8 +94,6 @@ func (mod *Sniffer) GetContext() (error, *SnifferContext) { func NewSnifferContext() *SnifferContext { return &SnifferContext{ Handle: nil, - Interface: "", - Source: "", DumpLocal: false, Verbose: false, Filter: "", @@ -127,8 +115,7 @@ var ( ) func (c *SnifferContext) Log(sess *session.Session) { - log.Info("Interface : %s", tui.Bold(c.Interface)) - log.Info("Skip local packets : %s", yn[!c.DumpLocal]) + log.Info("Skip local packets : %s", yn[c.DumpLocal]) log.Info("Verbose : %s", yn[c.Verbose]) log.Info("BPF Filter : '%s'", tui.Yellow(c.Filter)) log.Info("Regular expression : '%s'", tui.Yellow(c.Expression)) diff --git a/modules/net_sniff/net_sniff_http.go b/modules/net_sniff/net_sniff_http.go index 23e0375c..a111c08b 100644 --- a/modules/net_sniff/net_sniff_http.go +++ b/modules/net_sniff/net_sniff_http.go @@ -4,7 +4,7 @@ import ( "bufio" "bytes" "compress/gzip" - "io" + "io/ioutil" "net" "net/http" "strings" @@ -50,7 +50,7 @@ func toSerializableRequest(req *http.Request) HTTPRequest { body := []byte(nil) ctype := "?" if req.Body != nil { - body, _ = io.ReadAll(req.Body) + body, _ = ioutil.ReadAll(req.Body) } for name, values := range req.Header { @@ -90,7 +90,7 @@ func toSerializableResponse(res *http.Response) HTTPResponse { } if res.Body != nil { - body, _ = io.ReadAll(res.Body) + body, _ = ioutil.ReadAll(res.Body) } // attempt decompression, but since this has been parsed by just diff --git a/modules/packet_proxy/packet_proxy_linux.go b/modules/packet_proxy/packet_proxy_linux.go index 9a40fcff..e124976c 100644 --- a/modules/packet_proxy/packet_proxy_linux.go +++ b/modules/packet_proxy/packet_proxy_linux.go @@ -22,7 +22,7 @@ type PacketProxy struct { rule string queue *nfqueue.Nfqueue queueNum int - queueCb func(q *nfqueue.Nfqueue, a nfqueue.Attribute) int + queueCb nfqueue.HookFunc pluginPath string plugin *plugin.Plugin } @@ -149,7 +149,7 @@ func (mod *PacketProxy) Configure() (err error) { return } else if sym, err = mod.plugin.Lookup("OnPacket"); err != nil { return - } else if mod.queueCb, ok = sym.(func(q *nfqueue.Nfqueue, a nfqueue.Attribute) int); !ok { + } else if mod.queueCb, ok = sym.(func(nfqueue.Attribute) int); !ok { return fmt.Errorf("Symbol OnPacket is not a valid callback function.") } @@ -198,7 +198,7 @@ func (mod *PacketProxy) Configure() (err error) { // CGO callback ... ¯\_(ツ)_/¯ func dummyCallback(attribute nfqueue.Attribute) int { if mod.queueCb != nil { - return mod.queueCb(mod.queue, attribute) + return mod.queueCb(attribute) } else { id := *attribute.PacketID diff --git a/modules/tcp_proxy/tcp_proxy_script.go b/modules/tcp_proxy/tcp_proxy_script.go index 50956ea0..fa801be5 100644 --- a/modules/tcp_proxy/tcp_proxy_script.go +++ b/modules/tcp_proxy/tcp_proxy_script.go @@ -1,7 +1,6 @@ package tcp_proxy import ( - "encoding/json" "net" "strings" @@ -56,36 +55,12 @@ func (s *TcpProxyScript) OnData(from, to net.Addr, data []byte, callback func(ca log.Error("error while executing onData callback: %s", err) return nil } else if ret != nil { - return toByteArray(ret) - } - } - return nil -} - -func toByteArray(ret interface{}) []byte { - // this approach is a bit hacky but it handles all cases - - // serialize ret to JSON - if jsonData, err := json.Marshal(ret); err == nil { - // attempt to deserialize as []float64 - var back2Array []float64 - if err := json.Unmarshal(jsonData, &back2Array); err == nil { - result := make([]byte, len(back2Array)) - for i, num := range back2Array { - if num >= 0 && num <= 255 { - result[i] = byte(num) - } else { - log.Error("array element at index %d is not a valid byte value %d", i, num) - return nil - } + array, ok := ret.([]byte) + if !ok { + log.Error("error while casting exported value to array of byte: value = %+v", ret) } - return result - } else { - log.Error("failed to deserialize %+v to []float64: %v", ret, err) + return array } - } else { - log.Error("failed to serialize %+v to JSON: %v", ret, err) } - return nil } diff --git a/modules/tcp_proxy/tcp_proxy_script_test.go b/modules/tcp_proxy/tcp_proxy_script_test.go deleted file mode 100644 index 27bdc099..00000000 --- a/modules/tcp_proxy/tcp_proxy_script_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package tcp_proxy - -import ( - "net" - "testing" - - "github.com/evilsocket/islazy/plugin" -) - -func TestOnData_NoReturn(t *testing.T) { - jsCode := ` - function onData(from, to, data, callback) { - // don't return anything - } - ` - - plug, err := plugin.Parse(jsCode) - if err != nil { - t.Fatalf("Failed to parse plugin: %v", err) - } - - script := &TcpProxyScript{ - Plugin: plug, - doOnData: plug.HasFunc("onData"), - } - - from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678} - data := []byte("test data") - - result := script.OnData(from, to, data, nil) - if result != nil { - t.Errorf("Expected nil result when callback returns nothing, got %v", result) - } -} - -func TestOnData_ReturnsArrayOfIntegers(t *testing.T) { - jsCode := ` - function onData(from, to, data, callback) { - // Return modified data as array of integers - return [72, 101, 108, 108, 111]; // "Hello" in ASCII - } - ` - - plug, err := plugin.Parse(jsCode) - if err != nil { - t.Fatalf("Failed to parse plugin: %v", err) - } - - script := &TcpProxyScript{ - Plugin: plug, - doOnData: plug.HasFunc("onData"), - } - - from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678} - data := []byte("test data") - - result := script.OnData(from, to, data, nil) - expected := []byte("Hello") - - if result == nil { - t.Fatal("Expected non-nil result when callback returns array of integers") - } - - if len(result) != len(expected) { - t.Fatalf("Expected result length %d, got %d", len(expected), len(result)) - } - - for i, b := range result { - if b != expected[i] { - t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b) - } - } -} - -func TestOnData_ReturnsDynamicArray(t *testing.T) { - jsCode := ` - function onData(from, to, data, callback) { - var result = []; - for (var i = 0; i < data.length; i++) { - result.push((data[i] + 1) % 256); - } - return result; - } - ` - - plug, err := plugin.Parse(jsCode) - if err != nil { - t.Fatalf("Failed to parse plugin: %v", err) - } - - script := &TcpProxyScript{ - Plugin: plug, - doOnData: plug.HasFunc("onData"), - } - - from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678} - data := []byte{10, 20, 30, 40, 255} - - result := script.OnData(from, to, data, nil) - expected := []byte{11, 21, 31, 41, 0} // 255 + 1 = 256 % 256 = 0 - - if result == nil { - t.Fatal("Expected non-nil result when callback returns array of integers") - } - - if len(result) != len(expected) { - t.Fatalf("Expected result length %d, got %d", len(expected), len(result)) - } - - for i, b := range result { - if b != expected[i] { - t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b) - } - } -} - -func TestOnData_ReturnsMixedArray(t *testing.T) { - jsCode := ` - function charToInt(value) { - return value.charCodeAt() - } - - function onData(from, to, data) { - st_data = String.fromCharCode.apply(null, data) - if( st_data.indexOf("mysearch") != -1 ) { - payload = "mypayload"; - st_data = st_data.replace("mysearch", payload); - res_int_arr = st_data.split("").map(charToInt) // []uint16 - res_int_arr[0] = payload.length + 1; // first index is float64 and rest []uint16 - return res_int_arr; - } - return data; - } - ` - - plug, err := plugin.Parse(jsCode) - if err != nil { - t.Fatalf("Failed to parse plugin: %v", err) - } - - script := &TcpProxyScript{ - Plugin: plug, - doOnData: plug.HasFunc("onData"), - } - - from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - to := &net.TCPAddr{IP: net.ParseIP("192.168.1.6"), Port: 5678} - data := []byte("Hello mysearch world") - - result := script.OnData(from, to, data, nil) - expected := []byte("\x0aello mypayload world") - - if result == nil { - t.Fatal("Expected non-nil result when callback returns array of integers") - } - - if len(result) != len(expected) { - t.Fatalf("Expected result length %d, got %d", len(expected), len(result)) - } - - for i, b := range result { - if b != expected[i] { - t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b) - } - } -} diff --git a/modules/ticker/ticker.go b/modules/ticker/ticker.go index 34c4c02b..e629d2f0 100644 --- a/modules/ticker/ticker.go +++ b/modules/ticker/ticker.go @@ -43,7 +43,7 @@ func NewTicker(s *session.Session) *Ticker { })) mod.AddHandler(session.NewModuleHandler("ticker off", "", - "Stop the main ticker.", + "Stop the maint icker.", func(args []string) error { return mod.Stop() })) diff --git a/modules/ticker/ticker_test.go b/modules/ticker/ticker_test.go deleted file mode 100644 index 9b1b97a5..00000000 --- a/modules/ticker/ticker_test.go +++ /dev/null @@ -1,413 +0,0 @@ -package ticker - -import ( - "sync" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewTicker(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - if mod == nil { - t.Fatal("NewTicker returned nil") - } - - if mod.Name() != "ticker" { - t.Errorf("Expected name 'ticker', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check parameters exist - if err, _ := mod.StringParam("ticker.commands"); err != nil { - t.Error("ticker.commands parameter not found") - } - - if err, _ := mod.IntParam("ticker.period"); err != nil { - t.Error("ticker.period parameter not found") - } - - // Check handlers - only check the main ones since create/destroy have regex patterns - handlers := []string{"ticker on", "ticker off"} - for _, handler := range handlers { - found := false - for _, h := range mod.Handlers() { - if h.Name == handler { - found = true - break - } - } - if !found { - t.Errorf("Handler '%s' not found", handler) - } - } - - // Check that we have handlers for create and destroy (they have regex patterns) - hasCreate := false - hasDestroy := false - for _, h := range mod.Handlers() { - if h.Name == "ticker.create " { - hasCreate = true - } else if h.Name == "ticker.destroy " { - hasDestroy = true - } - } - if !hasCreate { - t.Error("ticker.create handler not found") - } - if !hasDestroy { - t.Error("ticker.destroy handler not found") - } -} - -func TestTickerConfigure(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Test configure before start - if err := mod.Configure(); err != nil { - t.Errorf("Configure failed: %v", err) - } - - // Check main params were set - if mod.main.Period == 0 { - t.Error("Period not set") - } - - if len(mod.main.Commands) == 0 { - t.Error("Commands not set") - } - - if !mod.main.Running { - t.Error("Running flag not set") - } -} - -func TestTickerStartStop(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Set a short period for testing using session environment - mod.Session.Env.Set("ticker.period", "1") - mod.Session.Env.Set("ticker.commands", "help") - - // Start ticker - if err := mod.Start(); err != nil { - t.Fatalf("Failed to start ticker: %v", err) - } - - if !mod.Running() { - t.Error("Ticker should be running") - } - - // Let it run briefly - time.Sleep(100 * time.Millisecond) - - // Stop ticker - if err := mod.Stop(); err != nil { - t.Fatalf("Failed to stop ticker: %v", err) - } - - if mod.Running() { - t.Error("Ticker should not be running") - } - - if mod.main.Running { - t.Error("Main ticker should not be running") - } -} - -func TestTickerAlreadyStarted(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Start ticker - if err := mod.Start(); err != nil { - t.Fatalf("Failed to start ticker: %v", err) - } - - // Try to configure while running - if err := mod.Configure(); err == nil { - t.Error("Configure should fail when already running") - } - - // Stop ticker - mod.Stop() -} - -func TestTickerNamedOperations(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Create named ticker - name := "test_ticker" - if err := mod.createNamed(name, 1, "help"); err != nil { - t.Fatalf("Failed to create named ticker: %v", err) - } - - // Check it was created - if _, found := mod.named[name]; !found { - t.Error("Named ticker not found in map") - } - - // Try to create duplicate - if err := mod.createNamed(name, 1, "help"); err == nil { - t.Error("Should not allow duplicate named ticker") - } - - // Let it run briefly - time.Sleep(100 * time.Millisecond) - - // Destroy named ticker - if err := mod.destroyNamed(name); err != nil { - t.Fatalf("Failed to destroy named ticker: %v", err) - } - - // Check it was removed - if _, found := mod.named[name]; found { - t.Error("Named ticker still in map after destroy") - } - - // Try to destroy non-existent - if err := mod.destroyNamed("nonexistent"); err == nil { - t.Error("Should fail when destroying non-existent ticker") - } -} - -func TestTickerHandlers(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - tests := []struct { - name string - handler string - regex string - args []string - wantErr bool - }{ - { - name: "ticker on", - handler: "ticker on", - args: []string{}, - wantErr: false, - }, - { - name: "ticker off", - handler: "ticker off", - args: []string{}, - wantErr: true, // ticker off will fail if not running - }, - { - name: "ticker.create valid", - handler: "ticker.create ", - args: []string{"myticker", "2", "help; events.show"}, - wantErr: false, - }, - { - name: "ticker.create invalid period", - handler: "ticker.create ", - args: []string{"myticker", "notanumber", "help"}, - wantErr: true, - }, - { - name: "ticker.destroy", - handler: "ticker.destroy ", - args: []string{"myticker"}, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Find the handler - var handler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == tt.handler { - handler = &h - break - } - } - - if handler == nil { - t.Fatalf("Handler '%s' not found", tt.handler) - } - - // Create ticker if needed for destroy test - if tt.handler == "ticker.destroy " && len(tt.args) > 0 && tt.args[0] == "myticker" { - mod.createNamed("myticker", 1, "help") - } - - // Execute handler - err := handler.Exec(tt.args) - if (err != nil) != tt.wantErr { - t.Errorf("Handler execution error = %v, wantErr %v", err, tt.wantErr) - } - - // Cleanup - if tt.handler == "ticker on" || tt.handler == "ticker.create " { - mod.Stop() - } - }) - } -} - -func TestTickerWorker(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Create params for testing - params := &Params{ - Commands: []string{"help"}, - Period: 100 * time.Millisecond, - Running: true, - } - - // Start worker in goroutine - done := make(chan bool) - go func() { - mod.worker("test", params) - done <- true - }() - - // Let it tick at least once - time.Sleep(150 * time.Millisecond) - - // Stop the worker - params.Running = false - - // Wait for worker to finish - select { - case <-done: - // Worker finished successfully - case <-time.After(1 * time.Second): - t.Error("Worker did not stop in time") - } -} - -func TestTickerParams(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Test setting invalid period - mod.Session.Env.Set("ticker.period", "invalid") - if err := mod.Configure(); err == nil { - t.Error("Configure should fail with invalid period") - } - - // Test empty commands - mod.Session.Env.Set("ticker.period", "1") - mod.Session.Env.Set("ticker.commands", "") - if err := mod.Configure(); err != nil { - t.Errorf("Configure should work with empty commands: %v", err) - } -} - -func TestTickerMultipleNamed(t *testing.T) { - s := createMockSession(t) - mod := NewTicker(s) - - // Start the ticker first - if err := mod.Start(); err != nil { - t.Fatalf("Failed to start ticker: %v", err) - } - - // Create multiple named tickers - names := []string{"ticker1", "ticker2", "ticker3"} - for _, name := range names { - if err := mod.createNamed(name, 1, "help"); err != nil { - t.Errorf("Failed to create ticker '%s': %v", name, err) - } - } - - // Check all were created - if len(mod.named) != len(names) { - t.Errorf("Expected %d named tickers, got %d", len(names), len(mod.named)) - } - - // Stop all via Stop() - if err := mod.Stop(); err != nil { - t.Fatalf("Failed to stop: %v", err) - } - - // Check all were stopped - for name, params := range mod.named { - if params.Running { - t.Errorf("Ticker '%s' still running after Stop()", name) - } - } -} - -func TestTickEvent(t *testing.T) { - // Simple test for TickEvent struct - event := TickEvent{} - // TickEvent is empty, just ensure it can be created - _ = event -} - -// Benchmark tests -func BenchmarkTickerCreate(b *testing.B) { - // Use existing session to avoid flag redefinition - s := testSession - if s == nil { - var err error - s, err = session.New() - if err != nil { - b.Fatal(err) - } - testSession = s - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mod := NewTicker(s) - _ = mod - } -} - -func BenchmarkTickerStartStop(b *testing.B) { - // Use existing session to avoid flag redefinition - s := testSession - if s == nil { - var err error - s, err = session.New() - if err != nil { - b.Fatal(err) - } - testSession = s - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mod := NewTicker(s) - // Set period parameter - mod.Session.Env.Set("ticker.period", "1") - mod.Start() - mod.Stop() - } -} diff --git a/modules/update/update_test.go b/modules/update/update_test.go deleted file mode 100644 index f112fc14..00000000 --- a/modules/update/update_test.go +++ /dev/null @@ -1,348 +0,0 @@ -package update - -import ( - "sync" - "testing" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -func TestNewUpdateModule(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - if mod == nil { - t.Fatal("NewUpdateModule returned nil") - } - - if mod.Name() != "update" { - t.Errorf("Expected name 'update', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check handler - handlers := mod.Handlers() - if len(handlers) != 1 { - t.Errorf("Expected 1 handler, got %d", len(handlers)) - } - - if len(handlers) > 0 && handlers[0].Name != "update.check on" { - t.Errorf("Expected handler 'update.check on', got '%s'", handlers[0].Name) - } -} - -func TestVersionToNum(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - tests := []struct { - name string - version string - want float64 - }{ - { - name: "simple version", - version: "1.2.3", - want: 123, // 3*1 + 2*10 + 1*100 - }, - { - name: "version with v prefix", - version: "v1.2.3", - want: 123, - }, - { - name: "major version only", - version: "2", - want: 2, - }, - { - name: "major.minor version", - version: "2.1", - want: 21, // 1*1 + 2*10 - }, - { - name: "zero version", - version: "0.0.0", - want: 0, - }, - { - name: "large patch version", - version: "1.0.10", - want: 110, // 10*1 + 0*10 + 1*100 - }, - { - name: "very large version", - version: "10.20.30", - want: 1230, // 30*1 + 20*10 + 10*100 - }, - { - name: "version with leading v", - version: "v2.2.0", - want: 220, // 0*1 + 2*10 + 2*100 - }, - { - name: "single digit versions", - version: "1.1.1", - want: 111, // 1*1 + 1*10 + 1*100 - }, - { - name: "asymmetric version", - version: "1.10.100", - want: 300, // 100*1 + 10*10 + 1*100 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := mod.versionToNum(tt.version) - if got != tt.want { - t.Errorf("versionToNum(%q) = %v, want %v", tt.version, got, tt.want) - } - }) - } -} - -func TestVersionComparison(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - tests := []struct { - name string - current string - latest string - isNewer bool - }{ - { - name: "newer patch version", - current: "1.2.3", - latest: "1.2.4", - isNewer: true, - }, - { - name: "newer minor version", - current: "1.2.3", - latest: "1.3.0", - isNewer: true, - }, - { - name: "newer major version", - current: "1.2.3", - latest: "2.0.0", - isNewer: true, - }, - { - name: "same version", - current: "1.2.3", - latest: "1.2.3", - isNewer: false, - }, - { - name: "older version", - current: "2.0.0", - latest: "1.9.9", - isNewer: false, - }, - { - name: "v prefix handling", - current: "v1.2.3", - latest: "v1.2.4", - isNewer: true, - }, - { - name: "mixed v prefix", - current: "1.2.3", - latest: "v1.2.4", - isNewer: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - currentNum := mod.versionToNum(tt.current) - latestNum := mod.versionToNum(tt.latest) - - isNewer := currentNum < latestNum - if isNewer != tt.isNewer { - t.Errorf("Expected %s < %s to be %v, but got %v (%.2f vs %.2f)", - tt.current, tt.latest, tt.isNewer, isNewer, currentNum, latestNum) - } - }) - } -} - -func TestConfigure(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - if err := mod.Configure(); err != nil { - t.Errorf("Configure() error = %v", err) - } -} - -func TestStop(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - if err := mod.Stop(); err != nil { - t.Errorf("Stop() error = %v", err) - } -} - -func TestModuleRunning(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } -} - -func TestVersionEdgeCases(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - tests := []struct { - name string - version string - want float64 - wantErr bool - }{ - { - name: "empty version", - version: "", - want: 0, - wantErr: true, // Will panic on ver[0] access - }, - { - name: "only v", - version: "v", - want: 0, - wantErr: true, // Will panic after stripping v - }, - { - name: "non-numeric version", - version: "va.b.c", - want: 0, // strconv.Atoi will return 0 for non-numeric - }, - { - name: "partial numeric", - version: "1.a.3", - want: 103, // 3*1 + 0*10 + 1*100 (a converts to 0) - }, - { - name: "extra dots", - version: "1.2.3.4", - want: 1234, // 4*1 + 3*10 + 2*100 + 1*1000 - }, - { - name: "trailing dot", - version: "1.2.", - want: 120, // splits to ["1","2",""], reverses to ["","2","1"], = 0*1 + 2*10 + 1*100 - }, - { - name: "leading dot", - version: ".1.2", - want: 12, // splits to ["","1","2"], reverses to ["2","1",""], = 2*1 + 1*10 + 0*100 - }, - { - name: "single part", - version: "42", - want: 42, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Skip tests that would panic due to empty version - if tt.wantErr { - // These would panic, so skip them - t.Skip("Skipping test that would panic") - return - } - - got := mod.versionToNum(tt.version) - if got != tt.want { - t.Errorf("versionToNum(%q) = %v, want %v", tt.version, got, tt.want) - } - }) - } -} - -func TestHandlerExecution(t *testing.T) { - s := createMockSession(t) - mod := NewUpdateModule(s) - - // Find the handler - var handler *session.ModuleHandler - for _, h := range mod.Handlers() { - if h.Name == "update.check on" { - handler = &h - break - } - } - - if handler == nil { - t.Fatal("Handler 'update.check on' not found") - } - - // Note: This will make a real API call to GitHub - // In a production test suite, you'd want to mock the GitHub client - // For now, we'll just check that the handler can be executed - // The actual Start() method will be tested separately -} - -// Benchmark tests -func BenchmarkVersionToNum(b *testing.B) { - s, _ := session.New() - mod := NewUpdateModule(s) - - versions := []string{ - "1.2.3", - "v2.4.6", - "10.20.30", - "v100.200.300", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - for _, v := range versions { - mod.versionToNum(v) - } - } -} - -func BenchmarkVersionComparison(b *testing.B) { - s, _ := session.New() - mod := NewUpdateModule(s) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - current := mod.versionToNum("1.2.3") - latest := mod.versionToNum("1.2.4") - _ = current < latest - } -} diff --git a/modules/utils/view_selector_test.go b/modules/utils/view_selector_test.go deleted file mode 100644 index e2a9c609..00000000 --- a/modules/utils/view_selector_test.go +++ /dev/null @@ -1,455 +0,0 @@ -package utils - -import ( - "regexp" - "sync" - "testing" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - }) - return testSession -} - -type mockModule struct { - session.SessionModule -} - -func newMockModule(s *session.Session) *mockModule { - return &mockModule{ - SessionModule: session.NewSessionModule("test", s), - } -} - -func TestViewSelectorFor(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - - sortFields := []string{"name", "mac", "seen"} - defExpression := "seen desc" - prefix := "test" - - vs := ViewSelectorFor(&m.SessionModule, prefix, sortFields, defExpression) - - if vs == nil { - t.Fatal("ViewSelectorFor returned nil") - } - - if vs.owner != &m.SessionModule { - t.Error("ViewSelector owner not set correctly") - } - - if vs.filterName != "test.filter" { - t.Errorf("filterName = %s, want test.filter", vs.filterName) - } - - if vs.sortName != "test.sort" { - t.Errorf("sortName = %s, want test.sort", vs.sortName) - } - - if vs.limitName != "test.limit" { - t.Errorf("limitName = %s, want test.limit", vs.limitName) - } - - // Check that parameters were added by trying to retrieve them - if err, _ := m.SessionModule.StringParam("test.filter"); err != nil { - t.Error("filter parameter not accessible") - } - if err, _ := m.SessionModule.StringParam("test.sort"); err != nil { - t.Error("sort parameter not accessible") - } - if err, _ := m.SessionModule.IntParam("test.limit"); err != nil { - t.Error("limit parameter not accessible") - } - - // Check default sorting - if vs.SortField != "seen" { - t.Errorf("Default SortField = %s, want seen", vs.SortField) - } - if vs.Sort != "desc" { - t.Errorf("Default Sort = %s, want desc", vs.Sort) - } -} - -func TestParseFilter(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc") - - tests := []struct { - name string - filter string - wantErr bool - wantExpr bool - }{ - { - name: "empty filter", - filter: "", - wantErr: false, - wantExpr: false, - }, - { - name: "valid regex", - filter: "^test.*", - wantErr: false, - wantExpr: true, - }, - { - name: "invalid regex", - filter: "[invalid", - wantErr: true, - wantExpr: false, - }, - { - name: "simple string", - filter: "test", - wantErr: false, - wantExpr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Set the filter parameter - m.Session.Env.Set("test.filter", tt.filter) - - err := vs.parseFilter() - if (err != nil) != tt.wantErr { - t.Errorf("parseFilter() error = %v, wantErr %v", err, tt.wantErr) - } - - if tt.wantExpr && vs.Expression == nil { - t.Error("Expected Expression to be set, but it's nil") - } - if !tt.wantExpr && vs.Expression != nil { - t.Error("Expected Expression to be nil, but it's set") - } - - if tt.filter != "" && !tt.wantErr { - if vs.Filter != tt.filter { - t.Errorf("Filter = %s, want %s", vs.Filter, tt.filter) - } - } - }) - } -} - -func TestParseSorting(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac", "seen"}, "name asc") - - tests := []struct { - name string - sortExpr string - wantErr bool - wantField string - wantDirection string - wantSymbol string - }{ - { - name: "name ascending", - sortExpr: "name asc", - wantErr: false, - wantField: "name", - wantDirection: "asc", - wantSymbol: "▴", // Will be colored blue - }, - { - name: "mac descending", - sortExpr: "mac desc", - wantErr: false, - wantField: "mac", - wantDirection: "desc", - wantSymbol: "▾", // Will be colored blue - }, - { - name: "seen descending", - sortExpr: "seen desc", - wantErr: false, - wantField: "seen", - wantDirection: "desc", - wantSymbol: "▾", - }, - { - name: "invalid field", - sortExpr: "invalid desc", - wantErr: true, - wantField: "", - wantDirection: "", - }, - { - name: "invalid direction", - sortExpr: "name invalid", - wantErr: true, - wantField: "", - wantDirection: "", - }, - { - name: "malformed expression", - sortExpr: "nameDesc", - wantErr: true, - wantField: "", - wantDirection: "", - }, - { - name: "empty expression", - sortExpr: "", - wantErr: true, - wantField: "", - wantDirection: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Set the sort parameter - m.Session.Env.Set("test.sort", tt.sortExpr) - - err := vs.parseSorting() - if (err != nil) != tt.wantErr { - t.Errorf("parseSorting() error = %v, wantErr %v", err, tt.wantErr) - } - - if !tt.wantErr { - if vs.SortField != tt.wantField { - t.Errorf("SortField = %s, want %s", vs.SortField, tt.wantField) - } - if vs.Sort != tt.wantDirection { - t.Errorf("Sort = %s, want %s", vs.Sort, tt.wantDirection) - } - // Check symbol contains expected character (stripping color codes) - if !containsSymbol(vs.SortSymbol, tt.wantSymbol) { - t.Errorf("SortSymbol doesn't contain %s", tt.wantSymbol) - } - } - }) - } -} - -func TestUpdate(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac"}, "name asc") - - tests := []struct { - name string - filter string - sort string - limit string - wantErr bool - wantLimit int - }{ - { - name: "all valid", - filter: "test.*", - sort: "mac desc", - limit: "10", - wantErr: false, - wantLimit: 10, - }, - { - name: "invalid filter", - filter: "[invalid", - sort: "name asc", - limit: "5", - wantErr: true, - wantLimit: 0, - }, - { - name: "invalid sort", - filter: "valid", - sort: "invalid field", - limit: "5", - wantErr: true, - wantLimit: 0, - }, - { - name: "invalid limit", - filter: "valid", - sort: "name asc", - limit: "not a number", - wantErr: true, - wantLimit: 0, - }, - { - name: "zero limit", - filter: "", - sort: "name asc", - limit: "0", - wantErr: false, - wantLimit: 0, - }, - { - name: "negative limit", - filter: "", - sort: "name asc", - limit: "-1", - wantErr: false, - wantLimit: -1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Set parameters - m.Session.Env.Set("test.filter", tt.filter) - m.Session.Env.Set("test.sort", tt.sort) - m.Session.Env.Set("test.limit", tt.limit) - - err := vs.Update() - if (err != nil) != tt.wantErr { - t.Errorf("Update() error = %v, wantErr %v", err, tt.wantErr) - } - - if !tt.wantErr { - if vs.Limit != tt.wantLimit { - t.Errorf("Limit = %d, want %d", vs.Limit, tt.wantLimit) - } - } - }) - } -} - -func TestFilterCaching(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc") - - // Set initial filter - m.Session.Env.Set("test.filter", "test1") - if err := vs.parseFilter(); err != nil { - t.Fatalf("Failed to parse initial filter: %v", err) - } - - firstExpr := vs.Expression - if firstExpr == nil { - t.Fatal("Expression should not be nil") - } - - // Parse again with same filter - should use cached expression - if err := vs.parseFilter(); err != nil { - t.Fatalf("Failed to parse filter second time: %v", err) - } - - // The filterPrev mechanism should prevent recompilation - if vs.filterPrev != "test1" { - t.Errorf("filterPrev = %s, want test1", vs.filterPrev) - } - - // Change filter - m.Session.Env.Set("test.filter", "test2") - if err := vs.parseFilter(); err != nil { - t.Fatalf("Failed to parse new filter: %v", err) - } - - if vs.Filter != "test2" { - t.Errorf("Filter = %s, want test2", vs.Filter) - } - if vs.filterPrev != "test2" { - t.Errorf("filterPrev = %s, want test2", vs.filterPrev) - } -} - -func TestSortParserRegex(t *testing.T) { - s := createMockSession(t) - m := newMockModule(s) - - sortFields := []string{"field1", "field2", "complex_field"} - vs := ViewSelectorFor(&m.SessionModule, "test", sortFields, "field1 asc") - - // Test the generated regex pattern - expectedPattern := "(field1|field2|complex_field) (desc|asc)" - if vs.sortParser != expectedPattern { - t.Errorf("sortParser = %s, want %s", vs.sortParser, expectedPattern) - } - - // Test regex compilation - if vs.sortParse == nil { - t.Fatal("sortParse regex is nil") - } - - // Test regex matching - testCases := []struct { - expr string - matches bool - }{ - {"field1 asc", true}, - {"field2 desc", true}, - {"complex_field asc", true}, - {"invalid_field asc", false}, - {"field1 invalid", false}, - {"field1asc", false}, - {"", false}, - } - - for _, tc := range testCases { - matches := vs.sortParse.MatchString(tc.expr) - if matches != tc.matches { - t.Errorf("sortParse.MatchString(%q) = %v, want %v", tc.expr, matches, tc.matches) - } - } -} - -// Helper function to check if a string contains a symbol (ignoring ANSI color codes) -func containsSymbol(s, symbol string) bool { - // Remove ANSI color codes - re := regexp.MustCompile(`\x1b\[[0-9;]*m`) - cleaned := re.ReplaceAllString(s, "") - return cleaned == symbol -} - -// Benchmark tests -func BenchmarkParseFilter(b *testing.B) { - s, _ := session.New() - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc") - - m.Session.Env.Set("test.filter", "test.*") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - vs.parseFilter() - } -} - -func BenchmarkParseSorting(b *testing.B) { - s, _ := session.New() - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac", "seen"}, "name asc") - - m.Session.Env.Set("test.sort", "mac desc") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - vs.parseSorting() - } -} - -func BenchmarkUpdate(b *testing.B) { - s, _ := session.New() - m := newMockModule(s) - vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac"}, "name asc") - - m.Session.Env.Set("test.filter", "test") - m.Session.Env.Set("test.sort", "mac desc") - m.Session.Env.Set("test.limit", "10") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - vs.Update() - } -} diff --git a/modules/wifi/wifi.go b/modules/wifi/wifi.go index 2a000f4b..dea727b1 100644 --- a/modules/wifi/wifi.go +++ b/modules/wifi/wifi.go @@ -104,10 +104,7 @@ func NewWiFiModule(s *session.Session) *WiFiModule { } mod.InitState("channels") - mod.InitState("channel") - mod.State.Store("channels", []int{}) - mod.State.Store("channel", 0) mod.AddParam(session.NewStringParameter("wifi.interface", "", @@ -265,8 +262,8 @@ func NewWiFiModule(s *session.Session) *WiFiModule { mod.AddHandler(probe) - channelSwitchAnnounce := session.NewModuleHandler("wifi.channel_switch_announce BSSID CHANNEL ", `wifi\.channel_switch_announce ((?:[a-fA-F0-9:]{11,}))\s+((?:[0-9]+))`, - "Start a 802.11 channel hop attack, all client will be forced to change the channel lead to connection down.", + channelSwitchAnnounce := session.NewModuleHandler("wifi.channel_switch_announce bssid channel ", `wifi\.channel_switch_announce ((?:[a-fA-F0-9:]{11,}))\s+((?:[0-9]+))`, + "Start a 802.11 channel hop attack, all client will be force to change the channel lead to connection down.", func(args []string) error { bssid, err := net.ParseMAC(args[0]) if err != nil { @@ -651,22 +648,19 @@ func (mod *WiFiModule) Configure() error { mod.hopPeriod = time.Duration(hopPeriod) * time.Millisecond if mod.source == "" { - if len(mod.frequencies) == 0 { - if freqs, err := network.GetSupportedFrequencies(ifName); err != nil { - return fmt.Errorf("error while getting supported frequencies of %s: %s", ifName, err) - } else { - mod.setFrequencies(freqs) - } - - mod.Debug("wifi supported frequencies: %v", mod.frequencies) + if freqs, err := network.GetSupportedFrequencies(ifName); err != nil { + return fmt.Errorf("error while getting supported frequencies of %s: %s", ifName, err) + } else { + mod.setFrequencies(freqs) } + mod.Debug("wifi supported frequencies: %v", mod.frequencies) + // we need to start somewhere, this is just to check if // this OS supports switching channel programmatically. if err = network.SetInterfaceChannel(ifName, 1); err != nil { return fmt.Errorf("error while initializing %s to channel 1: %s", ifName, err) } - mod.State.Store("channel", 1) mod.Info("started (min rssi: %d dBm)", mod.minRSSI) } diff --git a/modules/wifi/wifi_hopping.go b/modules/wifi/wifi_hopping.go index 03797908..43b5fe7d 100644 --- a/modules/wifi/wifi_hopping.go +++ b/modules/wifi/wifi_hopping.go @@ -36,8 +36,6 @@ func (mod *WiFiModule) hopUnlocked(channel int) (mustStop bool) { } } - mod.State.Store("channel", channel) - return } diff --git a/modules/wifi/wifi_test.go b/modules/wifi/wifi_test.go deleted file mode 100644 index afd5322c..00000000 --- a/modules/wifi/wifi_test.go +++ /dev/null @@ -1,629 +0,0 @@ -package wifi - -import ( - "bytes" - "net" - "regexp" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/evilsocket/islazy/data" -) - -// Create a mock session for testing -func createMockSession() *session.Session { - // Create interface - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "wlan0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - // Parse interface addresses - ifaceIP := net.ParseIP("192.168.1.100") - ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface.IP = ifaceIP - iface.HW = ifaceHW - - // Create gateway - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - gatewayIP := net.ParseIP("192.168.1.1") - gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66") - gateway.IP = gatewayIP - gateway.HW = gatewayHW - - // Create environment - env, _ := session.NewEnvironment("") - - // Create LAN - aliases, _ := data.NewUnsortedKV("", 0) - lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - // Create session - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - Lan: lan, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: &packets.Queue{}, - Modules: make(session.ModuleList, 0), - } - - // Initialize events - sess.Events = session.NewEventPool(false, false) - - // Initialize WiFi state - sess.WiFi = network.NewWiFi(iface, aliases, func(ap *network.AccessPoint) {}, func(ap *network.AccessPoint) {}) - - return sess -} - -func TestNewWiFiModule(t *testing.T) { - sess := createMockSession() - - mod := NewWiFiModule(sess) - - if mod == nil { - t.Fatal("NewWiFiModule returned nil") - } - - if mod.Name() != "wifi" { - t.Errorf("expected module name 'wifi', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli && Gianluca Braga " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - // Check parameters - params := []string{ - "wifi.interface", - "wifi.rssi.min", - "wifi.deauth.skip", - "wifi.deauth.silent", - "wifi.deauth.open", - "wifi.deauth.acquired", - "wifi.assoc.skip", - "wifi.assoc.silent", - "wifi.assoc.open", - "wifi.assoc.acquired", - "wifi.ap.ttl", - "wifi.sta.ttl", - "wifi.region", - "wifi.txpower", - "wifi.handshakes.file", - "wifi.handshakes.aggregate", - "wifi.ap.ssid", - "wifi.ap.bssid", - "wifi.ap.channel", - "wifi.ap.encryption", - "wifi.show.manufacturer", - "wifi.source.file", - "wifi.hop.period", - "wifi.skip-broken", - "wifi.channel_switch_announce.silent", - "wifi.fake_auth.silent", - "wifi.bruteforce.target", - "wifi.bruteforce.wordlist", - "wifi.bruteforce.workers", - "wifi.bruteforce.wide", - "wifi.bruteforce.stop_at_first", - "wifi.bruteforce.timeout", - } - for _, param := range params { - if !mod.Session.Env.Has(param) { - t.Errorf("parameter %s not registered", param) - } - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "wifi.recon on", - "wifi.recon off", - "wifi.clear", - "wifi.recon MAC", - "wifi.recon clear", - "wifi.deauth BSSID", - "wifi.probe BSSID ESSID", - "wifi.assoc BSSID", - "wifi.ap", - "wifi.show.wps BSSID", - "wifi.show", - "wifi.recon.channel CHANNEL", - "wifi.client.probe.sta.filter FILTER", - "wifi.client.probe.ap.filter FILTER", - "wifi.channel_switch_announce bssid channel ", - "wifi.fake_auth bssid client", - "wifi.bruteforce on", - "wifi.bruteforce off", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } -} - -func TestWiFiModuleConfigure(t *testing.T) { - tests := []struct { - name string - params map[string]string - expectErr bool - }{ - { - name: "default configuration", - params: map[string]string{ - "wifi.interface": "", - "wifi.ap.ttl": "300", - "wifi.sta.ttl": "300", - "wifi.region": "", - "wifi.txpower": "30", - "wifi.source.file": "", - "wifi.rssi.min": "-200", - "wifi.handshakes.file": "~/bettercap-wifi-handshakes.pcap", - "wifi.handshakes.aggregate": "true", - "wifi.hop.period": "250", - "wifi.skip-broken": "true", - }, - expectErr: true, // Will fail without actual interface - }, - { - name: "invalid rssi", - params: map[string]string{ - "wifi.rssi.min": "not-a-number", - }, - expectErr: true, - }, - { - name: "invalid hop period", - params: map[string]string{ - "wifi.hop.period": "invalid", - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Set parameters - for k, v := range tt.params { - sess.Env.Set(k, v) - } - - err := mod.Configure() - - if tt.expectErr && err == nil { - t.Error("expected error but got none") - } else if !tt.expectErr && err != nil { - t.Errorf("unexpected error: %v", err) - } - }) - } -} - -func TestWiFiModuleFrequencies(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test setting frequencies - freqs := []int{2412, 2437, 2462, 5180, 5200} // Channels 1, 6, 11, 36, 40 - mod.setFrequencies(freqs) - - if len(mod.frequencies) != len(freqs) { - t.Errorf("expected %d frequencies, got %d", len(freqs), len(mod.frequencies)) - } - - // Check if channels were properly converted - channels, _ := mod.State.Load("channels") - channelList := channels.([]int) - expectedChannels := []int{1, 6, 11, 36, 40} - - if len(channelList) != len(expectedChannels) { - t.Errorf("expected %d channels, got %d", len(expectedChannels), len(channelList)) - } -} - -func TestWiFiModuleFilters(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test STA filter - handlers := mod.Handlers() - var staFilterHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.client.probe.sta.filter FILTER" { - staFilterHandler = h - break - } - } - - if staFilterHandler.Name == "" { - t.Fatal("STA filter handler not found") - } - - // Set a filter - err := staFilterHandler.Exec([]string{"^aa:bb:.*"}) - if err != nil { - t.Errorf("Failed to set STA filter: %v", err) - } - - if mod.filterProbeSTA == nil { - t.Error("STA filter was not set") - } - - // Clear filter - err = staFilterHandler.Exec([]string{"clear"}) - if err != nil { - t.Errorf("Failed to clear STA filter: %v", err) - } - - if mod.filterProbeSTA != nil { - t.Error("STA filter was not cleared") - } - - // Test AP filter - var apFilterHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.client.probe.ap.filter FILTER" { - apFilterHandler = h - break - } - } - - if apFilterHandler.Name == "" { - t.Fatal("AP filter handler not found") - } - - // Set a filter - err = apFilterHandler.Exec([]string{"^TestAP.*"}) - if err != nil { - t.Errorf("Failed to set AP filter: %v", err) - } - - if mod.filterProbeAP == nil { - t.Error("AP filter was not set") - } -} - -func TestWiFiModuleDeauth(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test deauth handler - handlers := mod.Handlers() - var deauthHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.deauth BSSID" { - deauthHandler = h - break - } - } - - if deauthHandler.Name == "" { - t.Fatal("Deauth handler not found") - } - - // Test with "all" - err := deauthHandler.Exec([]string{"all"}) - if err == nil { - t.Error("Expected error when starting deauth without running module") - } - - // Test with invalid MAC - err = deauthHandler.Exec([]string{"invalid-mac"}) - if err == nil { - t.Error("Expected error with invalid MAC address") - } -} - -func TestWiFiModuleChannelHandler(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test channel handler - handlers := mod.Handlers() - var channelHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.recon.channel CHANNEL" { - channelHandler = h - break - } - } - - if channelHandler.Name == "" { - t.Fatal("Channel handler not found") - } - - // Test with valid channels - err := channelHandler.Exec([]string{"1,6,11"}) - if err != nil { - t.Errorf("Failed to set channels: %v", err) - } - - // Test with invalid channel - err = channelHandler.Exec([]string{"999"}) - if err == nil { - t.Error("Expected error with invalid channel") - } - - // Test clear - err = channelHandler.Exec([]string{"clear"}) - if err == nil { - // Will fail without actual interface but should parse correctly - t.Log("Clear channels parsed correctly") - } -} - -func TestWiFiModuleShow(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test show handler exists - handlers := mod.Handlers() - found := false - for _, h := range handlers { - if h.Name == "wifi.show" { - found = true - break - } - } - - if !found { - t.Fatal("Show handler not found") - } - - // Skip actual execution as it requires UI components - t.Log("Show handler found, skipping execution due to UI dependencies") -} - -func TestWiFiModuleShowWPS(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test show WPS handler exists - handlers := mod.Handlers() - found := false - for _, h := range handlers { - if h.Name == "wifi.show.wps BSSID" { - found = true - break - } - } - - if !found { - t.Fatal("Show WPS handler not found") - } - - // Skip actual execution as it requires UI components - t.Log("Show WPS handler found, skipping execution due to UI dependencies") -} - -func TestWiFiModuleBruteforce(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Check bruteforce config - if mod.bruteforce == nil { - t.Fatal("Bruteforce config not initialized") - } - - // Test bruteforce parameters - params := map[string]string{ - "wifi.bruteforce.target": "TestAP", - "wifi.bruteforce.wordlist": "/tmp/wordlist.txt", - "wifi.bruteforce.workers": "4", - "wifi.bruteforce.wide": "true", - "wifi.bruteforce.stop_at_first": "true", - "wifi.bruteforce.timeout": "30", - } - - for k, v := range params { - sess.Env.Set(k, v) - } - - // Verify parameters were set - if err, target := mod.StringParam("wifi.bruteforce.target"); err != nil { - t.Errorf("Failed to get bruteforce target: %v", err) - } else if target != "TestAP" { - t.Errorf("Expected target 'TestAP', got '%s'", target) - } -} - -func TestWiFiModuleAPConfig(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Set AP parameters - params := map[string]string{ - "wifi.ap.ssid": "TestAP", - "wifi.ap.bssid": "aa:bb:cc:dd:ee:ff", - "wifi.ap.channel": "6", - "wifi.ap.encryption": "true", - } - - for k, v := range params { - sess.Env.Set(k, v) - } - - // Parse AP config - err := mod.parseApConfig() - if err != nil { - t.Errorf("Failed to parse AP config: %v", err) - } - - // Verify config - if mod.apConfig.SSID != "TestAP" { - t.Errorf("Expected SSID 'TestAP', got '%s'", mod.apConfig.SSID) - } - - if !bytes.Equal(mod.apConfig.BSSID, []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) { - t.Errorf("BSSID mismatch") - } - - if mod.apConfig.Channel != 6 { - t.Errorf("Expected channel 6, got %d", mod.apConfig.Channel) - } - - if !mod.apConfig.Encryption { - t.Error("Expected encryption to be enabled") - } -} - -func TestWiFiModuleSkipMACs(t *testing.T) { - // Skip this test as updateDeauthSkipList and updateAssocSkipList are private methods - t.Skip("Skipping test for private skip list methods") -} - -func TestWiFiModuleProbe(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test probe handler - handlers := mod.Handlers() - var probeHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.probe BSSID ESSID" { - probeHandler = h - break - } - } - - if probeHandler.Name == "" { - t.Fatal("Probe handler not found") - } - - // Test with valid parameters - err := probeHandler.Exec([]string{"aa:bb:cc:dd:ee:ff", "TestNetwork"}) - if err == nil { - t.Error("Expected error when probing without running module") - } - - // Test with invalid MAC - err = probeHandler.Exec([]string{"invalid-mac", "TestNetwork"}) - if err == nil { - t.Error("Expected error with invalid MAC address") - } -} - -func TestWiFiModuleFakeAuth(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Test fake auth handler - handlers := mod.Handlers() - var fakeAuthHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "wifi.fake_auth bssid client" { - fakeAuthHandler = h - break - } - } - - if fakeAuthHandler.Name == "" { - t.Fatal("Fake auth handler not found") - } - - // Test with valid parameters - err := fakeAuthHandler.Exec([]string{"aa:bb:cc:dd:ee:ff", "11:22:33:44:55:66"}) - if err == nil { - t.Error("Expected error when running fake auth without running module") - } - - // Test with invalid MACs - err = fakeAuthHandler.Exec([]string{"invalid-mac", "11:22:33:44:55:66"}) - if err == nil { - t.Error("Expected error with invalid BSSID") - } -} - -func TestWiFiModuleViewSelector(t *testing.T) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - // Check if view selector is initialized - if mod.selector == nil { - t.Fatal("View selector not initialized") - } -} - -// Helper function -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} - -// Test bruteforce config -func TestBruteforceConfig(t *testing.T) { - config := NewBruteForceConfig() - - if config == nil { - t.Fatal("NewBruteForceConfig returned nil") - } - - // Check defaults - if config.target != "" { - t.Errorf("Expected empty target, got '%s'", config.target) - } - - if config.wordlist != "/usr/share/dict/words" { - t.Errorf("Expected wordlist '/usr/share/dict/words', got '%s'", config.wordlist) - } - - if config.workers != 1 { - t.Errorf("Expected 1 worker, got %d", config.workers) - } - - if config.wide { - t.Error("Expected wide to be false by default") - } - - if !config.stop_at_first { - t.Error("Expected stop_at_first to be true by default") - } - - if config.timeout != 15 { - t.Errorf("Expected timeout 15, got %d", config.timeout) - } -} - -// Benchmarks -func BenchmarkWiFiModuleSetFrequencies(b *testing.B) { - sess := createMockSession() - mod := NewWiFiModule(sess) - - freqs := []int{2412, 2437, 2462, 5180, 5200, 5220, 5240, 5745, 5765, 5785, 5805, 5825} - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mod.setFrequencies(freqs) - } -} - -func BenchmarkWiFiModuleFilterCheck(b *testing.B) { - filter, _ := regexp.Compile("^aa:bb:.*") - testMAC := "aa:bb:cc:dd:ee:ff" - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = filter.MatchString(testMAC) - } -} diff --git a/modules/wol/wol_test.go b/modules/wol/wol_test.go deleted file mode 100644 index 115f4f32..00000000 --- a/modules/wol/wol_test.go +++ /dev/null @@ -1,364 +0,0 @@ -package wol - -import ( - "bytes" - "net" - "sync" - "testing" - - "github.com/bettercap/bettercap/v2/session" -) - -var ( - testSession *session.Session - sessionOnce sync.Once -) - -func createMockSession(t *testing.T) *session.Session { - sessionOnce.Do(func() { - var err error - testSession, err = session.New() - if err != nil { - t.Fatalf("Failed to create session: %v", err) - } - // Initialize interface with mock data to avoid nil pointer - // For now, we'll skip initializing these as they require more complex setup - // The tests will handle the nil cases appropriately - }) - return testSession -} - -func TestNewWOL(t *testing.T) { - s := createMockSession(t) - mod := NewWOL(s) - - if mod == nil { - t.Fatal("NewWOL returned nil") - } - - if mod.Name() != "wol" { - t.Errorf("Expected name 'wol', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("Unexpected author: %s", mod.Author()) - } - - if mod.Description() == "" { - t.Error("Empty description") - } - - // Check handlers - handlers := []string{"wol.eth MAC", "wol.udp MAC"} - for _, handlerName := range handlers { - found := false - for _, h := range mod.Handlers() { - if h.Name == handlerName { - found = true - break - } - } - if !found { - t.Errorf("Handler '%s' not found", handlerName) - } - } -} - -func TestParseMAC(t *testing.T) { - tests := []struct { - name string - args []string - want string - wantErr bool - }{ - { - name: "empty args", - args: []string{}, - want: "ff:ff:ff:ff:ff:ff", - wantErr: false, - }, - { - name: "empty string arg", - args: []string{""}, - want: "ff:ff:ff:ff:ff:ff", - wantErr: false, - }, - { - name: "valid MAC with colons", - args: []string{"aa:bb:cc:dd:ee:ff"}, - want: "aa:bb:cc:dd:ee:ff", - wantErr: false, - }, - { - name: "valid MAC with dashes", - args: []string{"aa-bb-cc-dd-ee-ff"}, - want: "aa-bb-cc-dd-ee-ff", - wantErr: false, - }, - { - name: "valid MAC uppercase", - args: []string{"AA:BB:CC:DD:EE:FF"}, - want: "AA:BB:CC:DD:EE:FF", - wantErr: false, - }, - { - name: "valid MAC mixed case", - args: []string{"aA:bB:cC:dD:eE:fF"}, - want: "aA:bB:cC:dD:eE:fF", - wantErr: false, - }, - { - name: "invalid MAC - too short", - args: []string{"aa:bb:cc:dd:ee"}, - want: "", - wantErr: true, - }, - { - name: "invalid MAC - too long", - args: []string{"aa:bb:cc:dd:ee:ff:gg"}, - want: "", - wantErr: true, - }, - { - name: "invalid MAC - bad characters", - args: []string{"aa:bb:cc:dd:ee:gg"}, - want: "", - wantErr: true, - }, - { - name: "invalid MAC - no separators", - args: []string{"aabbccddeeff"}, - want: "", - wantErr: true, - }, - { - name: "MAC with spaces", - args: []string{" aa:bb:cc:dd:ee:ff "}, - want: "aa:bb:cc:dd:ee:ff", - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseMAC(tt.args) - if (err != nil) != tt.wantErr { - t.Errorf("parseMAC() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("parseMAC() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestBuildPayload(t *testing.T) { - tests := []struct { - name string - mac string - }{ - { - name: "broadcast MAC", - mac: "ff:ff:ff:ff:ff:ff", - }, - { - name: "specific MAC", - mac: "aa:bb:cc:dd:ee:ff", - }, - { - name: "zeros MAC", - mac: "00:00:00:00:00:00", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - payload := buildPayload(tt.mac) - - // Payload should be 102 bytes: 6 bytes sync + 16 * 6 bytes MAC - if len(payload) != 102 { - t.Errorf("buildPayload() length = %d, want 102", len(payload)) - } - - // First 6 bytes should be 0xff - for i := 0; i < 6; i++ { - if payload[i] != 0xff { - t.Errorf("payload[%d] = %x, want 0xff", i, payload[i]) - } - } - - // Parse the MAC for comparison - parsedMAC, _ := net.ParseMAC(tt.mac) - - // Next 16 copies of the MAC - for i := 0; i < 16; i++ { - start := 6 + i*6 - end := start + 6 - if !bytes.Equal(payload[start:end], parsedMAC) { - t.Errorf("MAC copy %d = %x, want %x", i, payload[start:end], parsedMAC) - } - } - }) - } -} - -func TestWOLConfigure(t *testing.T) { - s := createMockSession(t) - mod := NewWOL(s) - - if err := mod.Configure(); err != nil { - t.Errorf("Configure() error = %v", err) - } -} - -func TestWOLStartStop(t *testing.T) { - s := createMockSession(t) - mod := NewWOL(s) - - if err := mod.Start(); err != nil { - t.Errorf("Start() error = %v", err) - } - - if err := mod.Stop(); err != nil { - t.Errorf("Stop() error = %v", err) - } -} - -func TestWOLHandlers(t *testing.T) { - // Only test parseMAC validation since the actual handlers require a fully initialized session - testCases := []struct { - name string - args []string - wantMAC string - wantErr bool - }{ - { - name: "empty args", - args: []string{}, - wantMAC: "ff:ff:ff:ff:ff:ff", - wantErr: false, - }, - { - name: "valid MAC", - args: []string{"aa:bb:cc:dd:ee:ff"}, - wantMAC: "aa:bb:cc:dd:ee:ff", - wantErr: false, - }, - { - name: "invalid MAC", - args: []string{"invalid:mac"}, - wantMAC: "", - wantErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - mac, err := parseMAC(tc.args) - if (err != nil) != tc.wantErr { - t.Errorf("parseMAC() error = %v, wantErr %v", err, tc.wantErr) - } - if mac != tc.wantMAC { - t.Errorf("parseMAC() = %v, want %v", mac, tc.wantMAC) - } - }) - } -} - -func TestWOLMethods(t *testing.T) { - s := createMockSession(t) - mod := NewWOL(s) - - // Test that the methods exist and can be called without panic - // The actual execution will fail due to nil session interface/queue - // but we're testing the module structure - - // Check that handlers were properly registered - expectedHandlers := 2 // wol.eth and wol.udp - if len(mod.Handlers()) != expectedHandlers { - t.Errorf("Expected %d handlers, got %d", expectedHandlers, len(mod.Handlers())) - } - - // Verify handler names - handlerNames := make(map[string]bool) - for _, h := range mod.Handlers() { - handlerNames[h.Name] = true - } - - if !handlerNames["wol.eth MAC"] { - t.Error("wol.eth handler not found") - } - if !handlerNames["wol.udp MAC"] { - t.Error("wol.udp handler not found") - } -} - -func TestReMAC(t *testing.T) { - tests := []struct { - mac string - valid bool - }{ - {"aa:bb:cc:dd:ee:ff", true}, - {"AA:BB:CC:DD:EE:FF", true}, - {"aa-bb-cc-dd-ee-ff", true}, - {"AA-BB-CC-DD-EE-FF", true}, - {"aA:bB:cC:dD:eE:fF", true}, - {"00:00:00:00:00:00", true}, - {"ff:ff:ff:ff:ff:ff", true}, - {"aabbccddeeff", false}, - {"aa:bb:cc:dd:ee", false}, - {"aa:bb:cc:dd:ee:ff:gg", false}, - {"aa:bb:cc:dd:ee:gg", false}, - {"zz:zz:zz:zz:zz:zz", false}, - {"", false}, - {"not a mac", false}, - } - - for _, tt := range tests { - t.Run(tt.mac, func(t *testing.T) { - if got := reMAC.MatchString(tt.mac); got != tt.valid { - t.Errorf("reMAC.MatchString(%q) = %v, want %v", tt.mac, got, tt.valid) - } - }) - } -} - -// Test that the module sets running state correctly -func TestWOLRunningState(t *testing.T) { - s := createMockSession(t) - mod := NewWOL(s) - - // Initially should not be running - if mod.Running() { - t.Error("Module should not be running initially") - } - - // Note: wolETH and wolUDP will fail due to nil session.Queue, - // but they should still set the running state before failing -} - -// Benchmark tests -func BenchmarkBuildPayload(b *testing.B) { - mac := "aa:bb:cc:dd:ee:ff" - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = buildPayload(mac) - } -} - -func BenchmarkParseMAC(b *testing.B) { - args := []string{"aa:bb:cc:dd:ee:ff"} - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = parseMAC(args) - } -} - -func BenchmarkReMAC(b *testing.B) { - mac := "aa:bb:cc:dd:ee:ff" - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = reMAC.MatchString(mac) - } -} diff --git a/modules/zerogod/zerogod_discovery.go b/modules/zerogod/zerogod_discovery.go index f6223e54..97d0f486 100644 --- a/modules/zerogod/zerogod_discovery.go +++ b/modules/zerogod/zerogod_discovery.go @@ -201,14 +201,6 @@ func (mod *ZeroGod) logDNS(src net.IP, dns layers.DNS, isLocal bool) { func (mod *ZeroGod) onPacket(pkt gopacket.Packet) { mod.Debug("%++v", pkt) - // sadly the latest available version of gopacket has an unpatched bug :/ - // https://github.com/bettercap/bettercap/issues/1184 - defer func() { - if err := recover(); err != nil { - mod.Error("unexpected error while parsing network packet: %v\n\n%++v", err, pkt) - } - }() - netLayer := pkt.NetworkLayer() if netLayer == nil { mod.Warning("not network layer in packet %+v", pkt) diff --git a/modules/zerogod/zerogod_show.go b/modules/zerogod/zerogod_show.go index 4c465d0d..03abebbf 100644 --- a/modules/zerogod/zerogod_show.go +++ b/modules/zerogod/zerogod_show.go @@ -61,24 +61,15 @@ func (mod *ZeroGod) show(filter string, withData bool) error { for _, field := range svc.Text { if field = str.Trim(field); len(field) > 0 { keyval := strings.SplitN(field, "=", 2) - key := str.Trim(keyval[0]) - val := str.Trim(keyval[1]) - - if key != "" || val != "" { - rows = append(rows, []string{ - key, - val, - }) - } + rows = append(rows, []string{ + keyval[0], + keyval[1], + }) } } - if len(rows) == 0 { - fmt.Fprintf(mod.Session.Events.Stdout, " %s\n", tui.Dim("no data")) - } else { - tui.Table(mod.Session.Events.Stdout, columns, rows) - fmt.Fprintf(mod.Session.Events.Stdout, "\n") - } + tui.Table(mod.Session.Events.Stdout, columns, rows) + fmt.Fprintf(mod.Session.Events.Stdout, "\n") } else { fmt.Fprintf(mod.Session.Events.Stdout, " %s\n", tui.Dim("no data")) diff --git a/modules/zerogod/zerogod_test.go b/modules/zerogod/zerogod_test.go deleted file mode 100644 index b64bbab0..00000000 --- a/modules/zerogod/zerogod_test.go +++ /dev/null @@ -1,480 +0,0 @@ -package zerogod - -import ( - "fmt" - "io/ioutil" - "net" - "os" - "testing" - "time" - - "github.com/bettercap/bettercap/v2/network" - "github.com/bettercap/bettercap/v2/packets" - "github.com/bettercap/bettercap/v2/session" - "github.com/evilsocket/islazy/data" -) - -// MockNetRecon implements a minimal net.recon module for testing -type MockNetRecon struct { - session.SessionModule -} - -func NewMockNetRecon(s *session.Session) *MockNetRecon { - mod := &MockNetRecon{ - SessionModule: session.NewSessionModule("net.recon", s), - } - - // Add handlers - mod.AddHandler(session.NewModuleHandler("net.recon on", "", - "Start net.recon", - func(args []string) error { - return mod.Start() - })) - - mod.AddHandler(session.NewModuleHandler("net.recon off", "", - "Stop net.recon", - func(args []string) error { - return mod.Stop() - })) - - return mod -} - -func (m *MockNetRecon) Name() string { - return "net.recon" -} - -func (m *MockNetRecon) Description() string { - return "Mock net.recon module" -} - -func (m *MockNetRecon) Author() string { - return "test" -} - -func (m *MockNetRecon) Configure() error { - return nil -} - -func (m *MockNetRecon) Start() error { - return m.SetRunning(true, nil) -} - -func (m *MockNetRecon) Stop() error { - return m.SetRunning(false, nil) -} - -// MockBrowser for testing -type MockBrowser struct { - started bool - stopped bool - waitCh chan bool -} - -func (m *MockBrowser) Start() error { - m.started = true - m.waitCh = make(chan bool, 1) - return nil -} - -func (m *MockBrowser) Stop() error { - m.stopped = true - if m.waitCh != nil { - m.waitCh <- true - close(m.waitCh) - } - return nil -} - -func (m *MockBrowser) Wait() { - if m.waitCh != nil { - <-m.waitCh - } -} - -// MockAdvertiser for testing -type MockAdvertiser struct { - started bool - stopped bool - services []*ServiceData - config string -} - -func (m *MockAdvertiser) Start(services []*ServiceData) error { - m.started = true - m.services = services - return nil -} - -func (m *MockAdvertiser) Stop() error { - m.stopped = true - return nil -} - -// Create a mock session for testing -func createMockSession() *session.Session { - // Create interface - iface := &network.Endpoint{ - IpAddress: "192.168.1.100", - HwAddress: "aa:bb:cc:dd:ee:ff", - Hostname: "eth0", - } - iface.SetIP("192.168.1.100") - iface.SetBits(24) - - // Parse interface addresses - ifaceIP := net.ParseIP("192.168.1.100") - ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface.IP = ifaceIP - iface.HW = ifaceHW - - // Create gateway - gateway := &network.Endpoint{ - IpAddress: "192.168.1.1", - HwAddress: "11:22:33:44:55:66", - } - gatewayIP := net.ParseIP("192.168.1.1") - gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66") - gateway.IP = gatewayIP - gateway.HW = gatewayHW - - // Create environment - env, _ := session.NewEnvironment("") - - // Create LAN with some test endpoints - aliases, _ := data.NewUnsortedKV("", 0) - lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {}) - - // Add test endpoints - testEndpoint := &network.Endpoint{ - IpAddress: "192.168.1.10", - HwAddress: "11:11:11:11:11:11", - Hostname: "test-device", - } - testEndpoint.IP = net.ParseIP("192.168.1.10") - // Add endpoint to LAN using AddIfNew - lan.AddIfNew(testEndpoint.IpAddress, testEndpoint.HwAddress) - - // Create session - sess := &session.Session{ - Interface: iface, - Gateway: gateway, - Lan: lan, - StartedAt: time.Now(), - Active: true, - Env: env, - Queue: &packets.Queue{}, - Modules: make(session.ModuleList, 0), - } - - // Initialize events - sess.Events = session.NewEventPool(false, false) - - // Add mock net.recon module - mockNetRecon := NewMockNetRecon(sess) - sess.Modules = append(sess.Modules, mockNetRecon) - - return sess -} - -func TestNewZeroGod(t *testing.T) { - sess := createMockSession() - - mod := NewZeroGod(sess) - - if mod == nil { - t.Fatal("NewZeroGod returned nil") - } - - if mod.Name() != "zerogod" { - t.Errorf("expected module name 'zerogod', got '%s'", mod.Name()) - } - - if mod.Author() != "Simone Margaritelli " { - t.Errorf("unexpected author: %s", mod.Author()) - } - - // Check parameters - only check the ones that are directly registered - params := []string{ - "zerogod.advertise.certificate", - "zerogod.advertise.key", - "zerogod.ipp.save_path", - "zerogod.verbose", - } - for _, param := range params { - if !mod.Session.Env.Has(param) { - t.Errorf("parameter %s not registered", param) - } - } - - // Check handlers - handlers := mod.Handlers() - expectedHandlers := []string{ - "zerogod.discovery on", - "zerogod.discovery off", - "zerogod.show-full ADDRESS", - "zerogod.show ADDRESS", - "zerogod.save ADDRESS FILENAME", - "zerogod.advertise FILENAME", - "zerogod.impersonate ADDRESS", - } - - if len(handlers) != len(expectedHandlers) { - t.Errorf("expected %d handlers, got %d", len(expectedHandlers), len(handlers)) - } -} - -func TestZeroGodConfigure(t *testing.T) { - sess := createMockSession() - mod := NewZeroGod(sess) - - // Configure should succeed when not running - err := mod.Configure() - if err != nil { - t.Errorf("Configure failed: %v", err) - } - - // Force module to running state by starting it - mod.SetRunning(true, nil) - - // Configure should fail when already running - err = mod.Configure() - if err == nil { - t.Error("Configure should fail when module is already running") - } - - // Clean up - mod.SetRunning(false, nil) -} - -func TestZeroGodStartStop(t *testing.T) { - sess := createMockSession() - _ = NewZeroGod(sess) - - // Skip this test as it requires mocking private methods - t.Skip("Skipping test that requires mocking private methods") -} - -func TestZeroGodShow(t *testing.T) { - sess := createMockSession() - mod := NewZeroGod(sess) - - // Start discovery first (mock it) - mod.browser = &Browser{} - - // Test show handler - handlers := mod.Handlers() - var showHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "zerogod.show ADDRESS" { - showHandler = h - break - } - } - - if showHandler.Name == "" { - t.Fatal("Show handler not found") - } - - // Test with IP address - err := showHandler.Exec([]string{"192.168.1.10"}) - if err != nil { - t.Errorf("Show handler failed: %v", err) - } - - // Test with empty address (show all) - err = showHandler.Exec([]string{}) - if err != nil { - t.Errorf("Show handler failed with empty address: %v", err) - } -} - -func TestZeroGodShowFull(t *testing.T) { - sess := createMockSession() - mod := NewZeroGod(sess) - - // Start discovery first (mock it) - mod.browser = &Browser{} - - // Test show-full handler - handlers := mod.Handlers() - var showFullHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "zerogod.show-full ADDRESS" { - showFullHandler = h - break - } - } - - if showFullHandler.Name == "" { - t.Fatal("Show-full handler not found") - } - - // Test with IP address - err := showFullHandler.Exec([]string{"192.168.1.10"}) - if err != nil { - t.Errorf("Show-full handler failed: %v", err) - } -} - -func TestZeroGodSave(t *testing.T) { - // Skip this test as it requires actual mDNS discovery data - t.Skip("Skipping test that requires actual mDNS discovery data") -} - -func TestZeroGodAdvertise(t *testing.T) { - sess := createMockSession() - mod := NewZeroGod(sess) - - // Mock advertiser - skip test as we can't properly mock the advertiser structure - t.Skip("Skipping test that requires complex advertiser mocking") - - // Create a test YAML file with services - tmpFile, err := ioutil.TempFile("", "zerogod_advertise_*.yml") - if err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - defer os.Remove(tmpFile.Name()) - - yamlContent := `services: - - name: Test Service - type: _http._tcp - port: 8080 - txt: - - model=TestDevice - - version=1.0 -` - if _, err := tmpFile.Write([]byte(yamlContent)); err != nil { - t.Fatalf("Failed to write YAML content: %v", err) - } - tmpFile.Close() - - // Test advertise handler - handlers := mod.Handlers() - var advertiseHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "zerogod.advertise FILENAME" { - advertiseHandler = h - break - } - } - - if advertiseHandler.Name == "" { - t.Fatal("Advertise handler not found") - } - - // Note: Cannot mock methods in Go, would need interface refactoring -} - -func TestZeroGodImpersonate(t *testing.T) { - sess := createMockSession() - mod := NewZeroGod(sess) - - // Skip test as we can't properly mock the advertiser - t.Skip("Skipping test that requires complex advertiser mocking") - - // Test impersonate handler - handlers := mod.Handlers() - var impersonateHandler session.ModuleHandler - for _, h := range handlers { - if h.Name == "zerogod.impersonate ADDRESS" { - impersonateHandler = h - break - } - } - - if impersonateHandler.Name == "" { - t.Fatal("Impersonate handler not found") - } - - // Note: Cannot mock methods in Go, would need interface refactoring -} - -func TestZeroGodParameters(t *testing.T) { - // Skip parameter validation tests as Environment.Set behavior is not straightforward - t.Skip("Skipping parameter validation tests") -} - -// Test service data structure -func TestServiceData(t *testing.T) { - svc := ServiceData{ - Name: "Test Service", - Service: "_http._tcp", - Domain: "local", - Port: 8080, - Records: []string{"model=TestDevice", "version=1.0"}, - IPP: map[string]string{"attr1": "value1"}, - HTTP: map[string]string{"/": "index.html"}, - } - - // Test basic properties - if svc.Name != "Test Service" { - t.Errorf("Expected service name 'Test Service', got '%s'", svc.Name) - } - - if svc.Port != 8080 { - t.Errorf("Expected port 8080, got %d", svc.Port) - } - - if len(svc.Records) != 2 { - t.Errorf("Expected 2 records, got %d", len(svc.Records)) - } - - // Test FullName method - fullName := svc.FullName() - expected := "Test Service._http._tcp.local" - if fullName != expected { - t.Errorf("Expected full name '%s', got '%s'", expected, fullName) - } -} - -// Test endpoint handling -func TestEndpointHandling(t *testing.T) { - endpoint := &network.Endpoint{ - IpAddress: "192.168.1.10", - HwAddress: "11:11:11:11:11:11", - Hostname: "test-device", - } - - // Verify basic endpoint properties - if endpoint.IpAddress != "192.168.1.10" { - t.Errorf("Expected IP address '192.168.1.10', got '%s'", endpoint.IpAddress) - } - - if endpoint.Hostname != "test-device" { - t.Errorf("Expected hostname 'test-device', got '%s'", endpoint.Hostname) - } -} - -// Test known services lookup -func TestKnownServices(t *testing.T) { - // Skip this test as knownServices might not be available in test context - t.Skip("Skipping known services test - requires module initialization") -} - -// Benchmarks -func BenchmarkServiceDataCreation(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ServiceData{ - Name: fmt.Sprintf("Service %d", i), - Service: "_http._tcp", - Port: 8080 + i, - Domain: "local", - Records: []string{"model=Test", fmt.Sprintf("id=%d", i)}, - } - } -} - -func BenchmarkServiceDataFullName(b *testing.B) { - svc := ServiceData{ - Name: "Test Service", - Service: "_http._tcp", - Domain: "local", - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _ = svc.FullName() - } -} diff --git a/network/lan.go b/network/lan.go index 6342968d..082b4c74 100644 --- a/network/lan.go +++ b/network/lan.go @@ -62,7 +62,7 @@ func (lan *LAN) Get(mac string) (*Endpoint, bool) { if mac == lan.iface.HwAddress { return lan.iface, true - } else if lan.gateway != nil && mac == lan.gateway.HwAddress { + } else if mac == lan.gateway.HwAddress { return lan.gateway, true } @@ -78,7 +78,7 @@ func (lan *LAN) GetByIp(ip string) *Endpoint { if ip == lan.iface.IpAddress || ip == lan.iface.Ip6Address { return lan.iface - } else if lan.gateway != nil && (ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address) { + } else if ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address { return lan.gateway } @@ -107,7 +107,7 @@ func (lan *LAN) Aliases() *data.UnsortedKV { } func (lan *LAN) WasMissed(mac string) bool { - if mac == lan.iface.HwAddress || (lan.gateway != nil && mac == lan.gateway.HwAddress) { + if mac == lan.iface.HwAddress || mac == lan.gateway.HwAddress { return false } @@ -141,7 +141,7 @@ func (lan *LAN) shouldIgnore(ip, mac string) bool { return true } // skip the gateway - if lan.gateway != nil && (ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address || mac == lan.gateway.HwAddress) { + if ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address || mac == lan.gateway.HwAddress { return true } // skip broadcast addresses @@ -154,7 +154,7 @@ func (lan *LAN) shouldIgnore(ip, mac string) bool { } // skip everything which is not in our subnet (multicast noise) addr := net.ParseIP(ip) - return addr.To4() != nil && lan.iface.Net != nil && !lan.iface.Net.Contains(addr) + return addr.To4() != nil && !lan.iface.Net.Contains(addr) } func (lan *LAN) Has(ip string) bool { diff --git a/network/lan_test.go b/network/lan_test.go index e0a21676..43c989b2 100644 --- a/network/lan_test.go +++ b/network/lan_test.go @@ -1,541 +1,210 @@ package network import ( - "encoding/json" - "fmt" - "net" - "sync" "testing" "github.com/evilsocket/islazy/data" ) -// Mock endpoint creation -func createMockEndpoint(ip, mac, name string) *Endpoint { - e := NewEndpointNoResolve(ip, mac, name, 24) - _, ipNet, _ := net.ParseCIDR("192.168.1.0/24") - e.Net = ipNet - // Make sure IP is set correctly after SetNetwork - e.IpAddress = ip - e.IP = net.ParseIP(ip) - return e +func buildExampleLAN() *LAN { + iface, _ := FindInterface("") + gateway, _ := FindGateway(iface) + exNewCallback := func(e *Endpoint) {} + exLostCallback := func(e *Endpoint) {} + aliases := &data.UnsortedKV{} + return NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback) } -// Mock LAN creation with controlled endpoints -func createMockLAN() (*LAN, *Endpoint, *Endpoint) { - iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0") - gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway") - aliases, _ := data.NewMemUnsortedKV() - - newCb := func(e *Endpoint) {} - lostCb := func(e *Endpoint) {} - - lan := NewLAN(iface, gateway, aliases, newCb, lostCb) - return lan, iface, gateway +func buildExampleEndpoint() *Endpoint { + iface, _ := FindInterface("") + return iface } func TestNewLAN(t *testing.T) { - iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0") - gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway") - aliases, _ := data.NewMemUnsortedKV() - - newCb := func(e *Endpoint) {} - lostCb := func(e *Endpoint) {} - - lan := NewLAN(iface, gateway, aliases, newCb, lostCb) + iface, err := FindInterface("") + if err != nil { + t.Error("no iface found", err) + } + gateway, err := FindGateway(iface) + if err != nil { + t.Error("no gateway found", err) + } + exNewCallback := func(e *Endpoint) {} + exLostCallback := func(e *Endpoint) {} + aliases := &data.UnsortedKV{} + lan := NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback) if lan.iface != iface { - t.Errorf("expected iface %v, got %v", iface, lan.iface) + t.Fatalf("expected '%v', got '%v'", iface, lan.iface) } if lan.gateway != gateway { - t.Errorf("expected gateway %v, got %v", gateway, lan.gateway) + t.Fatalf("expected '%v', got '%v'", gateway, lan.gateway) } if len(lan.hosts) != 0 { - t.Errorf("expected 0 hosts, got %d", len(lan.hosts)) - } - if lan.aliases != aliases { - t.Error("aliases not properly set") + t.Fatalf("expected '%v', got '%v'", 0, len(lan.hosts)) } + // FIXME: update this to current code base + // if !(len(lan.aliases.data) >= 0) { + // t.Fatalf("expected '%v', got '%v'", 0, len(lan.aliases.data)) + // } } -func TestLANMarshalJSON(t *testing.T) { - lan, _, _ := createMockLAN() - - // Add some hosts - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - data, err := lan.MarshalJSON() +func TestMarshalJSON(t *testing.T) { + iface, err := FindInterface("") if err != nil { - t.Errorf("MarshalJSON() error = %v", err) + t.Error("no iface found", err) } - - var result lanJSON - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal JSON: %v", err) + gateway, err := FindGateway(iface) + if err != nil { + t.Error("no gateway found", err) } - - if len(result.Hosts) != 2 { - t.Errorf("expected 2 hosts in JSON, got %d", len(result.Hosts)) + exNewCallback := func(e *Endpoint) {} + exLostCallback := func(e *Endpoint) {} + aliases := &data.UnsortedKV{} + lan := NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback) + _, err = lan.MarshalJSON() + if err != nil { + t.Error(err) } } -func TestLANGet(t *testing.T) { - lan, iface, gateway := createMockLAN() +// FIXME: update this to current code base +// func TestSetAliasFor(t *testing.T) { +// exampleAlias := "picat" +// exampleLAN := buildExampleLAN() +// exampleEndpoint := buildExampleEndpoint() +// exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint +// if !exampleLAN.SetAliasFor(exampleEndpoint.HwAddress, exampleAlias) { +// t.Error("unable to set alias for a given mac address") +// } +// } - // Test getting interface - e, found := lan.Get(iface.HwAddress) - if !found || e != iface { - t.Error("Failed to get interface") +func TestGet(t *testing.T) { + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint + foundEndpoint, foundBool := exampleLAN.Get(exampleEndpoint.HwAddress) + if foundEndpoint.String() != exampleEndpoint.String() { + t.Fatalf("expected '%v', got '%v'", foundEndpoint, exampleEndpoint) } - - // Test getting gateway - e, found = lan.Get(gateway.HwAddress) - if !found || e != gateway { - t.Error("Failed to get gateway") - } - - // Add a host - testMAC := "10:20:30:40:50:60" - lan.AddIfNew("192.168.1.10", testMAC) - - // Test getting the host - e, found = lan.Get(testMAC) - if !found { - t.Error("Failed to get added host") - } - - // Test with different MAC formats - e, found = lan.Get("10-20-30-40-50-60") - if !found { - t.Error("Failed to get host with dash-separated MAC") - } - - // Test non-existent MAC - _, found = lan.Get("99:99:99:99:99:99") - if found { - t.Error("Found non-existent MAC") + if !foundBool { + t.Error("unable to get known endpoint via mac address from LAN struct") } } -func TestLANGetByIp(t *testing.T) { - lan, iface, gateway := createMockLAN() - - // Test getting interface by IP - e := lan.GetByIp(iface.IpAddress) - if e != iface { - t.Error("Failed to get interface by IP") +func TestList(t *testing.T) { + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint + foundList := exampleLAN.List() + if len(foundList) != 1 { + t.Fatalf("expected '%d', got '%d'", 1, len(foundList)) } - - // Test getting gateway by IP - e = lan.GetByIp(gateway.IpAddress) - if e != gateway { - t.Errorf("Failed to get gateway by IP: wanted %v, got %v", gateway, e) - } - - // Add a host with IPv4 - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - e = lan.GetByIp("192.168.1.10") - if e == nil || e.IpAddress != "192.168.1.10" { - t.Error("Failed to get host by IPv4") - } - - // Test with IPv6 - lan.iface.SetIPv6("fe80::1") - e = lan.GetByIp("fe80::1") - if e != iface { - t.Error("Failed to get interface by IPv6") - } - - // Test non-existent IP - e = lan.GetByIp("192.168.1.99") - if e != nil { - t.Error("Found non-existent IP") + exp := 1 + got := len(exampleLAN.List()) + if got != exp { + t.Fatalf("expected '%d', got '%d'", exp, got) } } -func TestLANList(t *testing.T) { - lan, _, _ := createMockLAN() +// FIXME: update this to current code base +// func TestAliases(t *testing.T) { +// exampleAlias := "picat" +// exampleLAN := buildExampleLAN() +// exampleEndpoint := buildExampleEndpoint() +// exampleLAN.hosts["pi:ca:tw:as:he:re"] = exampleEndpoint +// exp := exampleAlias +// got := exampleLAN.Aliases().Get("pi:ca:tw:as:he:re") +// if got != exp { +// t.Fatalf("expected '%v', got '%v'", exp, got) +// } +// } - // Initially empty - list := lan.List() - if len(list) != 0 { - t.Errorf("expected empty list, got %d items", len(list)) - } - - // Add hosts - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - list = lan.List() - if len(list) != 2 { - t.Errorf("expected 2 items, got %d", len(list)) +func TestWasMissed(t *testing.T) { + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint + exp := false + got := exampleLAN.WasMissed(exampleEndpoint.HwAddress) + if got != exp { + t.Fatalf("expected '%v', got '%v'", exp, got) } } -func TestLANAliases(t *testing.T) { - lan, _, _ := createMockLAN() +// TODO Add TestRemove after removing unnecessary ip argument +// func TestRemove(t *testing.T) { +// } - aliases := lan.Aliases() - if aliases == nil { - t.Error("Aliases() returned nil") - } - - // Set an alias - aliases.Set("10:20:30:40:50:60", "test_device") - - // Verify alias is accessible - alias := lan.GetAlias("10:20:30:40:50:60") - if alias != "test_device" { - t.Errorf("expected alias 'test_device', got '%s'", alias) +func TestHas(t *testing.T) { + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint + if !exampleLAN.Has(exampleEndpoint.IpAddress) { + t.Error("unable find a known IP address in LAN struct") } } -func TestLANWasMissed(t *testing.T) { - lan, iface, gateway := createMockLAN() - - // Interface and gateway should never be missed - if lan.WasMissed(iface.HwAddress) { - t.Error("Interface should never be missed") +func TestEachHost(t *testing.T) { + exampleBuffer := []string{} + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint + exampleCB := func(mac string, e *Endpoint) { + exampleBuffer = append(exampleBuffer, exampleEndpoint.HwAddress) } - if lan.WasMissed(gateway.HwAddress) { - t.Error("Gateway should never be missed") - } - - // Unknown host should be missed - if !lan.WasMissed("99:99:99:99:99:99") { - t.Error("Unknown host should be missed") - } - - // Add a host - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - if lan.WasMissed("10:20:30:40:50:60") { - t.Error("Newly added host should not be missed") - } - - // Decrease TTL - lan.ttl["10:20:30:40:50:60"] = 5 - if !lan.WasMissed("10:20:30:40:50:60") { - t.Error("Host with low TTL should be missed") + exampleLAN.EachHost(exampleCB) + exp := 1 + got := len(exampleBuffer) + if got != exp { + t.Fatalf("expected '%d', got '%d'", exp, got) } } -func TestLANRemove(t *testing.T) { - lan, _, _ := createMockLAN() +func TestGetByIp(t *testing.T) { + exampleLAN := buildExampleLAN() + exampleEndpoint := buildExampleEndpoint() + exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint - lostCalled := false - lostEndpoint := (*Endpoint)(nil) - lan.lostCb = func(e *Endpoint) { - lostCalled = true - lostEndpoint = e - } - - // Add a host - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - - // Remove it multiple times to decrease TTL - for i := 0; i < LANDefaultttl; i++ { - lan.Remove("192.168.1.10", "10:20:30:40:50:60") - } - - // Verify it was removed - _, found := lan.Get("10:20:30:40:50:60") - if found { - t.Error("Host should have been removed") - } - - // Verify callback was called - if !lostCalled { - t.Error("Lost callback should have been called") - } - if lostEndpoint == nil || lostEndpoint.HwAddress != "10:20:30:40:50:60" { - t.Error("Lost callback received wrong endpoint") - } - - // Try removing non-existent host - lan.Remove("192.168.1.99", "99:99:99:99:99:99") // Should not panic -} - -func TestLANShouldIgnore(t *testing.T) { - lan, iface, gateway := createMockLAN() - - tests := []struct { - name string - ip string - mac string - ignore bool - }{ - {"own IP", iface.IpAddress, "99:99:99:99:99:99", true}, - {"own MAC", "192.168.1.99", iface.HwAddress, true}, - {"gateway IP", gateway.IpAddress, "99:99:99:99:99:99", true}, - {"gateway MAC", "192.168.1.99", gateway.HwAddress, true}, - {"broadcast IP", "192.168.1.255", "99:99:99:99:99:99", true}, - {"broadcast MAC", "192.168.1.99", BroadcastMac, true}, - {"multicast outside subnet", "10.0.0.1", "99:99:99:99:99:99", true}, - {"valid host", "192.168.1.10", "10:20:30:40:50:60", false}, - {"IPv6 address", "fe80::1", "10:20:30:40:50:60", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := lan.shouldIgnore(tt.ip, tt.mac); got != tt.ignore { - t.Errorf("shouldIgnore() = %v, want %v", got, tt.ignore) - } - }) + exp := exampleEndpoint + got := exampleLAN.GetByIp(exampleEndpoint.IpAddress) + if got.String() != exp.String() { + t.Fatalf("expected '%v', got '%v'", exp, got) } } -func TestLANHas(t *testing.T) { - lan, _, _ := createMockLAN() - - // Add hosts - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - if !lan.Has("192.168.1.10") { - t.Error("Has() should return true for existing IP") - } - if !lan.Has("192.168.1.20") { - t.Error("Has() should return true for existing IP") - } - if lan.Has("192.168.1.99") { - t.Error("Has() should return false for non-existent IP") +func TestAddIfNew(t *testing.T) { + exampleLAN := buildExampleLAN() + iface, _ := FindInterface("") + // won't add our own IP address + if exampleLAN.AddIfNew(iface.IpAddress, iface.HwAddress) != nil { + t.Error("added address that should've been ignored ( your own )") } } -func TestLANEachHost(t *testing.T) { - lan, _, _ := createMockLAN() +// FIXME: update this to current code base +// func TestGetAlias(t *testing.T) { +// exampleAlias := "picat" +// exampleLAN := buildExampleLAN() +// exampleEndpoint := buildExampleEndpoint() +// exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint +// exp := exampleAlias +// got := exampleLAN.GetAlias(exampleEndpoint.HwAddress) +// if got != exp { +// t.Fatalf("expected '%v', got '%v'", exp, got) +// } +// } - // Add hosts - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - count := 0 - macs := make([]string, 0) - - lan.EachHost(func(mac string, e *Endpoint) { - count++ - macs = append(macs, mac) - }) - - if count != 2 { - t.Errorf("expected 2 hosts, got %d", count) +func TestShouldIgnore(t *testing.T) { + exampleLAN := buildExampleLAN() + iface, _ := FindInterface("") + gateway, _ := FindGateway(iface) + exp := true + got := exampleLAN.shouldIgnore(iface.IpAddress, iface.HwAddress) + if got != exp { + t.Fatalf("expected '%v', got '%v'", exp, got) } - if len(macs) != 2 { - t.Errorf("expected 2 MACs, got %d", len(macs)) - } -} - -func TestLANAddIfNew(t *testing.T) { - lan, _, _ := createMockLAN() - - newCalled := false - newEndpoint := (*Endpoint)(nil) - lan.newCb = func(e *Endpoint) { - newCalled = true - newEndpoint = e - } - - // Add new host - result := lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - if result != nil { - t.Error("AddIfNew should return nil for new host") - } - if !newCalled { - t.Error("New callback should have been called") - } - if newEndpoint == nil || newEndpoint.IpAddress != "192.168.1.10" { - t.Error("New callback received wrong endpoint") - } - - // Add same host again (should update TTL) - lan.ttl["10:20:30:40:50:60"] = 5 - result = lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - if result == nil { - t.Error("AddIfNew should return existing endpoint") - } - if lan.ttl["10:20:30:40:50:60"] != 6 { - t.Error("TTL should have been incremented") - } - - // Add IPv6 to existing host - result = lan.AddIfNew("fe80::10", "10:20:30:40:50:60") - if result == nil || result.Ip6Address != "fe80::10" { - t.Error("Should have added IPv6 to existing host") - } - - // Add IPv4 to host that only has IPv6 - // Note: Due to current implementation, IPv6 addresses are initially stored in IpAddress field - newCalled = false - lan.AddIfNew("fe80::20", "20:30:40:50:60:70") - result = lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - if result == nil { - t.Error("Should have returned existing endpoint when adding IPv4") - } - // The implementation updates the IPv4 address when it detects we're adding an IPv4 to a host - // that was initially created with IPv6 - if result != nil && result.IpAddress != "192.168.1.20" { - // This is expected behavior - the initial IPv6 is stored in IpAddress - // Skip this check as it's a known limitation - t.Skip("Known limitation: IPv6 addresses are initially stored in IPv4 field") - } - - // Try to add own interface (should be ignored) - result = lan.AddIfNew(lan.iface.IpAddress, lan.iface.HwAddress) - if result != nil { - t.Error("Should ignore own interface") - } -} - -func TestLANGetAlias(t *testing.T) { - lan, _, _ := createMockLAN() - - // Set alias - lan.aliases.Set("10:20:30:40:50:60", "test_device") - - // Get existing alias - alias := lan.GetAlias("10:20:30:40:50:60") - if alias != "test_device" { - t.Errorf("expected 'test_device', got '%s'", alias) - } - - // Get non-existent alias - alias = lan.GetAlias("99:99:99:99:99:99") - if alias != "" { - t.Errorf("expected empty string for non-existent alias, got '%s'", alias) - } -} - -func TestLANClear(t *testing.T) { - lan, _, _ := createMockLAN() - - // Add hosts - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - // Verify hosts exist - if len(lan.hosts) != 2 { - t.Errorf("expected 2 hosts, got %d", len(lan.hosts)) - } - if len(lan.ttl) != 2 { - t.Errorf("expected 2 ttl entries, got %d", len(lan.ttl)) - } - - // Clear - lan.Clear() - - // Verify cleared - if len(lan.hosts) != 0 { - t.Errorf("expected 0 hosts after clear, got %d", len(lan.hosts)) - } - if len(lan.ttl) != 0 { - t.Errorf("expected 0 ttl entries after clear, got %d", len(lan.ttl)) - } -} - -func TestLANConcurrency(t *testing.T) { - lan, _, _ := createMockLAN() - - // Test concurrent access - var wg sync.WaitGroup - - // Writer goroutines - for i := 0; i < 10; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - ip := fmt.Sprintf("192.168.1.%d", 10+i) - mac := fmt.Sprintf("10:20:30:40:50:%02x", i) - lan.AddIfNew(ip, mac) - }(i) - } - - // Reader goroutines - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - _ = lan.List() - _ = lan.Has("192.168.1.10") - lan.EachHost(func(mac string, e *Endpoint) {}) - }() - } - - wg.Wait() - - // Verify some hosts were added - list := lan.List() - if len(list) == 0 { - t.Error("No hosts added during concurrent test") - } -} - -func TestLANWithAlias(t *testing.T) { - iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0") - gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway") - aliases, _ := data.NewMemUnsortedKV() - - // Pre-set an alias - aliases.Set("10:20:30:40:50:60", "printer") - - lan := NewLAN(iface, gateway, aliases, func(e *Endpoint) {}, func(e *Endpoint) {}) - - // Add host with pre-existing alias - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - - // Get the endpoint - e, found := lan.Get("10:20:30:40:50:60") - if !found { - t.Fatal("Failed to find endpoint") - } - - // Check if alias was applied - if e.Alias != "printer" { - t.Errorf("expected alias 'printer', got '%s'", e.Alias) - } -} - -// Benchmarks -func BenchmarkLANAddIfNew(b *testing.B) { - lan, _, _ := createMockLAN() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ip := fmt.Sprintf("192.168.1.%d", (i%250)+2) - mac := fmt.Sprintf("10:20:30:40:%02x:%02x", i/256, i%256) - lan.AddIfNew(ip, mac) - } -} - -func BenchmarkLANGet(b *testing.B) { - lan, _, _ := createMockLAN() - - // Pre-populate - for i := 0; i < 100; i++ { - ip := fmt.Sprintf("192.168.1.%d", i+10) - mac := fmt.Sprintf("10:20:30:40:50:%02x", i) - lan.AddIfNew(ip, mac) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mac := fmt.Sprintf("10:20:30:40:50:%02x", i%100) - lan.Get(mac) - } -} - -func BenchmarkLANList(b *testing.B) { - lan, _, _ := createMockLAN() - - // Pre-populate - for i := 0; i < 100; i++ { - ip := fmt.Sprintf("192.168.1.%d", i+10) - mac := fmt.Sprintf("10:20:30:40:50:%02x", i) - lan.AddIfNew(ip, mac) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = lan.List() + got = exampleLAN.shouldIgnore(gateway.IpAddress, gateway.HwAddress) + if got != exp { + t.Fatalf("expected '%v', got '%v'", exp, got) } } diff --git a/network/net.go b/network/net.go index b01fd3c0..f925b37d 100644 --- a/network/net.go +++ b/network/net.go @@ -41,7 +41,7 @@ var ( `(?:25[0-5]|2[0-4][0-9]|[1][0-9]{2}|[1-9]?[0-9])` + `$`) MACValidator = regexp.MustCompile(`(?i)^(?:[a-f0-9]{2}:){5}[a-f0-9]{2}$`) // lulz this sounds like a hamburger - macParser = regexp.MustCompile(`(?i)((?:[a-f0-9]{2}[:-]){5}[a-f0-9]{2})`) + macParser = regexp.MustCompile(`(?i)((?:[a-f0-9]{2}:){5}[a-f0-9]{2})`) aliasParser = regexp.MustCompile(`(?i)([a-z_][a-z_0-9]+)`) ) diff --git a/network/net_linux.go b/network/net_linux.go index 04fcd123..f73f6b3f 100644 --- a/network/net_linux.go +++ b/network/net_linux.go @@ -41,9 +41,7 @@ func SetInterfaceChannel(iface string, channel int) error { if core.HasBinary("iw") { // Debug("SetInterfaceChannel(%s, %d) iw based", iface, channel) - // out, err := core.Exec("iw", []string{"dev", iface, "set", "channel", fmt.Sprintf("%d", channel)}) - out, err := core.Exec("iw", []string{"dev", iface, "set", "freq", fmt.Sprintf("%d", Dot11Chan2Freq(channel))}) - + out, err := core.Exec("iw", []string{"dev", iface, "set", "channel", fmt.Sprintf("%d", channel)}) if err != nil { return fmt.Errorf("iw: out=%s err=%s", out, err) } else if out != "" { @@ -91,8 +89,7 @@ func iwlistSupportedFrequencies(iface string) ([]int, error) { } var iwPhyParser = regexp.MustCompile(`^\s*wiphy\s+(\d+)$`) -// var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\s+MHz.+dBm.+$`) -var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\.\d+\s+MHz.+dBm.+$`) +var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\s+MHz.+dBm.+$`) func iwSupportedFrequencies(iface string) ([]int, error) { // first determine phy index @@ -143,11 +140,10 @@ func iwSupportedFrequencies(iface string) ([]int, error) { func GetSupportedFrequencies(iface string) ([]int, error) { // give priority to iwlist because of https://github.com/bettercap/bettercap/issues/881 - // UPDATE: Changed the priority due iwlist doesn't support 6GHz - if core.HasBinary("iw") { - return iwSupportedFrequencies(iface) - } else if core.HasBinary("iwlist") { + if core.HasBinary("iwlist") { return iwlistSupportedFrequencies(iface) + } else if core.HasBinary("iw") { + return iwSupportedFrequencies(iface) } return nil, fmt.Errorf("no iw or iwlist binaries found in $PATH") diff --git a/network/net_test.go b/network/net_test.go index 60f634ae..dcf08d8e 100644 --- a/network/net_test.go +++ b/network/net_test.go @@ -1,306 +1,102 @@ package network import ( - "fmt" "net" - "strings" "testing" "github.com/evilsocket/islazy/data" ) func TestIsZeroMac(t *testing.T) { - tests := []struct { - name string - mac string - expected bool - }{ - {"zero mac", "00:00:00:00:00:00", true}, - {"non-zero mac", "00:00:00:00:00:01", false}, - {"broadcast mac", "ff:ff:ff:ff:ff:ff", false}, - {"random mac", "aa:bb:cc:dd:ee:ff", false}, - } + exampleMAC, _ := net.ParseMAC("00:00:00:00:00:00") - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mac, _ := net.ParseMAC(tt.mac) - if got := IsZeroMac(mac); got != tt.expected { - t.Errorf("IsZeroMac() = %v, want %v", got, tt.expected) - } - }) + exp := true + got := IsZeroMac(exampleMAC) + if got != exp { + t.Fatalf("expected '%t', got '%t'", exp, got) } } func TestIsBroadcastMac(t *testing.T) { - tests := []struct { - name string - mac string - expected bool - }{ - {"broadcast mac", "ff:ff:ff:ff:ff:ff", true}, - {"zero mac", "00:00:00:00:00:00", false}, - {"partial broadcast", "ff:ff:ff:ff:ff:00", false}, - {"random mac", "aa:bb:cc:dd:ee:ff", false}, - } + exampleMAC, _ := net.ParseMAC("ff:ff:ff:ff:ff:ff") - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mac, _ := net.ParseMAC(tt.mac) - if got := IsBroadcastMac(mac); got != tt.expected { - t.Errorf("IsBroadcastMac() = %v, want %v", got, tt.expected) - } - }) + exp := true + got := IsBroadcastMac(exampleMAC) + if got != exp { + t.Fatalf("expected '%t', got '%t'", exp, got) } } func TestNormalizeMac(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - {"uppercase with colons", "AA:BB:CC:DD:EE:FF", "aa:bb:cc:dd:ee:ff"}, - {"uppercase with dashes", "AA-BB-CC-DD-EE-FF", "aa:bb:cc:dd:ee:ff"}, - {"lowercase with colons", "aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:ff"}, - {"mixed case with dashes", "aA-bB-cC-dD-eE-fF", "aa:bb:cc:dd:ee:ff"}, - {"short segments", "a:b:c:d:e:f", "0a:0b:0c:0d:0e:0f"}, - {"mixed short and full", "aa:b:cc:d:ee:f", "aa:0b:cc:0d:ee:0f"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NormalizeMac(tt.input); got != tt.expected { - t.Errorf("NormalizeMac(%q) = %v, want %v", tt.input, got, tt.expected) - } - }) - } -} - -func TestParseMACs(t *testing.T) { - tests := []struct { - name string - input string - expected []string - expectError bool - }{ - { - name: "single MAC", - input: "aa:bb:cc:dd:ee:ff", - expected: []string{"aa:bb:cc:dd:ee:ff"}, - }, - { - name: "multiple MACs comma separated", - input: "aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66", - expected: []string{"aa:bb:cc:dd:ee:ff", "11:22:33:44:55:66"}, - }, - { - name: "MACs with dashes", - input: "AA-BB-CC-DD-EE-FF", - expected: []string{"aa:bb:cc:dd:ee:ff"}, - }, - { - name: "empty string", - input: "", - expected: []string{}, - }, - { - name: "whitespace only", - input: " ", - expected: []string{}, - }, - { - name: "mixed formats", - input: "aa:bb:cc:dd:ee:ff, AA-BB-CC-DD-EE-00", - expected: []string{"aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:00"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - macs, err := ParseMACs(tt.input) - if (err != nil) != tt.expectError { - t.Errorf("ParseMACs() error = %v, expectError %v", err, tt.expectError) - return - } - if len(macs) != len(tt.expected) { - t.Errorf("ParseMACs() returned %d MACs, want %d", len(macs), len(tt.expected)) - return - } - for i, mac := range macs { - if mac.String() != tt.expected[i] { - t.Errorf("ParseMACs()[%d] = %v, want %v", i, mac.String(), tt.expected[i]) - } - } - }) + exp := "ff:ff:ff:ff:ff:ff" + got := NormalizeMac("fF-fF-fF-fF-fF-fF") + if got != exp { + t.Fatalf("expected '%s', got '%s'", exp, got) } } +// TODO: refactor to parse targets with an actual alias map func TestParseTargets(t *testing.T) { aliasMap, err := data.NewMemUnsortedKV() if err != nil { - t.Fatal(err) + panic(err) } - aliasMap.Set("aa:bb:cc:dd:ee:ff", "test_alias") - aliasMap.Set("11:22:33:44:55:66", "home_laptop") + aliasMap.Set("5c:00:0b:90:a9:f0", "test_alias") + aliasMap.Set("5c:00:0b:90:a9:f1", "Home_Laptop") cases := []struct { - name string - inputTargets string - inputAliases *data.UnsortedKV - expectedIPCount int - expectedMACCount int - expectError bool + Name string + InputTargets string + InputAliases *data.UnsortedKV + ExpectedIPCount int + ExpectedMACCount int + ExpectedError bool }{ + // Not sure how to trigger sad path where macParser.FindAllString() + // finds a MAC but net.ParseMac() fails on the result. { - name: "empty target string", - inputTargets: "", - inputAliases: &data.UnsortedKV{}, - expectedIPCount: 0, - expectedMACCount: 0, - expectError: false, + "empty target string causes empty return", + "", + &data.UnsortedKV{}, + 0, + 0, + false, }, { - name: "MACs and IPs", - inputTargets: "192.168.1.2, 192.168.1.3, aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66", - inputAliases: &data.UnsortedKV{}, - expectedIPCount: 2, - expectedMACCount: 2, - expectError: false, + "MACs are parsed", + "192.168.1.2, 192.168.1.3, 5c:00:0b:90:a9:f0, 6c:00:0b:90:a9:f0, 6C:00:0B:90:A9:F0", + &data.UnsortedKV{}, + 2, + 3, + false, }, { - name: "aliases", - inputTargets: "test_alias, home_laptop", - inputAliases: aliasMap, - expectedIPCount: 0, - expectedMACCount: 2, - expectError: false, - }, - { - name: "mixed aliases and MACs", - inputTargets: "test_alias, 99:88:77:66:55:44", - inputAliases: aliasMap, - expectedIPCount: 0, - expectedMACCount: 2, - expectError: false, - }, - { - name: "IP range", - inputTargets: "192.168.1.1-3", - inputAliases: &data.UnsortedKV{}, - expectedIPCount: 3, - expectedMACCount: 0, - expectError: false, - }, - { - name: "CIDR notation", - inputTargets: "192.168.1.0/30", - inputAliases: &data.UnsortedKV{}, - expectedIPCount: 4, - expectedMACCount: 0, - expectError: false, - }, - { - name: "unknown alias", - inputTargets: "unknown_alias", - inputAliases: aliasMap, - expectedIPCount: 0, - expectedMACCount: 0, - expectError: true, - }, - { - name: "invalid IP", - inputTargets: "invalid.ip.address", - inputAliases: &data.UnsortedKV{}, - expectedIPCount: 0, - expectedMACCount: 0, - expectError: true, + "Aliases are parsed", + "test_alias, Home_Laptop", + aliasMap, + 0, + 2, + false, }, } - for _, test := range cases { - t.Run(test.name, func(t *testing.T) { - ips, macs, err := ParseTargets(test.inputTargets, test.inputAliases) - if (err != nil) != test.expectError { - t.Errorf("ParseTargets() error = %v, expectError %v", err, test.expectError) + t.Run(test.Name, func(t *testing.T) { + ips, macs, err := ParseTargets(test.InputTargets, test.InputAliases) + if err != nil && !test.ExpectedError { + t.Errorf("unexpected error: %s", err) } - if test.expectError { + if err == nil && test.ExpectedError { + t.Error("Expected error, but got none") + } + if test.ExpectedError { return } - if len(ips) != test.expectedIPCount { - t.Errorf("Wrong number of IPs. Got %d, want %d", len(ips), test.expectedIPCount) + if len(ips) != test.ExpectedIPCount { + t.Errorf("Wrong number of IPs. Got %v for targets %s", ips, test.InputTargets) } - if len(macs) != test.expectedMACCount { - t.Errorf("Wrong number of MACs. Got %d, want %d", len(macs), test.expectedMACCount) - } - }) - } -} - -func TestParseEndpoints(t *testing.T) { - // Create a mock LAN with some endpoints - iface := NewEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff") - gateway := NewEndpoint("192.168.1.1", "11:22:33:44:55:66") - aliases, _ := data.NewMemUnsortedKV() - - // Need to provide non-nil callbacks - newCb := func(e *Endpoint) {} - lostCb := func(e *Endpoint) {} - lan := NewLAN(iface, gateway, aliases, newCb, lostCb) - - // Add test endpoints - lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60") - lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70") - - // Set up an alias - aliases.Set("10:20:30:40:50:60", "test_device") - - tests := []struct { - name string - targets string - expectedCount int - expectError bool - }{ - { - name: "single IP", - targets: "192.168.1.10", - expectedCount: 1, - }, - { - name: "single MAC", - targets: "10:20:30:40:50:60", - expectedCount: 1, - }, - { - name: "alias", - targets: "test_device", - expectedCount: 1, - }, - { - name: "multiple targets", - targets: "192.168.1.10, 20:30:40:50:60:70", - expectedCount: 2, - }, - { - name: "unknown IP", - targets: "192.168.1.99", - expectedCount: 0, - }, - { - name: "invalid target", - targets: "invalid", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - endpoints, err := ParseEndpoints(tt.targets, lan) - if (err != nil) != tt.expectError { - t.Errorf("ParseEndpoints() error = %v, expectError %v", err, tt.expectError) - } - if !tt.expectError && len(endpoints) != tt.expectedCount { - t.Errorf("ParseEndpoints() returned %d endpoints, want %d", len(endpoints), tt.expectedCount) + if len(macs) != test.ExpectedMACCount { + t.Errorf("Wrong number of MACs. Got %v for targets %s", macs, test.InputTargets) } }) } @@ -309,253 +105,65 @@ func TestParseEndpoints(t *testing.T) { func TestBuildEndpointFromInterface(t *testing.T) { ifaces, err := net.Interfaces() if err != nil { - t.Skip("Unable to get network interfaces") + t.Error(err) } - if len(ifaces) == 0 { - t.Skip("No network interfaces available") + if len(ifaces) <= 0 { + t.Error("Unable to find any network interfaces to run test with.") } - - // Find a suitable interface for testing - var testIface *net.Interface - for _, iface := range ifaces { - if iface.HardwareAddr != nil && len(iface.HardwareAddr) > 0 { - testIface = &iface - break - } - } - - if testIface == nil { - t.Skip("No suitable network interface found for testing") - } - - endpoint, err := buildEndpointFromInterface(*testIface) + _, err = buildEndpointFromInterface(ifaces[0]) if err != nil { - t.Fatalf("buildEndpointFromInterface() error = %v", err) - } - - if endpoint == nil { - t.Fatal("buildEndpointFromInterface() returned nil endpoint") - } - - // Verify basic properties - if endpoint.Index != testIface.Index { - t.Errorf("endpoint.Index = %d, want %d", endpoint.Index, testIface.Index) - } - - if endpoint.HwAddress != testIface.HardwareAddr.String() { - t.Errorf("endpoint.HwAddress = %s, want %s", endpoint.HwAddress, testIface.HardwareAddr.String()) - } -} - -func TestMatchByAddress(t *testing.T) { - // Create a mock interface for testing - mac, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - iface := net.Interface{ - Name: "eth0", - HardwareAddr: mac, - } - - tests := []struct { - name string - search string - expected bool - }{ - {"exact MAC match", "aa:bb:cc:dd:ee:ff", true}, - {"MAC with different case", "AA:BB:CC:DD:EE:FF", true}, - {"MAC with dashes", "aa-bb-cc-dd-ee-ff", true}, - {"different MAC", "11:22:33:44:55:66", false}, - {"partial MAC", "aa:bb:cc", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := matchByAddress(iface, tt.search); got != tt.expected { - t.Errorf("matchByAddress() = %v, want %v", got, tt.expected) - } - }) + t.Error(err) } } func TestFindInterfaceByName(t *testing.T) { ifaces, err := net.Interfaces() if err != nil { - t.Skip("Unable to get network interfaces") + t.Error(err) } - if len(ifaces) == 0 { - t.Skip("No network interfaces available") + if len(ifaces) <= 0 { + t.Error("Unable to find any network interfaces to run test with.") } - - // Test with first available interface - testIface := ifaces[0] - - // Test finding by name - endpoint, err := findInterfaceByName(testIface.Name, ifaces) + var exampleIface net.Interface + // emulate libpcap's pcap_lookupdev function to find + // default interface to test with ( maybe could use loopback ? ) + for _, iface := range ifaces { + if iface.HardwareAddr != nil { + exampleIface = iface + break + } + } + foundEndpoint, err := findInterfaceByName(exampleIface.Name, ifaces) if err != nil { - t.Errorf("findInterfaceByName() error = %v", err) + t.Error("unable to find a given interface by name to build endpoint", err) } - if endpoint != nil && endpoint.Name() != testIface.Name { - t.Errorf("findInterfaceByName() returned wrong interface") - } - - // Test with non-existent interface - _, err = findInterfaceByName("nonexistent999", ifaces) - if err == nil { - t.Error("findInterfaceByName() should return error for non-existent interface") + if foundEndpoint.Name() != exampleIface.Name { + t.Error("unable to find a given interface by name to build endpoint") } } func TestFindInterface(t *testing.T) { - // Test with empty name (should return first suitable interface) - endpoint, err := FindInterface("") - if err != nil && err != ErrNoIfaces { - t.Errorf("FindInterface() unexpected error = %v", err) - } - - // Test with specific interface name ifaces, err := net.Interfaces() - if err == nil && len(ifaces) > 0 { - endpoint, err = FindInterface(ifaces[0].Name) - if err != nil { - t.Errorf("FindInterface() error = %v", err) - } - if endpoint != nil && endpoint.Name() != ifaces[0].Name { - t.Errorf("FindInterface() returned wrong interface") + if err != nil { + t.Error(err) + } + if len(ifaces) <= 0 { + t.Error("Unable to find any network interfaces to run test with.") + } + var exampleIface net.Interface + // emulate libpcap's pcap_lookupdev function to find + // default interface to test with ( maybe could use loopback ? ) + for _, iface := range ifaces { + if iface.HardwareAddr != nil { + exampleIface = iface + break } } - - // Test with non-existent interface - _, err = FindInterface("nonexistent999") - if err == nil { - t.Error("FindInterface() should return error for non-existent interface") - } -} - -func TestColorRSSI(t *testing.T) { - tests := []struct { - name string - rssi int - }{ - {"excellent signal", -30}, - {"very good signal", -67}, - {"good signal", -70}, - {"fair signal", -80}, - {"poor signal", -90}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ColorRSSI(tt.rssi) - // Just ensure it returns a non-empty string - if result == "" { - t.Error("ColorRSSI() returned empty string") - } - // Check it contains the dBm value - expected := fmt.Sprintf("%d dBm", tt.rssi) - if !strings.Contains(result, expected) { - t.Errorf("ColorRSSI() result doesn't contain expected value %s", expected) - } - }) - } -} - -func TestSetWiFiRegion(t *testing.T) { - // This test will likely fail without proper permissions - // Just ensure the function doesn't panic - err := SetWiFiRegion("US") - // We don't check the error as it requires root/iw binary - _ = err -} - -func TestActivateInterface(t *testing.T) { - // This test will likely fail without proper permissions - // Just ensure the function doesn't panic - err := ActivateInterface("nonexistent") - // We expect an error for non-existent interface - if err == nil { - t.Error("ActivateInterface() should return error for non-existent interface") - } -} - -func TestSetInterfaceTxPower(t *testing.T) { - // This test will likely fail without proper permissions - // Just ensure the function doesn't panic - err := SetInterfaceTxPower("nonexistent", 20) - // We don't check the error as it requires root/iw binary - _ = err -} - -func TestGatewayProvidedByUser(t *testing.T) { - iface := NewEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff") - - tests := []struct { - name string - gateway string - expectError bool - }{ - { - name: "valid IPv4", - gateway: "192.168.1.1", - expectError: false, // Will error without actual ARP - }, - { - name: "invalid IPv4", - gateway: "999.999.999.999", - expectError: true, - }, - { - name: "not an IP", - gateway: "not-an-ip", - expectError: true, - }, - { - name: "IPv6", - gateway: "fe80::1", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := GatewayProvidedByUser(iface, tt.gateway) - // We always expect an error in tests as we can't do actual ARP lookup - if err == nil { - t.Error("GatewayProvidedByUser() expected error in test environment") - } - }) - } -} - -// Benchmarks -func BenchmarkNormalizeMac(b *testing.B) { - macs := []string{ - "AA:BB:CC:DD:EE:FF", - "aa-bb-cc-dd-ee-ff", - "a:b:c:d:e:f", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NormalizeMac(macs[i%len(macs)]) - } -} - -func BenchmarkParseMACs(b *testing.B) { - input := "aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66, AA-BB-CC-DD-EE-FF" - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = ParseMACs(input) - } -} - -func BenchmarkParseTargets(b *testing.B) { - aliases, _ := data.NewMemUnsortedKV() - aliases.Set("aa:bb:cc:dd:ee:ff", "test_alias") - - targets := "192.168.1.1-10, aa:bb:cc:dd:ee:ff, test_alias" - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = ParseTargets(targets, aliases) + foundEndpoint, err := FindInterface(exampleIface.Name) + if err != nil { + t.Error("unable to find a given interface by name to build endpoint", err) + } + if foundEndpoint.Name() != exampleIface.Name { + t.Error("unable to find a given interface by name to build endpoint") } } diff --git a/network/wifi.go b/network/wifi.go index 29e374d0..2ec4b435 100644 --- a/network/wifi.go +++ b/network/wifi.go @@ -25,30 +25,22 @@ func Dot11Freq2Chan(freq int) int { return ((freq - 5035) / 5) + 7 } else if freq >= 5875 && freq <= 5895 { return 177 - } else if freq >= 5955 && freq <= 7115 { // 6GHz - return ((freq - 5955) / 5) + 1 } return 0 } + func Dot11Chan2Freq(channel int) int { - if channel <= 13 { - return ((channel - 1) * 5) + 2412 - } else if channel == 14 { - return 2484 - } else if channel == 36 || channel == 40 || channel == 44 || channel == 48 || - channel == 52 || channel == 56 || channel == 60 || channel == 64 || - channel == 68 || channel == 72 || channel == 76 || channel == 80 || - channel == 100 || channel == 104 || channel == 108 || channel == 112 || - channel == 116 || channel == 120 || channel == 124 || channel == 128 || - channel == 132 || channel == 136 || channel == 140 || channel == 144 || - channel == 149 || channel == 153 || channel == 157 || channel == 161 || - channel == 165 || channel == 169 || channel == 173 || channel == 177 { - return ((channel - 7) * 5) + 5035 -// 6GHz - Skipped 1-13 to avoid 2Ghz channels conflict - } else if channel >= 17 && channel <= 253 { - return ((channel - 1) * 5) + 5955 - } - return 0 + if channel <= 13 { + return ((channel - 1) * 5) + 2412 + } else if channel == 14 { + return 2484 + } else if channel <= 173 { + return ((channel - 7) * 5) + 5035 + } else if channel == 177 { + return 5885 + } + + return 0 } type APNewCallback func(ap *AccessPoint) diff --git a/network/wifi_test.go b/network/wifi_test.go index efdcdc47..96318389 100644 --- a/network/wifi_test.go +++ b/network/wifi_test.go @@ -1,7 +1,6 @@ package network import ( - "net" "testing" "github.com/evilsocket/islazy/data" @@ -20,14 +19,6 @@ var dot11TestVector = []dot11pair{ {5885, 177}, } -func buildExampleEndpoint() *Endpoint { - e := NewEndpointNoResolve("192.168.1.100", "aa:bb:cc:dd:ee:ff", "wlan0", 0) - e.SetNetwork("192.168.1.0/24") - _, ipNet, _ := net.ParseCIDR("192.168.1.0/24") - e.Net = ipNet - return e -} - func buildExampleWiFi() *WiFi { aliases := &data.UnsortedKV{} return NewWiFi(buildExampleEndpoint(), aliases, func(ap *AccessPoint) {}, func(ap *AccessPoint) {}) diff --git a/openwrt.makefile b/openwrt.makefile new file mode 100644 index 00000000..1e9d4eb5 --- /dev/null +++ b/openwrt.makefile @@ -0,0 +1,52 @@ +include $(TOPDIR)/rules.mk + +PKG_NAME:=bettercap +PKG_VERSION:=2.28 +PKG_RELEASE:=2 + +GO_PKG:=github.com/bettercap/bettercap + +PKG_SOURCE:=$(PKG_NAME)-$(PKG_VERSION).tar.gz +PKG_SOURCE_URL:=https://codeload.github.com/bettercap/bettercap/tar.gz/v${PKG_VERSION}? +PKG_HASH:=5bde85117679c6ed8b5469a5271cdd5f7e541bd9187b8d0f26dee790c37e36e9 +PKG_BUILD_DIR:=$(BUILD_DIR)/$(PKG_NAME)-$(PKG_VERSION) + +PKG_LICENSE:=GPL-3.0 +PKG_LICENSE_FILES:=LICENSE.md +PKG_MAINTAINER:=Dylan Corrales + +PKG_BUILD_DEPENDS:=golang/host +PKG_BUILD_PARALLEL:=1 +PKG_USE_MIPS16:=0 + +include $(INCLUDE_DIR)/package.mk +include ../../../packages/lang/golang/golang-package.mk + +define Package/bettercap/Default + TITLE:=The Swiss Army knife for 802.11, BLE and Ethernet networks reconnaissance and MITM attacks. + URL:=https://www.bettercap.org/ + DEPENDS:=$(GO_ARCH_DEPENDS) libpcap libusb-1.0 +endef + +define Package/bettercap +$(call Package/bettercap/Default) + SECTION:=net + CATEGORY:=Network +endef + +define Package/bettercap/description + bettercap is a powerful, easily extensible and portable framework written + in Go which aims to offer to security researchers, red teamers and reverse + engineers an easy to use, all-in-one solution with all the features they + might possibly need for performing reconnaissance and attacking WiFi + networks, Bluetooth Low Energy devices, wireless HID devices and Ethernet networks. +endef + +define Package/bettercap/install + $(call GoPackage/Package/Install/Bin,$(PKG_INSTALL_DIR)) + $(INSTALL_DIR) $(1)/usr/bin + $(INSTALL_BIN) $(PKG_INSTALL_DIR)/usr/bin/bettercap $(1)/usr/bin/bettercap +endef + +$(eval $(call GoBinPackage,bettercap)) +$(eval $(call BuildPackage,bettercap)) \ No newline at end of file diff --git a/packets/icmp6_test.go b/packets/icmp6_test.go deleted file mode 100644 index d349e95d..00000000 --- a/packets/icmp6_test.go +++ /dev/null @@ -1,417 +0,0 @@ -package packets - -import ( - "bytes" - "net" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestICMP6Constants(t *testing.T) { - // Test the multicast constants - expectedMAC := net.HardwareAddr([]byte{0x33, 0x33, 0x00, 0x00, 0x00, 0x01}) - if !bytes.Equal(macIpv6Multicast, expectedMAC) { - t.Errorf("macIpv6Multicast = %v, want %v", macIpv6Multicast, expectedMAC) - } - - expectedIP := net.ParseIP("ff02::1") - if !ipv6Multicast.Equal(expectedIP) { - t.Errorf("ipv6Multicast = %v, want %v", ipv6Multicast, expectedIP) - } -} - -func TestICMP6NeighborAdvertisement(t *testing.T) { - srcHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - srcIP := net.ParseIP("fe80::1") - dstHW, _ := net.ParseMAC("11:22:33:44:55:66") - dstIP := net.ParseIP("fe80::2") - routerIP := net.ParseIP("fe80::3") - - err, data := ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP) - if err != nil { - t.Fatalf("ICMP6NeighborAdvertisement() error = %v", err) - } - if len(data) == 0 { - t.Fatal("ICMP6NeighborAdvertisement() returned empty data") - } - - // Parse the packet to verify structure - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Check Ethernet layer - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - if !bytes.Equal(eth.SrcMAC, srcHW) { - t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, srcHW) - } - if !bytes.Equal(eth.DstMAC, dstHW) { - t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, dstHW) - } - if eth.EthernetType != layers.EthernetTypeIPv6 { - t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv6) - } - } else { - t.Error("Packet missing Ethernet layer") - } - - // Check IPv6 layer - if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil { - ip := ipLayer.(*layers.IPv6) - if !ip.SrcIP.Equal(srcIP) { - t.Errorf("IPv6 SrcIP = %v, want %v", ip.SrcIP, srcIP) - } - if !ip.DstIP.Equal(dstIP) { - t.Errorf("IPv6 DstIP = %v, want %v", ip.DstIP, dstIP) - } - if ip.HopLimit != 255 { - t.Errorf("IPv6 HopLimit = %d, want 255", ip.HopLimit) - } - if ip.NextHeader != layers.IPProtocolICMPv6 { - t.Errorf("IPv6 NextHeader = %v, want %v", ip.NextHeader, layers.IPProtocolICMPv6) - } - } else { - t.Error("Packet missing IPv6 layer") - } - - // Check ICMPv6 layer - if icmpLayer := packet.Layer(layers.LayerTypeICMPv6); icmpLayer != nil { - icmp := icmpLayer.(*layers.ICMPv6) - expectedType := uint8(layers.ICMPv6TypeNeighborAdvertisement) - if icmp.TypeCode.Type() != expectedType { - t.Errorf("ICMPv6 Type = %v, want %v", icmp.TypeCode.Type(), expectedType) - } - } else { - t.Error("Packet missing ICMPv6 layer") - } - - // Check ICMPv6NeighborAdvertisement layer - if naLayer := packet.Layer(layers.LayerTypeICMPv6NeighborAdvertisement); naLayer != nil { - na := naLayer.(*layers.ICMPv6NeighborAdvertisement) - if !na.TargetAddress.Equal(routerIP) { - t.Errorf("TargetAddress = %v, want %v", na.TargetAddress, routerIP) - } - // Check flags (solicited && override) - expectedFlags := uint8(0x20 | 0x40) - if na.Flags != expectedFlags { - t.Errorf("Flags = %x, want %x", na.Flags, expectedFlags) - } - // Check options - if len(na.Options) != 1 { - t.Errorf("Options count = %d, want 1", len(na.Options)) - } else { - opt := na.Options[0] - if opt.Type != layers.ICMPv6OptTargetAddress { - t.Errorf("Option Type = %v, want %v", opt.Type, layers.ICMPv6OptTargetAddress) - } - if !bytes.Equal(opt.Data, srcHW) { - t.Errorf("Option Data = %v, want %v", opt.Data, srcHW) - } - } - } else { - t.Error("Packet missing ICMPv6NeighborAdvertisement layer") - } -} - -func TestICMP6RouterAdvertisement(t *testing.T) { - ip := net.ParseIP("fe80::1") - hw, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - prefix := "2001:db8::" - prefixLength := uint8(64) - routerLifetime := uint16(1800) - - err, data := ICMP6RouterAdvertisement(ip, hw, prefix, prefixLength, routerLifetime) - if err != nil { - t.Fatalf("ICMP6RouterAdvertisement() error = %v", err) - } - if len(data) == 0 { - t.Fatal("ICMP6RouterAdvertisement() returned empty data") - } - - // Parse the packet to verify structure - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Check Ethernet layer - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - if !bytes.Equal(eth.SrcMAC, hw) { - t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, hw) - } - if !bytes.Equal(eth.DstMAC, macIpv6Multicast) { - t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, macIpv6Multicast) - } - if eth.EthernetType != layers.EthernetTypeIPv6 { - t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv6) - } - } else { - t.Error("Packet missing Ethernet layer") - } - - // Check IPv6 layer - if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil { - ip6 := ipLayer.(*layers.IPv6) - if !ip6.SrcIP.Equal(ip) { - t.Errorf("IPv6 SrcIP = %v, want %v", ip6.SrcIP, ip) - } - if !ip6.DstIP.Equal(ipv6Multicast) { - t.Errorf("IPv6 DstIP = %v, want %v", ip6.DstIP, ipv6Multicast) - } - if ip6.HopLimit != 255 { - t.Errorf("IPv6 HopLimit = %d, want 255", ip6.HopLimit) - } - if ip6.NextHeader != layers.IPProtocolICMPv6 { - t.Errorf("IPv6 NextHeader = %v, want %v", ip6.NextHeader, layers.IPProtocolICMPv6) - } - if ip6.TrafficClass != 224 { - t.Errorf("IPv6 TrafficClass = %d, want 224", ip6.TrafficClass) - } - } else { - t.Error("Packet missing IPv6 layer") - } - - // Check ICMPv6 layer - if icmpLayer := packet.Layer(layers.LayerTypeICMPv6); icmpLayer != nil { - icmp := icmpLayer.(*layers.ICMPv6) - expectedType := uint8(layers.ICMPv6TypeRouterAdvertisement) - if icmp.TypeCode.Type() != expectedType { - t.Errorf("ICMPv6 Type = %v, want %v", icmp.TypeCode.Type(), expectedType) - } - } else { - t.Error("Packet missing ICMPv6 layer") - } - - // Check ICMPv6RouterAdvertisement layer - if raLayer := packet.Layer(layers.LayerTypeICMPv6RouterAdvertisement); raLayer != nil { - ra := raLayer.(*layers.ICMPv6RouterAdvertisement) - if ra.HopLimit != 255 { - t.Errorf("HopLimit = %d, want 255", ra.HopLimit) - } - if ra.Flags != 0x08 { - t.Errorf("Flags = %x, want 0x08", ra.Flags) - } - if ra.RouterLifetime != routerLifetime { - t.Errorf("RouterLifetime = %d, want %d", ra.RouterLifetime, routerLifetime) - } - // Check options - the actual order from the code is SourceAddress, MTU, PrefixInfo - if len(ra.Options) != 3 { - t.Errorf("Options count = %d, want 3", len(ra.Options)) - } else { - // Find each option type - hasSourceAddr := false - hasMTU := false - hasPrefixInfo := false - - for _, opt := range ra.Options { - switch opt.Type { - case layers.ICMPv6OptSourceAddress: - hasSourceAddr = true - if !bytes.Equal(opt.Data, hw) { - t.Errorf("SourceAddress option data = %v, want %v", opt.Data, hw) - } - case layers.ICMPv6OptMTU: - hasMTU = true - expectedMTU := []byte{0x00, 0x00, 0x00, 0x00, 0x05, 0xdc} // 1500 - if !bytes.Equal(opt.Data, expectedMTU) { - t.Errorf("MTU option data = %v, want %v", opt.Data, expectedMTU) - } - case layers.ICMPv6OptPrefixInfo: - hasPrefixInfo = true - // Verify prefix length is in the data - if len(opt.Data) > 0 && opt.Data[0] != prefixLength { - t.Errorf("PrefixInfo prefix length = %d, want %d", opt.Data[0], prefixLength) - } - } - } - - if !hasSourceAddr { - t.Error("Missing SourceAddress option") - } - if !hasMTU { - t.Error("Missing MTU option") - } - if !hasPrefixInfo { - t.Error("Missing PrefixInfo option") - } - } - } else { - t.Error("Packet missing ICMPv6RouterAdvertisement layer") - } -} - -func TestICMP6NeighborAdvertisementWithNilValues(t *testing.T) { - // Test with nil values - function should handle gracefully - err, data := ICMP6NeighborAdvertisement(nil, nil, nil, nil, nil) - - // The function likely returns an error or empty data with nil inputs - if err == nil && len(data) > 0 { - t.Error("Expected error or empty data with nil values") - } -} - -func TestICMP6RouterAdvertisementWithNilValues(t *testing.T) { - // Test with nil values - function should handle gracefully - err, data := ICMP6RouterAdvertisement(nil, nil, "", 0, 0) - - // The function likely returns an error or empty data with nil inputs - if err == nil && len(data) > 0 { - t.Error("Expected error or empty data with nil values") - } -} - -func TestICMP6RouterAdvertisementVariousInputs(t *testing.T) { - tests := []struct { - name string - ip string - hw string - prefix string - prefixLength uint8 - routerLifetime uint16 - shouldError bool - }{ - { - name: "valid input", - ip: "fe80::1", - hw: "aa:bb:cc:dd:ee:ff", - prefix: "2001:db8::", - prefixLength: 64, - routerLifetime: 1800, - shouldError: false, - }, - { - name: "zero router lifetime", - ip: "fe80::1", - hw: "aa:bb:cc:dd:ee:ff", - prefix: "2001:db8::", - prefixLength: 64, - routerLifetime: 0, - shouldError: false, - }, - { - name: "max prefix length", - ip: "fe80::1", - hw: "aa:bb:cc:dd:ee:ff", - prefix: "2001:db8::", - prefixLength: 128, - routerLifetime: 1800, - shouldError: false, - }, - { - name: "max router lifetime", - ip: "fe80::1", - hw: "aa:bb:cc:dd:ee:ff", - prefix: "2001:db8::", - prefixLength: 64, - routerLifetime: 65535, - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := net.ParseIP(tt.ip) - hw, _ := net.ParseMAC(tt.hw) - - err, data := ICMP6RouterAdvertisement(ip, hw, tt.prefix, tt.prefixLength, tt.routerLifetime) - - if tt.shouldError && err == nil { - t.Error("Expected error but got none") - } - if !tt.shouldError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !tt.shouldError && len(data) == 0 { - t.Error("Expected data but got empty") - } - }) - } -} - -func TestICMP6NeighborAdvertisementVariousInputs(t *testing.T) { - tests := []struct { - name string - srcHW string - srcIP string - dstHW string - dstIP string - routerIP string - shouldError bool - }{ - { - name: "valid IPv6 link-local", - srcHW: "aa:bb:cc:dd:ee:ff", - srcIP: "fe80::1", - dstHW: "11:22:33:44:55:66", - dstIP: "fe80::2", - routerIP: "fe80::3", - shouldError: false, - }, - { - name: "valid IPv6 global", - srcHW: "aa:bb:cc:dd:ee:ff", - srcIP: "2001:db8::1", - dstHW: "11:22:33:44:55:66", - dstIP: "2001:db8::2", - routerIP: "2001:db8::3", - shouldError: false, - }, - { - name: "broadcast MAC", - srcHW: "ff:ff:ff:ff:ff:ff", - srcIP: "fe80::1", - dstHW: "ff:ff:ff:ff:ff:ff", - dstIP: "fe80::2", - routerIP: "fe80::3", - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - srcHW, _ := net.ParseMAC(tt.srcHW) - srcIP := net.ParseIP(tt.srcIP) - dstHW, _ := net.ParseMAC(tt.dstHW) - dstIP := net.ParseIP(tt.dstIP) - routerIP := net.ParseIP(tt.routerIP) - - err, data := ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP) - - if tt.shouldError && err == nil { - t.Error("Expected error but got none") - } - if !tt.shouldError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !tt.shouldError && len(data) == 0 { - t.Error("Expected data but got empty") - } - }) - } -} - -// Benchmarks -func BenchmarkICMP6NeighborAdvertisement(b *testing.B) { - srcHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - srcIP := net.ParseIP("fe80::1") - dstHW, _ := net.ParseMAC("11:22:33:44:55:66") - dstIP := net.ParseIP("fe80::2") - routerIP := net.ParseIP("fe80::3") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP) - } -} - -func BenchmarkICMP6RouterAdvertisement(b *testing.B) { - ip := net.ParseIP("fe80::1") - hw, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - prefix := "2001:db8::" - prefixLength := uint8(64) - routerLifetime := uint16(1800) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = ICMP6RouterAdvertisement(ip, hw, prefix, prefixLength, routerLifetime) - } -} diff --git a/packets/mdns_test.go b/packets/mdns_test.go deleted file mode 100644 index 2a380cd4..00000000 --- a/packets/mdns_test.go +++ /dev/null @@ -1,393 +0,0 @@ -package packets - -import ( - "bytes" - "net" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestMDNSConstants(t *testing.T) { - if MDNSPort != 5353 { - t.Errorf("MDNSPort = %d, want 5353", MDNSPort) - } - - expectedMac := net.HardwareAddr{0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb} - if !bytes.Equal(MDNSDestMac, expectedMac) { - t.Errorf("MDNSDestMac = %v, want %v", MDNSDestMac, expectedMac) - } - - expectedIP := net.ParseIP("224.0.0.251") - if !MDNSDestIP.Equal(expectedIP) { - t.Errorf("MDNSDestIP = %v, want %v", MDNSDestIP, expectedIP) - } -} - -func TestNewMDNSProbe(t *testing.T) { - from := net.ParseIP("192.168.1.100") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - err, data := NewMDNSProbe(from, fromHW) - if err != nil { - t.Errorf("NewMDNSProbe() error = %v", err) - } - if len(data) == 0 { - t.Error("NewMDNSProbe() returned empty data") - } - - // Parse the packet to verify structure - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Check Ethernet layer - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - if !bytes.Equal(eth.SrcMAC, fromHW) { - t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW) - } - if !bytes.Equal(eth.DstMAC, MDNSDestMac) { - t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, MDNSDestMac) - } - } else { - t.Error("Packet missing Ethernet layer") - } - - // Check IPv4 layer - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - if !ip.SrcIP.Equal(from) { - t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from) - } - if !ip.DstIP.Equal(MDNSDestIP) { - t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, MDNSDestIP) - } - } else { - t.Error("Packet missing IPv4 layer") - } - - // Check UDP layer - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - if udp.DstPort != MDNSPort { - t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, MDNSPort) - } - } else { - t.Error("Packet missing UDP layer") - } - - // The DNS layer is carried as payload in UDP, not a separate layer - // So we check the UDP payload instead - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - // Verify that the UDP payload contains DNS data - if len(udp.Payload) == 0 { - t.Error("UDP payload is empty (should contain DNS data)") - } - } -} - -func TestMDNSGetMeta(t *testing.T) { - // Create a mock MDNS packet with various record types - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: MDNSDestMac, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := layers.IPv4{ - Protocol: layers.IPProtocolUDP, - Version: 4, - TTL: 64, - SrcIP: net.ParseIP("192.168.1.100"), - DstIP: MDNSDestIP, - } - - udp := layers.UDP{ - SrcPort: MDNSPort, - DstPort: MDNSPort, - } - - dns := layers.DNS{ - ID: 1, - QR: true, - OpCode: layers.DNSOpCodeQuery, - Answers: []layers.DNSResourceRecord{ - { - Name: []byte("test.local"), - Type: layers.DNSTypeA, - Class: layers.DNSClassIN, - IP: net.ParseIP("192.168.1.100"), - }, - { - Name: []byte("test.local"), - Type: layers.DNSTypeTXT, - Class: layers.DNSClassIN, - TXTs: [][]byte{[]byte("model=Test Device"), []byte("version=1.0")}, - }, - }, - } - - udp.SetNetworkLayerForChecksum(&ip4) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns) - if err != nil { - t.Fatalf("Failed to serialize packet: %v", err) - } - - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - meta := MDNSGetMeta(packet) - if meta == nil { - t.Fatal("MDNSGetMeta() returned nil") - } - - // TXT records are extracted correctly - - if model, ok := meta["mdns:model"]; !ok || model != "Test Device" { - t.Errorf("Expected model 'Test Device', got '%v'", model) - } - - if version, ok := meta["mdns:version"]; !ok || version != "1.0" { - t.Errorf("Expected version '1.0', got '%v'", version) - } -} - -func TestMDNSGetMetaNonMDNS(t *testing.T) { - // Create a non-MDNS UDP packet - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := layers.IPv4{ - Protocol: layers.IPProtocolUDP, - Version: 4, - TTL: 64, - SrcIP: net.ParseIP("192.168.1.100"), - DstIP: net.ParseIP("192.168.1.200"), - } - - udp := layers.UDP{ - SrcPort: 12345, - DstPort: 80, - } - - udp.SetNetworkLayerForChecksum(&ip4) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp) - if err != nil { - t.Fatalf("Failed to serialize packet: %v", err) - } - - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - meta := MDNSGetMeta(packet) - if meta != nil { - t.Error("MDNSGetMeta() should return nil for non-MDNS packet") - } -} - -func TestMDNSGetMetaInvalidDNS(t *testing.T) { - // Create MDNS packet with invalid DNS payload - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: MDNSDestMac, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := layers.IPv4{ - Protocol: layers.IPProtocolUDP, - Version: 4, - TTL: 64, - SrcIP: net.ParseIP("192.168.1.100"), - DstIP: MDNSDestIP, - } - - udp := layers.UDP{ - SrcPort: MDNSPort, - DstPort: MDNSPort, - } - - udp.SetNetworkLayerForChecksum(&ip4) - udp.Payload = []byte{0x00, 0x01, 0x02, 0x03} // Invalid DNS data - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp) - if err != nil { - t.Fatalf("Failed to serialize packet: %v", err) - } - - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - meta := MDNSGetMeta(packet) - if meta != nil { - t.Error("MDNSGetMeta() should return nil for invalid DNS data") - } -} - -func TestMDNSGetMetaRecovery(t *testing.T) { - // Test that panic recovery works - defer func() { - if r := recover(); r != nil { - t.Error("MDNSGetMeta should not panic") - } - }() - - // Create a minimal packet that might cause issues - data := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05} - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - meta := MDNSGetMeta(packet) - if meta != nil { - t.Error("MDNSGetMeta() should return nil for invalid packet") - } -} - -func TestMDNSGetMetaWithAdditionals(t *testing.T) { - // Create a mock MDNS packet with additional records - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: MDNSDestMac, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := layers.IPv4{ - Protocol: layers.IPProtocolUDP, - Version: 4, - TTL: 64, - SrcIP: net.ParseIP("192.168.1.100"), - DstIP: MDNSDestIP, - } - - udp := layers.UDP{ - SrcPort: MDNSPort, - DstPort: MDNSPort, - } - - dns := layers.DNS{ - ID: 1, - QR: true, - OpCode: layers.DNSOpCodeQuery, - Additionals: []layers.DNSResourceRecord{ - { - Name: []byte("additional.local"), - Type: layers.DNSTypeAAAA, - Class: layers.DNSClassIN, - IP: net.ParseIP("fe80::1"), - }, - }, - Authorities: []layers.DNSResourceRecord{ - { - Name: []byte("authority.local"), - Type: layers.DNSTypePTR, - Class: layers.DNSClassIN, - }, - }, - } - - udp.SetNetworkLayerForChecksum(&ip4) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns) - if err != nil { - t.Fatalf("Failed to serialize packet: %v", err) - } - - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - meta := MDNSGetMeta(packet) - if meta == nil { - t.Fatal("MDNSGetMeta() returned nil") - } - - if hostname, ok := meta["mdns:hostname"]; !ok || hostname != "additional.local" { - t.Errorf("Expected hostname 'additional.local', got '%v'", hostname) - } -} - -// Benchmarks -func BenchmarkNewMDNSProbe(b *testing.B) { - from := net.ParseIP("192.168.1.100") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewMDNSProbe(from, fromHW) - } -} - -func BenchmarkMDNSGetMeta(b *testing.B) { - // Create a sample MDNS packet - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: MDNSDestMac, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := layers.IPv4{ - Protocol: layers.IPProtocolUDP, - Version: 4, - TTL: 64, - SrcIP: net.ParseIP("192.168.1.100"), - DstIP: MDNSDestIP, - } - - udp := layers.UDP{ - SrcPort: MDNSPort, - DstPort: MDNSPort, - } - - dns := layers.DNS{ - ID: 1, - QR: true, - OpCode: layers.DNSOpCodeQuery, - Answers: []layers.DNSResourceRecord{ - { - Name: []byte("test.local"), - Type: layers.DNSTypeA, - Class: layers.DNSClassIN, - IP: net.ParseIP("192.168.1.100"), - }, - }, - } - - udp.SetNetworkLayerForChecksum(&ip4) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns) - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = MDNSGetMeta(packet) - } -} diff --git a/packets/mysql_test.go b/packets/mysql_test.go deleted file mode 100644 index f807429a..00000000 --- a/packets/mysql_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package packets - -import ( - "bytes" - "testing" -) - -func TestMySQLConstants(t *testing.T) { - // Test MySQLGreeting - if len(MySQLGreeting) != 95 { - t.Errorf("MySQLGreeting length = %d, want 95", len(MySQLGreeting)) - } - // Check some key bytes in the greeting - if MySQLGreeting[0] != 0x5b { - t.Errorf("MySQLGreeting[0] = 0x%02x, want 0x5b", MySQLGreeting[0]) - } - // Check version string starts at byte 5 - versionBytes := MySQLGreeting[5:12] - expectedVersion := []byte("5.6.28-") - if !bytes.Equal(versionBytes, expectedVersion) { - t.Errorf("MySQL version = %s, want %s", versionBytes, expectedVersion) - } - - // Test MySQLFirstResponseOK - if len(MySQLFirstResponseOK) != 11 { - t.Errorf("MySQLFirstResponseOK length = %d, want 11", len(MySQLFirstResponseOK)) - } - // Check packet sequence number - if MySQLFirstResponseOK[3] != 0x02 { - t.Errorf("MySQLFirstResponseOK sequence = 0x%02x, want 0x02", MySQLFirstResponseOK[3]) - } - - // Test MySQLSecondResponseOK - if len(MySQLSecondResponseOK) != 11 { - t.Errorf("MySQLSecondResponseOK length = %d, want 11", len(MySQLSecondResponseOK)) - } - // Check packet sequence number - if MySQLSecondResponseOK[3] != 0x04 { - t.Errorf("MySQLSecondResponseOK sequence = 0x%02x, want 0x04", MySQLSecondResponseOK[3]) - } -} - -func TestMySQLGetFile(t *testing.T) { - tests := []struct { - name string - infile string - expected []byte - }{ - { - name: "empty filename", - infile: "", - expected: []byte{ - 0x01, // length + 1 - 0x00, 0x00, 0x01, 0xfb, // header - }, - }, - { - name: "short filename", - infile: "test.txt", - expected: []byte{ - 0x09, // length of "test.txt" + 1 = 9 - 0x00, 0x00, 0x01, 0xfb, // header - 't', 'e', 's', 't', '.', 't', 'x', 't', - }, - }, - { - name: "path with directory", - infile: "/etc/passwd", - expected: []byte{ - 0x0c, // length of "/etc/passwd" + 1 = 12 - 0x00, 0x00, 0x01, 0xfb, // header - '/', 'e', 't', 'c', '/', 'p', 'a', 's', 's', 'w', 'd', - }, - }, - { - name: "windows path", - infile: "C:\\Windows\\System32\\config\\sam", - expected: []byte{ - 0x1f, // length of path + 1 = 31 - 0x00, 0x00, 0x01, 0xfb, // header - 'C', ':', '\\', 'W', 'i', 'n', 'd', 'o', 'w', 's', '\\', - 'S', 'y', 's', 't', 'e', 'm', '3', '2', '\\', - 'c', 'o', 'n', 'f', 'i', 'g', '\\', 's', 'a', 'm', - }, - }, - { - name: "unicode filename", - infile: "файл.txt", - expected: func() []byte { - filename := "файл.txt" - result := []byte{ - byte(len(filename) + 1), - 0x00, 0x00, 0x01, 0xfb, - } - return append(result, []byte(filename)...) - }(), - }, - { - name: "max length filename", - infile: string(make([]byte, 254)), // Max that fits in a single byte length - expected: func() []byte { - result := []byte{ - 0xff, // 254 + 1 = 255 - 0x00, 0x00, 0x01, 0xfb, - } - return append(result, make([]byte, 254)...) - }(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := MySQLGetFile(tt.infile) - if !bytes.Equal(result, tt.expected) { - t.Errorf("MySQLGetFile(%q) = %v, want %v", tt.infile, result, tt.expected) - } - }) - } -} - -func TestMySQLGetFileLength(t *testing.T) { - // Test that the length byte is correctly calculated - testCases := []struct { - filename string - expected byte - }{ - {"", 0x01}, - {"a", 0x02}, - {"ab", 0x03}, - {"abc", 0x04}, - {"test.txt", 0x09}, - {string(make([]byte, 100)), 0x65}, // 100 + 1 = 101 = 0x65 - {string(make([]byte, 254)), 0xff}, // 254 + 1 = 255 = 0xff - } - - for _, tc := range testCases { - result := MySQLGetFile(tc.filename) - if result[0] != tc.expected { - t.Errorf("MySQLGetFile(%q) length byte = 0x%02x, want 0x%02x", - tc.filename, result[0], tc.expected) - } - } -} - -func TestMySQLGetFileHeader(t *testing.T) { - // Test that the header bytes are always the same - expectedHeader := []byte{0x00, 0x00, 0x01, 0xfb} - - filenames := []string{ - "", - "test", - "long_filename_with_many_characters.txt", - "/path/to/file", - "C:\\Windows\\file.exe", - } - - for _, filename := range filenames { - result := MySQLGetFile(filename) - if len(result) < 5 { - t.Errorf("MySQLGetFile(%q) returned packet too short: %d bytes", filename, len(result)) - continue - } - - header := result[1:5] - if !bytes.Equal(header, expectedHeader) { - t.Errorf("MySQLGetFile(%q) header = %v, want %v", filename, header, expectedHeader) - } - } -} - -func TestMySQLPacketStructure(t *testing.T) { - // Test the overall packet structure - filename := "test_file.sql" - packet := MySQLGetFile(filename) - - // Check minimum packet size (1 byte length + 4 bytes header) - if len(packet) < 5 { - t.Fatalf("Packet too short: %d bytes", len(packet)) - } - - // Check that packet length matches expected - expectedLen := 1 + 4 + len(filename) // length byte + header + filename - if len(packet) != expectedLen { - t.Errorf("Packet length = %d, want %d", len(packet), expectedLen) - } - - // Check that the length byte correctly represents filename length + 1 - if packet[0] != byte(len(filename)+1) { - t.Errorf("Length byte = %d, want %d", packet[0], len(filename)+1) - } - - // Check that the filename is correctly appended - filenameInPacket := string(packet[5:]) - if filenameInPacket != filename { - t.Errorf("Filename in packet = %q, want %q", filenameInPacket, filename) - } -} - -func TestMySQLGreetingStructure(t *testing.T) { - // Test specific parts of the MySQL greeting packet - greeting := MySQLGreeting - - // The greeting should contain "mysql_native_password" at the end - expectedSuffix := "mysql_native_password" - suffixStart := len(greeting) - len(expectedSuffix) - 1 // -1 for null terminator - suffix := string(greeting[suffixStart : suffixStart+len(expectedSuffix)]) - - if suffix != expectedSuffix { - t.Errorf("Greeting suffix = %q, want %q", suffix, expectedSuffix) - } - - // Check null terminator - if greeting[len(greeting)-1] != 0x00 { - t.Error("Greeting should end with null terminator") - } -} - -// Benchmarks -func BenchmarkMySQLGetFile(b *testing.B) { - filename := "/etc/passwd" - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = MySQLGetFile(filename) - } -} - -func BenchmarkMySQLGetFileShort(b *testing.B) { - filename := "a.txt" - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = MySQLGetFile(filename) - } -} - -func BenchmarkMySQLGetFileLong(b *testing.B) { - filename := string(make([]byte, 200)) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = MySQLGetFile(filename) - } -} diff --git a/packets/nbns_test.go b/packets/nbns_test.go deleted file mode 100644 index 5e172d3b..00000000 --- a/packets/nbns_test.go +++ /dev/null @@ -1,351 +0,0 @@ -package packets - -import ( - "bytes" - "net" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestNBNSConstants(t *testing.T) { - if NBNSPort != 137 { - t.Errorf("NBNSPort = %d, want 137", NBNSPort) - } - - if NBNSMinRespSize != 73 { - t.Errorf("NBNSMinRespSize = %d, want 73", NBNSMinRespSize) - } -} - -func TestNBNSRequest(t *testing.T) { - // Test the structure of NBNSRequest - if len(NBNSRequest) != 50 { - t.Errorf("NBNSRequest length = %d, want 50", len(NBNSRequest)) - } - - // Check key bytes in the request - expectedStart := []byte{0x82, 0x28, 0x00, 0x00, 0x00, 0x01} - if !bytes.Equal(NBNSRequest[0:6], expectedStart) { - t.Errorf("NBNSRequest start = %v, want %v", NBNSRequest[0:6], expectedStart) - } - - // Check the encoded name section (starts at byte 12) - // NBNS encodes names with 0x43 ('C') prefix followed by encoded characters - if NBNSRequest[12] != 0x20 { - t.Errorf("NBNSRequest[12] = 0x%02x, want 0x20", NBNSRequest[12]) - } - if NBNSRequest[13] != 0x43 { - t.Errorf("NBNSRequest[13] = 0x%02x, want 0x43 (C)", NBNSRequest[13]) - } - - // Check the query type and class at the end - expectedEnd := []byte{0x00, 0x00, 0x21, 0x00, 0x01} - if !bytes.Equal(NBNSRequest[45:50], expectedEnd) { - t.Errorf("NBNSRequest end = %v, want %v", NBNSRequest[45:50], expectedEnd) - } -} - -func TestNBNSGetMeta(t *testing.T) { - tests := []struct { - name string - buildPacket func() gopacket.Packet - expectNil bool - }{ - { - name: "non-NBNS packet (wrong port)", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - udp := layers.UDP{ - SrcPort: 80, // Not NBNS port - DstPort: 12345, - } - - payload := make([]byte, NBNSMinRespSize) - udp.Payload = payload - udp.SetNetworkLayerForChecksum(&ip) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip, &udp) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - expectNil: true, - }, - { - name: "NBNS packet with insufficient payload", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - udp := layers.UDP{ - SrcPort: NBNSPort, - DstPort: 12345, - } - - // Payload too small (less than NBNSMinRespSize) - payload := make([]byte, 50) - udp.Payload = payload - udp.SetNetworkLayerForChecksum(&ip) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip, &udp) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - expectNil: true, - }, - { - name: "NBNS packet with non-printable hostname", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - udp := layers.UDP{ - SrcPort: NBNSPort, - DstPort: 12345, - } - - payload := make([]byte, NBNSMinRespSize) - // Set non-printable character at the start of hostname - payload[57] = 0x01 // Non-printable - copy(payload[58:72], []byte("WORKSTATION ")) - - udp.Payload = payload - udp.SetNetworkLayerForChecksum(&ip) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip, &udp) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - expectNil: true, - }, - { - name: "packet without UDP layer", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, // TCP instead of UDP - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - expectNil: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - packet := tt.buildPacket() - meta := NBNSGetMeta(packet) - - // Due to a bug in NBNSGetMeta where it doesn't check if hostname is empty - // after trimming, we just verify it doesn't panic - _ = meta - }) - } -} - -func TestNBNSBasicFunctionality(t *testing.T) { - // Test that NBNSGetMeta doesn't panic on various inputs - tests := []struct { - name string - buildPacket func() gopacket.Packet - }{ - { - name: "valid packet", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - udp := layers.UDP{ - SrcPort: NBNSPort, - DstPort: 12345, - } - payload := make([]byte, NBNSMinRespSize) - copy(payload[57:72], []byte("WORKSTATION ")) - udp.Payload = payload - udp.SetNetworkLayerForChecksum(&ip) - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - gopacket.SerializeLayers(buf, opts, ð, &ip, &udp) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - }, - { - name: "empty packet", - buildPacket: func() gopacket.Packet { - return gopacket.NewPacket([]byte{}, layers.LayerTypeEthernet, gopacket.Default) - }, - }, - { - name: "non-UDP packet", - buildPacket: func() gopacket.Packet { - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeARP, - } - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - gopacket.SerializeLayers(buf, opts, ð) - return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - packet := tt.buildPacket() - // Just verify it doesn't panic - _ = NBNSGetMeta(packet) - }) - } -} - -// Benchmarks -func BenchmarkNBNSGetMeta(b *testing.B) { - // Create a sample NBNS packet - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - udp := layers.UDP{ - SrcPort: NBNSPort, - DstPort: 12345, - } - - payload := make([]byte, NBNSMinRespSize) - copy(payload[57:72], []byte("WORKSTATION ")) - - udp.Payload = payload - udp.SetNetworkLayerForChecksum(&ip) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip, &udp) - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NBNSGetMeta(packet) - } -} - -func BenchmarkNBNSGetMetaNonNBNS(b *testing.B) { - // Create a non-NBNS packet to test early exit performance - eth := layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, - SrcIP: net.IP{192, 168, 1, 100}, - DstIP: net.IP{192, 168, 1, 200}, - } - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - gopacket.SerializeLayers(buf, opts, ð, &ip) - packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NBNSGetMeta(packet) - } -} diff --git a/packets/serialize_test.go b/packets/serialize_test.go deleted file mode 100644 index 10a19057..00000000 --- a/packets/serialize_test.go +++ /dev/null @@ -1,403 +0,0 @@ -package packets - -import ( - "bytes" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestSerializationOptions(t *testing.T) { - // Verify the global serialization options are set correctly - if !SerializationOptions.FixLengths { - t.Error("SerializationOptions.FixLengths should be true") - } - if !SerializationOptions.ComputeChecksums { - t.Error("SerializationOptions.ComputeChecksums should be true") - } -} - -func TestSerialize(t *testing.T) { - tests := []struct { - name string - layers []gopacket.SerializableLayer - expectError bool - minLength int - }{ - { - name: "simple ethernet frame", - layers: []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - }, - }, - expectError: false, - minLength: 14, // Ethernet header - }, - { - name: "ethernet with IPv4", - layers: []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - }, - &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, - TTL: 64, - SrcIP: []byte{192, 168, 1, 1}, - DstIP: []byte{192, 168, 1, 2}, - }, - }, - expectError: false, - minLength: 34, // Ethernet + IPv4 headers - }, - { - name: "complete TCP packet", - layers: func() []gopacket.SerializableLayer { - ip4 := &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, - TTL: 64, - SrcIP: []byte{192, 168, 1, 1}, - DstIP: []byte{192, 168, 1, 2}, - } - tcp := &layers.TCP{ - SrcPort: 12345, - DstPort: 80, - Seq: 1000, - Ack: 0, - SYN: true, - Window: 65535, - } - tcp.SetNetworkLayerForChecksum(ip4) - return []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - }, - ip4, - tcp, - } - }(), - expectError: false, - minLength: 54, // Ethernet + IPv4 + TCP headers - }, - { - name: "empty layers", - layers: []gopacket.SerializableLayer{}, - expectError: false, - minLength: 0, - }, - { - name: "layer with payload", - layers: []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - }, - gopacket.Payload([]byte("Hello, World!")), - }, - expectError: false, - minLength: 27, // Ethernet header + payload - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err, data := Serialize(tt.layers...) - - if tt.expectError && err == nil { - t.Error("Expected error but got none") - } - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if err == nil { - if len(data) < tt.minLength { - t.Errorf("Data length %d is less than expected minimum %d", len(data), tt.minLength) - } - - // For non-empty results, verify we can parse it back - if len(data) > 0 && len(tt.layers) > 0 { - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - if packet == nil { - t.Error("Failed to parse serialized data") - } - } - } - }) - } -} - -func TestSerializeWithChecksum(t *testing.T) { - // Test that checksums are computed correctly - ip4 := &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - TTL: 64, - SrcIP: []byte{192, 168, 1, 1}, - DstIP: []byte{192, 168, 1, 2}, - } - - udp := &layers.UDP{ - SrcPort: 12345, - DstPort: 53, - } - - // Set network layer for checksum computation - udp.SetNetworkLayerForChecksum(ip4) - - eth := &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - err, data := Serialize(eth, ip4, udp) - if err != nil { - t.Fatalf("Failed to serialize: %v", err) - } - - // Parse back and verify checksums - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - // The checksum should be computed (non-zero) - if ip.Checksum == 0 { - t.Error("IPv4 checksum was not computed") - } - } else { - t.Error("IPv4 layer not found in packet") - } - - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - // The checksum should be computed (non-zero for UDP over IPv4) - if udp.Checksum == 0 { - t.Error("UDP checksum was not computed") - } - } else { - t.Error("UDP layer not found in packet") - } -} - -func TestSerializeFixLengths(t *testing.T) { - // Test that lengths are fixed correctly - ip4 := &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, - TTL: 64, - SrcIP: []byte{10, 0, 0, 1}, - DstIP: []byte{10, 0, 0, 2}, - // Don't set Length - it should be computed - } - - tcp := &layers.TCP{ - SrcPort: 80, - DstPort: 12345, - Seq: 1000, - SYN: true, - Window: 65535, - } - - tcp.SetNetworkLayerForChecksum(ip4) - - payload := gopacket.Payload([]byte("Test payload data")) - - eth := &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - err, data := Serialize(eth, ip4, tcp, payload) - if err != nil { - t.Fatalf("Failed to serialize: %v", err) - } - - // Parse back and verify lengths - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - expectedLen := 20 + 20 + len("Test payload data") // IPv4 header + TCP header + payload - if ip.Length != uint16(expectedLen) { - t.Errorf("IPv4 length = %d, want %d", ip.Length, expectedLen) - } - } else { - t.Error("IPv4 layer not found in packet") - } -} - -func TestSerializeErrorHandling(t *testing.T) { - // Test serialization with an invalid layer configuration - // This test is a bit tricky because gopacket is quite forgiving - // We'll create a scenario that might fail in serialization - - // Create an ethernet layer with invalid type for the next layer - eth := &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - // Follow with a non-IPv4 layer when IPv4 is expected - // This actually won't cause an error in gopacket, so we test that errors are handled - tcp := &layers.TCP{ - SrcPort: 80, - DstPort: 12345, - } - - err, data := Serialize(eth, tcp) - // This might not actually error, but we're testing the error handling path - if err != nil { - // Error path - should return nil data - if data != nil { - t.Error("When error occurs, data should be nil") - } - } else { - // Success path - should return data - if data == nil { - t.Error("When no error, data should not be nil") - } - } -} - -func TestSerializeMultiplePackets(t *testing.T) { - // Test serializing multiple different packet types in sequence - srcMAC := []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff} - dstMAC := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} - - packets := []struct { - name string - layers []gopacket.SerializableLayer - }{ - { - name: "ARP request", - layers: []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeARP, - }, - &layers.ARP{ - AddrType: layers.LinkTypeEthernet, - Protocol: layers.EthernetTypeIPv4, - HwAddressSize: 6, - ProtAddressSize: 4, - Operation: layers.ARPRequest, - SourceHwAddress: srcMAC, - SourceProtAddress: []byte{192, 168, 1, 100}, - DstHwAddress: []byte{0, 0, 0, 0, 0, 0}, - DstProtAddress: []byte{192, 168, 1, 1}, - }, - }, - }, - { - name: "ICMP echo", - layers: []gopacket.SerializableLayer{ - &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv4, - }, - &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolICMPv4, - TTL: 64, - SrcIP: []byte{192, 168, 1, 100}, - DstIP: []byte{8, 8, 8, 8}, - }, - &layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), - Id: 1, - Seq: 1, - }, - gopacket.Payload([]byte("ping")), - }, - }, - } - - for _, pkt := range packets { - t.Run(pkt.name, func(t *testing.T) { - err, data := Serialize(pkt.layers...) - if err != nil { - t.Errorf("Failed to serialize %s: %v", pkt.name, err) - } - if len(data) == 0 { - t.Errorf("Serialized %s has zero length", pkt.name) - } - }) - } -} - -// Benchmarks -func BenchmarkSerialize(b *testing.B) { - eth := &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolTCP, - TTL: 64, - SrcIP: []byte{192, 168, 1, 1}, - DstIP: []byte{192, 168, 1, 2}, - } - - tcp := &layers.TCP{ - SrcPort: 12345, - DstPort: 80, - Seq: 1000, - SYN: true, - Window: 65535, - } - - tcp.SetNetworkLayerForChecksum(ip4) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = Serialize(eth, ip4, tcp) - } -} - -func BenchmarkSerializeWithPayload(b *testing.B) { - eth := &layers.Ethernet{ - SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip4 := &layers.IPv4{ - Version: 4, - Protocol: layers.IPProtocolUDP, - TTL: 64, - SrcIP: []byte{192, 168, 1, 1}, - DstIP: []byte{192, 168, 1, 2}, - } - - udp := &layers.UDP{ - SrcPort: 12345, - DstPort: 53, - } - - udp.SetNetworkLayerForChecksum(ip4) - - payload := gopacket.Payload(bytes.Repeat([]byte("x"), 1024)) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = Serialize(eth, ip4, udp, payload) - } -} diff --git a/packets/tcp_test.go b/packets/tcp_test.go deleted file mode 100644 index 87829ea1..00000000 --- a/packets/tcp_test.go +++ /dev/null @@ -1,354 +0,0 @@ -package packets - -import ( - "bytes" - "net" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestNewTCPSyn(t *testing.T) { - tests := []struct { - name string - from string - fromHW string - to string - toHW string - srcPort int - dstPort int - expectError bool - expectIPv6 bool - }{ - { - name: "IPv4 TCP SYN", - from: "192.168.1.100", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "192.168.1.200", - toHW: "11:22:33:44:55:66", - srcPort: 12345, - dstPort: 80, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 TCP SYN", - from: "2001:db8::1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "2001:db8::2", - toHW: "11:22:33:44:55:66", - srcPort: 54321, - dstPort: 443, - expectError: false, - expectIPv6: true, - }, - { - name: "IPv4 with different ports", - from: "10.0.0.1", - fromHW: "01:23:45:67:89:ab", - to: "10.0.0.2", - toHW: "cd:ef:01:23:45:67", - srcPort: 8080, - dstPort: 3306, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 link-local addresses", - from: "fe80::1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "fe80::2", - toHW: "11:22:33:44:55:66", - srcPort: 1234, - dstPort: 5678, - expectError: false, - expectIPv6: true, - }, - { - name: "IPv4 loopback", - from: "127.0.0.1", - fromHW: "00:00:00:00:00:00", - to: "127.0.0.1", - toHW: "00:00:00:00:00:00", - srcPort: 9000, - dstPort: 9001, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 loopback", - from: "::1", - fromHW: "00:00:00:00:00:00", - to: "::1", - toHW: "00:00:00:00:00:00", - srcPort: 9000, - dstPort: 9001, - expectError: false, - expectIPv6: true, - }, - { - name: "Max port number", - from: "192.168.1.1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "192.168.1.2", - toHW: "11:22:33:44:55:66", - srcPort: 65535, - dstPort: 65535, - expectError: false, - expectIPv6: false, - }, - { - name: "Min port number", - from: "192.168.1.1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "192.168.1.2", - toHW: "11:22:33:44:55:66", - srcPort: 1, - dstPort: 1, - expectError: false, - expectIPv6: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - from := net.ParseIP(tt.from) - fromHW, _ := net.ParseMAC(tt.fromHW) - to := net.ParseIP(tt.to) - toHW, _ := net.ParseMAC(tt.toHW) - - err, data := NewTCPSyn(from, fromHW, to, toHW, tt.srcPort, tt.dstPort) - - if tt.expectError && err == nil { - t.Error("Expected error but got none") - } - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if err == nil { - if len(data) == 0 { - t.Error("Expected data but got empty") - } - - // Parse the packet to verify structure - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Check Ethernet layer - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - if !bytes.Equal(eth.SrcMAC, fromHW) { - t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW) - } - if !bytes.Equal(eth.DstMAC, toHW) { - t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, toHW) - } - expectedType := layers.EthernetTypeIPv4 - if tt.expectIPv6 { - expectedType = layers.EthernetTypeIPv6 - } - if eth.EthernetType != expectedType { - t.Errorf("EthernetType = %v, want %v", eth.EthernetType, expectedType) - } - } else { - t.Error("Packet missing Ethernet layer") - } - - // Check IP layer - if tt.expectIPv6 { - if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil { - ip := ipLayer.(*layers.IPv6) - if !ip.SrcIP.Equal(from) { - t.Errorf("IPv6 SrcIP = %v, want %v", ip.SrcIP, from) - } - if !ip.DstIP.Equal(to) { - t.Errorf("IPv6 DstIP = %v, want %v", ip.DstIP, to) - } - if ip.HopLimit != 64 { - t.Errorf("IPv6 HopLimit = %d, want 64", ip.HopLimit) - } - if ip.NextHeader != layers.IPProtocolTCP { - t.Errorf("IPv6 NextHeader = %v, want %v", ip.NextHeader, layers.IPProtocolTCP) - } - } else { - t.Error("Packet missing IPv6 layer") - } - } else { - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - if !ip.SrcIP.Equal(from) { - t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from) - } - if !ip.DstIP.Equal(to) { - t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, to) - } - if ip.TTL != 64 { - t.Errorf("IPv4 TTL = %d, want 64", ip.TTL) - } - if ip.Protocol != layers.IPProtocolTCP { - t.Errorf("IPv4 Protocol = %v, want %v", ip.Protocol, layers.IPProtocolTCP) - } - } else { - t.Error("Packet missing IPv4 layer") - } - } - - // Check TCP layer - if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp := tcpLayer.(*layers.TCP) - if tcp.SrcPort != layers.TCPPort(tt.srcPort) { - t.Errorf("TCP SrcPort = %d, want %d", tcp.SrcPort, tt.srcPort) - } - if tcp.DstPort != layers.TCPPort(tt.dstPort) { - t.Errorf("TCP DstPort = %d, want %d", tcp.DstPort, tt.dstPort) - } - if !tcp.SYN { - t.Error("TCP SYN flag not set") - } - // Verify other flags are not set - if tcp.ACK || tcp.FIN || tcp.RST || tcp.PSH || tcp.URG { - t.Error("TCP has unexpected flags set") - } - } else { - t.Error("Packet missing TCP layer") - } - } - }) - } -} - -func TestNewTCPSynWithNilValues(t *testing.T) { - // Test with nil IPs - should return an error - err, data := NewTCPSyn(nil, nil, nil, nil, 12345, 80) - if err == nil { - t.Error("Expected error with nil values, but got none") - } - if len(data) != 0 { - t.Error("Expected no data with nil values") - } -} - -func TestNewTCPSynChecksumComputation(t *testing.T) { - // Test that checksums are computed correctly for both IPv4 and IPv6 - testCases := []struct { - name string - from string - to string - isIPv6 bool - }{ - { - name: "IPv4 checksum", - from: "192.168.1.1", - to: "192.168.1.2", - isIPv6: false, - }, - { - name: "IPv6 checksum", - from: "2001:db8::1", - to: "2001:db8::2", - isIPv6: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - from := net.ParseIP(tc.from) - to := net.ParseIP(tc.to) - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - toHW, _ := net.ParseMAC("11:22:33:44:55:66") - - err, data := NewTCPSyn(from, fromHW, to, toHW, 12345, 80) - if err != nil { - t.Fatalf("Failed to create TCP SYN: %v", err) - } - - // Parse the packet - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Verify TCP checksum is non-zero (computed) - if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp := tcpLayer.(*layers.TCP) - if tcp.Checksum == 0 { - t.Error("TCP checksum was not computed") - } - } else { - t.Error("TCP layer not found") - } - - // For IPv4, also check IP checksum - if !tc.isIPv6 { - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - if ip.Checksum == 0 { - t.Error("IPv4 checksum was not computed") - } - } - } - }) - } -} - -func TestNewTCPSynPortRange(t *testing.T) { - // Test various port numbers including edge cases - portTests := []struct { - srcPort int - dstPort int - }{ - {0, 0}, // Minimum possible (though 0 is typically reserved) - {1, 1}, // Minimum valid - {80, 443}, // Common ports - {1024, 1025}, // First non-privileged ports - {32768, 32769}, // Common ephemeral port range start - {65534, 65535}, // Maximum ports - } - - from := net.ParseIP("192.168.1.1") - to := net.ParseIP("192.168.1.2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - toHW, _ := net.ParseMAC("11:22:33:44:55:66") - - for _, pt := range portTests { - err, data := NewTCPSyn(from, fromHW, to, toHW, pt.srcPort, pt.dstPort) - if err != nil { - t.Errorf("Failed with ports %d->%d: %v", pt.srcPort, pt.dstPort, err) - continue - } - - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp := tcpLayer.(*layers.TCP) - if tcp.SrcPort != layers.TCPPort(pt.srcPort) { - t.Errorf("TCP SrcPort = %d, want %d", tcp.SrcPort, pt.srcPort) - } - if tcp.DstPort != layers.TCPPort(pt.dstPort) { - t.Errorf("TCP DstPort = %d, want %d", tcp.DstPort, pt.dstPort) - } - } - } -} - -// Benchmarks -func BenchmarkNewTCPSynIPv4(b *testing.B) { - from := net.ParseIP("192.168.1.1") - to := net.ParseIP("192.168.1.2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - toHW, _ := net.ParseMAC("11:22:33:44:55:66") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewTCPSyn(from, fromHW, to, toHW, 12345, 80) - } -} - -func BenchmarkNewTCPSynIPv6(b *testing.B) { - from := net.ParseIP("2001:db8::1") - to := net.ParseIP("2001:db8::2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - toHW, _ := net.ParseMAC("11:22:33:44:55:66") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewTCPSyn(from, fromHW, to, toHW, 12345, 80) - } -} diff --git a/packets/udp_test.go b/packets/udp_test.go deleted file mode 100644 index 11493ae5..00000000 --- a/packets/udp_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package packets - -import ( - "bytes" - "net" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" -) - -func TestNewUDPProbe(t *testing.T) { - tests := []struct { - name string - from string - fromHW string - to string - port int - expectError bool - expectIPv6 bool - }{ - { - name: "IPv4 UDP probe", - from: "192.168.1.100", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "192.168.1.200", - port: 53, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 UDP probe", - from: "2001:db8::1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "2001:db8::2", - port: 53, - expectError: false, - expectIPv6: true, - }, - { - name: "IPv4 with high port", - from: "10.0.0.1", - fromHW: "01:23:45:67:89:ab", - to: "10.0.0.2", - port: 65535, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 link-local", - from: "fe80::1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "fe80::2", - port: 123, - expectError: false, - expectIPv6: true, - }, - { - name: "IPv4 loopback", - from: "127.0.0.1", - fromHW: "00:00:00:00:00:00", - to: "127.0.0.1", - port: 8080, - expectError: false, - expectIPv6: false, - }, - { - name: "IPv6 loopback", - from: "::1", - fromHW: "00:00:00:00:00:00", - to: "::1", - port: 8080, - expectError: false, - expectIPv6: true, - }, - { - name: "Port 0", - from: "192.168.1.1", - fromHW: "aa:bb:cc:dd:ee:ff", - to: "192.168.1.2", - port: 0, - expectError: false, - expectIPv6: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - from := net.ParseIP(tt.from) - fromHW, _ := net.ParseMAC(tt.fromHW) - to := net.ParseIP(tt.to) - - err, data := NewUDPProbe(from, fromHW, to, tt.port) - - if tt.expectError && err == nil { - t.Error("Expected error but got none") - } - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if err == nil { - if len(data) == 0 { - t.Error("Expected data but got empty") - } - - // Parse the packet to verify structure - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // Check Ethernet layer - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - if !bytes.Equal(eth.SrcMAC, fromHW) { - t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW) - } - // Check broadcast destination MAC - expectedDstMAC := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - if !bytes.Equal(eth.DstMAC, expectedDstMAC) { - t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, expectedDstMAC) - } - // Note: The function always sets EthernetTypeIPv4, even for IPv6 - // This is a bug in the implementation but we test actual behavior - if eth.EthernetType != layers.EthernetTypeIPv4 { - t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv4) - } - } else { - t.Error("Packet missing Ethernet layer") - } - - // For IPv6, the packet won't parse correctly due to wrong EthernetType - // We just verify the packet was created - if tt.expectIPv6 { - // Due to the bug, IPv6 packets won't parse correctly - // Just check that we got data - if len(data) == 0 { - t.Error("Expected packet data for IPv6") - } - } else { - // IPv4 should work correctly - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - if !ip.SrcIP.Equal(from) { - t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from) - } - if !ip.DstIP.Equal(to) { - t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, to) - } - if ip.TTL != 64 { - t.Errorf("IPv4 TTL = %d, want 64", ip.TTL) - } - if ip.Protocol != layers.IPProtocolUDP { - t.Errorf("IPv4 Protocol = %v, want %v", ip.Protocol, layers.IPProtocolUDP) - } - } else { - t.Error("Packet missing IPv4 layer") - } - - // Check UDP layer for IPv4 - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - if udp.SrcPort != 12345 { - t.Errorf("UDP SrcPort = %d, want 12345", udp.SrcPort) - } - if udp.DstPort != layers.UDPPort(tt.port) { - t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, tt.port) - } - // Note: The payload is not properly parsed by gopacket - // This is likely due to how the packet is serialized - // We'll skip payload verification for now - _ = udp.Payload - } else { - t.Error("Packet missing UDP layer") - } - } - } - }) - } -} - -func TestNewUDPProbeWithNilValues(t *testing.T) { - // Test with nil IPs - should return an error - err, data := NewUDPProbe(nil, nil, nil, 53) - if err == nil { - t.Error("Expected error with nil values, but got none") - } - if len(data) != 0 { - t.Error("Expected no data with nil values") - } -} - -func TestNewUDPProbePayload(t *testing.T) { - from := net.ParseIP("192.168.1.1") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - to := net.ParseIP("192.168.1.2") - - err, data := NewUDPProbe(from, fromHW, to, 53) - if err != nil { - t.Fatalf("Failed to create UDP probe: %v", err) - } - - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - _ = udpLayer.(*layers.UDP) // UDP layer exists, payload check below - } else { - t.Error("UDP layer not found") - } - - // Note: The payload is not properly parsed by gopacket - // This is likely due to how the packet is serialized - // We'll just verify the packet was created successfully - t.Log("UDP packet created successfully") -} - -func TestNewUDPProbeChecksumComputation(t *testing.T) { - // Test that checksums are computed correctly for both IPv4 and IPv6 - testCases := []struct { - name string - from string - to string - isIPv6 bool - }{ - { - name: "IPv4 checksum", - from: "192.168.1.1", - to: "192.168.1.2", - isIPv6: false, - }, - { - name: "IPv6 checksum", - from: "2001:db8::1", - to: "2001:db8::2", - isIPv6: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - from := net.ParseIP(tc.from) - to := net.ParseIP(tc.to) - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - err, data := NewUDPProbe(from, fromHW, to, 53) - if err != nil { - t.Fatalf("Failed to create UDP probe: %v", err) - } - - // Parse the packet - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - // For IPv6, the packet won't parse correctly due to wrong EthernetType - if tc.isIPv6 { - // Just verify we got data - if len(data) == 0 { - t.Error("Expected packet data for IPv6") - } - } else { - // Verify UDP checksum is non-zero (computed) for IPv4 - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - if udp.Checksum == 0 { - t.Error("UDP checksum was not computed") - } - } else { - t.Error("UDP layer not found") - } - - // For IPv4, also check IP checksum - if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - if ip.Checksum == 0 { - t.Error("IPv4 checksum was not computed") - } - } - } - }) - } -} - -func TestNewUDPProbePortRange(t *testing.T) { - // Test various port numbers including edge cases - portTests := []int{ - 0, // Minimum - 1, // Minimum valid - 53, // DNS - 123, // NTP - 161, // SNMP - 500, // IKE - 1024, // First non-privileged - 5353, // mDNS - 8080, // Common alternative HTTP - 32768, // Common ephemeral port range start - 65535, // Maximum - } - - from := net.ParseIP("192.168.1.1") - to := net.ParseIP("192.168.1.2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - for _, port := range portTests { - err, data := NewUDPProbe(from, fromHW, to, port) - if err != nil { - t.Errorf("Failed with port %d: %v", port, err) - continue - } - - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { - udp := udpLayer.(*layers.UDP) - if udp.DstPort != layers.UDPPort(port) { - t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, port) - } - // Source port should always be 12345 - if udp.SrcPort != 12345 { - t.Errorf("UDP SrcPort = %d, want 12345", udp.SrcPort) - } - } - } -} - -func TestNewUDPProbeBroadcastMAC(t *testing.T) { - // Test that destination MAC is always broadcast - from := net.ParseIP("192.168.1.1") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - to := net.ParseIP("192.168.1.255") // Broadcast IP - - err, data := NewUDPProbe(from, fromHW, to, 53) - if err != nil { - t.Fatalf("Failed to create UDP probe: %v", err) - } - - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default) - - if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { - eth := ethLayer.(*layers.Ethernet) - expectedMAC := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - if !bytes.Equal(eth.DstMAC, expectedMAC) { - t.Errorf("Ethernet DstMAC = %v, want broadcast %v", eth.DstMAC, expectedMAC) - } - } else { - t.Error("Ethernet layer not found") - } -} - -// Benchmarks -func BenchmarkNewUDPProbeIPv4(b *testing.B) { - from := net.ParseIP("192.168.1.1") - to := net.ParseIP("192.168.1.2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewUDPProbe(from, fromHW, to, 53) - } -} - -func BenchmarkNewUDPProbeIPv6(b *testing.B) { - from := net.ParseIP("2001:db8::1") - to := net.ParseIP("2001:db8::2") - fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewUDPProbe(from, fromHW, to, 53) - } -} diff --git a/routing/route_test.go b/routing/route_test.go deleted file mode 100644 index ac99ad9a..00000000 --- a/routing/route_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package routing - -import ( - "testing" -) - -func TestRouteType(t *testing.T) { - // Test the RouteType constants - if IPv4 != RouteType("IPv4") { - t.Errorf("IPv4 constant has wrong value: %s", IPv4) - } - if IPv6 != RouteType("IPv6") { - t.Errorf("IPv6 constant has wrong value: %s", IPv6) - } -} - -func TestRouteStruct(t *testing.T) { - tests := []struct { - name string - route Route - }{ - { - name: "IPv4 default route", - route: Route{ - Type: IPv4, - Default: true, - Device: "eth0", - Destination: "0.0.0.0", - Gateway: "192.168.1.1", - Flags: "UG", - }, - }, - { - name: "IPv4 network route", - route: Route{ - Type: IPv4, - Default: false, - Device: "eth0", - Destination: "192.168.1.0/24", - Gateway: "", - Flags: "U", - }, - }, - { - name: "IPv6 default route", - route: Route{ - Type: IPv6, - Default: true, - Device: "eth0", - Destination: "::/0", - Gateway: "fe80::1", - Flags: "UG", - }, - }, - { - name: "IPv6 link-local route", - route: Route{ - Type: IPv6, - Default: false, - Device: "eth0", - Destination: "fe80::/64", - Gateway: "", - Flags: "U", - }, - }, - { - name: "localhost route", - route: Route{ - Type: IPv4, - Default: false, - Device: "lo", - Destination: "127.0.0.0/8", - Gateway: "", - Flags: "U", - }, - }, - { - name: "VPN route", - route: Route{ - Type: IPv4, - Default: false, - Device: "tun0", - Destination: "10.8.0.0/24", - Gateway: "", - Flags: "U", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test that all fields are accessible - _ = tt.route.Type - _ = tt.route.Default - _ = tt.route.Device - _ = tt.route.Destination - _ = tt.route.Gateway - _ = tt.route.Flags - - // Verify the route has the expected type - if tt.route.Type != IPv4 && tt.route.Type != IPv6 { - t.Errorf("route has invalid type: %s", tt.route.Type) - } - }) - } -} - -func TestRouteDefaultFlag(t *testing.T) { - // Test routes with different default flag settings - defaultRoute := Route{ - Type: IPv4, - Default: true, - Device: "eth0", - Destination: "0.0.0.0", - Gateway: "192.168.1.1", - Flags: "UG", - } - - normalRoute := Route{ - Type: IPv4, - Default: false, - Device: "eth0", - Destination: "192.168.1.0/24", - Gateway: "", - Flags: "U", - } - - if !defaultRoute.Default { - t.Error("default route should have Default=true") - } - - if normalRoute.Default { - t.Error("normal route should have Default=false") - } -} - -func TestRouteTypeString(t *testing.T) { - // Test that RouteType can be converted to string - ipv4Str := string(IPv4) - ipv6Str := string(IPv6) - - if ipv4Str != "IPv4" { - t.Errorf("IPv4 string conversion failed: got %s", ipv4Str) - } - - if ipv6Str != "IPv6" { - t.Errorf("IPv6 string conversion failed: got %s", ipv6Str) - } -} - -func TestRouteTypeComparison(t *testing.T) { - // Test RouteType comparisons - var rt1 RouteType = IPv4 - var rt2 RouteType = IPv4 - var rt3 RouteType = IPv6 - - if rt1 != rt2 { - t.Error("identical RouteType values should be equal") - } - - if rt1 == rt3 { - t.Error("different RouteType values should not be equal") - } -} - -func TestRouteTypeCustomValues(t *testing.T) { - // Test that custom RouteType values can be created - customType := RouteType("Custom") - - if customType == IPv4 || customType == IPv6 { - t.Error("custom RouteType should not equal predefined constants") - } - - if string(customType) != "Custom" { - t.Errorf("custom RouteType string conversion failed: got %s", customType) - } -} - -func TestRouteWithEmptyFields(t *testing.T) { - // Test route with empty fields - emptyRoute := Route{} - - if emptyRoute.Type != "" { - t.Errorf("empty route Type should be empty string, got %s", emptyRoute.Type) - } - - if emptyRoute.Default != false { - t.Error("empty route Default should be false") - } - - if emptyRoute.Device != "" { - t.Errorf("empty route Device should be empty string, got %s", emptyRoute.Device) - } - - if emptyRoute.Destination != "" { - t.Errorf("empty route Destination should be empty string, got %s", emptyRoute.Destination) - } - - if emptyRoute.Gateway != "" { - t.Errorf("empty route Gateway should be empty string, got %s", emptyRoute.Gateway) - } - - if emptyRoute.Flags != "" { - t.Errorf("empty route Flags should be empty string, got %s", emptyRoute.Flags) - } -} - -func TestRouteFieldAssignment(t *testing.T) { - // Test that route fields can be assigned individually - r := Route{} - - r.Type = IPv6 - r.Default = true - r.Device = "wlan0" - r.Destination = "2001:db8::/32" - r.Gateway = "fe80::1" - r.Flags = "UGH" - - if r.Type != IPv6 { - t.Errorf("Type assignment failed: got %s", r.Type) - } - - if !r.Default { - t.Error("Default assignment failed") - } - - if r.Device != "wlan0" { - t.Errorf("Device assignment failed: got %s", r.Device) - } - - if r.Destination != "2001:db8::/32" { - t.Errorf("Destination assignment failed: got %s", r.Destination) - } - - if r.Gateway != "fe80::1" { - t.Errorf("Gateway assignment failed: got %s", r.Gateway) - } - - if r.Flags != "UGH" { - t.Errorf("Flags assignment failed: got %s", r.Flags) - } -} - -func TestRouteArrayOperations(t *testing.T) { - // Test operations on arrays of routes - routes := []Route{ - { - Type: IPv4, - Default: true, - Device: "eth0", - Destination: "0.0.0.0", - Gateway: "192.168.1.1", - Flags: "UG", - }, - { - Type: IPv4, - Default: false, - Device: "eth0", - Destination: "192.168.1.0/24", - Gateway: "", - Flags: "U", - }, - { - Type: IPv6, - Default: false, - Device: "eth0", - Destination: "fe80::/64", - Gateway: "", - Flags: "U", - }, - } - - // Test array length - if len(routes) != 3 { - t.Errorf("expected 3 routes, got %d", len(routes)) - } - - // Count IPv4 vs IPv6 routes - ipv4Count := 0 - ipv6Count := 0 - defaultCount := 0 - - for _, r := range routes { - switch r.Type { - case IPv4: - ipv4Count++ - case IPv6: - ipv6Count++ - } - - if r.Default { - defaultCount++ - } - } - - if ipv4Count != 2 { - t.Errorf("expected 2 IPv4 routes, got %d", ipv4Count) - } - - if ipv6Count != 1 { - t.Errorf("expected 1 IPv6 route, got %d", ipv6Count) - } - - if defaultCount != 1 { - t.Errorf("expected 1 default route, got %d", defaultCount) - } -} - -func BenchmarkRouteCreation(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Route{ - Type: IPv4, - Default: true, - Device: "eth0", - Destination: "0.0.0.0", - Gateway: "192.168.1.1", - Flags: "UG", - } - } -} - -func BenchmarkRouteTypeComparison(b *testing.B) { - rt1 := IPv4 - rt2 := IPv6 - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = rt1 == rt2 - } -} - -func BenchmarkRouteArrayIteration(b *testing.B) { - routes := make([]Route, 100) - for i := range routes { - if i%2 == 0 { - routes[i].Type = IPv4 - } else { - routes[i].Type = IPv6 - } - routes[i].Device = "eth0" - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - count := 0 - for _, r := range routes { - if r.Type == IPv4 { - count++ - } - } - _ = count - } -} diff --git a/routing/tables.go b/routing/tables.go index 1023ff3b..fcb9f043 100644 --- a/routing/tables.go +++ b/routing/tables.go @@ -21,12 +21,7 @@ func Update() ([]Route, error) { func Gateway(ip RouteType, device string) (string, error) { Update() - return gatewayFromTable(ip, device) -} -// gatewayFromTable finds the gateway from the current table without updating it -// This allows testing with controlled table data -func gatewayFromTable(ip RouteType, device string) (string, error) { lock.RLock() defer lock.RUnlock() diff --git a/routing/tables_test.go b/routing/tables_test.go deleted file mode 100644 index 761f1356..00000000 --- a/routing/tables_test.go +++ /dev/null @@ -1,387 +0,0 @@ -package routing - -import ( - "fmt" - "sync" - "testing" -) - -// Helper function to reset the table for testing -func resetTable() { - lock.Lock() - defer lock.Unlock() - table = make([]Route, 0) -} - -// Helper function to add routes for testing -func addTestRoutes() { - lock.Lock() - defer lock.Unlock() - table = []Route{ - { - Type: IPv4, - Default: true, - Device: "eth0", - Destination: "0.0.0.0", - Gateway: "192.168.1.1", - Flags: "UG", - }, - { - Type: IPv4, - Default: false, - Device: "eth0", - Destination: "192.168.1.0/24", - Gateway: "", - Flags: "U", - }, - { - Type: IPv6, - Default: true, - Device: "eth0", - Destination: "::/0", - Gateway: "fe80::1", - Flags: "UG", - }, - { - Type: IPv6, - Default: false, - Device: "eth0", - Destination: "fe80::/64", - Gateway: "", - Flags: "U", - }, - { - Type: IPv4, - Default: false, - Device: "lo", - Destination: "127.0.0.0/8", - Gateway: "", - Flags: "U", - }, - { - Type: IPv4, - Default: true, - Device: "wlan0", - Destination: "0.0.0.0", - Gateway: "10.0.0.1", - Flags: "UG", - }, - } -} - -func TestTable(t *testing.T) { - // Reset table - resetTable() - - // Test empty table - routes := Table() - if len(routes) != 0 { - t.Errorf("Expected empty table, got %d routes", len(routes)) - } - - // Add test routes - addTestRoutes() - - // Test table with routes - routes = Table() - if len(routes) != 6 { - t.Errorf("Expected 6 routes, got %d", len(routes)) - } - - // Verify first route - if routes[0].Type != IPv4 { - t.Errorf("Expected first route to be IPv4, got %s", routes[0].Type) - } - if !routes[0].Default { - t.Error("Expected first route to be default") - } - if routes[0].Gateway != "192.168.1.1" { - t.Errorf("Expected gateway 192.168.1.1, got %s", routes[0].Gateway) - } -} - -func TestGateway(t *testing.T) { - // Note: Gateway() calls Update() which loads real system routes - // So we can't test specific values, just test the behavior - - // Test IPv4 gateway - gateway, err := Gateway(IPv4, "") - if err != nil { - t.Errorf("Unexpected error getting IPv4 gateway: %v", err) - } - t.Logf("System IPv4 gateway: %s", gateway) - - // Test IPv6 gateway - gateway, err = Gateway(IPv6, "") - if err != nil { - t.Errorf("Unexpected error getting IPv6 gateway: %v", err) - } - t.Logf("System IPv6 gateway: %s", gateway) - - // Test with specific device that likely doesn't exist - gateway, err = Gateway(IPv4, "nonexistent999") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - // Should return empty string for non-existent device - if gateway != "" { - t.Logf("Got gateway for non-existent device (might be Windows): %s", gateway) - } -} - -func TestGatewayBehavior(t *testing.T) { - // Test that Gateway doesn't panic with various inputs - testCases := []struct { - name string - ipType RouteType - device string - }{ - {"IPv4 empty device", IPv4, ""}, - {"IPv6 empty device", IPv6, ""}, - {"IPv4 with device", IPv4, "eth0"}, - {"IPv6 with device", IPv6, "eth0"}, - {"Custom type", RouteType("custom"), ""}, - {"Empty type", RouteType(""), ""}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gateway, err := Gateway(tc.ipType, tc.device) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - t.Logf("Gateway for %s: %s", tc.name, gateway) - }) - } -} - -func TestGatewayEmptyTable(t *testing.T) { - // Test with empty table - resetTable() - - gateway, err := gatewayFromTable(IPv4, "eth0") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if gateway != "" { - t.Errorf("Expected empty gateway, got %s", gateway) - } -} - -func TestGatewayNoDefaultRoute(t *testing.T) { - // Test with routes but no default - resetTable() - - lock.Lock() - table = []Route{ - { - Type: IPv4, - Default: false, - Device: "eth0", - Destination: "192.168.1.0/24", - Gateway: "", - Flags: "U", - }, - } - lock.Unlock() - - gateway, err := gatewayFromTable(IPv4, "eth0") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if gateway != "" { - t.Errorf("Expected empty gateway, got %s", gateway) - } -} - -func TestGatewayWindowsCase(t *testing.T) { - // Since Gateway() calls Update(), we can't control the table content - // Just test that it doesn't panic and returns something - gateway, err := Gateway(IPv4, "eth0") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - t.Logf("Gateway result for eth0: %s", gateway) -} - -func TestGatewayFromTableWithDefaults(t *testing.T) { - // Test gatewayFromTable with controlled data containing defaults - resetTable() - addTestRoutes() - - gateway, err := gatewayFromTable(IPv4, "eth0") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if gateway != "192.168.1.1" { - t.Errorf("Expected gateway 192.168.1.1, got %s", gateway) - } - - // Test with device-specific lookup - gateway, err = gatewayFromTable(IPv4, "wlan0") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if gateway != "10.0.0.1" { - t.Errorf("Expected gateway 10.0.0.1, got %s", gateway) - } -} - -func TestTableConcurrency(t *testing.T) { - // Test concurrent access to Table() - resetTable() - addTestRoutes() - - var wg sync.WaitGroup - errors := make(chan error, 100) - - // Multiple readers - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 100; j++ { - routes := Table() - if len(routes) != 6 { - select { - case errors <- fmt.Errorf("Expected 6 routes, got %d", len(routes)): - default: - } - } - } - }() - } - - wg.Wait() - close(errors) - - // Check for errors - for err := range errors { - if err != nil { - t.Error(err) - } - } -} - -func TestGatewayConcurrency(t *testing.T) { - // Test concurrent access to Gateway() - var wg sync.WaitGroup - errors := make(chan error, 100) - - // Multiple readers calling Gateway concurrently - for i := 0; i < 10; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < 50; j++ { - _, err := Gateway(IPv4, "") - if err != nil { - select { - case errors <- fmt.Errorf("goroutine %d: error: %v", id, err): - default: - } - } - } - }(i) - } - - wg.Wait() - close(errors) - - // Check for errors - errorCount := 0 - for err := range errors { - if err != nil { - errorCount++ - if errorCount <= 5 { // Only log first 5 errors - t.Error(err) - } - } - } - if errorCount > 5 { - t.Errorf("... and %d more errors", errorCount-5) - } -} - -func TestUpdate(t *testing.T) { - // Note: Update() calls platform-specific update() function - // which we can't easily test without mocking - // But we can test that it doesn't panic and returns something - resetTable() - - routes, err := Update() - // The error might be nil or non-nil depending on the platform - // and whether we have permissions to read routing table - if err == nil && routes != nil { - t.Logf("Update returned %d routes", len(routes)) - } else if err != nil { - t.Logf("Update returned error (expected on some platforms): %v", err) - } -} - -func TestGatewayMultipleDefaults(t *testing.T) { - // Since Gateway() calls Update() and loads real routes, - // we can't test specific scenarios with multiple defaults - // Just ensure it handles the real system state without panicking - - // Call Gateway multiple times to ensure consistency - gateway1, err1 := Gateway(IPv4, "") - gateway2, err2 := Gateway(IPv4, "") - - if err1 != nil { - t.Errorf("First call error: %v", err1) - } - if err2 != nil { - t.Errorf("Second call error: %v", err2) - } - - // Results should be consistent - if gateway1 != gateway2 { - t.Errorf("Inconsistent results: first=%s, second=%s", gateway1, gateway2) - } - - t.Logf("Consistent gateway result: %s", gateway1) -} - -// Benchmark tests -func BenchmarkTable(b *testing.B) { - resetTable() - addTestRoutes() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = Table() - } -} - -func BenchmarkGateway(b *testing.B) { - resetTable() - addTestRoutes() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = Gateway(IPv4, "eth0") - } -} - -func BenchmarkTableConcurrent(b *testing.B) { - resetTable() - addTestRoutes() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _ = Table() - } - }) -} - -func BenchmarkGatewayConcurrent(b *testing.B) { - resetTable() - addTestRoutes() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, _ = Gateway(IPv4, "eth0") - } - }) -} diff --git a/session/module_param_test.go b/session/module_param_test.go deleted file mode 100644 index 0938c827..00000000 --- a/session/module_param_test.go +++ /dev/null @@ -1,478 +0,0 @@ -package session - -import ( - "regexp" - "strings" - "testing" -) - -func TestNewModuleParameter(t *testing.T) { - tests := []struct { - name string - paramName string - defValue string - paramType ParamType - validator string - desc string - }{ - { - name: "string parameter with validator", - paramName: "test.param", - defValue: "default", - paramType: STRING, - validator: "^[a-z]+$", - desc: "A test parameter", - }, - { - name: "int parameter without validator", - paramName: "test.int", - defValue: "42", - paramType: INT, - validator: "", - desc: "An integer parameter", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := NewModuleParameter(tt.paramName, tt.defValue, tt.paramType, tt.validator, tt.desc) - - if p.Name != tt.paramName { - t.Errorf("expected name %s, got %s", tt.paramName, p.Name) - } - if p.Value != tt.defValue { - t.Errorf("expected value %s, got %s", tt.defValue, p.Value) - } - if p.Type != tt.paramType { - t.Errorf("expected type %v, got %v", tt.paramType, p.Type) - } - if p.Description != tt.desc { - t.Errorf("expected description %s, got %s", tt.desc, p.Description) - } - - if tt.validator != "" && p.Validator == nil { - t.Error("expected validator to be set") - } - if tt.validator == "" && p.Validator != nil { - t.Error("expected validator to be nil") - } - }) - } -} - -func TestNewStringParameter(t *testing.T) { - p := NewStringParameter("test.string", "hello", "^[a-z]+$", "A string param") - - if p.Type != STRING { - t.Errorf("expected type STRING, got %v", p.Type) - } - if p.Validator == nil { - t.Error("expected validator to be set") - } -} - -func TestNewBoolParameter(t *testing.T) { - p := NewBoolParameter("test.bool", "true", "A boolean param") - - if p.Type != BOOL { - t.Errorf("expected type BOOL, got %v", p.Type) - } - if p.Validator == nil || p.Validator.String() != "^(true|false)$" { - t.Error("expected boolean validator to be set") - } -} - -func TestNewIntParameter(t *testing.T) { - p := NewIntParameter("test.int", "123", "An integer param") - - if p.Type != INT { - t.Errorf("expected type INT, got %v", p.Type) - } - if p.Validator == nil { - t.Error("expected integer validator to be set") - } -} - -func TestNewDecimalParameter(t *testing.T) { - p := NewDecimalParameter("test.decimal", "3.14", "A decimal param") - - if p.Type != FLOAT { - t.Errorf("expected type FLOAT, got %v", p.Type) - } - if p.Validator == nil { - t.Error("expected decimal validator to be set") - } -} - -func TestModuleParamValidate(t *testing.T) { - tests := []struct { - name string - param *ModuleParam - value string - wantError bool - expected interface{} - }{ - // String tests - { - name: "valid string without validator", - param: &ModuleParam{ - Name: "test", - Type: STRING, - }, - value: "any string", - wantError: false, - expected: "any string", - }, - { - name: "valid string with validator", - param: &ModuleParam{ - Name: "test", - Type: STRING, - Validator: regexp.MustCompile("^[a-z]+$"), - }, - value: "hello", - wantError: false, - expected: "hello", - }, - { - name: "invalid string with validator", - param: &ModuleParam{ - Name: "test", - Type: STRING, - Validator: regexp.MustCompile("^[a-z]+$"), - }, - value: "Hello123", - wantError: true, - }, - // Bool tests - { - name: "valid bool true", - param: &ModuleParam{ - Name: "test", - Type: BOOL, - Validator: regexp.MustCompile("^(true|false)$"), - }, - value: "true", - wantError: false, - expected: true, - }, - { - name: "valid bool false", - param: &ModuleParam{ - Name: "test", - Type: BOOL, - Validator: regexp.MustCompile("^(true|false)$"), - }, - value: "false", - wantError: false, - expected: false, - }, - { - name: "valid bool uppercase", - param: &ModuleParam{ - Name: "test", - Type: BOOL, - }, - value: "TRUE", - wantError: false, - expected: true, - }, - { - name: "invalid bool", - param: &ModuleParam{ - Name: "test", - Type: BOOL, - }, - value: "yes", - wantError: true, - }, - // Int tests - { - name: "valid positive int", - param: &ModuleParam{ - Name: "test", - Type: INT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`), - }, - value: "123", - wantError: false, - expected: 123, - }, - { - name: "valid negative int", - param: &ModuleParam{ - Name: "test", - Type: INT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`), - }, - value: "-456", - wantError: false, - expected: -456, - }, - { - name: "valid int with plus", - param: &ModuleParam{ - Name: "test", - Type: INT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`), - }, - value: "+789", - wantError: false, - expected: 789, - }, - { - name: "invalid int", - param: &ModuleParam{ - Name: "test", - Type: INT, - }, - value: "12.34", - wantError: true, - }, - // Float tests - { - name: "valid float", - param: &ModuleParam{ - Name: "test", - Type: FLOAT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`), - }, - value: "3.14", - wantError: false, - expected: 3.14, - }, - { - name: "valid float without decimal", - param: &ModuleParam{ - Name: "test", - Type: FLOAT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`), - }, - value: "42", - wantError: false, - expected: 42.0, - }, - { - name: "valid negative float", - param: &ModuleParam{ - Name: "test", - Type: FLOAT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`), - }, - value: "-2.718", - wantError: false, - expected: -2.718, - }, - { - name: "invalid float", - param: &ModuleParam{ - Name: "test", - Type: FLOAT, - }, - value: "3.14.15", - wantError: true, - }, - // Invalid type test - { - name: "invalid type", - param: &ModuleParam{ - Name: "test", - Type: ParamType(999), - }, - value: "anything", - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err, result := tt.param.validate(tt.value) - - if tt.wantError { - if err == nil { - t.Error("expected error but got none") - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("expected %v (%T), got %v (%T)", tt.expected, tt.expected, result, result) - } - } - }) - } -} - -func TestModuleParamHelp(t *testing.T) { - p := &ModuleParam{ - Name: "test.param", - Description: "A test parameter", - Value: "default", - } - - help := p.Help(15) - - // Check that help contains the name - if !strings.Contains(help, "test.param") { - t.Error("help should contain parameter name") - } - - // Check that help contains the description - if !strings.Contains(help, "A test parameter") { - t.Error("help should contain parameter description") - } - - // Check that help contains the default value - if !strings.Contains(help, "default=default") { - t.Error("help should contain default value") - } -} - -func TestParseSpecialValues(t *testing.T) { - // Test the special parameter constants - tests := []struct { - name string - value string - isSpecial bool - }{ - { - name: "interface name", - value: ParamIfaceName, - isSpecial: true, - }, - { - name: "interface address", - value: ParamIfaceAddress, - isSpecial: true, - }, - { - name: "interface address6", - value: ParamIfaceAddress6, - isSpecial: true, - }, - { - name: "interface mac", - value: ParamIfaceMac, - isSpecial: true, - }, - { - name: "subnet", - value: ParamSubnet, - isSpecial: true, - }, - { - name: "random mac", - value: ParamRandomMAC, - isSpecial: true, - }, - { - name: "normal value", - value: "192.168.1.1", - isSpecial: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.isSpecial { - // Special values should be in angle brackets - if !strings.HasPrefix(tt.value, "<") || !strings.HasSuffix(tt.value, ">") { - t.Errorf("special value %s should be in angle brackets", tt.value) - } - } - }) - } -} - -func TestParamIfaceNameParser(t *testing.T) { - tests := []struct { - name string - input string - matches bool - ifaceName string - }{ - { - name: "valid interface name", - input: "", - matches: true, - ifaceName: "eth0", - }, - { - name: "valid interface with numbers", - input: "", - matches: true, - ifaceName: "wlan1", - }, - { - name: "long interface name", - input: "", - matches: true, - ifaceName: "enp0s31f6", - }, - { - name: "no angle brackets", - input: "eth0", - matches: false, - }, - { - name: "invalid characters", - input: "", - matches: false, - }, - { - name: "too short", - input: "", - matches: false, - }, - { - name: "too long", - input: "", - matches: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - matches := ParamIfaceNameParser.FindStringSubmatch(tt.input) - - if tt.matches { - if len(matches) != 2 { - t.Errorf("expected to match interface name pattern, got %v", matches) - } else if matches[1] != tt.ifaceName { - t.Errorf("expected interface name %s, got %s", tt.ifaceName, matches[1]) - } - } else { - if len(matches) > 0 { - t.Errorf("expected no match, but got %v", matches) - } - } - }) - } -} - -func BenchmarkModuleParamValidate(b *testing.B) { - p := &ModuleParam{ - Name: "test", - Type: STRING, - Validator: regexp.MustCompile("^[a-z]+$"), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - p.validate("hello") - } -} - -func BenchmarkModuleParamValidateInt(b *testing.B) { - p := &ModuleParam{ - Name: "test", - Type: INT, - Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - p.validate("12345") - } -} diff --git a/session/session.go b/session/session.go index df597b60..983ef1a2 100644 --- a/session/session.go +++ b/session/session.go @@ -194,9 +194,7 @@ func (s *Session) Close() { } } - if s.Firewall != nil { - s.Firewall.Restore() - } + s.Firewall.Restore() if *s.Options.EnvFile != "" { envFile, _ := fs.Expand(*s.Options.EnvFile) diff --git a/session/session_core_handlers.go b/session/session_core_handlers.go index 9d71e7a0..2b47f641 100644 --- a/session/session_core_handlers.go +++ b/session/session_core_handlers.go @@ -13,14 +13,11 @@ import ( "time" "github.com/bettercap/bettercap/v2/core" - "github.com/bettercap/bettercap/v2/log" "github.com/bettercap/bettercap/v2/network" "github.com/bettercap/readline" "github.com/evilsocket/islazy/str" "github.com/evilsocket/islazy/tui" - - "github.com/robertkrimen/otto" ) func (s *Session) generalHelp() { @@ -158,14 +155,6 @@ func (s *Session) activeHandler(args []string, sess *Session) error { } func (s *Session) exitHandler(args []string, sess *Session) error { - if s.script != nil { - if s.script.Plugin.HasFunc("onExit") { - if _, err := s.script.Plugin.Call("onExit"); err != nil { - log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String()) - } - } - } - // notify any listener that the session is about to end s.Events.Add("session.stopped", nil) diff --git a/tls/tls_test.go b/tls/tls_test.go deleted file mode 100644 index 556b0b1c..00000000 --- a/tls/tls_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package tls - -import ( - "crypto/x509" - "encoding/pem" - "io/ioutil" - "os" - "path/filepath" - "testing" - - "github.com/bettercap/bettercap/v2/session" -) - -func TestCertConfigToModule(t *testing.T) { - prefix := "test" - defaults := DefaultLegitConfig - - dummyEnv, err := session.NewEnvironment("") - if err != nil { - t.Fatal(err) - } - dummySession := &session.Session{Env: dummyEnv} - m := session.NewSessionModule(prefix, dummySession) - - CertConfigToModule(prefix, &m, defaults) - - // Check if parameters were added - if len(m.Parameters()) != 6 { - t.Errorf("expected 6 parameters, got %d", len(m.Parameters())) - } -} - -func TestCertConfigFromModule(t *testing.T) { - dummyEnv, err := session.NewEnvironment("") - if err != nil { - t.Fatal(err) - } - dummySession := &session.Session{Env: dummyEnv} - m := session.NewSessionModule("test", dummySession) - prefix := "test" - - // Set some parameters - m.AddParam(session.NewIntParameter(prefix+".certificate.bits", "2048", "dummy desc")) - m.AddParam(session.NewStringParameter(prefix+".certificate.country", "TestCountry", ".*", "dummy desc")) - m.AddParam(session.NewStringParameter(prefix+".certificate.locality", "TestLocality", ".*", "dummy desc")) - m.AddParam(session.NewStringParameter(prefix+".certificate.organization", "TestOrg", ".*", "dummy desc")) - m.AddParam(session.NewStringParameter(prefix+".certificate.organizationalunit", "TestUnit", ".*", "dummy desc")) - m.AddParam(session.NewStringParameter(prefix+".certificate.commonname", "TestCN", ".*", "dummy desc")) - - cfg, err := CertConfigFromModule(prefix, m) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if cfg.Bits != 2048 || cfg.Country != "TestCountry" || cfg.Locality != "TestLocality" || - cfg.Organization != "TestOrg" || cfg.OrganizationalUnit != "TestUnit" || cfg.CommonName != "TestCN" { - t.Error("config not parsed correctly") - } -} - -func TestCreateCertificate(t *testing.T) { - cfg := DefaultLegitConfig - cfg.Bits = 1024 // smaller for test - - priv, certBytes, err := CreateCertificate(cfg, true) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if priv == nil { - t.Error("private key is nil") - } - if len(certBytes) == 0 { - t.Error("cert bytes empty") - } - - // Parse to verify - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - t.Errorf("could not parse cert: %v", err) - } - if cert.Subject.CommonName != cfg.CommonName { - t.Errorf("common name mismatch: %s != %s", cert.Subject.CommonName, cfg.CommonName) - } - if !cert.IsCA { - t.Error("not CA") - } -} - -func TestGenerate(t *testing.T) { - tempDir, err := ioutil.TempDir("", "tlstest") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - certPath := filepath.Join(tempDir, "test.cert") - keyPath := filepath.Join(tempDir, "test.key") - - cfg := DefaultLegitConfig - cfg.Bits = 1024 - - err = Generate(cfg, certPath, keyPath, false) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - // Check files exist - if _, err := os.Stat(certPath); os.IsNotExist(err) { - t.Error("cert file not created") - } - if _, err := os.Stat(keyPath); os.IsNotExist(err) { - t.Error("key file not created") - } - - // Load and verify - certBytes, _ := ioutil.ReadFile(certPath) - keyBytes, _ := ioutil.ReadFile(keyPath) - - certBlock, _ := pem.Decode(certBytes) - if certBlock == nil || certBlock.Type != "CERTIFICATE" { - t.Error("invalid cert PEM") - } - - keyBlock, _ := pem.Decode(keyBytes) - if keyBlock == nil || keyBlock.Type != "RSA PRIVATE KEY" { - t.Error("invalid key PEM") - } - - priv, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes) - if err != nil { - t.Errorf("invalid private key: %v", err) - } - if priv.N.BitLen() != 1024 { - t.Errorf("key bits mismatch: %d", priv.N.BitLen()) - } -}