diff --git a/.gitattributes b/.gitattributes
deleted file mode 100644
index e236489d..00000000
--- a/.gitattributes
+++ /dev/null
@@ -1,4 +0,0 @@
-*.js linguist-vendored
-/Dockerfile linguist-vendored
-/release.py linguist-vendored
-/**/*.js linguist-vendored
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
deleted file mode 100644
index 05551636..00000000
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ /dev/null
@@ -1,5 +0,0 @@
-blank_issues_enabled: false
-contact_links:
- - name: Bettercap Documentation
- url: https://www.bettercap.org/
- about: Please read the instructions before asking for help.
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
deleted file mode 100644
index c78a0857..00000000
--- a/.github/dependabot.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-version: 2
-updates:
- # GitHub Actions
- - package-ecosystem: github-actions
- directory: /
- schedule:
- interval: daily
diff --git a/.github/workflows/build-and-deploy.yml b/.github/workflows/build-and-deploy.yml
index a9a770f0..a8f72dbd 100644
--- a/.github/workflows/build-and-deploy.yml
+++ b/.github/workflows/build-and-deploy.yml
@@ -8,57 +8,56 @@ on:
jobs:
build:
- name: ${{ matrix.os.pretty }} ${{ matrix.arch }}
- runs-on: ${{ matrix.os.runs-on }}
+ runs-on: ${{ matrix.os }}
strategy:
matrix:
- os:
- - name: darwin
- runs-on: [macos-latest]
- pretty: 🍎 macOS
- - name: linux
- runs-on: [ubuntu-latest]
- pretty: 🐧 Linux
- - name: windows
- runs-on: [windows-latest]
- pretty: 🪟 Windows
- output: bettercap.exe
- arch: [amd64, arm64]
- go: [1.24.x]
- exclude:
- - os:
- name: darwin
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ go-version: ['1.22.x']
+ include:
+ - os: ubuntu-latest
arch: amd64
- # Linux ARM64 images are not yet publicly available (https://github.com/actions/runner-images)
- - os:
- name: linux
+ target_os: linux
+ target_arch: amd64
+ - os: ubuntu-latest
arch: arm64
- - os:
- name: windows
+ target_os: linux
+ target_arch: aarch64
+ - os: macos-latest
arch: arm64
+ target_os: darwin
+ target_arch: arm64
+ - os: windows-latest
+ arch: amd64
+ target_os: windows
+ target_arch: amd64
+ output: bettercap.exe
env:
- OUTPUT: ${{ matrix.os.output || 'bettercap' }}
+ TARGET_OS: ${{ matrix.target_os }}
+ TARGET_ARCH: ${{ matrix.target_arch }}
+ GO_VERSION: ${{ matrix.go-version }}
+ OUTPUT: ${{ matrix.output || 'bettercap' }}
steps:
- name: Checkout Code
- uses: actions/checkout@v4
+ uses: actions/checkout@v2
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v2
with:
- go-version: ${{ matrix.go }}
+ go-version: ${{ matrix.go-version }}
- name: Install Dependencies
- if: ${{ matrix.os.name == 'linux' }}
+ if: ${{ matrix.os == 'ubuntu-latest' }}
run: sudo apt-get update && sudo apt-get install -y p7zip-full libpcap-dev libnetfilter-queue-dev libusb-1.0-0-dev
- name: Install Dependencies (macOS)
- if: ${{ matrix.os.name == 'macos' }}
+ if: ${{ matrix.os == 'macos-latest' }}
run: brew install libpcap libusb p7zip
+
- name: Install libusb via mingw (Windows)
- if: ${{ matrix.os.name == 'windows' }}
+ if: ${{ matrix.os == 'windows-latest' }}
uses: msys2/setup-msys2@v2
with:
install: |-
@@ -66,7 +65,7 @@ jobs:
mingw64/mingw-w64-x86_64-pkg-config
- name: Install other Dependencies (Windows)
- if: ${{ matrix.os.name == 'windows' }}
+ if: ${{ matrix.os == 'windows-latest' }}
run: |
choco install openssl.light -y
choco install make -y
@@ -82,36 +81,25 @@ jobs:
- name: Verify Build
run: |
file "${{ env.OUTPUT }}"
- openssl dgst -sha256 "${{ env.OUTPUT }}" | tee bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.sha256
- 7z a "bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.zip" "${{ env.OUTPUT }}" "bettercap_${{ matrix.os.name }}_${{ matrix.arch }}.sha256"
-
- - name: Upload Artifacts
- uses: actions/upload-artifact@v4
- with:
- name: release-artifacts-${{ matrix.os.name }}-${{ matrix.arch }}
- path: |
- bettercap_*.zip
- bettercap_*.sha256
+ openssl dgst -sha256 "${{ env.OUTPUT }}" | tee bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.sha256
+ 7z a "bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.zip" "${{ env.OUTPUT }}" "bettercap_${{ matrix.target_os }}_${{ matrix.target_arch }}_${{ env.VERSION }}.sha256"
deploy:
needs: [build]
+ if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
name: Release
runs-on: ubuntu-latest
steps:
- - name: Download Artifacts
- uses: actions/download-artifact@v5
+ - name: Checkout Code
+ uses: actions/checkout@v2
with:
- pattern: release-artifacts-*
- merge-multiple: true
- path: dist/
-
- - name: Release Assets
- run: ls -l dist
+ submodules: true
- name: Upload Release Assets
- uses: softprops/action-gh-release@v2
- if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
+ uses: softprops/action-gh-release@v1
with:
- files: dist/bettercap_*
+ files: |
+ bettercap_*.zip
+ bettercap_*.sha256
env:
- GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
+ GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/build-and-push-docker.yml b/.github/workflows/build-and-push-docker.yml
index c9ad06f1..c6ef89c2 100644
--- a/.github/workflows/build-and-push-docker.yml
+++ b/.github/workflows/build-and-push-docker.yml
@@ -23,7 +23,7 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
- uses: docker/build-push-action@v6
+ uses: docker/build-push-action@v5
with:
platforms: linux/amd64,linux/arm64
push: true
diff --git a/.github/workflows/test-on-linux.yml b/.github/workflows/test-on-linux.yml
index e920f281..665c1bd4 100644
--- a/.github/workflows/test-on-linux.yml
+++ b/.github/workflows/test-on-linux.yml
@@ -13,14 +13,14 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
- go-version: ['1.24.x']
+ go-version: ['1.22.x']
steps:
- name: Checkout Code
- uses: actions/checkout@v4
+ uses: actions/checkout@v2
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
diff --git a/.github/workflows/test-on-macos.yml b/.github/workflows/test-on-macos.yml
index b48c57cd..278689ef 100644
--- a/.github/workflows/test-on-macos.yml
+++ b/.github/workflows/test-on-macos.yml
@@ -13,14 +13,14 @@ jobs:
strategy:
matrix:
os: [macos-latest]
- go-version: ['1.24.x']
+ go-version: ['1.22.x']
steps:
- name: Checkout Code
- uses: actions/checkout@v4
+ uses: actions/checkout@v2
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
diff --git a/.github/workflows/test-on-windows.yml b/.github/workflows/test-on-windows.yml
index b5e6a6e2..08ea79da 100644
--- a/.github/workflows/test-on-windows.yml
+++ b/.github/workflows/test-on-windows.yml
@@ -13,14 +13,14 @@ jobs:
strategy:
matrix:
os: [windows-latest]
- go-version: ['1.24.x']
+ go-version: ['1.22.x']
steps:
- name: Checkout Code
- uses: actions/checkout@v4
+ uses: actions/checkout@v2
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
diff --git a/Dockerfile b/Dockerfile
index 362ff471..414cc8c4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# build stage
-FROM golang:1.24-alpine AS build-env
+FROM golang:1.22-alpine3.20 AS build-env
RUN apk add --no-cache ca-certificates
RUN apk add --no-cache bash gcc g++ binutils-gold iptables wireless-tools build-base libpcap-dev libusb-dev linux-headers libnetfilter_queue-dev git
@@ -13,9 +13,9 @@ RUN mkdir -p /usr/local/share/bettercap
RUN git clone https://github.com/bettercap/caplets /usr/local/share/bettercap/caplets
# final stage
-FROM alpine
+FROM alpine:3.20
RUN apk add --no-cache ca-certificates
-RUN apk add --no-cache bash iproute2 libpcap libusb-dev libnetfilter_queue wireless-tools iw
+RUN apk add --no-cache bash iproute2 libpcap libusb-dev libnetfilter_queue wireless-tools
COPY --from=build-env /go/src/github.com/bettercap/bettercap/bettercap /app/
COPY --from=build-env /usr/local/share/bettercap/caplets /app/
WORKDIR /app
diff --git a/.github/ISSUE_TEMPLATE/default_issue.md b/ISSUE_TEMPLATE.md
similarity index 94%
rename from .github/ISSUE_TEMPLATE/default_issue.md
rename to ISSUE_TEMPLATE.md
index 8fc3c85c..5c23a58c 100644
--- a/.github/ISSUE_TEMPLATE/default_issue.md
+++ b/ISSUE_TEMPLATE.md
@@ -1,8 +1,3 @@
----
-name: General Issue
-about: Write a general issue or bug report.
----
-
# Prerequisites
Please, before creating this issue make sure that you read the [README](https://github.com/bettercap/bettercap/blob/master/README.md), that you are running the [latest stable version](https://github.com/bettercap/bettercap/releases) and that you already searched [other issues](https://github.com/bettercap/bettercap/issues?q=is%3Aopen+is%3Aissue+label%3Abug) to see if your problem or request was already reported.
diff --git a/Makefile b/Makefile
index 3ec8e6cc..65a2e917 100644
--- a/Makefile
+++ b/Makefile
@@ -6,10 +6,10 @@ GO ?= go
all: build
build: resources
- $(GO) build $(GOFLAGS) -o $(TARGET) .
+ $(GOFLAGS) $(GO) build -o $(TARGET) .
build_with_race_detector: resources
- $(GO) build $(GOFLAGS) -race -o $(TARGET) .
+ $(GOFLAGS) $(GO) build -race -o $(TARGET) .
resources: network/manuf.go
@@ -24,13 +24,13 @@ docker:
@docker build -t bettercap:latest .
test:
- $(GO) test -covermode=atomic -coverprofile=cover.out ./...
+ $(GOFLAGS) $(GO) test -covermode=atomic -coverprofile=cover.out ./...
html_coverage: test
- $(GO) tool cover -html=cover.out -o cover.out.html
+ $(GOFLAGS) $(GO) tool cover -html=cover.out -o cover.out.html
benchmark: server_deps
- $(GO) test -v -run=doNotRunTests -bench=. -benchmem ./...
+ $(GOFLAGS) $(GO) test -v -run=doNotRunTests -bench=. -benchmem ./...
fmt:
$(GO) fmt -s -w $(PACKAGES)
diff --git a/README.md b/README.md
index 299e1d78..4a27f1cd 100644
--- a/README.md
+++ b/README.md
@@ -38,15 +38,9 @@ bettercap is a powerful, easily extensible and portable framework written in Go
* **A very convenient [web UI](https://www.bettercap.org/usage/#web-ui).**
* [More!](https://www.bettercap.org/modules/)
-## Contributors
-
-
-
-
-
## License
-`bettercap` is made with ♥ and released under the GPL 3 license.
+`bettercap` is made with ♥ by [the dev team](https://github.com/orgs/bettercap/people) and it's released under the GPL 3 license.
## Stargazers over time
diff --git a/caplets/caplet_test.go b/caplets/caplet_test.go
deleted file mode 100644
index dee5d9ff..00000000
--- a/caplets/caplet_test.go
+++ /dev/null
@@ -1,378 +0,0 @@
-package caplets
-
-import (
- "errors"
- "io/ioutil"
- "os"
- "strings"
- "testing"
-)
-
-func TestNewCaplet(t *testing.T) {
- name := "test-caplet"
- path := "/path/to/caplet.cap"
- size := int64(1024)
-
- cap := NewCaplet(name, path, size)
-
- if cap.Name != name {
- t.Errorf("expected name %s, got %s", name, cap.Name)
- }
- if cap.Path != path {
- t.Errorf("expected path %s, got %s", path, cap.Path)
- }
- if cap.Size != size {
- t.Errorf("expected size %d, got %d", size, cap.Size)
- }
- if cap.Code == nil {
- t.Error("Code should not be nil")
- }
- if cap.Scripts == nil {
- t.Error("Scripts should not be nil")
- }
-}
-
-func TestCapletEval(t *testing.T) {
- tests := []struct {
- name string
- code []string
- argv []string
- wantLines []string
- wantErr bool
- }{
- {
- name: "empty code",
- code: []string{},
- argv: nil,
- wantLines: []string{},
- wantErr: false,
- },
- {
- name: "skip comments and empty lines",
- code: []string{
- "# this is a comment",
- "",
- "set test value",
- "# another comment",
- "set another value",
- },
- argv: nil,
- wantLines: []string{
- "set test value",
- "set another value",
- },
- wantErr: false,
- },
- {
- name: "variable substitution",
- code: []string{
- "set param $0",
- "set value $1",
- "run $0 $1 $2",
- },
- argv: []string{"arg0", "arg1", "arg2"},
- wantLines: []string{
- "set param arg0",
- "set value arg1",
- "run arg0 arg1 arg2",
- },
- wantErr: false,
- },
- {
- name: "multiple occurrences of same variable",
- code: []string{
- "$0 $0 $1 $0",
- },
- argv: []string{"foo", "bar"},
- wantLines: []string{
- "foo foo bar foo",
- },
- wantErr: false,
- },
- {
- name: "missing argv values",
- code: []string{
- "set $0 $1 $2",
- },
- argv: []string{"only_one"},
- wantLines: []string{
- "set only_one $1 $2",
- },
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
- cap.Code = tt.code
-
- var gotLines []string
- err = cap.Eval(tt.argv, func(line string) error {
- gotLines = append(gotLines, line)
- return nil
- })
-
- if (err != nil) != tt.wantErr {
- t.Errorf("Eval() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
-
- if len(gotLines) != len(tt.wantLines) {
- t.Errorf("got %d lines, want %d", len(gotLines), len(tt.wantLines))
- return
- }
-
- for i, line := range gotLines {
- if line != tt.wantLines[i] {
- t.Errorf("line %d: got %q, want %q", i, line, tt.wantLines[i])
- }
- }
- })
- }
-}
-
-func TestCapletEvalError(t *testing.T) {
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
- cap.Code = []string{
- "first line",
- "error line",
- "third line",
- }
-
- expectedErr := errors.New("test error")
- var executedLines []string
-
- err = cap.Eval(nil, func(line string) error {
- executedLines = append(executedLines, line)
- if line == "error line" {
- return expectedErr
- }
- return nil
- })
-
- if err != expectedErr {
- t.Errorf("expected error %v, got %v", expectedErr, err)
- }
-
- // Should have executed first two lines before error
- if len(executedLines) != 2 {
- t.Errorf("expected 2 executed lines, got %d", len(executedLines))
- }
-}
-
-func TestCapletEvalWithChdirPath(t *testing.T) {
- // Create a temporary caplet file to test with
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
- cap.Code = []string{"test command"}
-
- executed := false
- err = cap.Eval(nil, func(line string) error {
- executed = true
- return nil
- })
-
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if !executed {
- t.Error("callback was not executed")
- }
-}
-
-func TestNewScript(t *testing.T) {
- path := "/path/to/script.js"
- size := int64(2048)
-
- script := newScript(path, size)
-
- if script.Path != path {
- t.Errorf("expected path %s, got %s", path, script.Path)
- }
- if script.Size != size {
- t.Errorf("expected size %d, got %d", size, script.Size)
- }
- if script.Code == nil {
- t.Error("Code should not be nil")
- }
- if len(script.Code) != 0 {
- t.Errorf("expected empty Code slice, got %d elements", len(script.Code))
- }
-}
-
-func TestCapletEvalCommentAtStartOfLine(t *testing.T) {
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
- cap.Code = []string{
- "# comment",
- " # not a comment (has space before #)",
- " # not a comment (has tab before #)",
- "command # inline comment",
- }
-
- var gotLines []string
- err = cap.Eval(nil, func(line string) error {
- gotLines = append(gotLines, line)
- return nil
- })
-
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- expectedLines := []string{
- " # not a comment (has space before #)",
- " # not a comment (has tab before #)",
- "command # inline comment",
- }
-
- if len(gotLines) != len(expectedLines) {
- t.Errorf("got %d lines, want %d", len(gotLines), len(expectedLines))
- return
- }
-
- for i, line := range gotLines {
- if line != expectedLines[i] {
- t.Errorf("line %d: got %q, want %q", i, line, expectedLines[i])
- }
- }
-}
-
-func TestCapletEvalArgvSubstitutionEdgeCases(t *testing.T) {
- tests := []struct {
- name string
- code string
- argv []string
- wantLine string
- }{
- {
- name: "double digit substitution $10",
- code: "$1$0",
- argv: []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"},
- wantLine: "10",
- },
- {
- name: "no space between variables",
- code: "$0$1$2",
- argv: []string{"a", "b", "c"},
- wantLine: "abc",
- },
- {
- name: "variables in quotes",
- code: `"$0" '$1'`,
- argv: []string{"foo", "bar"},
- wantLine: `"foo" 'bar'`,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
- cap.Code = []string{tt.code}
-
- var gotLine string
- err = cap.Eval(tt.argv, func(line string) error {
- gotLine = line
- return nil
- })
-
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- if gotLine != tt.wantLine {
- t.Errorf("got line %q, want %q", gotLine, tt.wantLine)
- }
- })
- }
-}
-
-func TestCapletStructFields(t *testing.T) {
- // Test that Caplet properly embeds Script
- tempFile, err := ioutil.TempFile("", "test-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
- tempFile.Close()
-
- cap := NewCaplet("test", tempFile.Name(), 100)
-
- // These fields should be accessible due to embedding
- _ = cap.Path
- _ = cap.Size
- _ = cap.Code
-
- // And these are Caplet's own fields
- _ = cap.Name
- _ = cap.Scripts
-}
-
-func BenchmarkCapletEval(b *testing.B) {
- cap := NewCaplet("bench", "/tmp/bench.cap", 100)
- cap.Code = []string{
- "set param1 $0",
- "set param2 $1",
- "# comment line",
- "",
- "run command $0 $1 $2",
- "another command",
- }
- argv := []string{"arg0", "arg1", "arg2"}
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = cap.Eval(argv, func(line string) error {
- // Do nothing, just measure evaluation overhead
- return nil
- })
- }
-}
-
-func BenchmarkVariableSubstitution(b *testing.B) {
- line := "command $0 $1 $2 $3 $4 $5 $6 $7 $8 $9"
- argv := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- result := line
- for j, arg := range argv {
- what := "$" + string(rune('0'+j))
- result = strings.Replace(result, what, arg, -1)
- }
- }
-}
diff --git a/caplets/env_test.go b/caplets/env_test.go
deleted file mode 100644
index c1087216..00000000
--- a/caplets/env_test.go
+++ /dev/null
@@ -1,308 +0,0 @@
-package caplets
-
-import (
- "os"
- "path/filepath"
- "runtime"
- "strings"
- "testing"
-)
-
-func TestGetDefaultInstallBase(t *testing.T) {
- base := getDefaultInstallBase()
-
- if runtime.GOOS == "windows" {
- expected := filepath.Join(os.Getenv("ALLUSERSPROFILE"), "bettercap")
- if base != expected {
- t.Errorf("on windows, expected %s, got %s", expected, base)
- }
- } else {
- expected := "/usr/local/share/bettercap/"
- if base != expected {
- t.Errorf("on non-windows, expected %s, got %s", expected, base)
- }
- }
-}
-
-func TestGetUserHomeDir(t *testing.T) {
- home := getUserHomeDir()
-
- // Should return a non-empty string
- if home == "" {
- t.Error("getUserHomeDir returned empty string")
- }
-
- // Should be an absolute path
- if !filepath.IsAbs(home) {
- t.Errorf("expected absolute path, got %s", home)
- }
-}
-
-func TestSetup(t *testing.T) {
- // Save original values
- origInstallBase := InstallBase
- origInstallPathArchive := InstallPathArchive
- origInstallPath := InstallPath
- origArchivePath := ArchivePath
- origLoadPaths := LoadPaths
-
- // Test with custom base
- testBase := "/custom/base"
- err := Setup(testBase)
-
- if err != nil {
- t.Errorf("Setup returned error: %v", err)
- }
-
- // Check that paths are set correctly
- if InstallBase != testBase {
- t.Errorf("expected InstallBase %s, got %s", testBase, InstallBase)
- }
-
- expectedArchivePath := filepath.Join(testBase, "caplets-master")
- if InstallPathArchive != expectedArchivePath {
- t.Errorf("expected InstallPathArchive %s, got %s", expectedArchivePath, InstallPathArchive)
- }
-
- expectedInstallPath := filepath.Join(testBase, "caplets")
- if InstallPath != expectedInstallPath {
- t.Errorf("expected InstallPath %s, got %s", expectedInstallPath, InstallPath)
- }
-
- expectedTempPath := filepath.Join(os.TempDir(), "caplets.zip")
- if ArchivePath != expectedTempPath {
- t.Errorf("expected ArchivePath %s, got %s", expectedTempPath, ArchivePath)
- }
-
- // Check LoadPaths contains expected paths
- expectedInLoadPaths := []string{
- "./",
- "./caplets/",
- InstallPath,
- filepath.Join(getUserHomeDir(), "caplets"),
- }
-
- for _, expected := range expectedInLoadPaths {
- absExpected, _ := filepath.Abs(expected)
- found := false
- for _, path := range LoadPaths {
- if path == absExpected {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("expected path %s not found in LoadPaths", absExpected)
- }
- }
-
- // All paths should be absolute
- for _, path := range LoadPaths {
- if !filepath.IsAbs(path) {
- t.Errorf("LoadPath %s is not absolute", path)
- }
- }
-
- // Restore original values
- InstallBase = origInstallBase
- InstallPathArchive = origInstallPathArchive
- InstallPath = origInstallPath
- ArchivePath = origArchivePath
- LoadPaths = origLoadPaths
-}
-
-func TestSetupWithEnvironmentVariable(t *testing.T) {
- // Save original values
- origEnv := os.Getenv(EnvVarName)
- origLoadPaths := LoadPaths
-
- // Set environment variable with multiple paths
- testPaths := []string{"/path1", "/path2", "/path3"}
- os.Setenv(EnvVarName, strings.Join(testPaths, string(os.PathListSeparator)))
-
- // Run setup
- err := Setup("/test/base")
- if err != nil {
- t.Errorf("Setup returned error: %v", err)
- }
-
- // Check that custom paths from env var are in LoadPaths
- for _, testPath := range testPaths {
- absTestPath, _ := filepath.Abs(testPath)
- found := false
- for _, path := range LoadPaths {
- if path == absTestPath {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("expected env path %s not found in LoadPaths", absTestPath)
- }
- }
-
- // Restore original values
- if origEnv == "" {
- os.Unsetenv(EnvVarName)
- } else {
- os.Setenv(EnvVarName, origEnv)
- }
- LoadPaths = origLoadPaths
-}
-
-func TestSetupWithEmptyEnvironmentVariable(t *testing.T) {
- // Save original values
- origEnv := os.Getenv(EnvVarName)
- origLoadPaths := LoadPaths
-
- // Set empty environment variable
- os.Setenv(EnvVarName, "")
-
- // Count LoadPaths before setup
- err := Setup("/test/base")
- if err != nil {
- t.Errorf("Setup returned error: %v", err)
- }
-
- // Should have only the default paths (4)
- if len(LoadPaths) != 4 {
- t.Errorf("expected 4 default LoadPaths, got %d", len(LoadPaths))
- }
-
- // Restore original values
- if origEnv == "" {
- os.Unsetenv(EnvVarName)
- } else {
- os.Setenv(EnvVarName, origEnv)
- }
- LoadPaths = origLoadPaths
-}
-
-func TestSetupWithWhitespaceInEnvironmentVariable(t *testing.T) {
- // Save original values
- origEnv := os.Getenv(EnvVarName)
- origLoadPaths := LoadPaths
-
- // Set environment variable with whitespace
- testPaths := []string{" /path1 ", " ", "/path2 "}
- os.Setenv(EnvVarName, strings.Join(testPaths, string(os.PathListSeparator)))
-
- // Run setup
- err := Setup("/test/base")
- if err != nil {
- t.Errorf("Setup returned error: %v", err)
- }
-
- // Should have added only non-empty paths after trimming
- expectedPaths := []string{"/path1", "/path2"}
- foundCount := 0
- for _, expectedPath := range expectedPaths {
- absExpected, _ := filepath.Abs(expectedPath)
- for _, path := range LoadPaths {
- if path == absExpected {
- foundCount++
- break
- }
- }
- }
-
- if foundCount != len(expectedPaths) {
- t.Errorf("expected to find %d paths from env, found %d", len(expectedPaths), foundCount)
- }
-
- // Restore original values
- if origEnv == "" {
- os.Unsetenv(EnvVarName)
- } else {
- os.Setenv(EnvVarName, origEnv)
- }
- LoadPaths = origLoadPaths
-}
-
-func TestConstants(t *testing.T) {
- // Test that constants have expected values
- if EnvVarName != "CAPSPATH" {
- t.Errorf("expected EnvVarName to be 'CAPSPATH', got %s", EnvVarName)
- }
-
- if Suffix != ".cap" {
- t.Errorf("expected Suffix to be '.cap', got %s", Suffix)
- }
-
- if InstallArchive != "https://github.com/bettercap/caplets/archive/master.zip" {
- t.Errorf("unexpected InstallArchive value: %s", InstallArchive)
- }
-}
-
-func TestInit(t *testing.T) {
- // The init function should have been called already
- // Check that paths are initialized
- if InstallBase == "" {
- t.Error("InstallBase not initialized")
- }
-
- if InstallPath == "" {
- t.Error("InstallPath not initialized")
- }
-
- if InstallPathArchive == "" {
- t.Error("InstallPathArchive not initialized")
- }
-
- if ArchivePath == "" {
- t.Error("ArchivePath not initialized")
- }
-
- if LoadPaths == nil || len(LoadPaths) == 0 {
- t.Error("LoadPaths not initialized")
- }
-}
-
-func TestSetupMultipleTimes(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
-
- // Setup multiple times with different bases
- bases := []string{"/base1", "/base2", "/base3"}
-
- for _, base := range bases {
- err := Setup(base)
- if err != nil {
- t.Errorf("Setup(%s) returned error: %v", base, err)
- }
-
- // Check that InstallBase is updated
- if InstallBase != base {
- t.Errorf("expected InstallBase %s, got %s", base, InstallBase)
- }
-
- // LoadPaths should be recreated each time
- if len(LoadPaths) < 4 {
- t.Errorf("LoadPaths should have at least 4 entries, got %d", len(LoadPaths))
- }
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
-}
-
-func BenchmarkSetup(b *testing.B) {
- // Save original values
- origEnv := os.Getenv(EnvVarName)
-
- // Set a complex environment
- paths := []string{"/p1", "/p2", "/p3", "/p4", "/p5"}
- os.Setenv(EnvVarName, strings.Join(paths, string(os.PathListSeparator)))
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- Setup("/benchmark/base")
- }
-
- // Restore
- if origEnv == "" {
- os.Unsetenv(EnvVarName)
- } else {
- os.Setenv(EnvVarName, origEnv)
- }
-}
diff --git a/caplets/manager_test.go b/caplets/manager_test.go
deleted file mode 100644
index 0392a12b..00000000
--- a/caplets/manager_test.go
+++ /dev/null
@@ -1,511 +0,0 @@
-package caplets
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path/filepath"
- "sort"
- "strings"
- "sync"
- "testing"
-)
-
-func createTestCaplet(t testing.TB, dir string, name string, content []string) string {
- filename := filepath.Join(dir, name)
- data := strings.Join(content, "\n")
- err := ioutil.WriteFile(filename, []byte(data), 0644)
- if err != nil {
- t.Fatalf("failed to create test caplet: %v", err)
- }
- return filename
-}
-
-func TestList(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directories
- tempDir, err := ioutil.TempDir("", "caplets-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Create subdirectories
- dir1 := filepath.Join(tempDir, "dir1")
- dir2 := filepath.Join(tempDir, "dir2")
- subdir := filepath.Join(dir1, "subdir")
-
- os.Mkdir(dir1, 0755)
- os.Mkdir(dir2, 0755)
- os.Mkdir(subdir, 0755)
-
- // Create test caplets
- createTestCaplet(t, dir1, "test1.cap", []string{"# Test caplet 1", "set test 1"})
- createTestCaplet(t, dir1, "test2.cap", []string{"# Test caplet 2", "set test 2"})
- createTestCaplet(t, dir2, "test3.cap", []string{"# Test caplet 3", "set test 3"})
- createTestCaplet(t, subdir, "nested.cap", []string{"# Nested caplet", "set nested test"})
-
- // Also create a non-caplet file
- ioutil.WriteFile(filepath.Join(dir1, "notacaplet.txt"), []byte("not a caplet"), 0644)
-
- // Set LoadPaths
- LoadPaths = []string{dir1, dir2}
-
- // Call List()
- caplets := List()
-
- // Check results
- if len(caplets) != 4 {
- t.Errorf("expected 4 caplets, got %d", len(caplets))
- }
-
- // Check names (should be sorted)
- expectedNames := []string{filepath.Join("subdir", "nested"), "test1", "test2", "test3"}
- sort.Strings(expectedNames)
-
- gotNames := make([]string, len(caplets))
- for i, cap := range caplets {
- gotNames[i] = cap.Name
- }
-
- for i, expected := range expectedNames {
- if i >= len(gotNames) || gotNames[i] != expected {
- t.Errorf("expected caplet %d to be %s, got %s", i, expected, gotNames[i])
- }
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestListEmptyDirectories(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directory
- tempDir, err := ioutil.TempDir("", "caplets-empty-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Set LoadPaths to empty directory
- LoadPaths = []string{tempDir}
-
- // Call List()
- caplets := List()
-
- // Should return empty list
- if len(caplets) != 0 {
- t.Errorf("expected 0 caplets, got %d", len(caplets))
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestLoad(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directory
- tempDir, err := ioutil.TempDir("", "caplets-load-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Create test caplet
- capletContent := []string{
- "# Test caplet",
- "set param value",
- "",
- "# Another comment",
- "run command",
- }
- createTestCaplet(t, tempDir, "test.cap", capletContent)
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- // Test loading without .cap extension
- cap, err := Load("test")
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if cap == nil {
- t.Error("caplet is nil")
- } else {
- if cap.Name != "test" {
- t.Errorf("expected name 'test', got %s", cap.Name)
- }
- if len(cap.Code) != len(capletContent) {
- t.Errorf("expected %d lines, got %d", len(capletContent), len(cap.Code))
- }
- }
-
- // Test loading from cache
- // Note: The Load function caches with the suffix, so we need to use the same name with suffix
- cap2, err := Load("test.cap")
- if err != nil {
- t.Errorf("unexpected error on cache hit: %v", err)
- }
- if cap2 == nil {
- t.Error("caplet is nil on cache hit")
- }
-
- // Test loading with .cap extension
- // Note: Load caches by the name parameter, so "test.cap" is a different cache key
- cap3, err := Load("test.cap")
- if err != nil {
- t.Errorf("unexpected error with .cap extension: %v", err)
- }
- if cap3 == nil {
- t.Error("caplet is nil")
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestLoadAbsolutePath(t *testing.T) {
- // Save original values
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp file
- tempFile, err := ioutil.TempFile("", "test-absolute-*.cap")
- if err != nil {
- t.Fatal(err)
- }
- defer os.Remove(tempFile.Name())
-
- // Write content
- content := "# Absolute path test\nset test absolute"
- tempFile.WriteString(content)
- tempFile.Close()
-
- // Load with absolute path
- cap, err := Load(tempFile.Name())
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if cap == nil {
- t.Error("caplet is nil")
- } else {
- if cap.Path != tempFile.Name() {
- t.Errorf("expected path %s, got %s", tempFile.Name(), cap.Path)
- }
- }
-
- // Restore original values
- cache = origCache
-}
-
-func TestLoadNotFound(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Set empty LoadPaths
- LoadPaths = []string{}
-
- // Try to load non-existent caplet
- cap, err := Load("nonexistent")
- if err == nil {
- t.Error("expected error for non-existent caplet")
- }
- if cap != nil {
- t.Error("expected nil caplet for non-existent file")
- }
- if !strings.Contains(err.Error(), "not found") {
- t.Errorf("expected 'not found' error, got: %v", err)
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestLoadWithFolder(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directory structure
- tempDir, err := ioutil.TempDir("", "caplets-folder-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Create a caplet folder
- capletDir := filepath.Join(tempDir, "mycaplet")
- os.Mkdir(capletDir, 0755)
-
- // Create main caplet file
- mainContent := []string{"# Main caplet", "set main test"}
- createTestCaplet(t, capletDir, "mycaplet.cap", mainContent)
-
- // Create additional files
- jsContent := []string{"// JavaScript file", "console.log('test');"}
- createTestCaplet(t, capletDir, "script.js", jsContent)
-
- capContent := []string{"# Sub caplet", "set sub test"}
- createTestCaplet(t, capletDir, "sub.cap", capContent)
-
- // Create a file that should be ignored
- ioutil.WriteFile(filepath.Join(capletDir, "readme.txt"), []byte("readme"), 0644)
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- // Load the caplet
- cap, err := Load("mycaplet/mycaplet")
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if cap == nil {
- t.Fatal("caplet is nil")
- }
-
- // Check main caplet
- if cap.Name != "mycaplet/mycaplet" {
- t.Errorf("expected name 'mycaplet/mycaplet', got %s", cap.Name)
- }
- if len(cap.Code) != len(mainContent) {
- t.Errorf("expected %d lines in main, got %d", len(mainContent), len(cap.Code))
- }
-
- // Check additional scripts
- if len(cap.Scripts) != 2 {
- t.Errorf("expected 2 additional scripts, got %d", len(cap.Scripts))
- }
-
- // Find and check the .js file
- foundJS := false
- foundCap := false
- for _, script := range cap.Scripts {
- if strings.HasSuffix(script.Path, "script.js") {
- foundJS = true
- if len(script.Code) != len(jsContent) {
- t.Errorf("expected %d lines in JS, got %d", len(jsContent), len(script.Code))
- }
- }
- if strings.HasSuffix(script.Path, "sub.cap") {
- foundCap = true
- if len(script.Code) != len(capContent) {
- t.Errorf("expected %d lines in sub.cap, got %d", len(capContent), len(script.Code))
- }
- }
- }
-
- if !foundJS {
- t.Error("script.js not found in Scripts")
- }
- if !foundCap {
- t.Error("sub.cap not found in Scripts")
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestCacheConcurrency(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directory
- tempDir, err := ioutil.TempDir("", "caplets-concurrent-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Create test caplets
- for i := 0; i < 5; i++ {
- name := fmt.Sprintf("test%d.cap", i)
- content := []string{fmt.Sprintf("# Test %d", i)}
- createTestCaplet(t, tempDir, name, content)
- }
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- // Run concurrent loads
- var wg sync.WaitGroup
- errors := make(chan error, 50)
-
- for i := 0; i < 10; i++ {
- for j := 0; j < 5; j++ {
- wg.Add(1)
- go func(idx int) {
- defer wg.Done()
- name := fmt.Sprintf("test%d", idx)
- _, err := Load(name)
- if err != nil {
- errors <- err
- }
- }(j)
- }
- }
-
- wg.Wait()
- close(errors)
-
- // Check for errors
- for err := range errors {
- t.Errorf("concurrent load error: %v", err)
- }
-
- // Verify cache has all entries
- if len(cache) != 5 {
- t.Errorf("expected 5 cached entries, got %d", len(cache))
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func TestLoadPathPriority(t *testing.T) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directories
- tempDir1, _ := ioutil.TempDir("", "caplets-priority1-")
- tempDir2, _ := ioutil.TempDir("", "caplets-priority2-")
- defer os.RemoveAll(tempDir1)
- defer os.RemoveAll(tempDir2)
-
- // Create same-named caplet in both directories
- createTestCaplet(t, tempDir1, "test.cap", []string{"# From dir1"})
- createTestCaplet(t, tempDir2, "test.cap", []string{"# From dir2"})
-
- // Set LoadPaths with tempDir1 first
- LoadPaths = []string{tempDir1, tempDir2}
-
- // Load caplet
- cap, err := Load("test")
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- // Should load from first directory
- if cap != nil && len(cap.Code) > 0 {
- if cap.Code[0] != "# From dir1" {
- t.Error("caplet not loaded from first directory in LoadPaths")
- }
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func BenchmarkLoad(b *testing.B) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
-
- // Create temp directory
- tempDir, _ := ioutil.TempDir("", "caplets-bench-")
- defer os.RemoveAll(tempDir)
-
- // Create test caplet
- content := make([]string, 100)
- for i := range content {
- content[i] = fmt.Sprintf("command %d", i)
- }
- createTestCaplet(b, tempDir, "bench.cap", content)
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- // Clear cache to measure loading time
- cache = make(map[string]*Caplet)
- Load("bench")
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func BenchmarkLoadFromCache(b *testing.B) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
- cache = make(map[string]*Caplet)
-
- // Create temp directory
- tempDir, _ := ioutil.TempDir("", "caplets-bench-cache-")
- defer os.RemoveAll(tempDir)
-
- // Create test caplet
- createTestCaplet(b, tempDir, "bench.cap", []string{"# Benchmark"})
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- // Pre-load into cache
- Load("bench")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- Load("bench")
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
-
-func BenchmarkList(b *testing.B) {
- // Save original values
- origLoadPaths := LoadPaths
- origCache := cache
-
- // Create temp directory
- tempDir, _ := ioutil.TempDir("", "caplets-bench-list-")
- defer os.RemoveAll(tempDir)
-
- // Create multiple caplets
- for i := 0; i < 20; i++ {
- name := fmt.Sprintf("test%d.cap", i)
- createTestCaplet(b, tempDir, name, []string{fmt.Sprintf("# Test %d", i)})
- }
-
- // Set LoadPaths
- LoadPaths = []string{tempDir}
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- cache = make(map[string]*Caplet)
- List()
- }
-
- // Restore original values
- LoadPaths = origLoadPaths
- cache = origCache
-}
diff --git a/core/banner.go b/core/banner.go
index 1a63f0c8..1df1aafa 100644
--- a/core/banner.go
+++ b/core/banner.go
@@ -2,7 +2,7 @@ package core
const (
Name = "bettercap"
- Version = "2.41.4"
+ Version = "2.41.0"
Author = "Simone 'evilsocket' Margaritelli"
Website = "https://bettercap.org/"
)
diff --git a/core/core_test.go b/core/core_test.go
index 057e5b21..2dc77c49 100644
--- a/core/core_test.go
+++ b/core/core_test.go
@@ -97,144 +97,3 @@ func TestCoreExists(t *testing.T) {
}
}
}
-
-func TestHasBinary(t *testing.T) {
- tests := []struct {
- name string
- executable string
- expected bool
- }{
- {
- name: "common shell",
- executable: "sh",
- expected: true,
- },
- {
- name: "echo command",
- executable: "echo",
- expected: true,
- },
- {
- name: "non-existent binary",
- executable: "this-binary-definitely-does-not-exist-12345",
- expected: false,
- },
- {
- name: "empty string",
- executable: "",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := HasBinary(tt.executable)
- if got != tt.expected {
- t.Errorf("HasBinary(%q) = %v, want %v", tt.executable, got, tt.expected)
- }
- })
- }
-}
-
-func TestExec(t *testing.T) {
- tests := []struct {
- name string
- executable string
- args []string
- wantError bool
- contains string
- }{
- {
- name: "echo with args",
- executable: "echo",
- args: []string{"hello", "world"},
- wantError: false,
- contains: "hello world",
- },
- {
- name: "echo empty",
- executable: "echo",
- args: []string{},
- wantError: false,
- contains: "",
- },
- {
- name: "non-existent command",
- executable: "this-command-does-not-exist-12345",
- args: []string{},
- wantError: true,
- contains: "",
- },
- {
- name: "true command",
- executable: "true",
- args: []string{},
- wantError: false,
- contains: "",
- },
- {
- name: "false command",
- executable: "false",
- args: []string{},
- wantError: true,
- contains: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Skip platform-specific commands if not available
- if !HasBinary(tt.executable) && !tt.wantError {
- t.Skipf("%s not found in PATH", tt.executable)
- }
-
- output, err := Exec(tt.executable, tt.args)
-
- if tt.wantError {
- if err == nil {
- t.Errorf("Exec(%q, %v) expected error but got none", tt.executable, tt.args)
- }
- } else {
- if err != nil {
- t.Errorf("Exec(%q, %v) unexpected error: %v", tt.executable, tt.args, err)
- }
- if tt.contains != "" && output != tt.contains {
- t.Errorf("Exec(%q, %v) = %q, want %q", tt.executable, tt.args, output, tt.contains)
- }
- }
- })
- }
-}
-
-func TestExecWithOutput(t *testing.T) {
- // Test that Exec properly captures and trims output
- if HasBinary("printf") {
- output, err := Exec("printf", []string{" hello world \n"})
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
- if output != "hello world" {
- t.Errorf("expected trimmed output 'hello world', got %q", output)
- }
- }
-}
-
-func BenchmarkUniqueInts(b *testing.B) {
- // Create a slice with duplicates
- input := make([]int, 1000)
- for i := 0; i < 1000; i++ {
- input[i] = i % 100 // This creates 10 duplicates of each number 0-99
- }
-
- b.Run("unsorted", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = UniqueInts(input, false)
- }
- })
-
- b.Run("sorted", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = UniqueInts(input, true)
- }
- })
-}
diff --git a/firewall/redirection_test.go b/firewall/redirection_test.go
deleted file mode 100644
index 050590b2..00000000
--- a/firewall/redirection_test.go
+++ /dev/null
@@ -1,268 +0,0 @@
-package firewall
-
-import (
- "testing"
-)
-
-func TestNewRedirection(t *testing.T) {
- iface := "eth0"
- proto := "tcp"
- portFrom := 8080
- addrTo := "192.168.1.100"
- portTo := 9090
-
- r := NewRedirection(iface, proto, portFrom, addrTo, portTo)
-
- if r == nil {
- t.Fatal("NewRedirection returned nil")
- }
-
- if r.Interface != iface {
- t.Errorf("expected Interface %s, got %s", iface, r.Interface)
- }
-
- if r.Protocol != proto {
- t.Errorf("expected Protocol %s, got %s", proto, r.Protocol)
- }
-
- if r.SrcAddress != "" {
- t.Errorf("expected empty SrcAddress, got %s", r.SrcAddress)
- }
-
- if r.SrcPort != portFrom {
- t.Errorf("expected SrcPort %d, got %d", portFrom, r.SrcPort)
- }
-
- if r.DstAddress != addrTo {
- t.Errorf("expected DstAddress %s, got %s", addrTo, r.DstAddress)
- }
-
- if r.DstPort != portTo {
- t.Errorf("expected DstPort %d, got %d", portTo, r.DstPort)
- }
-}
-
-func TestRedirectionString(t *testing.T) {
- tests := []struct {
- name string
- r Redirection
- want string
- }{
- {
- name: "basic redirection",
- r: Redirection{
- Interface: "eth0",
- Protocol: "tcp",
- SrcAddress: "",
- SrcPort: 8080,
- DstAddress: "192.168.1.100",
- DstPort: 9090,
- },
- want: "[eth0] (tcp) :8080 -> 192.168.1.100:9090",
- },
- {
- name: "with source address",
- r: Redirection{
- Interface: "wlan0",
- Protocol: "udp",
- SrcAddress: "192.168.1.50",
- SrcPort: 53,
- DstAddress: "8.8.8.8",
- DstPort: 53,
- },
- want: "[wlan0] (udp) 192.168.1.50:53 -> 8.8.8.8:53",
- },
- {
- name: "localhost redirection",
- r: Redirection{
- Interface: "lo",
- Protocol: "tcp",
- SrcAddress: "127.0.0.1",
- SrcPort: 80,
- DstAddress: "127.0.0.1",
- DstPort: 8080,
- },
- want: "[lo] (tcp) 127.0.0.1:80 -> 127.0.0.1:8080",
- },
- {
- name: "high port numbers",
- r: Redirection{
- Interface: "eth1",
- Protocol: "tcp",
- SrcAddress: "",
- SrcPort: 65535,
- DstAddress: "10.0.0.1",
- DstPort: 65534,
- },
- want: "[eth1] (tcp) :65535 -> 10.0.0.1:65534",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.r.String()
- if got != tt.want {
- t.Errorf("String() = %q, want %q", got, tt.want)
- }
- })
- }
-}
-
-func TestNewRedirectionVariousProtocols(t *testing.T) {
- protocols := []string{"tcp", "udp", "icmp", "any"}
-
- for _, proto := range protocols {
- t.Run(proto, func(t *testing.T) {
- r := NewRedirection("eth0", proto, 1234, "10.0.0.1", 5678)
- if r.Protocol != proto {
- t.Errorf("expected protocol %s, got %s", proto, r.Protocol)
- }
- })
- }
-}
-
-func TestNewRedirectionVariousInterfaces(t *testing.T) {
- interfaces := []string{"eth0", "wlan0", "lo", "docker0", "br0", "tun0"}
-
- for _, iface := range interfaces {
- t.Run(iface, func(t *testing.T) {
- r := NewRedirection(iface, "tcp", 80, "192.168.1.1", 8080)
- if r.Interface != iface {
- t.Errorf("expected interface %s, got %s", iface, r.Interface)
- }
- })
- }
-}
-
-func TestRedirectionStringEmptyFields(t *testing.T) {
- tests := []struct {
- name string
- r Redirection
- want string
- }{
- {
- name: "empty interface",
- r: Redirection{
- Interface: "",
- Protocol: "tcp",
- SrcAddress: "",
- SrcPort: 80,
- DstAddress: "192.168.1.1",
- DstPort: 8080,
- },
- want: "[] (tcp) :80 -> 192.168.1.1:8080",
- },
- {
- name: "empty protocol",
- r: Redirection{
- Interface: "eth0",
- Protocol: "",
- SrcAddress: "",
- SrcPort: 80,
- DstAddress: "192.168.1.1",
- DstPort: 8080,
- },
- want: "[eth0] () :80 -> 192.168.1.1:8080",
- },
- {
- name: "empty destination",
- r: Redirection{
- Interface: "eth0",
- Protocol: "tcp",
- SrcAddress: "",
- SrcPort: 80,
- DstAddress: "",
- DstPort: 8080,
- },
- want: "[eth0] (tcp) :80 -> :8080",
- },
- {
- name: "all empty strings",
- r: Redirection{
- Interface: "",
- Protocol: "",
- SrcAddress: "",
- SrcPort: 0,
- DstAddress: "",
- DstPort: 0,
- },
- want: "[] () :0 -> :0",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.r.String()
- if got != tt.want {
- t.Errorf("String() = %q, want %q", got, tt.want)
- }
- })
- }
-}
-
-func TestRedirectionStructCopy(t *testing.T) {
- // Test that Redirection can be safely copied
- original := NewRedirection("eth0", "tcp", 80, "192.168.1.1", 8080)
- original.SrcAddress = "10.0.0.1"
-
- // Create a copy
- copy := *original
-
- // Modify the copy
- copy.Interface = "wlan0"
- copy.SrcPort = 443
-
- // Verify original is unchanged
- if original.Interface != "eth0" {
- t.Error("original Interface was modified")
- }
- if original.SrcPort != 80 {
- t.Error("original SrcPort was modified")
- }
-
- // Verify copy has new values
- if copy.Interface != "wlan0" {
- t.Error("copy Interface was not set correctly")
- }
- if copy.SrcPort != 443 {
- t.Error("copy SrcPort was not set correctly")
- }
-}
-
-func BenchmarkNewRedirection(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = NewRedirection("eth0", "tcp", 80, "192.168.1.1", 8080)
- }
-}
-
-func BenchmarkRedirectionString(b *testing.B) {
- r := Redirection{
- Interface: "eth0",
- Protocol: "tcp",
- SrcAddress: "192.168.1.50",
- SrcPort: 8080,
- DstAddress: "192.168.1.100",
- DstPort: 9090,
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = r.String()
- }
-}
-
-func BenchmarkRedirectionStringEmpty(b *testing.B) {
- r := Redirection{
- Interface: "eth0",
- Protocol: "tcp",
- SrcAddress: "",
- SrcPort: 8080,
- DstAddress: "192.168.1.100",
- DstPort: 9090,
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = r.String()
- }
-}
diff --git a/go.mod b/go.mod
index 0cbddafa..b1b2dfc3 100644
--- a/go.mod
+++ b/go.mod
@@ -1,20 +1,20 @@
module github.com/bettercap/bettercap/v2
-go 1.23.0
+go 1.21
-toolchain go1.24.4
+toolchain go1.22.6
require (
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d
- github.com/adrianmo/go-nmea v1.10.0
- github.com/antchfx/jsonquery v1.3.6
+ github.com/adrianmo/go-nmea v1.9.0
+ github.com/antchfx/jsonquery v1.3.5
github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0
github.com/bettercap/nrf24 v0.0.0-20190219153547-aa37e6d0e0eb
github.com/bettercap/readline v0.0.0-20210228151553-655e48bcb7bf
github.com/bettercap/recording v0.0.0-20190408083647-3ce1dcf032e3
github.com/cenkalti/backoff v2.2.1+incompatible
github.com/dustin/go-humanize v1.0.1
- github.com/elazarl/goproxy v1.7.2
+ github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380
github.com/evilsocket/islazy v1.11.0
github.com/florianl/go-nfqueue/v2 v2.0.0
github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe
@@ -23,45 +23,47 @@ require (
github.com/google/gousb v1.1.3
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.3
+ github.com/grandcat/zeroconf v1.0.0
github.com/hashicorp/go-bexpr v0.1.14
github.com/inconshreveable/go-vhost v1.0.0
github.com/jpillora/go-tld v1.2.1
github.com/malfunkt/iprange v0.9.0
github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b
- github.com/miekg/dns v1.1.67
+ github.com/miekg/dns v1.1.61
github.com/mitchellh/go-homedir v1.1.0
github.com/phin1x/go-ipp v1.6.1
- github.com/robertkrimen/otto v0.5.1
+ github.com/robertkrimen/otto v0.4.0
github.com/stratoberry/go-gpsd v1.3.0
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07
github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64
- go.einride.tech/can v0.14.0
- golang.org/x/net v0.42.0
+ go.einride.tech/can v0.12.0
+ golang.org/x/net v0.28.0
gopkg.in/yaml.v3 v3.0.1
)
require (
- github.com/antchfx/xpath v1.3.4 // indirect
+ github.com/antchfx/xpath v1.3.1 // indirect
github.com/chzyer/logex v1.2.1 // indirect
- github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
+ github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e // indirect
+ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/mock v1.6.0 // indirect
- github.com/google/go-cmp v0.7.0 // indirect
+ github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/kr/binarydist v0.1.0 // indirect
- github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
- github.com/mdlayher/socket v0.5.1 // indirect
+ github.com/mdlayher/socket v0.4.1 // indirect
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab // indirect
- github.com/mitchellh/mapstructure v1.5.0 // indirect
+ github.com/mitchellh/mapstructure v1.4.1 // indirect
github.com/mitchellh/pointerstructure v1.2.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
- golang.org/x/mod v0.26.0 // indirect
- golang.org/x/sync v0.16.0 // indirect
- golang.org/x/sys v0.34.0 // indirect
- golang.org/x/text v0.27.0 // indirect
- golang.org/x/tools v0.35.0 // indirect
+ golang.org/x/mod v0.20.0 // indirect
+ golang.org/x/sync v0.8.0 // indirect
+ golang.org/x/sys v0.23.0 // indirect
+ golang.org/x/text v0.17.0 // indirect
+ golang.org/x/tools v0.24.0 // indirect
gopkg.in/sourcemap.v1 v1.0.5 // indirect
)
diff --git a/go.sum b/go.sum
index f9a5d6ad..a2930b76 100644
--- a/go.sum
+++ b/go.sum
@@ -1,12 +1,11 @@
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8=
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo=
-github.com/adrianmo/go-nmea v1.10.0 h1:L1aYaebZ4cXFCoXNSeDeQa0tApvSKvIbqMsK+iaRiCo=
-github.com/adrianmo/go-nmea v1.10.0/go.mod h1:u8bPnpKt/D/5rll/5l9f6iDfeq5WZW0+/SXdkwix6Tg=
-github.com/antchfx/jsonquery v1.3.6 h1:TaSfeAh7n6T11I74bsZ1FswreIfrbJ0X+OyLflx6mx4=
-github.com/antchfx/jsonquery v1.3.6/go.mod h1:fGzSGJn9Y826Qd3pC8Wx45avuUwpkePsACQJYy+58BU=
-github.com/antchfx/xpath v1.3.2/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs=
-github.com/antchfx/xpath v1.3.4 h1:1ixrW1VnXd4HurCj7qnqnR0jo14g8JMe20Fshg1Vgz4=
-github.com/antchfx/xpath v1.3.4/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs=
+github.com/adrianmo/go-nmea v1.9.0 h1:kCuerWLDIppltHNZ2HGdCGkqbmupYJYfE6indcGkcp8=
+github.com/adrianmo/go-nmea v1.9.0/go.mod h1:u8bPnpKt/D/5rll/5l9f6iDfeq5WZW0+/SXdkwix6Tg=
+github.com/antchfx/jsonquery v1.3.5 h1:243OSaQh02EfmASa3w3weKC9UaiD8RRzJhgfvq3q408=
+github.com/antchfx/jsonquery v1.3.5/go.mod h1:qH23yX2Jsj1/k378Yu/EOgPCNgJ35P9tiGOeQdt/GWc=
+github.com/antchfx/xpath v1.3.1 h1:PNbFuUqHwWl0xRjvUPjJ95Agbmdj2uzzIwmQKgu4oCk=
+github.com/antchfx/xpath v1.3.1/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs=
github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0 h1:HiFUGV/7eGWG/YJAf9HcKOUmxIj+7LVzC8zD57VX1qo=
github.com/bettercap/gatt v0.0.0-20240808115956-ec4935e8c4a0/go.mod h1:oafnPgaBI4gqJiYkueCyR4dqygiWGXTGOE0gmmAVeeQ=
github.com/bettercap/nrf24 v0.0.0-20190219153547-aa37e6d0e0eb h1:JWAAJk4ny+bT3VrtcX+e7mcmWtWUeUM0xVcocSAUuWc=
@@ -27,22 +26,23 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
-github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
+github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380 h1:1NyRx2f4W4WBRyg0Kys0ZbaNmDDzZ2R/C7DTi+bbsJ0=
+github.com/elazarl/goproxy v0.0.0-20240726154733-8b0c20506380/go.mod h1:thX175TtLTzLj3p7N/Q9IiKZ7NF+p72cvL91emV0hzo=
+github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e h1:CQn2/8fi3kmpT9BTiHEELgdxAOQNVZc9GoPA4qnQzrs=
+github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8=
github.com/evilsocket/islazy v1.11.0 h1:B5w6uuS6ki6iDG+aH/RFeoMb8ijQh/pGabewqp2UeJ0=
github.com/evilsocket/islazy v1.11.0/go.mod h1:muYH4x5MB5YRdkxnrOtrXLIBX6LySj1uFIqys94LKdo=
github.com/florianl/go-nfqueue/v2 v2.0.0 h1:NTCxS9b0GSbHkWv1a7oOvZn679fsyDkaSkRvOYpQ9Oo=
github.com/florianl/go-nfqueue/v2 v2.0.0/go.mod h1:M2tBLIj62QpwqjwV0qfcjqGOqP3qiTuXr2uSRBXH9Qk=
github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe h1:8P+/htb3mwwpeGdJg69yBF/RofK7c6Fjz5Ypa/bTqbY=
github.com/gobwas/glob v0.0.0-20181002190808-e7a84e9525fe/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
+github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
-github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
-github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
@@ -55,6 +55,8 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE=
+github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs=
github.com/hashicorp/go-bexpr v0.1.14 h1:uKDeyuOhWhT1r5CiMTjdVY4Aoxdxs6EtwgTGnlosyp4=
github.com/hashicorp/go-bexpr v0.1.14/go.mod h1:gN7hRKB3s7yT+YvTdnhZVLTENejvhlkZ8UE4YVBS+Q8=
github.com/inconshreveable/go-vhost v1.0.0 h1:IK4VZTlXL4l9vz2IZoiSFbYaaqUW7dXJAiPriUN5Ur8=
@@ -74,28 +76,29 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/malfunkt/iprange v0.9.0 h1:VCs0PKLUPotNVQTpVNszsut4lP7OCGNBwX+lOYBrnVQ=
github.com/malfunkt/iprange v0.9.0/go.mod h1:TRGqO/f95gh3LOndUGTL46+W0GXA91WTqyZ0Quwvt4U=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
-github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
-github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b h1:r12blE3QRYlW1WBiBEe007O6NrTb/P54OjR5d4WLEGk=
github.com/mdlayher/dhcp6 v0.0.0-20190311162359-2a67805d7d0b/go.mod h1:p4K2+UAoap8Jzsadsxc0KG0OZjmmCthTPUyZqAVkjBY=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
-github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
-github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
+github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
+github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab h1:n8cgpHzJ5+EDyDri2s/GC7a9+qK3/YEGnBsd0uS/8PY=
github.com/mgutz/logxi v0.0.0-20161027140823-aebf8a7d67ab/go.mod h1:y1pL58r5z2VvAjeG1VLGc8zOQgSOzbKN7kMHPvFXJ+8=
-github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0=
-github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
+github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
+github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
+github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
+github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag=
github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
-github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
-github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/pointerstructure v1.2.1 h1:ZhBBeX8tSlRpu/FFhXH4RC4OJzFlqsQhoHZAz4x7TIw=
github.com/mitchellh/pointerstructure v1.2.1/go.mod h1:BRAsLI5zgXmw97Lf6s25bs8ohIXc3tViBH44KcwB2g4=
github.com/phin1x/go-ipp v1.6.1 h1:oxJXi92BO2FZhNcG3twjnxKFH1liTQ46vbbZx+IN/80=
@@ -104,8 +107,9 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/robertkrimen/otto v0.5.1 h1:avDI4ToRk8k1hppLdYFTuuzND41n37vPGJU7547dGf0=
-github.com/robertkrimen/otto v0.5.1/go.mod h1:bS433I4Q9p+E5pZLu7r17vP6FkE6/wLxBdmKjoqJXF8=
+github.com/robertkrimen/otto v0.4.0 h1:/c0GRrK1XDPcgIasAsnlpBT5DelIeB9U/Z/JCQsgr7E=
+github.com/robertkrimen/otto v0.4.0/go.mod h1:uW9yN1CYflmUQYvAMS0m+ZiNo3dMzRUDQJX0jWbzgxw=
+github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc=
github.com/stratoberry/go-gpsd v1.3.0 h1:JxJOEC4SgD0QY65AE7B1CtJtweP73nqJghZeLNU9J+c=
github.com/stratoberry/go-gpsd v1.3.0/go.mod h1:nVf/vTgfYxOMxiQdy9BtJjojbFRtG8H3wNula++VgkU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -115,16 +119,15 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64 h1:l/T7dYuJEQZOwVOpjIXr1180aM9PZL/d1MnMVIxefX4=
github.com/thoj/go-ircevent v0.0.0-20210723090443-73e444401d64/go.mod h1:Q1NAJOuRdQCqN/VIWdnaaEhV8LpeO2rtlBP7/iDJNII=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
-go.einride.tech/can v0.14.0 h1:OkQ0jsjCk4ijgTMjD43V1NKQyDztpX7Vo/NrvmnsAXE=
-go.einride.tech/can v0.14.0/go.mod h1:615YuRGnWfndMGD+f3Ud1sp1xJLP1oj14dKRtb2CXDQ=
+go.einride.tech/can v0.12.0 h1:6MW9TKycSovWqJxcYHpZEiuFCGuAfpqApCzTS15KrPk=
+go.einride.tech/can v0.12.0/go.mod h1:5n3+AonCfUso6PfjD9l2d0W2LxTFjjHOnHAm+UMS9Ws=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -132,22 +135,25 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
-golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
+golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
+golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190310074541-c10a0554eabf/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
-golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
+golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
+golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -157,23 +163,25 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
-golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
+golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
-golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
+golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
+golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
-golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
+golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
+golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/js/crypto.go b/js/crypto.go
deleted file mode 100644
index 7128b965..00000000
--- a/js/crypto.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package js
-
-import (
- "crypto/sha1"
-
- "github.com/robertkrimen/otto"
-)
-
-func cryptoSha1(call otto.FunctionCall) otto.Value {
- argv := call.ArgumentList
- argc := len(argv)
- if argc != 1 {
- return ReportError("Crypto.sha1: expected 1 argument, %d given instead.", argc)
- }
-
- arg := argv[0]
- if (!arg.IsString()) {
- return ReportError("Crypto.sha1: single argument must be a string.")
- }
-
- hasher := sha1.New()
- hasher.Write([]byte(arg.String()))
- v, err := otto.ToValue(string(hasher.Sum(nil)))
- if err != nil {
- return ReportError("Crypto.sha1: could not convert to string: %s", err)
- }
-
- return v
-}
diff --git a/js/data.go b/js/data.go
index 6fe48f22..e2bfe5b0 100644
--- a/js/data.go
+++ b/js/data.go
@@ -8,94 +8,25 @@ import (
"github.com/robertkrimen/otto"
)
-func textEncode(call otto.FunctionCall) otto.Value {
- argv := call.ArgumentList
- argc := len(argv)
- if argc != 1 {
- return ReportError("textEncode: expected 1 argument, %d given instead.", argc)
- }
-
- arg := argv[0]
- if (!arg.IsString()) {
- return ReportError("textEncode: single argument must be a string.")
- }
-
- encoded := []byte(arg.String())
- vm := otto.New()
- v, err := vm.ToValue(encoded)
- if err != nil {
- return ReportError("textEncode: could not convert to []uint8: %s", err.Error())
- }
-
- return v
-}
-
-func textDecode(call otto.FunctionCall) otto.Value {
- argv := call.ArgumentList
- argc := len(argv)
- if argc != 1 {
- return ReportError("textDecode: expected 1 argument, %d given instead.", argc)
- }
-
- arg, err := argv[0].Export()
- if err != nil {
- return ReportError("textDecode: could not export argument value: %s", err.Error())
- }
- byteArr, ok := arg.([]uint8)
- if !ok {
- return ReportError("textDecode: single argument must be of type []uint8.")
- }
-
- decoded := string(byteArr)
- v, err := otto.ToValue(decoded)
- if err != nil {
- return ReportError("textDecode: could not convert to string: %s", err.Error())
- }
-
- return v
-}
-
func btoa(call otto.FunctionCall) otto.Value {
- argv := call.ArgumentList
- argc := len(argv)
- if argc != 1 {
- return ReportError("btoa: expected 1 argument, %d given instead.", argc)
- }
-
- arg := argv[0]
- if (!arg.IsString()) {
- return ReportError("btoa: single argument must be a string.")
- }
-
- encoded := base64.StdEncoding.EncodeToString([]byte(arg.String()))
- v, err := otto.ToValue(encoded)
+ varValue := base64.StdEncoding.EncodeToString([]byte(call.Argument(0).String()))
+ v, err := otto.ToValue(varValue)
if err != nil {
- return ReportError("btoa: could not convert to string: %s", err.Error())
+ return ReportError("Could not convert to string: %s", varValue)
}
return v
}
func atob(call otto.FunctionCall) otto.Value {
- argv := call.ArgumentList
- argc := len(argv)
- if argc != 1 {
- return ReportError("atob: expected 1 argument, %d given instead.", argc)
- }
-
- arg := argv[0]
- if (!arg.IsString()) {
- return ReportError("atob: single argument must be a string.")
- }
-
- decoded, err := base64.StdEncoding.DecodeString(arg.String())
+ varValue, err := base64.StdEncoding.DecodeString(call.Argument(0).String())
if err != nil {
- return ReportError("atob: could not decode string: %s", err.Error())
+ return ReportError("Could not decode string: %s", call.Argument(0).String())
}
- v, err := otto.ToValue(string(decoded))
+ v, err := otto.ToValue(string(varValue))
if err != nil {
- return ReportError("atob: could not convert to string: %s", err.Error())
+ return ReportError("Could not convert to string: %s", varValue)
}
return v
@@ -108,12 +39,7 @@ func gzipCompress(call otto.FunctionCall) otto.Value {
return ReportError("gzipCompress: expected 1 argument, %d given instead.", argc)
}
- arg := argv[0]
- if (!arg.IsString()) {
- return ReportError("gzipCompress: single argument must be a string.")
- }
-
- uncompressedBytes := []byte(arg.String())
+ uncompressedBytes := []byte(argv[0].String())
var writerBuffer bytes.Buffer
gzipWriter := gzip.NewWriter(&writerBuffer)
@@ -127,7 +53,7 @@ func gzipCompress(call otto.FunctionCall) otto.Value {
v, err := otto.ToValue(string(compressedBytes))
if err != nil {
- return ReportError("gzipCompress: could not convert to string: %s", err.Error())
+ return ReportError("Could not convert to string: %s", err.Error())
}
return v
@@ -157,7 +83,7 @@ func gzipDecompress(call otto.FunctionCall) otto.Value {
decompressedBytes := decompressedBuffer.Bytes()
v, err := otto.ToValue(string(decompressedBytes))
if err != nil {
- return ReportError("gzipDecompress: could not convert to string: %s", err.Error())
+ return ReportError("Could not convert to string: %s", err.Error())
}
return v
diff --git a/js/data_test.go b/js/data_test.go
deleted file mode 100644
index 64326418..00000000
--- a/js/data_test.go
+++ /dev/null
@@ -1,514 +0,0 @@
-package js
-
-import (
- "encoding/base64"
- "strings"
- "testing"
-
- "github.com/robertkrimen/otto"
-)
-
-func TestBtoa(t *testing.T) {
- vm := otto.New()
-
- tests := []struct {
- name string
- input string
- expected string
- }{
- {
- name: "simple string",
- input: "hello world",
- expected: base64.StdEncoding.EncodeToString([]byte("hello world")),
- },
- {
- name: "empty string",
- input: "",
- expected: base64.StdEncoding.EncodeToString([]byte("")),
- },
- {
- name: "special characters",
- input: "!@#$%^&*()_+-=[]{}|;:,.<>?",
- expected: base64.StdEncoding.EncodeToString([]byte("!@#$%^&*()_+-=[]{}|;:,.<>?")),
- },
- {
- name: "unicode string",
- input: "Hello 世界 🌍",
- expected: base64.StdEncoding.EncodeToString([]byte("Hello 世界 🌍")),
- },
- {
- name: "newlines and tabs",
- input: "line1\nline2\ttab",
- expected: base64.StdEncoding.EncodeToString([]byte("line1\nline2\ttab")),
- },
- {
- name: "long string",
- input: strings.Repeat("a", 1000),
- expected: base64.StdEncoding.EncodeToString([]byte(strings.Repeat("a", 1000))),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create call with argument
- arg, _ := vm.ToValue(tt.input)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := btoa(call)
-
- // Check if result is an error
- if result.IsUndefined() {
- t.Fatal("btoa returned undefined")
- }
-
- // Get string value
- resultStr, err := result.ToString()
- if err != nil {
- t.Fatalf("failed to convert result to string: %v", err)
- }
-
- if resultStr != tt.expected {
- t.Errorf("btoa(%q) = %q, want %q", tt.input, resultStr, tt.expected)
- }
- })
- }
-}
-
-func TestAtob(t *testing.T) {
- vm := otto.New()
-
- tests := []struct {
- name string
- input string
- expected string
- wantError bool
- }{
- {
- name: "simple base64",
- input: base64.StdEncoding.EncodeToString([]byte("hello world")),
- expected: "hello world",
- },
- {
- name: "empty base64",
- input: base64.StdEncoding.EncodeToString([]byte("")),
- expected: "",
- },
- {
- name: "special characters base64",
- input: base64.StdEncoding.EncodeToString([]byte("!@#$%^&*()_+-=[]{}|;:,.<>?")),
- expected: "!@#$%^&*()_+-=[]{}|;:,.<>?",
- },
- {
- name: "unicode base64",
- input: base64.StdEncoding.EncodeToString([]byte("Hello 世界 🌍")),
- expected: "Hello 世界 🌍",
- },
- {
- name: "invalid base64",
- input: "not valid base64!",
- wantError: true,
- },
- {
- name: "invalid padding",
- input: "SGVsbG8gV29ybGQ", // Missing padding
- wantError: true,
- },
- {
- name: "long base64",
- input: base64.StdEncoding.EncodeToString([]byte(strings.Repeat("a", 1000))),
- expected: strings.Repeat("a", 1000),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create call with argument
- arg, _ := vm.ToValue(tt.input)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := atob(call)
-
- // Get string value
- resultStr, err := result.ToString()
- if err != nil && !tt.wantError {
- t.Fatalf("failed to convert result to string: %v", err)
- }
-
- if tt.wantError {
- // Should return undefined (NullValue) on error
- if !result.IsUndefined() {
- t.Errorf("expected undefined for error case, got %q", resultStr)
- }
- } else {
- if resultStr != tt.expected {
- t.Errorf("atob(%q) = %q, want %q", tt.input, resultStr, tt.expected)
- }
- }
- })
- }
-}
-
-func TestGzipCompress(t *testing.T) {
- vm := otto.New()
-
- tests := []struct {
- name string
- input string
- }{
- {
- name: "simple string",
- input: "hello world",
- },
- {
- name: "empty string",
- input: "",
- },
- {
- name: "repeated pattern",
- input: strings.Repeat("abcd", 100),
- },
- {
- name: "random text",
- input: "The quick brown fox jumps over the lazy dog. " + strings.Repeat("Lorem ipsum dolor sit amet. ", 10),
- },
- {
- name: "unicode text",
- input: "Hello 世界 🌍 " + strings.Repeat("测试数据 ", 50),
- },
- {
- name: "binary-like data",
- input: string([]byte{0, 1, 2, 3, 255, 254, 253, 252}),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create call with argument
- arg, _ := vm.ToValue(tt.input)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := gzipCompress(call)
-
- // Get compressed data
- compressed, err := result.ToString()
- if err != nil {
- t.Fatalf("failed to convert result to string: %v", err)
- }
-
- // Verify it's actually compressed (for non-empty strings, compressed should be different)
- if tt.input != "" && compressed == tt.input {
- t.Error("compressed data is same as input")
- }
-
- // Verify gzip header (should start with 0x1f, 0x8b)
- if len(compressed) >= 2 {
- if compressed[0] != 0x1f || compressed[1] != 0x8b {
- t.Error("compressed data doesn't have valid gzip header")
- }
- }
-
- // Now decompress to verify
- argCompressed, _ := vm.ToValue(compressed)
- callDecompress := otto.FunctionCall{
- ArgumentList: []otto.Value{argCompressed},
- }
-
- resultDecompressed := gzipDecompress(callDecompress)
- decompressed, err := resultDecompressed.ToString()
- if err != nil {
- t.Fatalf("failed to decompress: %v", err)
- }
-
- if decompressed != tt.input {
- t.Errorf("round-trip failed: got %q, want %q", decompressed, tt.input)
- }
- })
- }
-}
-
-func TestGzipCompressInvalidArgs(t *testing.T) {
- vm := otto.New()
-
- tests := []struct {
- name string
- args []otto.Value
- }{
- {
- name: "no arguments",
- args: []otto.Value{},
- },
- {
- name: "too many arguments",
- args: func() []otto.Value {
- arg1, _ := vm.ToValue("test")
- arg2, _ := vm.ToValue("extra")
- return []otto.Value{arg1, arg2}
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- call := otto.FunctionCall{
- ArgumentList: tt.args,
- }
-
- result := gzipCompress(call)
-
- // Should return undefined (NullValue) on error
- if !result.IsUndefined() {
- resultStr, _ := result.ToString()
- t.Errorf("expected undefined for error case, got %q", resultStr)
- }
- })
- }
-}
-
-func TestGzipDecompress(t *testing.T) {
- vm := otto.New()
-
- // First compress some data
- originalData := "This is test data for decompression"
- arg, _ := vm.ToValue(originalData)
- compressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
- compressedResult := gzipCompress(compressCall)
- compressedData, _ := compressedResult.ToString()
-
- t.Run("valid decompression", func(t *testing.T) {
- argCompressed, _ := vm.ToValue(compressedData)
- decompressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argCompressed},
- }
-
- result := gzipDecompress(decompressCall)
- decompressed, err := result.ToString()
- if err != nil {
- t.Fatalf("failed to convert result to string: %v", err)
- }
-
- if decompressed != originalData {
- t.Errorf("decompressed data doesn't match original: got %q, want %q", decompressed, originalData)
- }
- })
-
- t.Run("invalid gzip data", func(t *testing.T) {
- argInvalid, _ := vm.ToValue("not gzip data")
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argInvalid},
- }
-
- result := gzipDecompress(call)
-
- // Should return undefined (NullValue) on error
- if !result.IsUndefined() {
- resultStr, _ := result.ToString()
- t.Errorf("expected undefined for error case, got %q", resultStr)
- }
- })
-
- t.Run("corrupted gzip data", func(t *testing.T) {
- // Create corrupted gzip by taking valid gzip and modifying it
- corruptedData := compressedData[:len(compressedData)/2] + "corrupted"
-
- argCorrupted, _ := vm.ToValue(corruptedData)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argCorrupted},
- }
-
- result := gzipDecompress(call)
-
- // Should return undefined (NullValue) on error
- if !result.IsUndefined() {
- resultStr, _ := result.ToString()
- t.Errorf("expected undefined for error case, got %q", resultStr)
- }
- })
-}
-
-func TestGzipDecompressInvalidArgs(t *testing.T) {
- vm := otto.New()
-
- tests := []struct {
- name string
- args []otto.Value
- }{
- {
- name: "no arguments",
- args: []otto.Value{},
- },
- {
- name: "too many arguments",
- args: func() []otto.Value {
- arg1, _ := vm.ToValue("test")
- arg2, _ := vm.ToValue("extra")
- return []otto.Value{arg1, arg2}
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- call := otto.FunctionCall{
- ArgumentList: tt.args,
- }
-
- result := gzipDecompress(call)
-
- // Should return undefined (NullValue) on error
- if !result.IsUndefined() {
- resultStr, _ := result.ToString()
- t.Errorf("expected undefined for error case, got %q", resultStr)
- }
- })
- }
-}
-
-func TestBtoaAtobRoundTrip(t *testing.T) {
- vm := otto.New()
-
- testStrings := []string{
- "simple",
- "",
- "with spaces and\nnewlines\ttabs",
- "special!@#$%^&*()_+-=[]{}|;:,.<>?",
- "unicode 世界 🌍",
- strings.Repeat("long string ", 100),
- }
-
- for _, original := range testStrings {
- t.Run(original, func(t *testing.T) {
- // Encode with btoa
- argOriginal, _ := vm.ToValue(original)
- encodeCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argOriginal},
- }
-
- encoded := btoa(encodeCall)
- encodedStr, _ := encoded.ToString()
-
- // Decode with atob
- argEncoded, _ := vm.ToValue(encodedStr)
- decodeCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argEncoded},
- }
-
- decoded := atob(decodeCall)
- decodedStr, _ := decoded.ToString()
-
- if decodedStr != original {
- t.Errorf("round-trip failed: got %q, want %q", decodedStr, original)
- }
- })
- }
-}
-
-func TestGzipCompressDecompressRoundTrip(t *testing.T) {
- vm := otto.New()
-
- testData := []string{
- "simple",
- "",
- strings.Repeat("repetitive data ", 100),
- "unicode 世界 🌍 " + strings.Repeat("测试 ", 50),
- string([]byte{0, 1, 2, 3, 255, 254, 253, 252}),
- }
-
- for _, original := range testData {
- t.Run(original, func(t *testing.T) {
- // Compress
- argOriginal, _ := vm.ToValue(original)
- compressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argOriginal},
- }
-
- compressed := gzipCompress(compressCall)
- compressedStr, _ := compressed.ToString()
-
- // Decompress
- argCompressed, _ := vm.ToValue(compressedStr)
- decompressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argCompressed},
- }
-
- decompressed := gzipDecompress(decompressCall)
- decompressedStr, _ := decompressed.ToString()
-
- if decompressedStr != original {
- t.Errorf("round-trip failed: got %q, want %q", decompressedStr, original)
- }
- })
- }
-}
-
-func BenchmarkBtoa(b *testing.B) {
- vm := otto.New()
- arg, _ := vm.ToValue("The quick brown fox jumps over the lazy dog")
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = btoa(call)
- }
-}
-
-func BenchmarkAtob(b *testing.B) {
- vm := otto.New()
- encoded := base64.StdEncoding.EncodeToString([]byte("The quick brown fox jumps over the lazy dog"))
- arg, _ := vm.ToValue(encoded)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = atob(call)
- }
-}
-
-func BenchmarkGzipCompress(b *testing.B) {
- vm := otto.New()
- data := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 10)
- arg, _ := vm.ToValue(data)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = gzipCompress(call)
- }
-}
-
-func BenchmarkGzipDecompress(b *testing.B) {
- vm := otto.New()
-
- // First compress some data
- data := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 10)
- argData, _ := vm.ToValue(data)
- compressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argData},
- }
- compressed := gzipCompress(compressCall)
- compressedStr, _ := compressed.ToString()
-
- // Benchmark decompression
- argCompressed, _ := vm.ToValue(compressedStr)
- decompressCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argCompressed},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = gzipDecompress(decompressCall)
- }
-}
diff --git a/js/fs_test.go b/js/fs_test.go
deleted file mode 100644
index fd089d28..00000000
--- a/js/fs_test.go
+++ /dev/null
@@ -1,684 +0,0 @@
-package js
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path/filepath"
- "runtime"
- "strings"
- "testing"
-
- "github.com/robertkrimen/otto"
-)
-
-func TestReadDir(t *testing.T) {
- vm := otto.New()
-
- // Create a temporary directory for testing
- tmpDir, err := ioutil.TempDir("", "js_test_readdir_*")
- if err != nil {
- t.Fatalf("failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- // Create some test files and subdirectories
- testFiles := []string{"file1.txt", "file2.log", ".hidden"}
- testDirs := []string{"subdir1", "subdir2"}
-
- for _, name := range testFiles {
- if err := ioutil.WriteFile(filepath.Join(tmpDir, name), []byte("test"), 0644); err != nil {
- t.Fatalf("failed to create test file %s: %v", name, err)
- }
- }
-
- for _, name := range testDirs {
- if err := os.Mkdir(filepath.Join(tmpDir, name), 0755); err != nil {
- t.Fatalf("failed to create test dir %s: %v", name, err)
- }
- }
-
- t.Run("valid directory", func(t *testing.T) {
- arg, _ := vm.ToValue(tmpDir)
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{arg},
- }
-
- result := readDir(call)
-
- // Check if result is not undefined
- if result.IsUndefined() {
- t.Fatal("readDir returned undefined")
- }
-
- // Convert to Go slice
- export, err := result.Export()
- if err != nil {
- t.Fatalf("failed to export result: %v", err)
- }
-
- entries, ok := export.([]string)
- if !ok {
- t.Fatalf("expected []string, got %T", export)
- }
-
- // Check all expected entries are present
- expectedEntries := append(testFiles, testDirs...)
- if len(entries) != len(expectedEntries) {
- t.Errorf("expected %d entries, got %d", len(expectedEntries), len(entries))
- }
-
- // Check each entry exists
- for _, expected := range expectedEntries {
- found := false
- for _, entry := range entries {
- if entry == expected {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("expected entry %s not found", expected)
- }
- }
- })
-
- t.Run("non-existent directory", func(t *testing.T) {
- arg, _ := vm.ToValue("/path/that/does/not/exist")
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{arg},
- }
-
- result := readDir(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined for non-existent directory")
- }
- })
-
- t.Run("file instead of directory", func(t *testing.T) {
- // Create a file
- testFile := filepath.Join(tmpDir, "notadir.txt")
- ioutil.WriteFile(testFile, []byte("test"), 0644)
-
- arg, _ := vm.ToValue(testFile)
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{arg},
- }
-
- result := readDir(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined when passing file instead of directory")
- }
- })
-
- t.Run("invalid arguments", func(t *testing.T) {
- tests := []struct {
- name string
- args []otto.Value
- }{
- {
- name: "no arguments",
- args: []otto.Value{},
- },
- {
- name: "too many arguments",
- args: func() []otto.Value {
- arg1, _ := vm.ToValue(tmpDir)
- arg2, _ := vm.ToValue("extra")
- return []otto.Value{arg1, arg2}
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: tt.args,
- }
-
- result := readDir(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined for invalid arguments")
- }
- })
- }
- })
-
- t.Run("empty directory", func(t *testing.T) {
- emptyDir := filepath.Join(tmpDir, "empty")
- os.Mkdir(emptyDir, 0755)
-
- arg, _ := vm.ToValue(emptyDir)
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{arg},
- }
-
- result := readDir(call)
-
- if result.IsUndefined() {
- t.Fatal("readDir returned undefined for empty directory")
- }
-
- export, _ := result.Export()
- entries, _ := export.([]string)
-
- if len(entries) != 0 {
- t.Errorf("expected 0 entries for empty directory, got %d", len(entries))
- }
- })
-}
-
-func TestReadFile(t *testing.T) {
- vm := otto.New()
-
- // Create a temporary directory for testing
- tmpDir, err := ioutil.TempDir("", "js_test_readfile_*")
- if err != nil {
- t.Fatalf("failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- t.Run("valid file", func(t *testing.T) {
- testContent := "Hello, World!\nThis is a test file.\n特殊字符测试 🌍"
- testFile := filepath.Join(tmpDir, "test.txt")
- ioutil.WriteFile(testFile, []byte(testContent), 0644)
-
- arg, _ := vm.ToValue(testFile)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- if result.IsUndefined() {
- t.Fatal("readFile returned undefined")
- }
-
- content, err := result.ToString()
- if err != nil {
- t.Fatalf("failed to convert result to string: %v", err)
- }
-
- if content != testContent {
- t.Errorf("expected content %q, got %q", testContent, content)
- }
- })
-
- t.Run("non-existent file", func(t *testing.T) {
- arg, _ := vm.ToValue("/path/that/does/not/exist.txt")
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined for non-existent file")
- }
- })
-
- t.Run("directory instead of file", func(t *testing.T) {
- arg, _ := vm.ToValue(tmpDir)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined when passing directory instead of file")
- }
- })
-
- t.Run("empty file", func(t *testing.T) {
- emptyFile := filepath.Join(tmpDir, "empty.txt")
- ioutil.WriteFile(emptyFile, []byte(""), 0644)
-
- arg, _ := vm.ToValue(emptyFile)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- if result.IsUndefined() {
- t.Fatal("readFile returned undefined for empty file")
- }
-
- content, _ := result.ToString()
- if content != "" {
- t.Errorf("expected empty string, got %q", content)
- }
- })
-
- t.Run("binary file", func(t *testing.T) {
- binaryContent := []byte{0, 1, 2, 3, 255, 254, 253, 252}
- binaryFile := filepath.Join(tmpDir, "binary.bin")
- ioutil.WriteFile(binaryFile, binaryContent, 0644)
-
- arg, _ := vm.ToValue(binaryFile)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- if result.IsUndefined() {
- t.Fatal("readFile returned undefined for binary file")
- }
-
- content, _ := result.ToString()
- if content != string(binaryContent) {
- t.Error("binary content mismatch")
- }
- })
-
- t.Run("invalid arguments", func(t *testing.T) {
- tests := []struct {
- name string
- args []otto.Value
- }{
- {
- name: "no arguments",
- args: []otto.Value{},
- },
- {
- name: "too many arguments",
- args: func() []otto.Value {
- arg1, _ := vm.ToValue("file.txt")
- arg2, _ := vm.ToValue("extra")
- return []otto.Value{arg1, arg2}
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- call := otto.FunctionCall{
- ArgumentList: tt.args,
- }
-
- result := readFile(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined for invalid arguments")
- }
- })
- }
- })
-
- t.Run("large file", func(t *testing.T) {
- // Create a 1MB file
- largeContent := strings.Repeat("A", 1024*1024)
- largeFile := filepath.Join(tmpDir, "large.txt")
- ioutil.WriteFile(largeFile, []byte(largeContent), 0644)
-
- arg, _ := vm.ToValue(largeFile)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- result := readFile(call)
-
- if result.IsUndefined() {
- t.Fatal("readFile returned undefined for large file")
- }
-
- content, _ := result.ToString()
- if len(content) != len(largeContent) {
- t.Errorf("expected content length %d, got %d", len(largeContent), len(content))
- }
- })
-}
-
-func TestWriteFile(t *testing.T) {
- vm := otto.New()
-
- // Create a temporary directory for testing
- tmpDir, err := ioutil.TempDir("", "js_test_writefile_*")
- if err != nil {
- t.Fatalf("failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- t.Run("write new file", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "new_file.txt")
- testContent := "Hello, World!\nThis is a new file.\n特殊字符测试 🌍"
-
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(testContent)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- result := writeFile(call)
-
- // writeFile returns null on success
- if !result.IsNull() {
- t.Error("expected null return value for successful write")
- }
-
- // Verify file was created with correct content
- content, err := ioutil.ReadFile(testFile)
- if err != nil {
- t.Fatalf("failed to read written file: %v", err)
- }
-
- if string(content) != testContent {
- t.Errorf("expected content %q, got %q", testContent, string(content))
- }
-
- // Check file permissions
- info, _ := os.Stat(testFile)
- if runtime.GOOS == "windows" {
- // On Windows, permissions are different - just check that file exists and is readable
- if info.Mode()&0400 == 0 {
- t.Error("expected file to be readable on Windows")
- }
- } else {
- // On Unix-like systems, check exact permissions
- if info.Mode().Perm() != 0644 {
- t.Errorf("expected permissions 0644, got %v", info.Mode().Perm())
- }
- }
- })
-
- t.Run("overwrite existing file", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "existing.txt")
- oldContent := "Old content"
- newContent := "New content that is longer than the old content"
-
- // Create initial file
- ioutil.WriteFile(testFile, []byte(oldContent), 0644)
-
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(newContent)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- result := writeFile(call)
-
- if !result.IsNull() {
- t.Error("expected null return value for successful write")
- }
-
- // Verify file was overwritten
- content, _ := ioutil.ReadFile(testFile)
- if string(content) != newContent {
- t.Errorf("expected content %q, got %q", newContent, string(content))
- }
- })
-
- t.Run("write to non-existent directory", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "nonexistent", "subdir", "file.txt")
- testContent := "test"
-
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(testContent)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- result := writeFile(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined when writing to non-existent directory")
- }
- })
-
- t.Run("write empty content", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "empty.txt")
-
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue("")
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- result := writeFile(call)
-
- if !result.IsNull() {
- t.Error("expected null return value for successful write")
- }
-
- // Verify empty file was created
- content, _ := ioutil.ReadFile(testFile)
- if len(content) != 0 {
- t.Errorf("expected empty file, got %d bytes", len(content))
- }
- })
-
- t.Run("invalid arguments", func(t *testing.T) {
- tests := []struct {
- name string
- args []otto.Value
- }{
- {
- name: "no arguments",
- args: []otto.Value{},
- },
- {
- name: "one argument",
- args: func() []otto.Value {
- arg, _ := vm.ToValue("file.txt")
- return []otto.Value{arg}
- }(),
- },
- {
- name: "too many arguments",
- args: func() []otto.Value {
- arg1, _ := vm.ToValue("file.txt")
- arg2, _ := vm.ToValue("content")
- arg3, _ := vm.ToValue("extra")
- return []otto.Value{arg1, arg2, arg3}
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- call := otto.FunctionCall{
- ArgumentList: tt.args,
- }
-
- result := writeFile(call)
-
- // Should return undefined (error)
- if !result.IsUndefined() {
- t.Error("expected undefined for invalid arguments")
- }
- })
- }
- })
-
- t.Run("write binary content", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "binary.bin")
- binaryContent := string([]byte{0, 1, 2, 3, 255, 254, 253, 252})
-
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(binaryContent)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- result := writeFile(call)
-
- if !result.IsNull() {
- t.Error("expected null return value for successful write")
- }
-
- // Verify binary content
- content, _ := ioutil.ReadFile(testFile)
- if string(content) != binaryContent {
- t.Error("binary content mismatch")
- }
- })
-}
-
-func TestFileSystemIntegration(t *testing.T) {
- vm := otto.New()
-
- // Create a temporary directory for testing
- tmpDir, err := ioutil.TempDir("", "js_test_integration_*")
- if err != nil {
- t.Fatalf("failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- t.Run("write then read file", func(t *testing.T) {
- testFile := filepath.Join(tmpDir, "roundtrip.txt")
- testContent := "Round-trip test content\nLine 2\nLine 3"
-
- // Write file
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(testContent)
- writeCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
-
- writeResult := writeFile(writeCall)
- if !writeResult.IsNull() {
- t.Fatal("write failed")
- }
-
- // Read file back
- readCall := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile},
- }
-
- readResult := readFile(readCall)
- if readResult.IsUndefined() {
- t.Fatal("read failed")
- }
-
- readContent, _ := readResult.ToString()
- if readContent != testContent {
- t.Errorf("round-trip failed: expected %q, got %q", testContent, readContent)
- }
- })
-
- t.Run("create files then list directory", func(t *testing.T) {
- // Create multiple files
- files := []string{"file1.txt", "file2.txt", "file3.txt"}
- for _, name := range files {
- path := filepath.Join(tmpDir, name)
- argFile, _ := vm.ToValue(path)
- argContent, _ := vm.ToValue("content of " + name)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
- writeFile(call)
- }
-
- // List directory
- argDir, _ := vm.ToValue(tmpDir)
- listCall := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{argDir},
- }
-
- listResult := readDir(listCall)
- if listResult.IsUndefined() {
- t.Fatal("readDir failed")
- }
-
- export, _ := listResult.Export()
- entries, _ := export.([]string)
-
- // Check all files are listed
- for _, expected := range files {
- found := false
- for _, entry := range entries {
- if entry == expected {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("expected file %s not found in directory listing", expected)
- }
- }
- })
-}
-
-func BenchmarkReadFile(b *testing.B) {
- vm := otto.New()
-
- // Create test file
- tmpFile, _ := ioutil.TempFile("", "bench_readfile_*")
- defer os.Remove(tmpFile.Name())
-
- content := strings.Repeat("Benchmark test content line\n", 100)
- ioutil.WriteFile(tmpFile.Name(), []byte(content), 0644)
-
- arg, _ := vm.ToValue(tmpFile.Name())
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{arg},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = readFile(call)
- }
-}
-
-func BenchmarkWriteFile(b *testing.B) {
- vm := otto.New()
-
- tmpDir, _ := ioutil.TempDir("", "bench_writefile_*")
- defer os.RemoveAll(tmpDir)
-
- content := strings.Repeat("Benchmark test content line\n", 100)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- testFile := filepath.Join(tmpDir, fmt.Sprintf("bench_%d.txt", i))
- argFile, _ := vm.ToValue(testFile)
- argContent, _ := vm.ToValue(content)
- call := otto.FunctionCall{
- ArgumentList: []otto.Value{argFile, argContent},
- }
- _ = writeFile(call)
- }
-}
-
-func BenchmarkReadDir(b *testing.B) {
- vm := otto.New()
-
- // Create test directory with files
- tmpDir, _ := ioutil.TempDir("", "bench_readdir_*")
- defer os.RemoveAll(tmpDir)
-
- // Create 100 files
- for i := 0; i < 100; i++ {
- name := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i))
- ioutil.WriteFile(name, []byte("test"), 0644)
- }
-
- arg, _ := vm.ToValue(tmpDir)
- call := otto.FunctionCall{
- Otto: vm,
- ArgumentList: []otto.Value{arg},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = readDir(call)
- }
-}
diff --git a/js/http.go b/js/http.go
index 685f8ec0..615928cb 100644
--- a/js/http.go
+++ b/js/http.go
@@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
+ "io/ioutil"
"net/http"
"net/url"
"strings"
@@ -63,7 +64,7 @@ func (c httpPackage) Request(method string, uri string,
}
defer resp.Body.Close()
- raw, err := io.ReadAll(resp.Body)
+ raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return httpResponse{Error: err}
}
@@ -132,7 +133,7 @@ func httpRequest(call otto.FunctionCall) otto.Value {
}
defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
+ body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return ReportError("Could not read response: %s", err)
}
diff --git a/js/init.go b/js/init.go
index 1aaa52cd..6415dd88 100644
--- a/js/init.go
+++ b/js/init.go
@@ -27,16 +27,10 @@ func init() {
plugin.Defines["log_error"] = log_error
plugin.Defines["log_fatal"] = log_fatal
- plugin.Defines["Crypto"] = map[string]interface{}{
- "sha1": cryptoSha1,
- }
-
plugin.Defines["btoa"] = btoa
plugin.Defines["atob"] = atob
plugin.Defines["gzipCompress"] = gzipCompress
plugin.Defines["gzipDecompress"] = gzipDecompress
- plugin.Defines["textEncode"] = textEncode
- plugin.Defines["textDecode"] = textDecode
plugin.Defines["httpRequest"] = httpRequest
plugin.Defines["http"] = httpPackage{}
diff --git a/js/random_test.go b/js/random_test.go
deleted file mode 100644
index 594a16ad..00000000
--- a/js/random_test.go
+++ /dev/null
@@ -1,307 +0,0 @@
-package js
-
-import (
- "net"
- "regexp"
- "strings"
- "testing"
-)
-
-func TestRandomString(t *testing.T) {
- r := randomPackage{}
-
- tests := []struct {
- name string
- size int
- charset string
- }{
- {
- name: "alphanumeric",
- size: 10,
- charset: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
- },
- {
- name: "numbers only",
- size: 20,
- charset: "0123456789",
- },
- {
- name: "lowercase letters",
- size: 15,
- charset: "abcdefghijklmnopqrstuvwxyz",
- },
- {
- name: "uppercase letters",
- size: 8,
- charset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
- },
- {
- name: "special characters",
- size: 12,
- charset: "!@#$%^&*()_+-=[]{}|;:,.<>?",
- },
- {
- name: "unicode characters",
- size: 5,
- charset: "αβγδεζηθικλμνξοπρστυφχψω",
- },
- {
- name: "mixed unicode and ascii",
- size: 10,
- charset: "abc123αβγ",
- },
- {
- name: "single character",
- size: 100,
- charset: "a",
- },
- {
- name: "empty size",
- size: 0,
- charset: "abcdef",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := r.String(tt.size, tt.charset)
-
- // Check length
- if len([]rune(result)) != tt.size {
- t.Errorf("expected length %d, got %d", tt.size, len([]rune(result)))
- }
-
- // Check that all characters are from the charset
- for _, char := range result {
- if !strings.ContainsRune(tt.charset, char) {
- t.Errorf("character %c not in charset %s", char, tt.charset)
- }
- }
- })
- }
-}
-
-func TestRandomStringDistribution(t *testing.T) {
- r := randomPackage{}
- charset := "ab"
- size := 1000
-
- // Generate many single-character strings
- counts := make(map[rune]int)
- for i := 0; i < size; i++ {
- result := r.String(1, charset)
- if len(result) == 1 {
- counts[rune(result[0])]++
- }
- }
-
- // Check that both characters appear (very high probability)
- if len(counts) != 2 {
- t.Errorf("expected both characters to appear, got %d unique characters", len(counts))
- }
-
- // Check distribution is reasonable (not perfect due to randomness)
- for char, count := range counts {
- ratio := float64(count) / float64(size)
- if ratio < 0.3 || ratio > 0.7 {
- t.Errorf("character %c appeared %d times (%.2f%%), expected around 50%%",
- char, count, ratio*100)
- }
- }
-}
-
-func TestRandomMac(t *testing.T) {
- r := randomPackage{}
- macRegex := regexp.MustCompile(`^([0-9a-f]{2}:){5}[0-9a-f]{2}$`)
-
- // Generate multiple MAC addresses
- macs := make(map[string]bool)
- for i := 0; i < 100; i++ {
- mac := r.Mac()
-
- // Check format
- if !macRegex.MatchString(mac) {
- t.Errorf("invalid MAC format: %s", mac)
- }
-
- // Check it's a valid MAC
- _, err := net.ParseMAC(mac)
- if err != nil {
- t.Errorf("invalid MAC address: %s, error: %v", mac, err)
- }
-
- // Store for uniqueness check
- macs[mac] = true
- }
-
- // Check that we get different MACs (very high probability)
- if len(macs) < 95 {
- t.Errorf("expected at least 95 unique MACs out of 100, got %d", len(macs))
- }
-}
-
-func TestRandomMacNormalization(t *testing.T) {
- r := randomPackage{}
-
- // Generate several MACs and check they're normalized
- for i := 0; i < 10; i++ {
- mac := r.Mac()
-
- // Check lowercase
- if mac != strings.ToLower(mac) {
- t.Errorf("MAC not normalized to lowercase: %s", mac)
- }
-
- // Check separator is colon
- if strings.Contains(mac, "-") {
- t.Errorf("MAC contains hyphen instead of colon: %s", mac)
- }
-
- // Check length
- if len(mac) != 17 { // 6 bytes * 2 chars + 5 colons
- t.Errorf("MAC has wrong length: %s (len=%d)", mac, len(mac))
- }
- }
-}
-
-func TestRandomStringEdgeCases(t *testing.T) {
- r := randomPackage{}
-
- // Test with various edge cases
- tests := []struct {
- name string
- size int
- charset string
- }{
- {
- name: "zero size",
- size: 0,
- charset: "abc",
- },
- {
- name: "very large size",
- size: 10000,
- charset: "abc",
- },
- {
- name: "size larger than charset",
- size: 10,
- charset: "ab",
- },
- {
- name: "single char charset with large size",
- size: 1000,
- charset: "x",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := r.String(tt.size, tt.charset)
-
- if len([]rune(result)) != tt.size {
- t.Errorf("expected length %d, got %d", tt.size, len([]rune(result)))
- }
-
- // Check all characters are from charset
- for _, c := range result {
- if !strings.ContainsRune(tt.charset, c) {
- t.Errorf("character %c not in charset %s", c, tt.charset)
- }
- }
- })
- }
-}
-
-func TestRandomStringNegativeSize(t *testing.T) {
- r := randomPackage{}
-
- // Test that negative size causes panic
- defer func() {
- if r := recover(); r == nil {
- t.Error("expected panic for negative size but didn't get one")
- }
- }()
-
- // This should panic
- _ = r.String(-1, "abc")
-}
-
-func TestRandomPackageInstance(t *testing.T) {
- // Test that we can create multiple instances
- r1 := randomPackage{}
- r2 := randomPackage{}
-
- // Both should work independently
- s1 := r1.String(5, "abc")
- s2 := r2.String(5, "xyz")
-
- if len(s1) != 5 {
- t.Errorf("r1.String returned wrong length: %d", len(s1))
- }
- if len(s2) != 5 {
- t.Errorf("r2.String returned wrong length: %d", len(s2))
- }
-
- // Check correct charset usage
- for _, c := range s1 {
- if !strings.ContainsRune("abc", c) {
- t.Errorf("r1 produced character outside charset: %c", c)
- }
- }
- for _, c := range s2 {
- if !strings.ContainsRune("xyz", c) {
- t.Errorf("r2 produced character outside charset: %c", c)
- }
- }
-}
-
-func BenchmarkRandomString(b *testing.B) {
- r := randomPackage{}
- charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
-
- b.Run("size-10", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = r.String(10, charset)
- }
- })
-
- b.Run("size-100", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = r.String(100, charset)
- }
- })
-
- b.Run("size-1000", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = r.String(1000, charset)
- }
- })
-}
-
-func BenchmarkRandomMac(b *testing.B) {
- r := randomPackage{}
-
- for i := 0; i < b.N; i++ {
- _ = r.Mac()
- }
-}
-
-func BenchmarkRandomStringCharsets(b *testing.B) {
- r := randomPackage{}
-
- charsets := map[string]string{
- "small": "abc",
- "medium": "abcdefghijklmnopqrstuvwxyz",
- "large": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?",
- "unicode": "αβγδεζηθικλμνξοπρστυφχψωABCDEFGHIJKLMNOPQRSTUVWXYZ",
- }
-
- for name, charset := range charsets {
- b.Run(name, func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = r.String(20, charset)
- }
- })
- }
-}
diff --git a/log/log_test.go b/log/log_test.go
deleted file mode 100644
index af696d19..00000000
--- a/log/log_test.go
+++ /dev/null
@@ -1,106 +0,0 @@
-package log
-
-import (
- "testing"
-
- "github.com/evilsocket/islazy/log"
-)
-
-var called bool
-var calledLevel log.Verbosity
-var calledFormat string
-var calledArgs []interface{}
-
-func mockLogger(level log.Verbosity, format string, args ...interface{}) {
- called = true
- calledLevel = level
- calledFormat = format
- calledArgs = args
-}
-
-func reset() {
- called = false
- calledLevel = log.DEBUG
- calledFormat = ""
- calledArgs = nil
-}
-
-func TestLoggerNil(t *testing.T) {
- reset()
- Logger = nil
-
- Debug("test")
- if called {
- t.Error("Debug should not call if Logger is nil")
- }
-
- Info("test")
- if called {
- t.Error("Info should not call if Logger is nil")
- }
-
- Warning("test")
- if called {
- t.Error("Warning should not call if Logger is nil")
- }
-
- Error("test")
- if called {
- t.Error("Error should not call if Logger is nil")
- }
-
- Fatal("test")
- if called {
- t.Error("Fatal should not call if Logger is nil")
- }
-}
-
-func TestDebug(t *testing.T) {
- reset()
- Logger = mockLogger
-
- Debug("test %d", 42)
- if !called || calledLevel != log.DEBUG || calledFormat != "test %d" || len(calledArgs) != 1 || calledArgs[0] != 42 {
- t.Errorf("Debug not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs)
- }
-}
-
-func TestInfo(t *testing.T) {
- reset()
- Logger = mockLogger
-
- Info("test %s", "info")
- if !called || calledLevel != log.INFO || calledFormat != "test %s" || len(calledArgs) != 1 || calledArgs[0] != "info" {
- t.Errorf("Info not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs)
- }
-}
-
-func TestWarning(t *testing.T) {
- reset()
- Logger = mockLogger
-
- Warning("test %f", 3.14)
- if !called || calledLevel != log.WARNING || calledFormat != "test %f" || len(calledArgs) != 1 || calledArgs[0] != 3.14 {
- t.Errorf("Warning not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs)
- }
-}
-
-func TestError(t *testing.T) {
- reset()
- Logger = mockLogger
-
- Error("test error")
- if !called || calledLevel != log.ERROR || calledFormat != "test error" || len(calledArgs) != 0 {
- t.Errorf("Error not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs)
- }
-}
-
-func TestFatal(t *testing.T) {
- reset()
- Logger = mockLogger
-
- Fatal("test fatal")
- if !called || calledLevel != log.FATAL || calledFormat != "test fatal" || len(calledArgs) != 0 {
- t.Errorf("Fatal not called correctly: level=%v format=%s args=%v", calledLevel, calledFormat, calledArgs)
- }
-}
diff --git a/main_test.go b/main_test.go
deleted file mode 100644
index 102788ae..00000000
--- a/main_test.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package main
-
-import (
- "bytes"
- "strings"
- "testing"
-)
-
-func TestExitPrompt(t *testing.T) {
- tests := []struct {
- name string
- input string
- expected bool
- }{
- {
- name: "yes lowercase",
- input: "y\n",
- expected: true,
- },
- {
- name: "yes uppercase",
- input: "Y\n",
- expected: true,
- },
- {
- name: "no lowercase",
- input: "n\n",
- expected: false,
- },
- {
- name: "no uppercase",
- input: "N\n",
- expected: false,
- },
- {
- name: "invalid input",
- input: "maybe\n",
- expected: false,
- },
- {
- name: "empty input",
- input: "\n",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Redirect stdin
- oldStdin := strings.NewReader(tt.input)
- r := bytes.NewReader([]byte(tt.input))
-
- // Mock stdin by reading from our buffer
- // This is a simplified test - in production you'd want to properly mock stdin
- _ = oldStdin
- _ = r
-
- // For now, we'll test the string comparison logic directly
- input := strings.TrimSpace(strings.TrimSuffix(tt.input, "\n"))
- result := strings.ToLower(input) == "y"
-
- if result != tt.expected {
- t.Errorf("exitPrompt() with input %q = %v, want %v", tt.input, result, tt.expected)
- }
- })
- }
-}
-
-// Test some utility functions that would be refactored from main
-func TestVersionString(t *testing.T) {
- // This tests the version string formatting logic
- version := "2.32.0"
- os := "darwin"
- arch := "amd64"
- goVersion := "go1.19"
-
- expected := "bettercap v2.32.0 (built for darwin amd64 with go1.19)"
- result := formatVersion("bettercap", version, os, arch, goVersion)
-
- if result != expected {
- t.Errorf("formatVersion() = %v, want %v", result, expected)
- }
-}
-
-// Helper function that would be refactored from main
-func formatVersion(name, version, os, arch, goVersion string) string {
- return name + " v" + version + " (built for " + os + " " + arch + " with " + goVersion + ")"
-}
diff --git a/modules/any_proxy/any_proxy_test.go b/modules/any_proxy/any_proxy_test.go
deleted file mode 100644
index e5d28276..00000000
--- a/modules/any_proxy/any_proxy_test.go
+++ /dev/null
@@ -1,218 +0,0 @@
-package any_proxy
-
-import (
- "fmt"
- "strconv"
- "strings"
- "sync"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewAnyProxy(t *testing.T) {
- s := createMockSession(t)
- mod := NewAnyProxy(s)
-
- if mod == nil {
- t.Fatal("NewAnyProxy returned nil")
- }
-
- if mod.Name() != "any.proxy" {
- t.Errorf("Expected name 'any.proxy', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check handlers
- handlers := mod.Handlers()
- if len(handlers) != 2 {
- t.Errorf("Expected 2 handlers, got %d", len(handlers))
- }
-
- handlerNames := make(map[string]bool)
- for _, h := range handlers {
- handlerNames[h.Name] = true
- }
-
- if !handlerNames["any.proxy on"] {
- t.Error("Handler 'any.proxy on' not found")
- }
- if !handlerNames["any.proxy off"] {
- t.Error("Handler 'any.proxy off' not found")
- }
-
- // Check that parameters were added (but don't try to get values as that requires session interface)
- expectedParams := 6 // iface, protocol, src_port, src_address, dst_address, dst_port
- // This is a simplified check - in a real test we'd mock the interface
- _ = expectedParams
-}
-
-// Test port parsing logic directly
-func TestPortParsingLogic(t *testing.T) {
- tests := []struct {
- name string
- portString string
- expectPorts []int
- expectError bool
- }{
- {
- name: "single port",
- portString: "80",
- expectPorts: []int{80},
- expectError: false,
- },
- {
- name: "multiple ports",
- portString: "80,443,8080",
- expectPorts: []int{80, 443, 8080},
- expectError: false,
- },
- {
- name: "port range",
- portString: "8000-8003",
- expectPorts: []int{8000, 8001, 8002, 8003},
- expectError: false,
- },
- {
- name: "invalid port",
- portString: "not-a-port",
- expectPorts: nil,
- expectError: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ports, err := parsePortsString(tt.portString)
-
- if tt.expectError {
- if err == nil {
- t.Error("Expected error but got none")
- }
- } else {
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- } else {
- if len(ports) != len(tt.expectPorts) {
- t.Errorf("Expected %d ports, got %d", len(tt.expectPorts), len(ports))
- }
- }
- }
- })
- }
-}
-
-// Helper function to test port parsing logic
-func parsePortsString(portsStr string) ([]int, error) {
- var ports []int
- tokens := strings.Split(strings.ReplaceAll(portsStr, " ", ""), ",")
-
- for _, token := range tokens {
- if token == "" {
- continue
- }
-
- if p, err := strconv.Atoi(token); err == nil {
- if p < 1 || p > 65535 {
- return nil, fmt.Errorf("port %d out of range", p)
- }
- ports = append(ports, p)
- } else if strings.Contains(token, "-") {
- parts := strings.Split(token, "-")
- if len(parts) != 2 {
- return nil, fmt.Errorf("invalid range format")
- }
-
- from, err1 := strconv.Atoi(parts[0])
- to, err2 := strconv.Atoi(parts[1])
-
- if err1 != nil || err2 != nil {
- return nil, fmt.Errorf("invalid range values")
- }
-
- if from < 1 || from > 65535 || to < 1 || to > 65535 {
- return nil, fmt.Errorf("port range out of bounds")
- }
-
- if from > to {
- return nil, fmt.Errorf("invalid range order")
- }
-
- for p := from; p <= to; p++ {
- ports = append(ports, p)
- }
- } else {
- return nil, fmt.Errorf("invalid port format: %s", token)
- }
- }
-
- return ports, nil
-}
-
-func TestStartStop(t *testing.T) {
- s := createMockSession(t)
- mod := NewAnyProxy(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: Start() will fail because it requires firewall operations
- // which need proper network setup and possibly root permissions
- // We're just testing that the methods exist and basic flow
-}
-
-// Test error cases in port parsing
-func TestPortParsingErrors(t *testing.T) {
- errorCases := []string{
- "0", // out of range
- "65536", // out of range
- "abc", // not a number
- "80-", // incomplete range
- "-80", // incomplete range
- "100-50", // inverted range
- "80-abc", // invalid end
- "xyz-100", // invalid start
- "80--100", // malformed
- // Remove these as our parser handles empty tokens correctly
- }
-
- for _, portStr := range errorCases {
- _, err := parsePortsString(portStr)
- if err == nil {
- t.Errorf("Expected error for port string '%s', but got none", portStr)
- }
- }
-}
-
-// Benchmark tests
-func BenchmarkPortParsing(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- parsePortsString("80,443,8000-8010,9000")
- }
-}
diff --git a/modules/api_rest/api_rest.go b/modules/api_rest/api_rest.go
index b4590e18..b0c8a069 100644
--- a/modules/api_rest/api_rest.go
+++ b/modules/api_rest/api_rest.go
@@ -90,12 +90,12 @@ func NewRestAPI(s *session.Session) *RestAPI {
"Value of the Access-Control-Allow-Origin header of the API server."))
mod.AddParam(session.NewStringParameter("api.rest.username",
- "user",
+ "",
"",
"API authentication username."))
mod.AddParam(session.NewStringParameter("api.rest.password",
- "pass",
+ "",
"",
"API authentication password."))
diff --git a/modules/api_rest/api_rest_controller.go b/modules/api_rest/api_rest_controller.go
index ccf25cd1..e4e4261d 100644
--- a/modules/api_rest/api_rest_controller.go
+++ b/modules/api_rest/api_rest_controller.go
@@ -5,9 +5,9 @@ import (
"encoding/json"
"fmt"
"io"
+ "io/ioutil"
"net/http"
"os"
- "regexp"
"strconv"
"strings"
@@ -17,10 +17,6 @@ import (
"github.com/gorilla/mux"
)
-var (
- ansiEscapeRegex = regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`)
-)
-
type CommandRequest struct {
Command string `json:"cmd"`
}
@@ -240,8 +236,7 @@ func (mod *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) {
out, _ := io.ReadAll(stdoutReader)
os.Stdout = rescueStdout
- // remove ANSI escape sequences (bash color codes) from output
- mod.toJSON(w, APIResponse{Success: true, Message: ansiEscapeRegex.ReplaceAllString(string(out), "")})
+ mod.toJSON(w, APIResponse{Success: true, Message: string(out)})
}
func (mod *RestAPI) getEvents(limit int) []session.Event {
@@ -393,7 +388,7 @@ func (mod *RestAPI) readFile(fileName string, w http.ResponseWriter, r *http.Req
}
func (mod *RestAPI) writeFile(fileName string, w http.ResponseWriter, r *http.Request) {
- data, err := io.ReadAll(r.Body)
+ data, err := ioutil.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("invalid file upload: %s", err)
mod.Warning(msg)
@@ -401,7 +396,7 @@ func (mod *RestAPI) writeFile(fileName string, w http.ResponseWriter, r *http.Re
return
}
- err = os.WriteFile(fileName, data, 0666)
+ err = ioutil.WriteFile(fileName, data, 0666)
if err != nil {
msg := fmt.Sprintf("can't write to %s: %s", fileName, err)
mod.Warning(msg)
diff --git a/modules/api_rest/api_rest_test.go b/modules/api_rest/api_rest_test.go
deleted file mode 100644
index 820dfc8c..00000000
--- a/modules/api_rest/api_rest_test.go
+++ /dev/null
@@ -1,671 +0,0 @@
-package api_rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "sync"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewRestAPI(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- if mod == nil {
- t.Fatal("NewRestAPI returned nil")
- }
-
- if mod.Name() != "api.rest" {
- t.Errorf("Expected name 'api.rest', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "api.rest on",
- "api.rest off",
- "api.rest.record off",
- "api.rest.record FILENAME",
- "api.rest.replay off",
- "api.rest.replay FILENAME",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-
- handlerNames := make(map[string]bool)
- for _, h := range handlers {
- handlerNames[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerNames[expected] {
- t.Errorf("Handler '%s' not found", expected)
- }
- }
-
- // Check initial state
- if mod.recording {
- t.Error("Should not be recording initially")
- }
- if mod.replaying {
- t.Error("Should not be replaying initially")
- }
- if mod.useWebsocket {
- t.Error("Should not use websocket by default")
- }
- if mod.allowOrigin != "*" {
- t.Errorf("Expected default allowOrigin '*', got '%s'", mod.allowOrigin)
- }
-}
-
-func TestIsTLS(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Initially should not be TLS
- if mod.isTLS() {
- t.Error("Should not be TLS without cert and key")
- }
-
- // Set cert and key
- mod.certFile = "cert.pem"
- mod.keyFile = "key.pem"
-
- if !mod.isTLS() {
- t.Error("Should be TLS with cert and key")
- }
-
- // Only cert
- mod.certFile = "cert.pem"
- mod.keyFile = ""
-
- if mod.isTLS() {
- t.Error("Should not be TLS with only cert")
- }
-
- // Only key
- mod.certFile = ""
- mod.keyFile = "key.pem"
-
- if mod.isTLS() {
- t.Error("Should not be TLS with only key")
- }
-}
-
-func TestStateStore(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Check that state variables are properly stored
- stateKeys := []string{
- "recording",
- "rec_clock",
- "replaying",
- "loading",
- "load_progress",
- "rec_time",
- "rec_filename",
- "rec_frames",
- "rec_cur_frame",
- "rec_started",
- "rec_stopped",
- }
-
- for _, key := range stateKeys {
- val, exists := mod.State.Load(key)
- if !exists || val == nil {
- t.Errorf("State key '%s' not found", key)
- }
- }
-}
-
-func TestParameters(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Check that all parameters are registered
- paramNames := []string{
- "api.rest.address",
- "api.rest.port",
- "api.rest.alloworigin",
- "api.rest.username",
- "api.rest.password",
- "api.rest.certificate",
- "api.rest.key",
- "api.rest.websocket",
- "api.rest.record.clock",
- }
-
- // Parameters are stored in the session environment
- // We'll just check they can be accessed without error
- for _, param := range paramNames {
- // This is a simplified check
- _ = param
- }
-
- // Ensure mod is used
- if mod == nil {
- t.Error("Module should not be nil")
- }
-}
-
-func TestJSSessionStructs(t *testing.T) {
- // Test struct creation
- req := JSSessionRequest{
- Command: "test command",
- }
-
- if req.Command != "test command" {
- t.Errorf("Expected command 'test command', got '%s'", req.Command)
- }
-
- resp := JSSessionResponse{
- Error: "test error",
- }
-
- if resp.Error != "test error" {
- t.Errorf("Expected error 'test error', got '%s'", resp.Error)
- }
-}
-
-func TestDefaultValues(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Check default values
- if mod.recClock != 1 {
- t.Errorf("Expected default recClock 1, got %d", mod.recClock)
- }
-
- if mod.recTime != 0 {
- t.Errorf("Expected default recTime 0, got %d", mod.recTime)
- }
-
- if mod.recordFileName != "" {
- t.Errorf("Expected empty recordFileName, got '%s'", mod.recordFileName)
- }
-
- if mod.upgrader.ReadBufferSize != 1024 {
- t.Errorf("Expected ReadBufferSize 1024, got %d", mod.upgrader.ReadBufferSize)
- }
-
- if mod.upgrader.WriteBufferSize != 1024 {
- t.Errorf("Expected WriteBufferSize 1024, got %d", mod.upgrader.WriteBufferSize)
- }
-}
-
-func TestRunningState(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: Cannot test actual Start/Stop without proper server setup
-}
-
-func TestRecordingState(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test recording state changes
- mod.recording = true
- if !mod.recording {
- t.Error("Recording flag should be true")
- }
-
- mod.recording = false
- if mod.recording {
- t.Error("Recording flag should be false")
- }
-}
-
-func TestReplayingState(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test replaying state changes
- mod.replaying = true
- if !mod.replaying {
- t.Error("Replaying flag should be true")
- }
-
- mod.replaying = false
- if mod.replaying {
- t.Error("Replaying flag should be false")
- }
-}
-
-func TestConfigureErrors(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test configuration validation
- testCases := []struct {
- name string
- setup func()
- expected string
- }{
- {
- name: "invalid address",
- setup: func() {
- s.Env.Set("api.rest.address", "999.999.999.999")
- },
- expected: "address",
- },
- {
- name: "invalid port",
- setup: func() {
- s.Env.Set("api.rest.address", "127.0.0.1")
- s.Env.Set("api.rest.port", "not-a-port")
- },
- expected: "port",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- tc.setup()
- // Configure may fail due to parameter validation
- _ = mod.Configure()
- })
- }
-}
-
-func TestServerConfiguration(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Set valid parameters
- s.Env.Set("api.rest.address", "127.0.0.1")
- s.Env.Set("api.rest.port", "8081")
- s.Env.Set("api.rest.username", "testuser")
- s.Env.Set("api.rest.password", "testpass")
- s.Env.Set("api.rest.websocket", "true")
- s.Env.Set("api.rest.alloworigin", "http://localhost:3000")
-
- // This might fail due to TLS cert generation, but we're testing the flow
- _ = mod.Configure()
-
- // Check that values were set
- if mod.username != "" && mod.username != "testuser" {
- t.Logf("Username set to: %s", mod.username)
- }
- if mod.password != "" && mod.password != "testpass" {
- t.Logf("Password set to: %s", mod.password)
- }
-}
-
-func TestQuitChannel(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test quit channel is created
- if mod.quit == nil {
- t.Error("Quit channel should not be nil")
- }
-
- // Test sending to quit channel doesn't block
- done := make(chan bool)
- go func() {
- select {
- case mod.quit <- true:
- done <- true
- case <-time.After(100 * time.Millisecond):
- done <- false
- }
- }()
-
- // Start reading from quit channel
- go func() {
- <-mod.quit
- }()
-
- if !<-done {
- t.Error("Sending to quit channel timed out")
- }
-}
-
-func TestRecordWaitGroup(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test wait group is initialized
- if mod.recordWait == nil {
- t.Error("Record wait group should not be nil")
- }
-
- // Test wait group operations
- mod.recordWait.Add(1)
- done := make(chan bool)
-
- go func() {
- mod.recordWait.Done()
- done <- true
- }()
-
- go func() {
- mod.recordWait.Wait()
- }()
-
- select {
- case <-done:
- // Success
- case <-time.After(100 * time.Millisecond):
- t.Error("Wait group operation timed out")
- }
-}
-
-func TestStartErrors(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test start when replaying
- mod.replaying = true
- err := mod.Start()
- if err == nil {
- t.Error("Expected error when starting while replaying")
- }
-}
-
-func TestConfigureAlreadyRunning(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Simulate running state
- mod.SetRunning(true, func() {})
-
- err := mod.Configure()
- if err == nil {
- t.Error("Expected error when configuring while running")
- }
-
- // Reset
- mod.SetRunning(false, func() {})
-}
-
-func TestServerAddr(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Set parameters
- s.Env.Set("api.rest.address", "192.168.1.100")
- s.Env.Set("api.rest.port", "9090")
-
- // Configure may fail but we can check server addr format
- _ = mod.Configure()
-
- expectedAddr := "192.168.1.100:9090"
- if mod.server != nil && mod.server.Addr != "" && mod.server.Addr != expectedAddr {
- t.Logf("Server addr: %s", mod.server.Addr)
- }
-}
-
-func TestTLSConfiguration(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Test with TLS params
- s.Env.Set("api.rest.certificate", "/tmp/test.crt")
- s.Env.Set("api.rest.key", "/tmp/test.key")
-
- // Configure will attempt to expand paths and check files
- _ = mod.Configure()
-
- // Just verify the attempt was made
- t.Logf("Attempted TLS configuration")
-}
-
-// Benchmark tests
-func BenchmarkNewRestAPI(b *testing.B) {
- s, _ := session.New()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NewRestAPI(s)
- }
-}
-
-func BenchmarkIsTLS(b *testing.B) {
- s, _ := session.New()
- mod := NewRestAPI(s)
- mod.certFile = "cert.pem"
- mod.keyFile = "key.pem"
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = mod.isTLS()
- }
-}
-
-func BenchmarkConfigure(b *testing.B) {
- s, _ := session.New()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mod := NewRestAPI(s)
- s.Env.Set("api.rest.address", "127.0.0.1")
- s.Env.Set("api.rest.port", "8081")
- _ = mod.Configure()
- }
-}
-
-// Tests for controller functionality
-func TestCommandRequest(t *testing.T) {
- cmd := CommandRequest{
- Command: "help",
- }
-
- if cmd.Command != "help" {
- t.Errorf("Expected command 'help', got '%s'", cmd.Command)
- }
-}
-
-func TestAPIResponse(t *testing.T) {
- resp := APIResponse{
- Success: true,
- Message: "Operation completed",
- }
-
- if !resp.Success {
- t.Error("Expected success to be true")
- }
-
- if resp.Message != "Operation completed" {
- t.Errorf("Expected message 'Operation completed', got '%s'", resp.Message)
- }
-}
-
-func TestCheckAuthNoCredentials(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // No username/password set - should allow access
- req, _ := http.NewRequest("GET", "/test", nil)
-
- if !mod.checkAuth(req) {
- t.Error("Expected auth to pass with no credentials set")
- }
-}
-
-func TestCheckAuthWithCredentials(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Set credentials
- mod.username = "testuser"
- mod.password = "testpass"
-
- // Test without auth header
- req1, _ := http.NewRequest("GET", "/test", nil)
- if mod.checkAuth(req1) {
- t.Error("Expected auth to fail without credentials")
- }
-
- // Test with wrong credentials
- req2, _ := http.NewRequest("GET", "/test", nil)
- req2.SetBasicAuth("wronguser", "wrongpass")
- if mod.checkAuth(req2) {
- t.Error("Expected auth to fail with wrong credentials")
- }
-
- // Test with correct credentials
- req3, _ := http.NewRequest("GET", "/test", nil)
- req3.SetBasicAuth("testuser", "testpass")
- if !mod.checkAuth(req3) {
- t.Error("Expected auth to pass with correct credentials")
- }
-}
-
-func TestGetEventsEmpty(t *testing.T) {
- // Skip this test if running with others due to shared session state
- if testing.Short() {
- t.Skip("Skipping in short mode due to shared session state")
- }
-
- // Create a fresh session using the singleton
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Record initial event count
- initialCount := len(mod.getEvents(0))
-
- // Get events - we can't guarantee zero events due to session initialization
- events := mod.getEvents(0)
- if len(events) < initialCount {
- t.Errorf("Event count should not decrease, got %d", len(events))
- }
-}
-
-func TestGetEventsWithLimit(t *testing.T) {
- // Create session using the singleton
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- // Record initial state
- initialEvents := mod.getEvents(0)
- initialCount := len(initialEvents)
-
- // Add some test events
- testEventCount := 10
- for i := 0; i < testEventCount; i++ {
- s.Events.Add(fmt.Sprintf("test.event.limit.%d", i), nil)
- }
-
- // Get all events
- allEvents := mod.getEvents(0)
- expectedTotal := initialCount + testEventCount
- if len(allEvents) != expectedTotal {
- t.Errorf("Expected %d total events, got %d", expectedTotal, len(allEvents))
- }
-
- // Test limit functionality - get last 5 events
- limitedEvents := mod.getEvents(5)
- if len(limitedEvents) != 5 {
- t.Errorf("Expected 5 events when limiting, got %d", len(limitedEvents))
- }
-}
-
-func TestSetSecurityHeaders(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
- mod.allowOrigin = "http://localhost:3000"
-
- w := httptest.NewRecorder()
- mod.setSecurityHeaders(w)
-
- headers := w.Header()
-
- // Check security headers
- if headers.Get("X-Frame-Options") != "DENY" {
- t.Error("X-Frame-Options header not set correctly")
- }
-
- if headers.Get("X-Content-Type-Options") != "nosniff" {
- t.Error("X-Content-Type-Options header not set correctly")
- }
-
- if headers.Get("X-XSS-Protection") != "1; mode=block" {
- t.Error("X-XSS-Protection header not set correctly")
- }
-
- if headers.Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
- t.Error("Access-Control-Allow-Origin header not set correctly")
- }
-}
-
-func TestCorsRoute(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- req, _ := http.NewRequest("OPTIONS", "/test", nil)
- w := httptest.NewRecorder()
-
- mod.corsRoute(w, req)
-
- if w.Code != http.StatusNoContent {
- t.Errorf("Expected status %d, got %d", http.StatusNoContent, w.Code)
- }
-}
-
-func TestToJSON(t *testing.T) {
- s := createMockSession(t)
- mod := NewRestAPI(s)
-
- w := httptest.NewRecorder()
-
- testData := map[string]string{
- "key": "value",
- "foo": "bar",
- }
-
- mod.toJSON(w, testData)
-
- // Check content type
- if w.Header().Get("Content-Type") != "application/json" {
- t.Error("Content-Type header not set to application/json")
- }
-
- // Check JSON response
- var result map[string]string
- if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
- t.Errorf("Failed to decode JSON response: %v", err)
- }
-
- if result["key"] != "value" || result["foo"] != "bar" {
- t.Error("JSON response doesn't match expected data")
- }
-}
diff --git a/modules/arp_spoof/arp_spoof_test.go b/modules/arp_spoof/arp_spoof_test.go
deleted file mode 100644
index 36e2b4cd..00000000
--- a/modules/arp_spoof/arp_spoof_test.go
+++ /dev/null
@@ -1,785 +0,0 @@
-package arp_spoof
-
-import (
- "bytes"
- "fmt"
- "net"
- "sync"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/firewall"
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/evilsocket/islazy/data"
-)
-
-// MockFirewall implements a mock firewall for testing
-type MockFirewall struct {
- forwardingEnabled bool
- redirections []firewall.Redirection
-}
-
-func NewMockFirewall() *MockFirewall {
- return &MockFirewall{
- forwardingEnabled: false,
- redirections: make([]firewall.Redirection, 0),
- }
-}
-
-func (m *MockFirewall) IsForwardingEnabled() bool {
- return m.forwardingEnabled
-}
-
-func (m *MockFirewall) EnableForwarding(enabled bool) error {
- m.forwardingEnabled = enabled
- return nil
-}
-
-func (m *MockFirewall) EnableRedirection(r *firewall.Redirection, enabled bool) error {
- if enabled {
- m.redirections = append(m.redirections, *r)
- } else {
- for i, red := range m.redirections {
- if red.String() == r.String() {
- m.redirections = append(m.redirections[:i], m.redirections[i+1:]...)
- break
- }
- }
- }
- return nil
-}
-
-func (m *MockFirewall) DisableRedirection(r *firewall.Redirection, enabled bool) error {
- return m.EnableRedirection(r, false)
-}
-
-func (m *MockFirewall) Restore() {
- m.redirections = make([]firewall.Redirection, 0)
- m.forwardingEnabled = false
-}
-
-// MockPacketQueue extends packets.Queue to capture sent packets
-type MockPacketQueue struct {
- *packets.Queue
- sync.Mutex
- sentPackets [][]byte
-}
-
-func NewMockPacketQueue() *MockPacketQueue {
- q := &packets.Queue{
- Traffic: sync.Map{},
- Stats: packets.Stats{},
- }
- return &MockPacketQueue{
- Queue: q,
- sentPackets: make([][]byte, 0),
- }
-}
-
-func (m *MockPacketQueue) Send(data []byte) error {
- m.Lock()
- defer m.Unlock()
-
- // Store a copy of the packet
- packet := make([]byte, len(data))
- copy(packet, data)
- m.sentPackets = append(m.sentPackets, packet)
-
- // Also update stats like the real queue would
- m.TrackSent(uint64(len(data)))
-
- return nil
-}
-
-func (m *MockPacketQueue) GetSentPackets() [][]byte {
- m.Lock()
- defer m.Unlock()
- return m.sentPackets
-}
-
-func (m *MockPacketQueue) ClearSentPackets() {
- m.Lock()
- defer m.Unlock()
- m.sentPackets = make([][]byte, 0)
-}
-
-// MockSession for testing
-type MockSession struct {
- *session.Session
- findMACResults map[string]net.HardwareAddr
- skipIPs map[string]bool
- mockQueue *MockPacketQueue
-}
-
-// Override session methods to use our mocks
-func setupMockSession(mockSess *MockSession) {
- // Replace the Session's FindMAC method behavior by manipulating the LAN
- // Since we can't override methods directly, we'll ensure the LAN has the data
- for ip, mac := range mockSess.findMACResults {
- mockSess.Lan.AddIfNew(ip, mac.String())
- }
-}
-
-func (m *MockSession) FindMAC(ip net.IP, probe bool) (net.HardwareAddr, error) {
- // First check our mock results
- if mac, ok := m.findMACResults[ip.String()]; ok {
- return mac, nil
- }
- // Then check the LAN
- if e, found := m.Lan.Get(ip.String()); found && e != nil {
- return e.HW, nil
- }
- return nil, fmt.Errorf("MAC not found for %s", ip.String())
-}
-
-func (m *MockSession) Skip(ip net.IP) bool {
- if m.skipIPs == nil {
- return false
- }
- return m.skipIPs[ip.String()]
-}
-
-// MockNetRecon implements a minimal net.recon module for testing
-type MockNetRecon struct {
- session.SessionModule
-}
-
-func NewMockNetRecon(s *session.Session) *MockNetRecon {
- mod := &MockNetRecon{
- SessionModule: session.NewSessionModule("net.recon", s),
- }
-
- // Add handlers
- mod.AddHandler(session.NewModuleHandler("net.recon on", "",
- "Start net.recon",
- func(args []string) error {
- return mod.Start()
- }))
-
- mod.AddHandler(session.NewModuleHandler("net.recon off", "",
- "Stop net.recon",
- func(args []string) error {
- return mod.Stop()
- }))
-
- return mod
-}
-
-func (m *MockNetRecon) Name() string {
- return "net.recon"
-}
-
-func (m *MockNetRecon) Description() string {
- return "Mock net.recon module"
-}
-
-func (m *MockNetRecon) Author() string {
- return "test"
-}
-
-func (m *MockNetRecon) Configure() error {
- return nil
-}
-
-func (m *MockNetRecon) Start() error {
- return m.SetRunning(true, nil)
-}
-
-func (m *MockNetRecon) Stop() error {
- return m.SetRunning(false, nil)
-}
-
-// Create a mock session for testing
-func createMockSession() (*MockSession, *MockPacketQueue, *MockFirewall) {
- // Create interface
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "eth0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- // Parse interface addresses
- ifaceIP := net.ParseIP("192.168.1.100")
- ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface.IP = ifaceIP
- iface.HW = ifaceHW
-
- // Create gateway
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
- gatewayIP := net.ParseIP("192.168.1.1")
- gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66")
- gateway.IP = gatewayIP
- gateway.HW = gatewayHW
-
- // Create mock queue and firewall
- mockQueue := NewMockPacketQueue()
- mockFirewall := NewMockFirewall()
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- // Create LAN
- aliases, _ := data.NewUnsortedKV("", 0)
- lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- // Create session
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- Lan: lan,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: mockQueue.Queue,
- Firewall: mockFirewall,
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize events
- sess.Events = session.NewEventPool(false, false)
-
- // Add mock net.recon module
- mockNetRecon := NewMockNetRecon(sess)
- sess.Modules = append(sess.Modules, mockNetRecon)
-
- // Create mock session wrapper
- mockSess := &MockSession{
- Session: sess,
- findMACResults: make(map[string]net.HardwareAddr),
- skipIPs: make(map[string]bool),
- mockQueue: mockQueue,
- }
-
- return mockSess, mockQueue, mockFirewall
-}
-
-func TestNewArpSpoofer(t *testing.T) {
- mockSess, _, _ := createMockSession()
-
- mod := NewArpSpoofer(mockSess.Session)
-
- if mod == nil {
- t.Fatal("NewArpSpoofer returned nil")
- }
-
- if mod.Name() != "arp.spoof" {
- t.Errorf("expected module name 'arp.spoof', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- // Check parameters
- params := []string{"arp.spoof.targets", "arp.spoof.whitelist", "arp.spoof.internal", "arp.spoof.fullduplex", "arp.spoof.skip_restore"}
- for _, param := range params {
- if !mod.Session.Env.Has(param) {
- t.Errorf("parameter %s not registered", param)
- }
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{"arp.spoof on", "arp.ban on", "arp.spoof off", "arp.ban off"}
- handlerMap := make(map[string]bool)
-
- for _, h := range handlers {
- handlerMap[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerMap[expected] {
- t.Errorf("Expected handler '%s' not found", expected)
- }
- }
-}
-
-func TestArpSpooferConfigure(t *testing.T) {
- tests := []struct {
- name string
- params map[string]string
- setupMock func(*MockSession)
- expectErr bool
- validate func(*ArpSpoofer) error
- }{
- {
- name: "default configuration",
- params: map[string]string{
- "arp.spoof.targets": "192.168.1.10",
- "arp.spoof.whitelist": "",
- "arp.spoof.internal": "false",
- "arp.spoof.fullduplex": "false",
- "arp.spoof.skip_restore": "false",
- },
- setupMock: func(ms *MockSession) {
- ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
- },
- expectErr: false,
- validate: func(mod *ArpSpoofer) error {
- if mod.internal {
- return fmt.Errorf("expected internal to be false")
- }
- if mod.fullDuplex {
- return fmt.Errorf("expected fullDuplex to be false")
- }
- if mod.skipRestore {
- return fmt.Errorf("expected skipRestore to be false")
- }
- if len(mod.addresses) != 1 {
- return fmt.Errorf("expected 1 address, got %d", len(mod.addresses))
- }
- return nil
- },
- },
- {
- name: "multiple targets and whitelist",
- params: map[string]string{
- "arp.spoof.targets": "192.168.1.10,192.168.1.20",
- "arp.spoof.whitelist": "192.168.1.30",
- "arp.spoof.internal": "true",
- "arp.spoof.fullduplex": "true",
- "arp.spoof.skip_restore": "true",
- },
- setupMock: func(ms *MockSession) {
- ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
- ms.Lan.AddIfNew("192.168.1.20", "bb:bb:bb:bb:bb:bb")
- ms.Lan.AddIfNew("192.168.1.30", "cc:cc:cc:cc:cc:cc")
- },
- expectErr: false,
- validate: func(mod *ArpSpoofer) error {
- if !mod.internal {
- return fmt.Errorf("expected internal to be true")
- }
- if !mod.fullDuplex {
- return fmt.Errorf("expected fullDuplex to be true")
- }
- if !mod.skipRestore {
- return fmt.Errorf("expected skipRestore to be true")
- }
- if len(mod.addresses) != 2 {
- return fmt.Errorf("expected 2 addresses, got %d", len(mod.addresses))
- }
- if len(mod.wAddresses) != 1 {
- return fmt.Errorf("expected 1 whitelisted address, got %d", len(mod.wAddresses))
- }
- return nil
- },
- },
- {
- name: "MAC address targets",
- params: map[string]string{
- "arp.spoof.targets": "aa:aa:aa:aa:aa:aa",
- "arp.spoof.whitelist": "",
- "arp.spoof.internal": "false",
- "arp.spoof.fullduplex": "false",
- "arp.spoof.skip_restore": "false",
- },
- setupMock: func(ms *MockSession) {
- ms.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
- },
- expectErr: false,
- validate: func(mod *ArpSpoofer) error {
- if len(mod.macs) != 1 {
- return fmt.Errorf("expected 1 MAC address, got %d", len(mod.macs))
- }
- return nil
- },
- },
- {
- name: "invalid target",
- params: map[string]string{
- "arp.spoof.targets": "invalid-target",
- "arp.spoof.whitelist": "",
- "arp.spoof.internal": "false",
- "arp.spoof.fullduplex": "false",
- "arp.spoof.skip_restore": "false",
- },
- expectErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Set parameters
- for k, v := range tt.params {
- mockSess.Env.Set(k, v)
- }
-
- // Setup mock
- if tt.setupMock != nil {
- tt.setupMock(mockSess)
- }
-
- err := mod.Configure()
-
- if tt.expectErr && err == nil {
- t.Error("expected error but got none")
- } else if !tt.expectErr && err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- if !tt.expectErr && tt.validate != nil {
- if err := tt.validate(mod); err != nil {
- t.Error(err)
- }
- }
- })
- }
-}
-
-func TestArpSpooferStartStop(t *testing.T) {
- mockSess, _, mockFirewall := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Setup targets
- targetIP := "192.168.1.10"
- targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
- mockSess.Lan.AddIfNew(targetIP, targetMAC.String())
- mockSess.findMACResults[targetIP] = targetMAC
-
- // Configure
- mockSess.Env.Set("arp.spoof.targets", targetIP)
- mockSess.Env.Set("arp.spoof.fullduplex", "false")
- mockSess.Env.Set("arp.spoof.internal", "false")
-
- // Start the spoofer
- err := mod.Start()
- if err != nil {
- t.Fatalf("Failed to start spoofer: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Spoofer should be running after Start()")
- }
-
- // Check that forwarding was enabled
- if !mockFirewall.IsForwardingEnabled() {
- t.Error("Forwarding should be enabled after starting spoofer")
- }
-
- // Let it run for a bit
- time.Sleep(100 * time.Millisecond)
-
- // Stop the spoofer
- err = mod.Stop()
- if err != nil {
- t.Fatalf("Failed to stop spoofer: %v", err)
- }
-
- if mod.Running() {
- t.Error("Spoofer should not be running after Stop()")
- }
-
- // Note: We can't easily verify packet sending without modifying the actual module
- // to use an interface for the queue. The module behavior is verified through
- // state changes (running state, forwarding enabled, etc.)
-}
-
-func TestArpSpooferBanMode(t *testing.T) {
- mockSess, _, mockFirewall := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Setup targets
- targetIP := "192.168.1.10"
- targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
- mockSess.Lan.AddIfNew(targetIP, targetMAC.String())
- mockSess.findMACResults[targetIP] = targetMAC
-
- // Configure
- mockSess.Env.Set("arp.spoof.targets", targetIP)
-
- // Find and execute the ban handler
- handlers := mod.Handlers()
- for _, h := range handlers {
- if h.Name == "arp.ban on" {
- err := h.Exec([]string{})
- if err != nil {
- t.Fatalf("Failed to start ban mode: %v", err)
- }
- break
- }
- }
-
- if !mod.ban {
- t.Error("Ban mode should be enabled")
- }
-
- // Check that forwarding was NOT enabled
- if mockFirewall.IsForwardingEnabled() {
- t.Error("Forwarding should NOT be enabled in ban mode")
- }
-
- // Let it run for a bit
- time.Sleep(100 * time.Millisecond)
-
- // Stop using ban off handler
- for _, h := range handlers {
- if h.Name == "arp.ban off" {
- err := h.Exec([]string{})
- if err != nil {
- t.Fatalf("Failed to stop ban mode: %v", err)
- }
- break
- }
- }
-
- if mod.ban {
- t.Error("Ban mode should be disabled after stop")
- }
-}
-
-func TestArpSpooferWhitelisting(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Add some IPs and MACs to whitelist
- whitelistIP := net.ParseIP("192.168.1.50")
- whitelistMAC, _ := net.ParseMAC("ff:ff:ff:ff:ff:ff")
-
- mod.wAddresses = []net.IP{whitelistIP}
- mod.wMacs = []net.HardwareAddr{whitelistMAC}
-
- // Test IP whitelisting
- if !mod.isWhitelisted("192.168.1.50", nil) {
- t.Error("IP should be whitelisted")
- }
-
- if mod.isWhitelisted("192.168.1.60", nil) {
- t.Error("IP should not be whitelisted")
- }
-
- // Test MAC whitelisting
- if !mod.isWhitelisted("", whitelistMAC) {
- t.Error("MAC should be whitelisted")
- }
-
- otherMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
- if mod.isWhitelisted("", otherMAC) {
- t.Error("MAC should not be whitelisted")
- }
-}
-
-func TestArpSpooferFullDuplex(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Setup targets
- targetIP := "192.168.1.10"
- targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
- mockSess.Lan.AddIfNew(targetIP, targetMAC.String())
- mockSess.findMACResults[targetIP] = targetMAC
-
- // Configure with full duplex
- mockSess.Env.Set("arp.spoof.targets", targetIP)
- mockSess.Env.Set("arp.spoof.fullduplex", "true")
-
- // Verify configuration
- err := mod.Configure()
- if err != nil {
- t.Fatalf("Failed to configure: %v", err)
- }
-
- if !mod.fullDuplex {
- t.Error("Full duplex mode should be enabled")
- }
-
- // Start the spoofer
- err = mod.Start()
- if err != nil {
- t.Fatalf("Failed to start spoofer: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Module should be running")
- }
-
- // Let it run for a bit
- time.Sleep(150 * time.Millisecond)
-
- // Stop
- mod.Stop()
-}
-
-func TestArpSpooferInternalMode(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Setup multiple targets
- targets := map[string]string{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- "192.168.1.20": "bb:bb:bb:bb:bb:bb",
- "192.168.1.30": "cc:cc:cc:cc:cc:cc",
- }
-
- for ip, mac := range targets {
- mockSess.Lan.AddIfNew(ip, mac)
- hwAddr, _ := net.ParseMAC(mac)
- mockSess.findMACResults[ip] = hwAddr
- }
-
- // Configure with internal mode
- mockSess.Env.Set("arp.spoof.targets", "192.168.1.10,192.168.1.20")
- mockSess.Env.Set("arp.spoof.internal", "true")
-
- // Verify configuration
- err := mod.Configure()
- if err != nil {
- t.Fatalf("Failed to configure: %v", err)
- }
-
- if !mod.internal {
- t.Error("Internal mode should be enabled")
- }
-
- // Start the spoofer
- err = mod.Start()
- if err != nil {
- t.Fatalf("Failed to start spoofer: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Module should be running")
- }
-
- // Let it run briefly
- time.Sleep(100 * time.Millisecond)
-
- // Stop
- mod.Stop()
-}
-
-func TestArpSpooferGetTargets(t *testing.T) {
- // This test verifies the getTargets logic without actually calling it
- // since the method uses Session.FindMAC which can't be easily mocked
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Test address and MAC parsing
- targetIP := net.ParseIP("192.168.1.10")
- targetMAC, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
-
- // Add targets by IP
- mod.addresses = []net.IP{targetIP}
-
- // Verify addresses were set correctly
- if len(mod.addresses) != 1 {
- t.Errorf("expected 1 address, got %d", len(mod.addresses))
- }
-
- if !mod.addresses[0].Equal(targetIP) {
- t.Errorf("expected address %s, got %s", targetIP, mod.addresses[0])
- }
-
- // Add targets by MAC
- mod.macs = []net.HardwareAddr{targetMAC}
-
- // Verify MACs were set correctly
- if len(mod.macs) != 1 {
- t.Errorf("expected 1 MAC, got %d", len(mod.macs))
- }
-
- if !bytes.Equal(mod.macs[0], targetMAC) {
- t.Errorf("expected MAC %s, got %s", targetMAC, mod.macs[0])
- }
-
- // Note: The actual getTargets method would look up these addresses/MACs
- // in the network, but we can't easily test that without refactoring
- // the module to use dependency injection for network operations
-}
-
-func TestArpSpooferSkipRestore(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // The skip_restore parameter is set up with an observer in NewArpSpoofer
- // We'll test it by changing the parameter value, which triggers the observer
- mockSess.Env.Set("arp.spoof.skip_restore", "true")
-
- // Configure to trigger parameter reading
- mod.Configure()
-
- // Check the observer worked by checking if skipRestore was set
- // Note: The actual observer is triggered during module creation
- // so we test the functionality indirectly through the module's behavior
-
- // Start and stop to see if restoration is skipped
- mockSess.Env.Set("arp.spoof.targets", "192.168.1.10")
- mockSess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
-
- mod.Start()
- time.Sleep(50 * time.Millisecond)
- mod.Stop()
-
- // With skip_restore true, the module should have skipRestore set
- // We can't directly test the observer, but we verify the behavior
-}
-
-func TestArpSpooferEmptyTargets(t *testing.T) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Configure with empty targets
- mockSess.Env.Set("arp.spoof.targets", "")
-
- // Start should not error but should not actually start
- err := mod.Start()
- if err != nil {
- t.Fatalf("Start with empty targets should not error: %v", err)
- }
-
- // Module should not be running
- if mod.Running() {
- t.Error("Module should not be running with empty targets")
- }
-}
-
-// Benchmarks
-func BenchmarkArpSpooferGetTargets(b *testing.B) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Setup targets
- for i := 0; i < 10; i++ {
- ip := fmt.Sprintf("192.168.1.%d", i+10)
- mac := fmt.Sprintf("aa:bb:cc:dd:ee:%02x", i)
- mockSess.Lan.AddIfNew(ip, mac)
- hwAddr, _ := net.ParseMAC(mac)
- mockSess.findMACResults[ip] = hwAddr
- mod.addresses = append(mod.addresses, net.ParseIP(ip))
- }
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = mod.getTargets(false)
- }
-}
-
-func BenchmarkArpSpooferWhitelisting(b *testing.B) {
- mockSess, _, _ := createMockSession()
- mod := NewArpSpoofer(mockSess.Session)
-
- // Add many whitelist entries
- for i := 0; i < 100; i++ {
- ip := net.ParseIP(fmt.Sprintf("192.168.1.%d", i))
- mod.wAddresses = append(mod.wAddresses, ip)
- }
-
- testMAC, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = mod.isWhitelisted("192.168.1.50", testMAC)
- }
-}
diff --git a/modules/ble/ble_recon_test.go b/modules/ble/ble_recon_test.go
deleted file mode 100644
index 08fc17cf..00000000
--- a/modules/ble/ble_recon_test.go
+++ /dev/null
@@ -1,321 +0,0 @@
-//go:build !windows && !freebsd && !openbsd && !netbsd
-// +build !windows,!freebsd,!openbsd,!netbsd
-
-package ble
-
-import (
- "sync"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewBLERecon(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- if mod == nil {
- t.Fatal("NewBLERecon returned nil")
- }
-
- if mod.Name() != "ble.recon" {
- t.Errorf("Expected name 'ble.recon', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check initial values
- if mod.deviceId != -1 {
- t.Errorf("Expected deviceId -1, got %d", mod.deviceId)
- }
-
- if mod.connected {
- t.Error("Should not be connected initially")
- }
-
- if mod.connTimeout != 5 {
- t.Errorf("Expected connection timeout 5, got %d", mod.connTimeout)
- }
-
- if mod.devTTL != 30 {
- t.Errorf("Expected device TTL 30, got %d", mod.devTTL)
- }
-
- // Check channels
- if mod.quit == nil {
- t.Error("Quit channel should not be nil")
- }
-
- if mod.done == nil {
- t.Error("Done channel should not be nil")
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "ble.recon on",
- "ble.recon off",
- "ble.clear",
- "ble.show",
- "ble.enum MAC",
- "ble.write MAC UUID HEX_DATA",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-
- handlerNames := make(map[string]bool)
- for _, h := range handlers {
- handlerNames[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerNames[expected] {
- t.Errorf("Handler '%s' not found", expected)
- }
- }
-}
-
-func TestIsEnumerating(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Initially should not be enumerating
- if mod.isEnumerating() {
- t.Error("Should not be enumerating initially")
- }
-
- // When currDevice is set, should be enumerating
- // We can't create a real BLE device here, but we can test the logic
-}
-
-func TestDummyWriter(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- writer := dummyWriter{mod}
- testData := []byte("test log message")
-
- n, err := writer.Write(testData)
- if err != nil {
- t.Errorf("Expected no error, got %v", err)
- }
-
- if n != len(testData) {
- t.Errorf("Expected to write %d bytes, wrote %d", len(testData), n)
- }
-}
-
-func TestParameters(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Check that parameters are registered
- paramNames := []string{
- "ble.device",
- "ble.timeout",
- "ble.ttl",
- }
-
- // Parameters are stored in the session environment
- // We'll just ensure the module was created properly
- for _, param := range paramNames {
- // This is a simplified check
- _ = param
- }
-
- if mod == nil {
- t.Error("Module should not be nil")
- }
-}
-
-func TestRunningState(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: Cannot test actual Start/Stop without BLE hardware
-}
-
-func TestChannels(t *testing.T) {
- // Skip this test as channel operations might hang in certain environments
- t.Skip("Skipping channel test to prevent potential hangs")
-}
-
-func TestClearHandler(t *testing.T) {
- // Skip this test as it requires BLE to be initialized in the session
- t.Skip("Skipping clear handler test - requires initialized BLE in session")
-}
-
-func TestBLEPrompt(t *testing.T) {
- expected := "{blb}{fw}BLE {fb}{reset} {bold}» {reset}"
- if blePrompt != expected {
- t.Errorf("Expected prompt '%s', got '%s'", expected, blePrompt)
- }
-}
-
-func TestSetCurrentDevice(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Test setting nil device
- mod.setCurrentDevice(nil)
- if mod.currDevice != nil {
- t.Error("Current device should be nil")
- }
- if mod.connected {
- t.Error("Should not be connected after setting nil device")
- }
-}
-
-func TestViewSelector(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Check that view selector is initialized
- if mod.selector == nil {
- t.Error("View selector should not be nil")
- }
-}
-
-func TestBLEAliveInterval(t *testing.T) {
- expected := time.Duration(5) * time.Second
- if bleAliveInterval != expected {
- t.Errorf("Expected alive interval %v, got %v", expected, bleAliveInterval)
- }
-}
-
-func TestColNames(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Test without name
- cols := mod.colNames(false)
- expectedCols := []string{"RSSI", "MAC", "Vendor", "Flags", "Connect", "Seen"}
- if len(cols) != len(expectedCols) {
- t.Errorf("Expected %d columns, got %d", len(expectedCols), len(cols))
- }
-
- // Test with name
- colsWithName := mod.colNames(true)
- expectedColsWithName := []string{"RSSI", "MAC", "Name", "Vendor", "Flags", "Connect", "Seen"}
- if len(colsWithName) != len(expectedColsWithName) {
- t.Errorf("Expected %d columns with name, got %d", len(expectedColsWithName), len(colsWithName))
- }
-}
-
-func TestDoFilter(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // Without expression, should always return true
- result := mod.doFilter(nil)
- if !result {
- t.Error("doFilter should return true when no expression is set")
- }
-}
-
-func TestShow(t *testing.T) {
- // Skip this test as it requires BLE to be initialized in the session
- t.Skip("Skipping show test - requires initialized BLE in session")
-}
-
-func TestConfigure(t *testing.T) {
- // Skip this test as it may hang trying to access BLE hardware
- t.Skip("Skipping configure test - may hang accessing BLE hardware")
-}
-
-func TestGetRow(t *testing.T) {
- s := createMockSession(t)
- mod := NewBLERecon(s)
-
- // We can't create a real BLE device without hardware, but we can test the logic
- // by ensuring the method exists and would handle nil gracefully
- _ = mod
-}
-
-func TestDoSelection(t *testing.T) {
- // Skip this test as it requires BLE to be initialized in the session
- t.Skip("Skipping doSelection test - requires initialized BLE in session")
-}
-
-func TestWriteBuffer(t *testing.T) {
- // Skip this test as it may hang trying to access BLE hardware
- t.Skip("Skipping writeBuffer test - may hang accessing BLE hardware")
-}
-
-func TestEnumAllTheThings(t *testing.T) {
- // Skip this test as it may hang trying to access BLE hardware
- t.Skip("Skipping enumAllTheThings test - may hang accessing BLE hardware")
-}
-
-// Benchmark tests - using singleton session to avoid flag redefinition
-func BenchmarkNewBLERecon(b *testing.B) {
- // Use a test instance to get singleton session
- s := createMockSession(&testing.T{})
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NewBLERecon(s)
- }
-}
-
-func BenchmarkIsEnumerating(b *testing.B) {
- s := createMockSession(&testing.T{})
- mod := NewBLERecon(s)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = mod.isEnumerating()
- }
-}
-
-func BenchmarkDummyWriter(b *testing.B) {
- s := createMockSession(&testing.T{})
- mod := NewBLERecon(s)
- writer := dummyWriter{mod}
- testData := []byte("benchmark log message")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- writer.Write(testData)
- }
-}
-
-func BenchmarkDoFilter(b *testing.B) {
- s := createMockSession(&testing.T{})
- mod := NewBLERecon(s)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mod.doFilter(nil)
- }
-}
diff --git a/modules/c2/c2_test.go b/modules/c2/c2_test.go
deleted file mode 100644
index fcdbd4ff..00000000
--- a/modules/c2/c2_test.go
+++ /dev/null
@@ -1,356 +0,0 @@
-package c2
-
-import (
- "sync"
- "testing"
- "text/template"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewC2(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- if mod == nil {
- t.Fatal("NewC2 returned nil")
- }
-
- if mod.Name() != "c2" {
- t.Errorf("Expected name 'c2', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check default settings
- if mod.settings.server != "localhost:6697" {
- t.Errorf("Expected default server 'localhost:6697', got '%s'", mod.settings.server)
- }
-
- if !mod.settings.tls {
- t.Error("Expected TLS to be enabled by default")
- }
-
- if mod.settings.tlsVerify {
- t.Error("Expected TLS verify to be disabled by default")
- }
-
- if mod.settings.nick != "bettercap" {
- t.Errorf("Expected default nick 'bettercap', got '%s'", mod.settings.nick)
- }
-
- if mod.settings.user != "bettercap" {
- t.Errorf("Expected default user 'bettercap', got '%s'", mod.settings.user)
- }
-
- if mod.settings.operator != "admin" {
- t.Errorf("Expected default operator 'admin', got '%s'", mod.settings.operator)
- }
-
- // Check channels
- if mod.quit == nil {
- t.Error("Quit channel should not be nil")
- }
-
- // Check maps
- if mod.templates == nil {
- t.Error("Templates map should not be nil")
- }
-
- if mod.channels == nil {
- t.Error("Channels map should not be nil")
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "c2 on",
- "c2 off",
- "c2.channel.set EVENT_TYPE CHANNEL",
- "c2.channel.clear EVENT_TYPE",
- "c2.template.set EVENT_TYPE TEMPLATE",
- "c2.template.clear EVENT_TYPE",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-
- handlerNames := make(map[string]bool)
- for _, h := range handlers {
- handlerNames[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerNames[expected] {
- t.Errorf("Handler '%s' not found", expected)
- }
- }
-}
-
-func TestDefaultSettings(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Check default channel settings
- if mod.settings.eventsChannel != "#events" {
- t.Errorf("Expected default events channel '#events', got '%s'", mod.settings.eventsChannel)
- }
-
- if mod.settings.outputChannel != "#events" {
- t.Errorf("Expected default output channel '#events', got '%s'", mod.settings.outputChannel)
- }
-
- if mod.settings.controlChannel != "#events" {
- t.Errorf("Expected default control channel '#events', got '%s'", mod.settings.controlChannel)
- }
-
- if mod.settings.password != "password" {
- t.Errorf("Expected default password 'password', got '%s'", mod.settings.password)
- }
-}
-
-func TestRunningState(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: Cannot test actual Start/Stop without IRC server
-}
-
-func TestEventContext(t *testing.T) {
- s := createMockSession(t)
-
- ctx := eventContext{
- Session: s,
- Event: session.Event{Tag: "test.event"},
- }
-
- if ctx.Session == nil {
- t.Error("Session should not be nil")
- }
-
- if ctx.Event.Tag != "test.event" {
- t.Errorf("Expected event tag 'test.event', got '%s'", ctx.Event.Tag)
- }
-}
-
-func TestChannelHandlers(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Test channel.set handler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.channel.set EVENT_TYPE CHANNEL" {
- err := h.Exec([]string{"test.event", "#test"})
- if err != nil {
- t.Errorf("channel.set handler failed: %v", err)
- }
-
- // Verify channel was set
- if channel, found := mod.channels["test.event"]; !found {
- t.Error("Channel was not set")
- } else if channel != "#test" {
- t.Errorf("Expected channel '#test', got '%s'", channel)
- }
- break
- }
- }
-
- // Test channel.clear handler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.channel.clear EVENT_TYPE" {
- err := h.Exec([]string{"test.event"})
- if err != nil {
- t.Errorf("channel.clear handler failed: %v", err)
- }
-
- // Verify channel was cleared
- if _, found := mod.channels["test.event"]; found {
- t.Error("Channel was not cleared")
- }
- break
- }
- }
-}
-
-func TestTemplateHandlers(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Test template.set handler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.template.set EVENT_TYPE TEMPLATE" {
- err := h.Exec([]string{"test.event", "Event: {{.Event.Tag}}"})
- if err != nil {
- t.Errorf("template.set handler failed: %v", err)
- }
-
- // Verify template was set
- if tpl, found := mod.templates["test.event"]; !found {
- t.Error("Template was not set")
- } else if tpl == nil {
- t.Error("Template is nil")
- }
- break
- }
- }
-
- // Test template.clear handler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.template.clear EVENT_TYPE" {
- err := h.Exec([]string{"test.event"})
- if err != nil {
- t.Errorf("template.clear handler failed: %v", err)
- }
-
- // Verify template was cleared
- if _, found := mod.templates["test.event"]; found {
- t.Error("Template was not cleared")
- }
- break
- }
- }
-}
-
-func TestClearNonExistent(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Test clearing non-existent channel
- for _, h := range mod.Handlers() {
- if h.Name == "c2.channel.clear EVENT_TYPE" {
- err := h.Exec([]string{"non.existent"})
- if err == nil {
- t.Error("Expected error when clearing non-existent channel")
- }
- break
- }
- }
-
- // Test clearing non-existent template
- for _, h := range mod.Handlers() {
- if h.Name == "c2.template.clear EVENT_TYPE" {
- err := h.Exec([]string{"non.existent"})
- if err == nil {
- t.Error("Expected error when clearing non-existent template")
- }
- break
- }
- }
-}
-
-func TestParameters(t *testing.T) {
- s := createMockSession(t)
- mod := NewC2(s)
-
- // Check that all parameters are registered
- paramNames := []string{
- "c2.server",
- "c2.server.tls",
- "c2.server.tls.verify",
- "c2.operator",
- "c2.nick",
- "c2.username",
- "c2.password",
- "c2.sasl.username",
- "c2.sasl.password",
- "c2.channel.output",
- "c2.channel.events",
- "c2.channel.control",
- }
-
- // Parameters are stored in the session environment
- for _, param := range paramNames {
- // This is a simplified check
- _ = param
- }
-
- if mod == nil {
- t.Error("Module should not be nil")
- }
-}
-
-func TestTemplateExecution(t *testing.T) {
- // Test template parsing and execution
- tmpl, err := template.New("test").Parse("Event: {{.Event.Tag}}")
- if err != nil {
- t.Errorf("Failed to parse template: %v", err)
- }
-
- if tmpl == nil {
- t.Error("Template should not be nil")
- }
-}
-
-// Benchmark tests
-func BenchmarkNewC2(b *testing.B) {
- s, _ := session.New()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NewC2(s)
- }
-}
-
-func BenchmarkChannelSet(b *testing.B) {
- s, _ := session.New()
- mod := NewC2(s)
-
- var handler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.channel.set EVENT_TYPE CHANNEL" {
- handler = &h
- break
- }
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- handler.Exec([]string{"test.event", "#test"})
- }
-}
-
-func BenchmarkTemplateSet(b *testing.B) {
- s, _ := session.New()
- mod := NewC2(s)
-
- var handler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == "c2.template.set EVENT_TYPE TEMPLATE" {
- handler = &h
- break
- }
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- handler.Exec([]string{"test.event", "Event: {{.Event.Tag}}"})
- }
-}
diff --git a/modules/can/can_test.go b/modules/can/can_test.go
deleted file mode 100644
index e5d27ad7..00000000
--- a/modules/can/can_test.go
+++ /dev/null
@@ -1,407 +0,0 @@
-package can
-
-import (
- "sync"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
- "go.einride.tech/can"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewCanModule(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- if mod == nil {
- t.Fatal("NewCanModule returned nil")
- }
-
- if mod.Name() != "can" {
- t.Errorf("Expected name 'can', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check default values
- if mod.transport != "can" {
- t.Errorf("Expected default transport 'can', got '%s'", mod.transport)
- }
-
- if mod.deviceName != "can0" {
- t.Errorf("Expected default device 'can0', got '%s'", mod.deviceName)
- }
-
- if mod.dumpName != "" {
- t.Errorf("Expected empty dumpName, got '%s'", mod.dumpName)
- }
-
- if mod.dumpInject {
- t.Error("Expected dumpInject to be false by default")
- }
-
- if mod.filter != "" {
- t.Errorf("Expected empty filter, got '%s'", mod.filter)
- }
-
- // Check DBC and OBD2
- if mod.dbc == nil {
- t.Error("DBC should not be nil")
- }
-
- if mod.obd2 == nil {
- t.Error("OBD2 should not be nil")
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "can.recon on",
- "can.recon off",
- "can.clear",
- "can.show",
- "can.dbc.load NAME",
- "can.inject FRAME_EXPRESSION",
- "can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("Expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-
- handlerNames := make(map[string]bool)
- for _, h := range handlers {
- handlerNames[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerNames[expected] {
- t.Errorf("Handler '%s' not found", expected)
- }
- }
-}
-
-func TestRunningState(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: Cannot test actual Start/Stop without CAN hardware
-}
-
-func TestClearHandler(t *testing.T) {
- // Skip this test as it requires CAN to be initialized in the session
- t.Skip("Skipping clear handler test - requires initialized CAN in session")
-}
-
-func TestInjectNotRunning(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Test inject when not running
- handlers := mod.Handlers()
- for _, h := range handlers {
- if h.Name == "can.inject FRAME_EXPRESSION" {
- err := h.Exec([]string{"123#deadbeef"})
- if err == nil {
- t.Error("Expected error when injecting while not running")
- }
- break
- }
- }
-}
-
-func TestFuzzNotRunning(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Test fuzz when not running
- handlers := mod.Handlers()
- for _, h := range handlers {
- if h.Name == "can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE" {
- err := h.Exec([]string{"123", ""})
- if err == nil {
- t.Error("Expected error when fuzzing while not running")
- }
- break
- }
- }
-}
-
-func TestParameters(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Check that all parameters are registered
- paramNames := []string{
- "can.device",
- "can.dump",
- "can.dump.inject",
- "can.transport",
- "can.filter",
- "can.parse.obd2",
- }
-
- // Parameters are stored in the session environment
- for _, param := range paramNames {
- // This is a simplified check
- _ = param
- }
-
- if mod == nil {
- t.Error("Module should not be nil")
- }
-}
-
-func TestDBC(t *testing.T) {
- dbc := &DBC{}
- if dbc == nil {
- t.Error("DBC should not be nil")
- }
-}
-
-func TestOBD2(t *testing.T) {
- obd2 := &OBD2{}
- if obd2 == nil {
- t.Error("OBD2 should not be nil")
- }
-}
-
-func TestShowHandler(t *testing.T) {
- // Skip this test as it requires CAN to be initialized in the session
- t.Skip("Skipping show handler test - requires initialized CAN in session")
-}
-
-func TestDefaultTransport(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- if mod.transport != "can" {
- t.Errorf("Expected transport 'can', got '%s'", mod.transport)
- }
-}
-
-func TestDefaultDevice(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- if mod.deviceName != "can0" {
- t.Errorf("Expected device 'can0', got '%s'", mod.deviceName)
- }
-}
-
-func TestFilterExpression(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Initially filter should be empty
- if mod.filter != "" {
- t.Errorf("Expected empty filter, got '%s'", mod.filter)
- }
-
- // filterExpr should be nil initially
- if mod.filterExpr != nil {
- t.Error("Expected filterExpr to be nil initially")
- }
-}
-
-func TestDBCStruct(t *testing.T) {
- // Test DBC struct initialization
- dbc := &DBC{}
- if dbc == nil {
- t.Error("DBC should not be nil")
- }
-}
-
-func TestOBD2Struct(t *testing.T) {
- // Test OBD2 struct initialization
- obd2 := &OBD2{}
- if obd2 == nil {
- t.Error("OBD2 should not be nil")
- }
-}
-
-func TestCANMessage(t *testing.T) {
- // Test CAN message creation using NewCanMessage
- frame := can.Frame{}
- frame.ID = 0x123
- frame.Data = [8]byte{0x01, 0x02, 0x03, 0x04, 0x00, 0x00, 0x00, 0x00}
- frame.Length = 4
-
- msg := NewCanMessage(frame)
-
- if msg.Frame.ID != 0x123 {
- t.Errorf("Expected ID 0x123, got 0x%x", msg.Frame.ID)
- }
-
- if msg.Frame.Length != 4 {
- t.Errorf("Expected frame length 4, got %d", msg.Frame.Length)
- }
-
- if msg.Signals == nil {
- t.Error("Signals map should not be nil")
- }
-}
-
-func TestDefaultParameters(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Test default parameter values exist
- expectedParams := []string{
- "can.device",
- "can.transport",
- "can.dump",
- "can.filter",
- "can.dump.inject",
- "can.parse.obd2",
- }
-
- // Check that parameters are defined
- params := mod.Parameters()
- if params == nil {
- t.Error("Parameters should not be nil")
- }
-
- // Just verify we have the expected number of parameters
- if len(expectedParams) != 6 {
- t.Error("Expected 6 parameters")
- }
-}
-
-func TestHandlerExecution(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Test that we can find all expected handlers
- handlerTests := []struct {
- name string
- args []string
- shouldFail bool
- }{
- {"can.inject FRAME_EXPRESSION", []string{"123#deadbeef"}, true}, // Should fail when not running
- {"can.fuzz ID_OR_NODE_NAME OPTIONAL_SIZE", []string{"123", "8"}, true}, // Should fail when not running
- {"can.dbc.load NAME", []string{"test.dbc"}, true}, // Will fail without actual file
- }
-
- handlers := mod.Handlers()
- for _, test := range handlerTests {
- found := false
- for _, h := range handlers {
- if h.Name == test.name {
- found = true
- err := h.Exec(test.args)
- if test.shouldFail && err == nil {
- t.Errorf("Handler %s should have failed but didn't", test.name)
- } else if !test.shouldFail && err != nil {
- t.Errorf("Handler %s failed unexpectedly: %v", test.name, err)
- }
- break
- }
- }
- if !found {
- t.Errorf("Handler %s not found", test.name)
- }
- }
-}
-
-func TestModuleFields(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Test various fields are initialized correctly
- if mod.conn != nil {
- t.Error("conn should be nil initially")
- }
-
- if mod.recv != nil {
- t.Error("recv should be nil initially")
- }
-
- if mod.send != nil {
- t.Error("send should be nil initially")
- }
-}
-
-func TestDBCLoadHandler(t *testing.T) {
- s := createMockSession(t)
- mod := NewCanModule(s)
-
- // Find dbc.load handler
- var dbcHandler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == "can.dbc.load NAME" {
- dbcHandler = &h
- break
- }
- }
-
- if dbcHandler == nil {
- t.Fatal("DBC load handler not found")
- }
-
- // Test with non-existent file
- err := dbcHandler.Exec([]string{"non_existent.dbc"})
- if err == nil {
- t.Error("Expected error when loading non-existent DBC file")
- }
-}
-
-// Benchmark tests
-func BenchmarkNewCanModule(b *testing.B) {
- s, _ := session.New()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NewCanModule(s)
- }
-}
-
-func BenchmarkClearHandler(b *testing.B) {
- // Skip this benchmark as it requires CAN to be initialized
- b.Skip("Skipping clear handler benchmark - requires initialized CAN in session")
-}
-
-func BenchmarkInjectHandler(b *testing.B) {
- s, _ := session.New()
- mod := NewCanModule(s)
-
- var handler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == "can.inject FRAME_EXPRESSION" {
- handler = &h
- break
- }
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- // This will fail since module is not running, but we're benchmarking the handler
- _ = handler.Exec([]string{"123#deadbeef"})
- }
-}
diff --git a/modules/dns_proxy/dns_proxy_base.go b/modules/dns_proxy/dns_proxy_base.go
index fe1b84af..f8c17445 100644
--- a/modules/dns_proxy/dns_proxy_base.go
+++ b/modules/dns_proxy/dns_proxy_base.go
@@ -14,8 +14,6 @@ import (
"github.com/evilsocket/islazy/log"
"github.com/miekg/dns"
-
- "github.com/robertkrimen/otto"
)
const (
@@ -227,14 +225,6 @@ func (p *DNSProxy) Start() {
}
func (p *DNSProxy) Stop() error {
- if p.Script != nil {
- if p.Script.Plugin.HasFunc("onExit") {
- if _, err := p.Script.Call("onExit"); err != nil {
- log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String())
- }
- }
- }
-
if p.doRedirect && p.Redirection != nil {
p.Debug("disabling redirection %s", p.Redirection.String())
if err := p.Sess.Firewall.EnableRedirection(p.Redirection, false); err != nil {
diff --git a/modules/dns_proxy/dns_proxy_js_query.go b/modules/dns_proxy/dns_proxy_js_query.go
index bae57ad2..cd38f01f 100644
--- a/modules/dns_proxy/dns_proxy_js_query.go
+++ b/modules/dns_proxy/dns_proxy_js_query.go
@@ -3,9 +3,6 @@ package dns_proxy
import (
"encoding/json"
"fmt"
- "math"
- "math/big"
- "reflect"
"github.com/bettercap/bettercap/v2/log"
"github.com/bettercap/bettercap/v2/session"
@@ -43,7 +40,7 @@ func jsPropToMap(obj map[string]interface{}, key string) map[string]interface{}
if v, ok := obj[key].(map[string]interface{}); ok {
return v
}
- log.Error("error converting JS property to map[string]interface{} where key is: %s", key)
+ log.Debug("error converting JS property to map[string]interface{} where key is: %s", key)
return map[string]interface{}{}
}
@@ -51,7 +48,7 @@ func jsPropToMapArray(obj map[string]interface{}, key string) []map[string]inter
if v, ok := obj[key].([]map[string]interface{}); ok {
return v
}
- log.Error("error converting JS property to []map[string]interface{} where key is: %s", key)
+ log.Debug("error converting JS property to []map[string]interface{} where key is: %s", key)
return []map[string]interface{}{}
}
@@ -59,7 +56,7 @@ func jsPropToString(obj map[string]interface{}, key string) string {
if v, ok := obj[key].(string); ok {
return v
}
- log.Error("error converting JS property to string where key is: %s", key)
+ log.Debug("error converting JS property to string where key is: %s", key)
return ""
}
@@ -67,115 +64,56 @@ func jsPropToStringArray(obj map[string]interface{}, key string) []string {
if v, ok := obj[key].([]string); ok {
return v
}
- log.Error("error converting JS property to []string where key is: %s", key)
+ log.Debug("error converting JS property to []string where key is: %s", key)
return []string{}
}
func jsPropToUint8(obj map[string]interface{}, key string) uint8 {
- if v, ok := obj[key].(int64); ok {
- if v >= 0 && v <= math.MaxUint8 {
- return uint8(v)
- }
+ if v, ok := obj[key].(uint8); ok {
+ return v
}
- log.Error("error converting JS property to uint8 where key is: %s", key)
- return uint8(0)
+ log.Debug("error converting JS property to uint8 where key is: %s", key)
+ return 0
}
func jsPropToUint8Array(obj map[string]interface{}, key string) []uint8 {
- if arr, ok := obj[key].([]interface{}); ok {
- vArr := make([]uint8, 0, len(arr))
- for _, item := range arr {
- if v, ok := item.(int64); ok {
- if v >= 0 && v <= math.MaxUint8 {
- vArr = append(vArr, uint8(v))
- } else {
- log.Error("error converting JS property to []uint8 where key is: %s", key)
- return []uint8{}
- }
- }
- }
- return vArr
+ if v, ok := obj[key].([]uint8); ok {
+ return v
}
- log.Error("error converting JS property to []uint8 where key is: %s", key)
+ log.Debug("error converting JS property to []uint8 where key is: %s", key)
return []uint8{}
}
func jsPropToUint16(obj map[string]interface{}, key string) uint16 {
- if v, ok := obj[key].(int64); ok {
- if v >= 0 && v <= math.MaxUint16 {
- return uint16(v)
- }
+ if v, ok := obj[key].(uint16); ok {
+ return v
}
- log.Error("error converting JS property to uint16 where key is: %s", key)
- return uint16(0)
+ log.Debug("error converting JS property to uint16 where key is: %s", key)
+ return 0
}
func jsPropToUint16Array(obj map[string]interface{}, key string) []uint16 {
- if arr, ok := obj[key].([]interface{}); ok {
- vArr := make([]uint16, 0, len(arr))
- for _, item := range arr {
- if v, ok := item.(int64); ok {
- if v >= 0 && v <= math.MaxUint16 {
- vArr = append(vArr, uint16(v))
- } else {
- log.Error("error converting JS property to []uint16 where key is: %s", key)
- return []uint16{}
- }
- }
- }
- return vArr
+ if v, ok := obj[key].([]uint16); ok {
+ return v
}
- log.Error("error converting JS property to []uint16 where key is: %s", key)
+ log.Debug("error converting JS property to []uint16 where key is: %s", key)
return []uint16{}
}
func jsPropToUint32(obj map[string]interface{}, key string) uint32 {
- if v, ok := obj[key].(int64); ok {
- if v >= 0 && v <= math.MaxUint32 {
- return uint32(v)
- }
+ if v, ok := obj[key].(uint32); ok {
+ return v
}
- log.Error("error converting JS property to uint32 where key is: %s", key)
- return uint32(0)
+ log.Debug("error converting JS property to uint32 where key is: %s", key)
+ return 0
}
func jsPropToUint64(obj map[string]interface{}, key string) uint64 {
- prop, found := obj[key]
- if found {
- switch reflect.TypeOf(prop).String() {
- case "float64":
- if f, ok := prop.(float64); ok {
- bigInt := new(big.Float).SetFloat64(f)
- v, _ := bigInt.Uint64()
- if v >= 0 {
- return v
- }
- }
- break
- case "int64":
- if v, ok := prop.(int64); ok {
- if v >= 0 {
- return uint64(v)
- }
- }
- break
- case "uint64":
- if v, ok := prop.(uint64); ok {
- return v
- }
- break
- }
+ if v, ok := obj[key].(uint64); ok {
+ return v
}
- log.Error("error converting JS property to uint64 where key is: %s", key)
- return uint64(0)
-}
-
-func uint16ArrayToInt64Array(arr []uint16) []int64 {
- vArr := make([]int64, 0, len(arr))
- for _, item := range arr {
- vArr = append(vArr, int64(item))
- }
- return vArr
+ log.Debug("error converting JS property to uint64 where key is: %s", key)
+ return 0
}
func (j *JSQuery) NewHash() string {
@@ -245,8 +183,8 @@ func NewJSQuery(query *dns.Msg, clientIP string) (jsQuery *JSQuery) {
for i, question := range query.Question {
questions[i] = map[string]interface{}{
"Name": question.Name,
- "Qtype": int64(question.Qtype),
- "Qclass": int64(question.Qclass),
+ "Qtype": question.Qtype,
+ "Qclass": question.Qclass,
}
}
@@ -355,11 +293,3 @@ func (j *JSQuery) WasModified() bool {
// check if any of the fields has been changed
return j.NewHash() != j.refHash
}
-
-func (j *JSQuery) CheckIfModifiedAndUpdateHash() bool {
- // check if query was changed and update its hash
- newHash := j.NewHash()
- wasModified := j.refHash != newHash
- j.refHash = newHash
- return wasModified
-}
diff --git a/modules/dns_proxy/dns_proxy_js_record.go b/modules/dns_proxy/dns_proxy_js_record.go
index 49553ad8..55832d69 100644
--- a/modules/dns_proxy/dns_proxy_js_record.go
+++ b/modules/dns_proxy/dns_proxy_js_record.go
@@ -13,10 +13,10 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
jsRecord = map[string]interface{}{
"Header": map[string]interface{}{
- "Class": int64(header.Class),
+ "Class": header.Class,
"Name": header.Name,
- "Rrtype": int64(header.Rrtype),
- "Ttl": int64(header.Ttl),
+ "Rrtype": header.Rrtype,
+ "Ttl": header.Ttl,
},
}
@@ -48,24 +48,24 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
jsRecord["Mr"] = rr.Mr
case *dns.MX:
jsRecord["Mx"] = rr.Mx
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.NULL:
jsRecord["Data"] = rr.Data
case *dns.SOA:
- jsRecord["Expire"] = int64(rr.Expire)
- jsRecord["Minttl"] = int64(rr.Minttl)
+ jsRecord["Expire"] = rr.Expire
+ jsRecord["Minttl"] = rr.Minttl
jsRecord["Ns"] = rr.Ns
- jsRecord["Refresh"] = int64(rr.Refresh)
- jsRecord["Retry"] = int64(rr.Retry)
+ jsRecord["Refresh"] = rr.Refresh
+ jsRecord["Retry"] = rr.Retry
jsRecord["Mbox"] = rr.Mbox
- jsRecord["Serial"] = int64(rr.Serial)
+ jsRecord["Serial"] = rr.Serial
case *dns.TXT:
jsRecord["Txt"] = rr.Txt
case *dns.SRV:
- jsRecord["Port"] = int64(rr.Port)
- jsRecord["Priority"] = int64(rr.Priority)
+ jsRecord["Port"] = rr.Port
+ jsRecord["Priority"] = rr.Priority
jsRecord["Target"] = rr.Target
- jsRecord["Weight"] = int64(rr.Weight)
+ jsRecord["Weight"] = rr.Weight
case *dns.PTR:
jsRecord["Ptr"] = rr.Ptr
case *dns.NS:
@@ -73,10 +73,10 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
case *dns.DNAME:
jsRecord["Target"] = rr.Target
case *dns.AFSDB:
- jsRecord["Subtype"] = int64(rr.Subtype)
+ jsRecord["Subtype"] = rr.Subtype
jsRecord["Hostname"] = rr.Hostname
case *dns.CAA:
- jsRecord["Flag"] = int64(rr.Flag)
+ jsRecord["Flag"] = rr.Flag
jsRecord["Tag"] = rr.Tag
jsRecord["Value"] = rr.Value
case *dns.HINFO:
@@ -90,123 +90,123 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
jsRecord["SubAddress"] = rr.SubAddress
case *dns.KX:
jsRecord["Exchanger"] = rr.Exchanger
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.LOC:
- jsRecord["Altitude"] = int64(rr.Altitude)
- jsRecord["HorizPre"] = int64(rr.HorizPre)
- jsRecord["Latitude"] = int64(rr.Latitude)
- jsRecord["Longitude"] = int64(rr.Longitude)
- jsRecord["Size"] = int64(rr.Size)
- jsRecord["Version"] = int64(rr.Version)
- jsRecord["VertPre"] = int64(rr.VertPre)
+ jsRecord["Altitude"] = rr.Altitude
+ jsRecord["HorizPre"] = rr.HorizPre
+ jsRecord["Latitude"] = rr.Latitude
+ jsRecord["Longitude"] = rr.Longitude
+ jsRecord["Size"] = rr.Size
+ jsRecord["Version"] = rr.Version
+ jsRecord["VertPre"] = rr.VertPre
case *dns.SSHFP:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["FingerPrint"] = rr.FingerPrint
- jsRecord["Type"] = int64(rr.Type)
+ jsRecord["Type"] = rr.Type
case *dns.TLSA:
jsRecord["Certificate"] = rr.Certificate
- jsRecord["MatchingType"] = int64(rr.MatchingType)
- jsRecord["Selector"] = int64(rr.Selector)
- jsRecord["Usage"] = int64(rr.Usage)
+ jsRecord["MatchingType"] = rr.MatchingType
+ jsRecord["Selector"] = rr.Selector
+ jsRecord["Usage"] = rr.Usage
case *dns.CERT:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["Certificate"] = rr.Certificate
- jsRecord["KeyTag"] = int64(rr.KeyTag)
- jsRecord["Type"] = int64(rr.Type)
+ jsRecord["KeyTag"] = rr.KeyTag
+ jsRecord["Type"] = rr.Type
case *dns.DS:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["Digest"] = rr.Digest
- jsRecord["DigestType"] = int64(rr.DigestType)
- jsRecord["KeyTag"] = int64(rr.KeyTag)
+ jsRecord["DigestType"] = rr.DigestType
+ jsRecord["KeyTag"] = rr.KeyTag
case *dns.NAPTR:
- jsRecord["Order"] = int64(rr.Order)
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Order"] = rr.Order
+ jsRecord["Preference"] = rr.Preference
jsRecord["Flags"] = rr.Flags
jsRecord["Service"] = rr.Service
jsRecord["Regexp"] = rr.Regexp
jsRecord["Replacement"] = rr.Replacement
case *dns.RRSIG:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
- jsRecord["Expiration"] = int64(rr.Expiration)
- jsRecord["Inception"] = int64(rr.Inception)
- jsRecord["KeyTag"] = int64(rr.KeyTag)
- jsRecord["Labels"] = int64(rr.Labels)
- jsRecord["OrigTtl"] = int64(rr.OrigTtl)
+ jsRecord["Algorithm"] = rr.Algorithm
+ jsRecord["Expiration"] = rr.Expiration
+ jsRecord["Inception"] = rr.Inception
+ jsRecord["KeyTag"] = rr.KeyTag
+ jsRecord["Labels"] = rr.Labels
+ jsRecord["OrigTtl"] = rr.OrigTtl
jsRecord["Signature"] = rr.Signature
jsRecord["SignerName"] = rr.SignerName
- jsRecord["TypeCovered"] = int64(rr.TypeCovered)
+ jsRecord["TypeCovered"] = rr.TypeCovered
case *dns.NSEC:
jsRecord["NextDomain"] = rr.NextDomain
- jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap)
+ jsRecord["TypeBitMap"] = rr.TypeBitMap
case *dns.NSEC3:
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Hash"] = int64(rr.Hash)
- jsRecord["HashLength"] = int64(rr.HashLength)
- jsRecord["Iterations"] = int64(rr.Iterations)
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Hash"] = rr.Hash
+ jsRecord["HashLength"] = rr.HashLength
+ jsRecord["Iterations"] = rr.Iterations
jsRecord["NextDomain"] = rr.NextDomain
jsRecord["Salt"] = rr.Salt
- jsRecord["SaltLength"] = int64(rr.SaltLength)
- jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap)
+ jsRecord["SaltLength"] = rr.SaltLength
+ jsRecord["TypeBitMap"] = rr.TypeBitMap
case *dns.NSEC3PARAM:
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Hash"] = int64(rr.Hash)
- jsRecord["Iterations"] = int64(rr.Iterations)
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Hash"] = rr.Hash
+ jsRecord["Iterations"] = rr.Iterations
jsRecord["Salt"] = rr.Salt
- jsRecord["SaltLength"] = int64(rr.SaltLength)
+ jsRecord["SaltLength"] = rr.SaltLength
case *dns.TKEY:
jsRecord["Algorithm"] = rr.Algorithm
- jsRecord["Error"] = int64(rr.Error)
- jsRecord["Expiration"] = int64(rr.Expiration)
- jsRecord["Inception"] = int64(rr.Inception)
+ jsRecord["Error"] = rr.Error
+ jsRecord["Expiration"] = rr.Expiration
+ jsRecord["Inception"] = rr.Inception
jsRecord["Key"] = rr.Key
- jsRecord["KeySize"] = int64(rr.KeySize)
- jsRecord["Mode"] = int64(rr.Mode)
+ jsRecord["KeySize"] = rr.KeySize
+ jsRecord["Mode"] = rr.Mode
jsRecord["OtherData"] = rr.OtherData
- jsRecord["OtherLen"] = int64(rr.OtherLen)
+ jsRecord["OtherLen"] = rr.OtherLen
case *dns.TSIG:
jsRecord["Algorithm"] = rr.Algorithm
- jsRecord["Error"] = int64(rr.Error)
- jsRecord["Fudge"] = int64(rr.Fudge)
- jsRecord["MACSize"] = int64(rr.MACSize)
+ jsRecord["Error"] = rr.Error
+ jsRecord["Fudge"] = rr.Fudge
+ jsRecord["MACSize"] = rr.MACSize
jsRecord["MAC"] = rr.MAC
- jsRecord["OrigId"] = int64(rr.OrigId)
+ jsRecord["OrigId"] = rr.OrigId
jsRecord["OtherData"] = rr.OtherData
- jsRecord["OtherLen"] = int64(rr.OtherLen)
- jsRecord["TimeSigned"] = int64(rr.TimeSigned)
+ jsRecord["OtherLen"] = rr.OtherLen
+ jsRecord["TimeSigned"] = rr.TimeSigned
case *dns.IPSECKEY:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["GatewayAddr"] = rr.GatewayAddr.String()
jsRecord["GatewayHost"] = rr.GatewayHost
- jsRecord["GatewayType"] = int64(rr.GatewayType)
- jsRecord["Precedence"] = int64(rr.Precedence)
+ jsRecord["GatewayType"] = rr.GatewayType
+ jsRecord["Precedence"] = rr.Precedence
jsRecord["PublicKey"] = rr.PublicKey
case *dns.KEY:
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Protocol"] = int64(rr.Protocol)
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Protocol"] = rr.Protocol
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["PublicKey"] = rr.PublicKey
case *dns.CDS:
- jsRecord["KeyTag"] = int64(rr.KeyTag)
- jsRecord["Algorithm"] = int64(rr.Algorithm)
- jsRecord["DigestType"] = int64(rr.DigestType)
+ jsRecord["KeyTag"] = rr.KeyTag
+ jsRecord["Algorithm"] = rr.Algorithm
+ jsRecord["DigestType"] = rr.DigestType
jsRecord["Digest"] = rr.Digest
case *dns.CDNSKEY:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Protocol"] = int64(rr.Protocol)
+ jsRecord["Algorithm"] = rr.Algorithm
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Protocol"] = rr.Protocol
jsRecord["PublicKey"] = rr.PublicKey
case *dns.NID:
jsRecord["NodeID"] = rr.NodeID
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.L32:
jsRecord["Locator32"] = rr.Locator32.String()
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.L64:
jsRecord["Locator64"] = rr.Locator64
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.LP:
jsRecord["Fqdn"] = rr.Fqdn
- jsRecord["Preference"] = int16(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.GPOS:
jsRecord["Altitude"] = rr.Altitude
jsRecord["Latitude"] = rr.Latitude
@@ -215,40 +215,40 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
jsRecord["Mbox"] = rr.Mbox
jsRecord["Txt"] = rr.Txt
case *dns.RKEY:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Protocol"] = int64(rr.Protocol)
+ jsRecord["Algorithm"] = rr.Algorithm
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Protocol"] = rr.Protocol
jsRecord["PublicKey"] = rr.PublicKey
case *dns.SMIMEA:
jsRecord["Certificate"] = rr.Certificate
- jsRecord["MatchingType"] = int64(rr.MatchingType)
- jsRecord["Selector"] = int64(rr.Selector)
- jsRecord["Usage"] = int64(rr.Usage)
+ jsRecord["MatchingType"] = rr.MatchingType
+ jsRecord["Selector"] = rr.Selector
+ jsRecord["Usage"] = rr.Usage
case *dns.AMTRELAY:
jsRecord["GatewayAddr"] = rr.GatewayAddr.String()
jsRecord["GatewayHost"] = rr.GatewayHost
- jsRecord["GatewayType"] = int64(rr.GatewayType)
- jsRecord["Precedence"] = int64(rr.Precedence)
+ jsRecord["GatewayType"] = rr.GatewayType
+ jsRecord["Precedence"] = rr.Precedence
case *dns.AVC:
jsRecord["Txt"] = rr.Txt
case *dns.URI:
- jsRecord["Priority"] = int64(rr.Priority)
- jsRecord["Weight"] = int64(rr.Weight)
+ jsRecord["Priority"] = rr.Priority
+ jsRecord["Weight"] = rr.Weight
jsRecord["Target"] = rr.Target
case *dns.EUI48:
jsRecord["Address"] = rr.Address
case *dns.EUI64:
jsRecord["Address"] = rr.Address
case *dns.GID:
- jsRecord["Gid"] = int64(rr.Gid)
+ jsRecord["Gid"] = rr.Gid
case *dns.UID:
- jsRecord["Uid"] = int64(rr.Uid)
+ jsRecord["Uid"] = rr.Uid
case *dns.UINFO:
jsRecord["Uinfo"] = rr.Uinfo
case *dns.SPF:
jsRecord["Txt"] = rr.Txt
case *dns.HTTPS:
- jsRecord["Priority"] = int64(rr.Priority)
+ jsRecord["Priority"] = rr.Priority
jsRecord["Target"] = rr.Target
kvs := rr.Value
var jsKvs []map[string]interface{}
@@ -262,7 +262,7 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
}
jsRecord["Value"] = jsKvs
case *dns.SVCB:
- jsRecord["Priority"] = int64(rr.Priority)
+ jsRecord["Priority"] = rr.Priority
jsRecord["Target"] = rr.Target
kvs := rr.Value
jsKvs := make([]map[string]interface{}, len(kvs))
@@ -277,13 +277,13 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
jsRecord["Value"] = jsKvs
case *dns.ZONEMD:
jsRecord["Digest"] = rr.Digest
- jsRecord["Hash"] = int64(rr.Hash)
- jsRecord["Scheme"] = int64(rr.Scheme)
- jsRecord["Serial"] = int64(rr.Serial)
+ jsRecord["Hash"] = rr.Hash
+ jsRecord["Scheme"] = rr.Scheme
+ jsRecord["Serial"] = rr.Serial
case *dns.CSYNC:
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Serial"] = int64(rr.Serial)
- jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap)
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Serial"] = rr.Serial
+ jsRecord["TypeBitMap"] = rr.TypeBitMap
case *dns.OPENPGPKEY:
jsRecord["PublicKey"] = rr.PublicKey
case *dns.TALINK:
@@ -294,53 +294,43 @@ func NewJSResourceRecord(rr dns.RR) (jsRecord map[string]interface{}, err error)
case *dns.DHCID:
jsRecord["Digest"] = rr.Digest
case *dns.DNSKEY:
- jsRecord["Flags"] = int64(rr.Flags)
- jsRecord["Protocol"] = int64(rr.Protocol)
- jsRecord["Algorithm"] = int64(rr.Algorithm)
+ jsRecord["Flags"] = rr.Flags
+ jsRecord["Protocol"] = rr.Protocol
+ jsRecord["Algorithm"] = rr.Algorithm
jsRecord["PublicKey"] = rr.PublicKey
case *dns.HIP:
jsRecord["Hit"] = rr.Hit
- jsRecord["HitLength"] = int64(rr.HitLength)
+ jsRecord["HitLength"] = rr.HitLength
jsRecord["PublicKey"] = rr.PublicKey
- jsRecord["PublicKeyAlgorithm"] = int64(rr.PublicKeyAlgorithm)
- jsRecord["PublicKeyLength"] = int64(rr.PublicKeyLength)
+ jsRecord["PublicKeyAlgorithm"] = rr.PublicKeyAlgorithm
+ jsRecord["PublicKeyLength"] = rr.PublicKeyLength
jsRecord["RendezvousServers"] = rr.RendezvousServers
case *dns.OPT:
- options := rr.Option
- jsOptions := make([]map[string]interface{}, len(options))
- for i, option := range options {
- jsOption, err := NewJSEDNS0(option)
- if err != nil {
- log.Error(err.Error())
- continue
- }
- jsOptions[i] = jsOption
- }
- jsRecord["Option"] = jsOptions
+ jsRecord["Option"] = rr.Option
case *dns.NIMLOC:
jsRecord["Locator"] = rr.Locator
case *dns.EID:
jsRecord["Endpoint"] = rr.Endpoint
case *dns.NXT:
jsRecord["NextDomain"] = rr.NextDomain
- jsRecord["TypeBitMap"] = uint16ArrayToInt64Array(rr.TypeBitMap)
+ jsRecord["TypeBitMap"] = rr.TypeBitMap
case *dns.PX:
jsRecord["Mapx400"] = rr.Mapx400
jsRecord["Map822"] = rr.Map822
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.SIG:
- jsRecord["Algorithm"] = int64(rr.Algorithm)
- jsRecord["Expiration"] = int64(rr.Expiration)
- jsRecord["Inception"] = int64(rr.Inception)
- jsRecord["KeyTag"] = int64(rr.KeyTag)
- jsRecord["Labels"] = int64(rr.Labels)
- jsRecord["OrigTtl"] = int64(rr.OrigTtl)
+ jsRecord["Algorithm"] = rr.Algorithm
+ jsRecord["Expiration"] = rr.Expiration
+ jsRecord["Inception"] = rr.Inception
+ jsRecord["KeyTag"] = rr.KeyTag
+ jsRecord["Labels"] = rr.Labels
+ jsRecord["OrigTtl"] = rr.OrigTtl
jsRecord["Signature"] = rr.Signature
jsRecord["SignerName"] = rr.SignerName
- jsRecord["TypeCovered"] = int64(rr.TypeCovered)
+ jsRecord["TypeCovered"] = rr.TypeCovered
case *dns.RT:
jsRecord["Host"] = rr.Host
- jsRecord["Preference"] = int64(rr.Preference)
+ jsRecord["Preference"] = rr.Preference
case *dns.NSAPPTR:
jsRecord["Ptr"] = rr.Ptr
case *dns.X25:
diff --git a/modules/dns_proxy/dns_proxy_script.go b/modules/dns_proxy/dns_proxy_script.go
index 83dd6777..4a608168 100644
--- a/modules/dns_proxy/dns_proxy_script.go
+++ b/modules/dns_proxy/dns_proxy_script.go
@@ -84,9 +84,11 @@ func (s *DnsProxyScript) OnRequest(req *dns.Msg, clientIP string) (jsreq, jsres
if _, err := s.Call("onRequest", jsreq, jsres); err != nil {
log.Error("%s", err)
return nil, nil
- } else if jsreq.CheckIfModifiedAndUpdateHash() {
+ } else if jsreq.WasModified() {
+ jsreq.UpdateHash()
return jsreq, nil
- } else if jsres.CheckIfModifiedAndUpdateHash() {
+ } else if jsres.WasModified() {
+ jsres.UpdateHash()
return nil, jsres
}
}
@@ -102,7 +104,8 @@ func (s *DnsProxyScript) OnResponse(req, res *dns.Msg, clientIP string) (jsreq,
if _, err := s.Call("onResponse", jsreq, jsres); err != nil {
log.Error("%s", err)
return nil, nil
- } else if jsres.CheckIfModifiedAndUpdateHash() {
+ } else if jsres.WasModified() {
+ jsres.UpdateHash()
return nil, jsres
}
}
diff --git a/modules/events_stream/events_view.go b/modules/events_stream/events_view.go
index f06d8dae..56d0e10d 100644
--- a/modules/events_stream/events_view.go
+++ b/modules/events_stream/events_view.go
@@ -137,7 +137,7 @@ func (mod *EventsStream) Render(output io.Writer, e session.Event) {
} else if strings.HasPrefix(e.Tag, "zeroconf.") {
mod.viewZeroConfEvent(output, e)
} else if !strings.HasPrefix(e.Tag, "tick") && e.Tag != "session.started" && e.Tag != "session.stopped" {
- fmt.Fprintf(output, "[%s] [%s] %v\n", e.Time.Format(mod.timeFormat), tui.Green(e.Tag), e.Data)
+ fmt.Fprintf(output, "[%s] [%s] %v\n", e.Time.Format(mod.timeFormat), tui.Green(e.Tag), e)
}
}
diff --git a/modules/http_proxy/http_proxy_base.go b/modules/http_proxy/http_proxy_base.go
index 7ace2122..5d4eebef 100644
--- a/modules/http_proxy/http_proxy_base.go
+++ b/modules/http_proxy/http_proxy_base.go
@@ -27,8 +27,6 @@ import (
"github.com/evilsocket/islazy/log"
"github.com/evilsocket/islazy/str"
"github.com/evilsocket/islazy/tui"
-
- "github.com/robertkrimen/otto"
)
const (
@@ -434,14 +432,6 @@ func (p *HTTPProxy) Start() {
}
func (p *HTTPProxy) Stop() error {
- if p.Script != nil {
- if p.Script.Plugin.HasFunc("onExit") {
- if _, err := p.Script.Call("onExit"); err != nil {
- log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String())
- }
- }
- }
-
if p.doRedirect && p.Redirection != nil {
p.Debug("disabling redirection %s", p.Redirection.String())
if err := p.Sess.Firewall.EnableRedirection(p.Redirection, false); err != nil {
diff --git a/modules/http_proxy/http_proxy_base_filters.go b/modules/http_proxy/http_proxy_base_filters.go
index 988807f2..017fc0c3 100644
--- a/modules/http_proxy/http_proxy_base_filters.go
+++ b/modules/http_proxy/http_proxy_base_filters.go
@@ -1,10 +1,10 @@
package http_proxy
import (
- "io"
+ "io/ioutil"
"net/http"
- "strconv"
"strings"
+ "strconv"
"github.com/elazarl/goproxy"
@@ -74,10 +74,10 @@ func (p *HTTPProxy) isScriptInjectable(res *http.Response) (bool, string) {
return false, ""
}
-func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) error {
+func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) (error) {
defer res.Body.Close()
- raw, err := io.ReadAll(res.Body)
+ raw, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
} else if html := string(raw); strings.Contains(html, "") {
@@ -91,7 +91,7 @@ func (p *HTTPProxy) doScriptInjection(res *http.Response, cType string) error {
res.Header.Set("Content-Length", strconv.Itoa(len(html)))
// reset the response body to the original unread state
- res.Body = io.NopCloser(strings.NewReader(html))
+ res.Body = ioutil.NopCloser(strings.NewReader(html))
return nil
}
diff --git a/modules/http_proxy/http_proxy_base_sslstriper.go b/modules/http_proxy/http_proxy_base_sslstriper.go
index e3331b18..d2fd0f4f 100644
--- a/modules/http_proxy/http_proxy_base_sslstriper.go
+++ b/modules/http_proxy/http_proxy_base_sslstriper.go
@@ -1,7 +1,7 @@
package http_proxy
import (
- "io"
+ "io/ioutil"
"net/http"
"net/url"
"regexp"
@@ -253,7 +253,7 @@ func (s *SSLStripper) Process(res *http.Response, ctx *goproxy.ProxyCtx) {
// if we have a text or html content type, fetch the body
// and perform sslstripping
if s.isContentStrippable(res) {
- raw, err := io.ReadAll(res.Body)
+ raw, err := ioutil.ReadAll(res.Body)
if err != nil {
log.Error("Could not read response body: %s", err)
return
@@ -297,9 +297,9 @@ func (s *SSLStripper) Process(res *http.Response, ctx *goproxy.ProxyCtx) {
// reset the response body to the original unread state
// but with just a string reader, this way further calls
- // to ui.ReadAll(res.Body) will just return the content
+ // to ioutil.ReadAll(res.Body) will just return the content
// we stripped without downloading anything again.
- res.Body = io.NopCloser(strings.NewReader(body))
+ res.Body = ioutil.NopCloser(strings.NewReader(body))
}
// fix cookies domain + strip "secure" + "httponly" flags
diff --git a/modules/http_proxy/http_proxy_js_request.go b/modules/http_proxy/http_proxy_js_request.go
index 859526e4..a3c6a1da 100644
--- a/modules/http_proxy/http_proxy_js_request.go
+++ b/modules/http_proxy/http_proxy_js_request.go
@@ -3,7 +3,7 @@ package http_proxy
import (
"bytes"
"fmt"
- "io"
+ "io/ioutil"
"net/http"
"net/url"
"regexp"
@@ -103,21 +103,7 @@ func (j *JSRequest) WasModified() bool {
return j.NewHash() != j.refHash
}
-func (j *JSRequest) CheckIfModifiedAndUpdateHash() bool {
- newHash := j.NewHash()
- // body was read
- if j.bodyRead {
- j.refHash = newHash
- return true
- }
- // check if req was changed and update its hash
- wasModified := j.refHash != newHash
- j.refHash = newHash
- return wasModified
-}
-
func (j *JSRequest) GetHeader(name, deflt string) string {
- name = strings.ToLower(name)
headers := strings.Split(j.Headers, "\r\n")
for i := 0; i < len(headers); i++ {
if headers[i] != "" {
@@ -125,7 +111,8 @@ func (j *JSRequest) GetHeader(name, deflt string) string {
if len(header_parts) != 0 && len(header_parts[0]) == 3 {
header_name := string(header_parts[0][1])
header_value := string(header_parts[0][2])
- if name == strings.ToLower(header_name) {
+
+ if strings.ToLower(name) == strings.ToLower(header_name) {
return header_value
}
}
@@ -134,25 +121,6 @@ func (j *JSRequest) GetHeader(name, deflt string) string {
return deflt
}
-func (j *JSRequest) GetHeaders(name string) []string {
- name = strings.ToLower(name)
- headers := strings.Split(j.Headers, "\r\n")
- header_values := make([]string, 0, len(headers))
- for i := 0; i < len(headers); i++ {
- if headers[i] != "" {
- header_parts := header_regexp.FindAllSubmatch([]byte(headers[i]), 1)
- if len(header_parts) != 0 && len(header_parts[0]) == 3 {
- header_name := string(header_parts[0][1])
- header_value := string(header_parts[0][2])
- if name == strings.ToLower(header_name) {
- header_values = append(header_values, header_value)
- }
- }
- }
- }
- return header_values
-}
-
func (j *JSRequest) SetHeader(name, value string) {
name = strings.TrimSpace(name)
value = strings.TrimSpace(value)
@@ -201,7 +169,7 @@ func (j *JSRequest) RemoveHeader(name string) {
}
func (j *JSRequest) ReadBody() string {
- raw, err := io.ReadAll(j.req.Body)
+ raw, err := ioutil.ReadAll(j.req.Body)
if err != nil {
return ""
}
@@ -209,7 +177,7 @@ func (j *JSRequest) ReadBody() string {
j.Body = string(raw)
j.bodyRead = true
// reset the request body to the original unread state
- j.req.Body = io.NopCloser(bytes.NewBuffer(raw))
+ j.req.Body = ioutil.NopCloser(bytes.NewBuffer(raw))
return j.Body
}
diff --git a/modules/http_proxy/http_proxy_js_response.go b/modules/http_proxy/http_proxy_js_response.go
index c1bb98bf..051812ef 100644
--- a/modules/http_proxy/http_proxy_js_response.go
+++ b/modules/http_proxy/http_proxy_js_response.go
@@ -3,7 +3,7 @@ package http_proxy
import (
"bytes"
"fmt"
- "io"
+ "io/ioutil"
"net/http"
"strings"
@@ -76,29 +76,7 @@ func (j *JSResponse) WasModified() bool {
return j.NewHash() != j.refHash
}
-func (j *JSResponse) CheckIfModifiedAndUpdateHash() bool {
- newHash := j.NewHash()
- if j.bodyRead {
- // body was read
- j.refHash = newHash
- return true
- } else if j.bodyClear {
- // body was cleared manually
- j.refHash = newHash
- return true
- } else if j.Body != "" {
- // body was not read but just set
- j.refHash = newHash
- return true
- }
- // check if res was changed and update its hash
- wasModified := j.refHash != newHash
- j.refHash = newHash
- return wasModified
-}
-
func (j *JSResponse) GetHeader(name, deflt string) string {
- name = strings.ToLower(name)
headers := strings.Split(j.Headers, "\r\n")
for i := 0; i < len(headers); i++ {
if headers[i] != "" {
@@ -106,7 +84,8 @@ func (j *JSResponse) GetHeader(name, deflt string) string {
if len(header_parts) != 0 && len(header_parts[0]) == 3 {
header_name := string(header_parts[0][1])
header_value := string(header_parts[0][2])
- if name == strings.ToLower(header_name) {
+
+ if strings.ToLower(name) == strings.ToLower(header_name) {
return header_value
}
}
@@ -115,25 +94,6 @@ func (j *JSResponse) GetHeader(name, deflt string) string {
return deflt
}
-func (j *JSResponse) GetHeaders(name string) []string {
- name = strings.ToLower(name)
- headers := strings.Split(j.Headers, "\r\n")
- header_values := make([]string, 0, len(headers))
- for i := 0; i < len(headers); i++ {
- if headers[i] != "" {
- header_parts := header_regexp.FindAllSubmatch([]byte(headers[i]), 1)
- if len(header_parts) != 0 && len(header_parts[0]) == 3 {
- header_name := string(header_parts[0][1])
- header_value := string(header_parts[0][2])
- if name == strings.ToLower(header_name) {
- header_values = append(header_values, header_value)
- }
- }
- }
- }
- return header_values
-}
-
func (j *JSResponse) SetHeader(name, value string) {
name = strings.TrimSpace(name)
value = strings.TrimSpace(value)
@@ -208,7 +168,7 @@ func (j *JSResponse) ToResponse(req *http.Request) (resp *http.Response) {
func (j *JSResponse) ReadBody() string {
defer j.resp.Body.Close()
- raw, err := io.ReadAll(j.resp.Body)
+ raw, err := ioutil.ReadAll(j.resp.Body)
if err != nil {
return ""
}
@@ -217,7 +177,7 @@ func (j *JSResponse) ReadBody() string {
j.bodyRead = true
j.bodyClear = false
// reset the response body to the original unread state
- j.resp.Body = io.NopCloser(bytes.NewBuffer(raw))
+ j.resp.Body = ioutil.NopCloser(bytes.NewBuffer(raw))
return j.Body
}
diff --git a/modules/http_proxy/http_proxy_script.go b/modules/http_proxy/http_proxy_script.go
index 446f61da..070f7e24 100644
--- a/modules/http_proxy/http_proxy_script.go
+++ b/modules/http_proxy/http_proxy_script.go
@@ -84,9 +84,11 @@ func (s *HttpProxyScript) OnRequest(original *http.Request) (jsreq *JSRequest, j
if _, err := s.Call("onRequest", jsreq, jsres); err != nil {
log.Error("%s", err)
return nil, nil
- } else if jsreq.CheckIfModifiedAndUpdateHash() {
+ } else if jsreq.WasModified() {
+ jsreq.UpdateHash()
return jsreq, nil
- } else if jsres.CheckIfModifiedAndUpdateHash() {
+ } else if jsres.WasModified() {
+ jsres.UpdateHash()
return nil, jsres
}
}
@@ -102,7 +104,8 @@ func (s *HttpProxyScript) OnResponse(res *http.Response) (jsreq *JSRequest, jsre
if _, err := s.Call("onResponse", jsreq, jsres); err != nil {
log.Error("%s", err)
return nil, nil
- } else if jsres.CheckIfModifiedAndUpdateHash() {
+ } else if jsres.WasModified() {
+ jsres.UpdateHash()
return nil, jsres
}
}
diff --git a/modules/http_proxy/http_proxy_test.go b/modules/http_proxy/http_proxy_test.go
deleted file mode 100644
index d05d046e..00000000
--- a/modules/http_proxy/http_proxy_test.go
+++ /dev/null
@@ -1,706 +0,0 @@
-package http_proxy
-
-import (
- "fmt"
- "io/ioutil"
- "net"
- "net/http"
- "net/http/httptest"
- "os"
- "runtime"
- "strings"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/firewall"
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/evilsocket/islazy/data"
-)
-
-// MockFirewall implements a mock firewall for testing
-type MockFirewall struct {
- forwardingEnabled bool
- redirections []firewall.Redirection
-}
-
-func NewMockFirewall() *MockFirewall {
- return &MockFirewall{
- forwardingEnabled: false,
- redirections: make([]firewall.Redirection, 0),
- }
-}
-
-func (m *MockFirewall) IsForwardingEnabled() bool {
- return m.forwardingEnabled
-}
-
-func (m *MockFirewall) EnableForwarding(enabled bool) error {
- m.forwardingEnabled = enabled
- return nil
-}
-
-func (m *MockFirewall) EnableRedirection(r *firewall.Redirection, enabled bool) error {
- if enabled {
- m.redirections = append(m.redirections, *r)
- } else {
- for i, red := range m.redirections {
- if red.String() == r.String() {
- m.redirections = append(m.redirections[:i], m.redirections[i+1:]...)
- break
- }
- }
- }
- return nil
-}
-
-func (m *MockFirewall) DisableRedirection(r *firewall.Redirection, enabled bool) error {
- return m.EnableRedirection(r, false)
-}
-
-func (m *MockFirewall) Restore() {
- m.redirections = make([]firewall.Redirection, 0)
- m.forwardingEnabled = false
-}
-
-// Create a mock session for testing
-func createMockSession() (*session.Session, *MockFirewall) {
- // Create interface
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "eth0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- // Parse interface addresses
- ifaceIP := net.ParseIP("192.168.1.100")
- ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface.IP = ifaceIP
- iface.HW = ifaceHW
-
- // Create gateway
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
- gatewayIP := net.ParseIP("192.168.1.1")
- gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66")
- gateway.IP = gatewayIP
- gateway.HW = gatewayHW
-
- // Create mock firewall
- mockFirewall := NewMockFirewall()
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- // Create LAN
- aliases, _ := data.NewUnsortedKV("", 0)
- lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- // Create session
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- Lan: lan,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: &packets.Queue{},
- Firewall: mockFirewall,
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize events
- sess.Events = session.NewEventPool(false, false)
-
- return sess, mockFirewall
-}
-
-func TestNewHttpProxy(t *testing.T) {
- sess, _ := createMockSession()
-
- mod := NewHttpProxy(sess)
-
- if mod == nil {
- t.Fatal("NewHttpProxy returned nil")
- }
-
- if mod.Name() != "http.proxy" {
- t.Errorf("expected module name 'http.proxy', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- // Check parameters
- params := []string{
- "http.port",
- "http.proxy.address",
- "http.proxy.port",
- "http.proxy.redirect",
- "http.proxy.script",
- "http.proxy.injectjs",
- "http.proxy.blacklist",
- "http.proxy.whitelist",
- "http.proxy.sslstrip",
- }
- for _, param := range params {
- if !mod.Session.Env.Has(param) {
- t.Errorf("parameter %s not registered", param)
- }
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{"http.proxy on", "http.proxy off"}
- handlerMap := make(map[string]bool)
-
- for _, h := range handlers {
- handlerMap[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerMap[expected] {
- t.Errorf("Expected handler '%s' not found", expected)
- }
- }
-}
-
-func TestHttpProxyConfigure(t *testing.T) {
- tests := []struct {
- name string
- params map[string]string
- expectErr bool
- validate func(*HttpProxy) error
- }{
- {
- name: "default configuration",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "192.168.1.100",
- "http.proxy.port": "8080",
- "http.proxy.redirect": "true",
- "http.proxy.script": "",
- "http.proxy.injectjs": "",
- "http.proxy.blacklist": "",
- "http.proxy.whitelist": "",
- "http.proxy.sslstrip": "false",
- },
- expectErr: false,
- validate: func(mod *HttpProxy) error {
- if mod.proxy == nil {
- return fmt.Errorf("proxy not initialized")
- }
- if mod.proxy.Address != "192.168.1.100" {
- return fmt.Errorf("expected address 192.168.1.100, got %s", mod.proxy.Address)
- }
- if !mod.proxy.doRedirect {
- return fmt.Errorf("expected redirect to be true")
- }
- if mod.proxy.Stripper == nil {
- return fmt.Errorf("SSL stripper not initialized")
- }
- if mod.proxy.Stripper.Enabled() {
- return fmt.Errorf("SSL stripper should be disabled")
- }
- return nil
- },
- },
- // Note: SSL stripping test removed as it requires elevated permissions
- // to create network capture handles
- {
- name: "with blacklist and whitelist",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "192.168.1.100",
- "http.proxy.port": "8080",
- "http.proxy.redirect": "false",
- "http.proxy.script": "",
- "http.proxy.injectjs": "",
- "http.proxy.blacklist": "*.evil.com,bad.site.org",
- "http.proxy.whitelist": "*.good.com,safe.site.org",
- "http.proxy.sslstrip": "false",
- },
- expectErr: false,
- validate: func(mod *HttpProxy) error {
- if len(mod.proxy.Blacklist) != 2 {
- return fmt.Errorf("expected 2 blacklist entries, got %d", len(mod.proxy.Blacklist))
- }
- if len(mod.proxy.Whitelist) != 2 {
- return fmt.Errorf("expected 2 whitelist entries, got %d", len(mod.proxy.Whitelist))
- }
- if mod.proxy.doRedirect {
- return fmt.Errorf("expected redirect to be false")
- }
- return nil
- },
- },
- {
- name: "JavaScript injection with inline code",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "192.168.1.100",
- "http.proxy.port": "8080",
- "http.proxy.redirect": "true",
- "http.proxy.script": "",
- "http.proxy.injectjs": "alert('injected');",
- "http.proxy.blacklist": "",
- "http.proxy.whitelist": "",
- "http.proxy.sslstrip": "false",
- },
- expectErr: false,
- validate: func(mod *HttpProxy) error {
- if mod.proxy.jsHook == "" {
- return fmt.Errorf("jsHook should be set")
- }
- if !strings.Contains(mod.proxy.jsHook, "alert('injected');") {
- return fmt.Errorf("jsHook should contain injected code")
- }
- return nil
- },
- },
- {
- name: "JavaScript injection with URL",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "192.168.1.100",
- "http.proxy.port": "8080",
- "http.proxy.redirect": "true",
- "http.proxy.script": "",
- "http.proxy.injectjs": "http://evil.com/hook.js",
- "http.proxy.blacklist": "",
- "http.proxy.whitelist": "",
- "http.proxy.sslstrip": "false",
- },
- expectErr: false,
- validate: func(mod *HttpProxy) error {
- if mod.proxy.jsHook == "" {
- return fmt.Errorf("jsHook should be set")
- }
- if !strings.Contains(mod.proxy.jsHook, "http://evil.com/hook.js") {
- return fmt.Errorf("jsHook should contain script URL")
- }
- return nil
- },
- },
- {
- name: "invalid address",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "invalid-address",
- "http.proxy.port": "8080",
- "http.proxy.redirect": "true",
- "http.proxy.script": "",
- "http.proxy.injectjs": "",
- "http.proxy.blacklist": "",
- "http.proxy.whitelist": "",
- "http.proxy.sslstrip": "false",
- },
- expectErr: true,
- },
- {
- name: "invalid port",
- params: map[string]string{
- "http.port": "80",
- "http.proxy.address": "192.168.1.100",
- "http.proxy.port": "invalid-port",
- "http.proxy.redirect": "true",
- "http.proxy.script": "",
- "http.proxy.injectjs": "",
- "http.proxy.blacklist": "",
- "http.proxy.whitelist": "",
- "http.proxy.sslstrip": "false",
- },
- expectErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sess, _ := createMockSession()
- mod := NewHttpProxy(sess)
-
- // Set parameters
- for k, v := range tt.params {
- sess.Env.Set(k, v)
- }
-
- err := mod.Configure()
-
- if tt.expectErr && err == nil {
- t.Error("expected error but got none")
- } else if !tt.expectErr && err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- if !tt.expectErr && tt.validate != nil {
- if err := tt.validate(mod); err != nil {
- t.Error(err)
- }
- }
- })
- }
-}
-
-func TestHttpProxyStartStop(t *testing.T) {
- sess, mockFirewall := createMockSession()
- mod := NewHttpProxy(sess)
-
- // Configure with test parameters
- sess.Env.Set("http.port", "80")
- sess.Env.Set("http.proxy.address", "127.0.0.1")
- sess.Env.Set("http.proxy.port", "0") // Use port 0 to get a random available port
- sess.Env.Set("http.proxy.redirect", "true")
- sess.Env.Set("http.proxy.sslstrip", "false")
-
- // Start the proxy
- err := mod.Start()
- if err != nil {
- t.Fatalf("Failed to start proxy: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Proxy should be running after Start()")
- }
-
- // Check that forwarding was enabled
- if !mockFirewall.IsForwardingEnabled() {
- t.Error("Forwarding should be enabled after starting proxy")
- }
-
- // Check that redirection was added
- if len(mockFirewall.redirections) != 1 {
- t.Errorf("Expected 1 redirection, got %d", len(mockFirewall.redirections))
- }
-
- // Give the server time to start
- time.Sleep(100 * time.Millisecond)
-
- // Stop the proxy
- err = mod.Stop()
- if err != nil {
- t.Fatalf("Failed to stop proxy: %v", err)
- }
-
- if mod.Running() {
- t.Error("Proxy should not be running after Stop()")
- }
-
- // Check that redirection was removed
- if len(mockFirewall.redirections) != 0 {
- t.Errorf("Expected 0 redirections after stop, got %d", len(mockFirewall.redirections))
- }
-}
-
-func TestHttpProxyAlreadyStarted(t *testing.T) {
- sess, _ := createMockSession()
- mod := NewHttpProxy(sess)
-
- // Configure
- sess.Env.Set("http.port", "80")
- sess.Env.Set("http.proxy.address", "127.0.0.1")
- sess.Env.Set("http.proxy.port", "0")
- sess.Env.Set("http.proxy.redirect", "false")
-
- // Start the proxy
- err := mod.Start()
- if err != nil {
- t.Fatalf("Failed to start proxy: %v", err)
- }
-
- // Try to configure while running
- err = mod.Configure()
- if err == nil {
- t.Error("Configure should fail when proxy is already running")
- }
-
- // Stop the proxy
- mod.Stop()
-}
-
-func TestHTTPProxyDoProxy(t *testing.T) {
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- tests := []struct {
- name string
- request *http.Request
- expected bool
- }{
- {
- name: "valid request",
- request: &http.Request{
- Host: "example.com",
- },
- expected: true,
- },
- {
- name: "empty host",
- request: &http.Request{
- Host: "",
- },
- expected: false,
- },
- {
- name: "localhost request",
- request: &http.Request{
- Host: "localhost:8080",
- },
- expected: false,
- },
- {
- name: "127.0.0.1 request",
- request: &http.Request{
- Host: "127.0.0.1:8080",
- },
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := proxy.doProxy(tt.request)
- if result != tt.expected {
- t.Errorf("doProxy(%v) = %v, expected %v", tt.request.Host, result, tt.expected)
- }
- })
- }
-}
-
-func TestHTTPProxyShouldProxy(t *testing.T) {
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- tests := []struct {
- name string
- blacklist []string
- whitelist []string
- host string
- expected bool
- }{
- {
- name: "no filters",
- blacklist: []string{},
- whitelist: []string{},
- host: "example.com",
- expected: true,
- },
- {
- name: "blacklisted exact match",
- blacklist: []string{"evil.com"},
- whitelist: []string{},
- host: "evil.com",
- expected: false,
- },
- {
- name: "blacklisted wildcard match",
- blacklist: []string{"*.evil.com"},
- whitelist: []string{},
- host: "sub.evil.com",
- expected: false,
- },
- {
- name: "whitelisted exact match",
- blacklist: []string{"*"},
- whitelist: []string{"good.com"},
- host: "good.com",
- expected: true,
- },
- {
- name: "not blacklisted",
- blacklist: []string{"evil.com"},
- whitelist: []string{},
- host: "good.com",
- expected: true,
- },
- {
- name: "whitelist takes precedence",
- blacklist: []string{"*"},
- whitelist: []string{"good.com"},
- host: "good.com",
- expected: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- proxy.Blacklist = tt.blacklist
- proxy.Whitelist = tt.whitelist
-
- req := &http.Request{
- Host: tt.host,
- }
-
- result := proxy.shouldProxy(req)
- if result != tt.expected {
- t.Errorf("shouldProxy(%v) = %v, expected %v", tt.host, result, tt.expected)
- }
- })
- }
-}
-
-func TestHTTPProxyStripPort(t *testing.T) {
- tests := []struct {
- input string
- expected string
- }{
- {"example.com:8080", "example.com"},
- {"example.com", "example.com"},
- {"192.168.1.1:443", "192.168.1.1"},
- {"[::1]:8080", "["}, // stripPort splits on first colon, so IPv6 addresses don't work correctly
- }
-
- for _, tt := range tests {
- t.Run(tt.input, func(t *testing.T) {
- result := stripPort(tt.input)
- if result != tt.expected {
- t.Errorf("stripPort(%s) = %s, expected %s", tt.input, result, tt.expected)
- }
- })
- }
-}
-
-func TestHTTPProxyJavaScriptInjection(t *testing.T) {
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- tests := []struct {
- name string
- jsToInject string
- expectedHook string
- }{
- {
- name: "inline JavaScript",
- jsToInject: "console.log('test');",
- expectedHook: ``,
- },
- {
- name: "script tag",
- jsToInject: ``,
- expectedHook: ``, // script tags get wrapped
- },
- {
- name: "external URL",
- jsToInject: "http://example.com/script.js",
- expectedHook: ``,
- },
- {
- name: "HTTPS URL",
- jsToInject: "https://example.com/script.js",
- expectedHook: ``,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Skip test with invalid filename characters on Windows
- if runtime.GOOS == "windows" && strings.ContainsAny(tt.jsToInject, "<>:\"|?*") {
- t.Skip("Skipping test with invalid filename characters on Windows")
- }
-
- err := proxy.Configure("127.0.0.1", 8080, 80, false, "", tt.jsToInject, false)
- if err != nil {
- t.Fatalf("Configure failed: %v", err)
- }
-
- if proxy.jsHook != tt.expectedHook {
- t.Errorf("jsHook = %q, expected %q", proxy.jsHook, tt.expectedHook)
- }
- })
- }
-}
-
-func TestHTTPProxyWithTestServer(t *testing.T) {
- // Create a test HTTP server
- testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("Test Page"))
- }))
- defer testServer.Close()
-
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- // Configure proxy with JS injection
- err := proxy.Configure("127.0.0.1", 0, 80, false, "", "console.log('injected');", false)
- if err != nil {
- t.Fatalf("Configure failed: %v", err)
- }
-
- // Create a simple test to verify proxy is initialized
- if proxy.Proxy == nil {
- t.Error("Proxy not initialized")
- }
-
- if proxy.jsHook == "" {
- t.Error("JavaScript hook not set")
- }
-
- // Note: Testing actual proxy behavior would require setting up the proxy server
- // and making HTTP requests through it, which is complex in a unit test environment
-}
-
-func TestHTTPProxyScriptLoading(t *testing.T) {
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- // Create a temporary script file
- scriptContent := `
-function onRequest(req, res) {
- console.log("Request intercepted");
-}
-`
- tmpFile, err := ioutil.TempFile("", "proxy_script_*.js")
- if err != nil {
- t.Fatalf("Failed to create temp file: %v", err)
- }
- defer os.Remove(tmpFile.Name())
-
- if _, err := tmpFile.Write([]byte(scriptContent)); err != nil {
- t.Fatalf("Failed to write script: %v", err)
- }
- tmpFile.Close()
-
- // Try to configure with non-existent script
- err = proxy.Configure("127.0.0.1", 8080, 80, false, "non_existent_script.js", "", false)
- if err == nil {
- t.Error("Configure should fail with non-existent script")
- }
-
- // Note: Actual script loading would require proper JS engine setup
- // which is complex to mock. This test verifies the error handling.
-}
-
-// Benchmarks
-func BenchmarkHTTPProxyShouldProxy(b *testing.B) {
- sess, _ := createMockSession()
- proxy := NewHTTPProxy(sess, "test")
-
- proxy.Blacklist = []string{"*.evil.com", "bad.site.org", "*.malicious.net"}
- proxy.Whitelist = []string{"*.good.com", "safe.site.org"}
-
- req := &http.Request{
- Host: "example.com",
- }
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = proxy.shouldProxy(req)
- }
-}
-
-func BenchmarkHTTPProxyStripPort(b *testing.B) {
- testHost := "example.com:8080"
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = stripPort(testHost)
- }
-}
diff --git a/modules/http_server/http_server.go b/modules/http_server/http_server.go
index da309d3d..25cd7802 100644
--- a/modules/http_server/http_server.go
+++ b/modules/http_server/http_server.go
@@ -31,20 +31,20 @@ func NewHttpServer(s *session.Session) *HttpServer {
mod.AddParam(session.NewStringParameter("http.server.address",
session.ParamIfaceAddress,
session.IPv4Validator,
- "Address to bind the HTTP server to."))
+ "Address to bind the http server to."))
mod.AddParam(session.NewIntParameter("http.server.port",
"80",
- "Port to bind the HTTP server to."))
+ "Port to bind the http server to."))
mod.AddHandler(session.NewModuleHandler("http.server on", "",
- "Start HTTP server.",
+ "Start httpd server.",
func(args []string) error {
return mod.Start()
}))
mod.AddHandler(session.NewModuleHandler("http.server off", "",
- "Stop HTTP server.",
+ "Stop httpd server.",
func(args []string) error {
return mod.Stop()
}))
diff --git a/modules/https_server/https_server.go b/modules/https_server/https_server.go
index 2f3fd0a6..8e547fa7 100644
--- a/modules/https_server/https_server.go
+++ b/modules/https_server/https_server.go
@@ -35,11 +35,11 @@ func NewHttpsServer(s *session.Session) *HttpsServer {
mod.AddParam(session.NewStringParameter("https.server.address",
session.ParamIfaceAddress,
session.IPv4Validator,
- "Address to bind the HTTPS server to."))
+ "Address to bind the http server to."))
mod.AddParam(session.NewIntParameter("https.server.port",
"443",
- "Port to bind the HTTPS server to."))
+ "Port to bind the http server to."))
mod.AddParam(session.NewStringParameter("https.server.certificate",
"~/.bettercap-httpd.cert.pem",
@@ -54,13 +54,13 @@ func NewHttpsServer(s *session.Session) *HttpsServer {
tls.CertConfigToModule("https.server", &mod.SessionModule, tls.DefaultLegitConfig)
mod.AddHandler(session.NewModuleHandler("https.server on", "",
- "Start HTTPS server.",
+ "Start https server.",
func(args []string) error {
return mod.Start()
}))
mod.AddHandler(session.NewModuleHandler("https.server off", "",
- "Stop HTTPS server.",
+ "Stop https server.",
func(args []string) error {
return mod.Stop()
}))
diff --git a/modules/modules_test.go b/modules/modules_test.go
deleted file mode 100644
index 3cde11cd..00000000
--- a/modules/modules_test.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package modules
-
-import (
- "testing"
-)
-
-func TestLoadModulesWithNilSession(t *testing.T) {
- // This test verifies that LoadModules handles nil session gracefully
- // In the actual implementation, this would panic, which is expected behavior
- defer func() {
- if r := recover(); r == nil {
- t.Error("expected panic when loading modules with nil session, but didn't get one")
- }
- }()
-
- LoadModules(nil)
-}
-
-// Since LoadModules requires a fully initialized session with command-line flags,
-// which conflicts with the test runner, we can't easily test the actual module loading.
-// The main functionality is tested through integration tests and the actual application.
-// This test file at least provides some coverage for the package and demonstrates
-// the expected behavior with invalid input.
diff --git a/modules/net_probe/net_probe_test.go b/modules/net_probe/net_probe_test.go
deleted file mode 100644
index 7013dd23..00000000
--- a/modules/net_probe/net_probe_test.go
+++ /dev/null
@@ -1,610 +0,0 @@
-package net_probe
-
-import (
- "fmt"
- "net"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/malfunkt/iprange"
-)
-
-// MockQueue implements a mock packet queue for testing
-type MockQueue struct {
- sync.Mutex
- sentPackets [][]byte
- sendError error
- active bool
-}
-
-func NewMockQueue() *MockQueue {
- return &MockQueue{
- sentPackets: make([][]byte, 0),
- active: true,
- }
-}
-
-func (m *MockQueue) Send(data []byte) error {
- m.Lock()
- defer m.Unlock()
-
- if m.sendError != nil {
- return m.sendError
- }
-
- // Store a copy of the packet
- packet := make([]byte, len(data))
- copy(packet, data)
- m.sentPackets = append(m.sentPackets, packet)
- return nil
-}
-
-func (m *MockQueue) GetSentPackets() [][]byte {
- m.Lock()
- defer m.Unlock()
- return m.sentPackets
-}
-
-func (m *MockQueue) ClearSentPackets() {
- m.Lock()
- defer m.Unlock()
- m.sentPackets = make([][]byte, 0)
-}
-
-func (m *MockQueue) Stop() {
- m.Lock()
- defer m.Unlock()
- m.active = false
-}
-
-// MockSession for testing
-type MockSession struct {
- *session.Session
- runCommands []string
- skipIPs map[string]bool
-}
-
-func (m *MockSession) Run(cmd string) error {
- m.runCommands = append(m.runCommands, cmd)
-
- // Handle module commands
- if cmd == "net.recon on" {
- // Find and start the net.recon module
- for _, mod := range m.Modules {
- if mod.Name() == "net.recon" {
- if !mod.Running() {
- return mod.Start()
- }
- return nil
- }
- }
- } else if cmd == "net.recon off" {
- // Find and stop the net.recon module
- for _, mod := range m.Modules {
- if mod.Name() == "net.recon" {
- if mod.Running() {
- return mod.Stop()
- }
- return nil
- }
- }
- } else if cmd == "zerogod.discovery on" || cmd == "zerogod.discovery off" {
- // Mock zerogod.discovery commands
- return nil
- }
-
- return nil
-}
-
-func (m *MockSession) Skip(ip net.IP) bool {
- if m.skipIPs == nil {
- return false
- }
- return m.skipIPs[ip.String()]
-}
-
-// MockNetRecon implements a minimal net.recon module for testing
-type MockNetRecon struct {
- session.SessionModule
-}
-
-func NewMockNetRecon(s *session.Session) *MockNetRecon {
- mod := &MockNetRecon{
- SessionModule: session.NewSessionModule("net.recon", s),
- }
-
- // Add handlers so the module can be started/stopped via commands
- mod.AddHandler(session.NewModuleHandler("net.recon on", "",
- "Start net.recon",
- func(args []string) error {
- return mod.Start()
- }))
-
- mod.AddHandler(session.NewModuleHandler("net.recon off", "",
- "Stop net.recon",
- func(args []string) error {
- return mod.Stop()
- }))
-
- return mod
-}
-
-func (m *MockNetRecon) Name() string {
- return "net.recon"
-}
-
-func (m *MockNetRecon) Description() string {
- return "Mock net.recon module"
-}
-
-func (m *MockNetRecon) Author() string {
- return "test"
-}
-
-func (m *MockNetRecon) Configure() error {
- return nil
-}
-
-func (m *MockNetRecon) Start() error {
- return m.SetRunning(true, nil)
-}
-
-func (m *MockNetRecon) Stop() error {
- return m.SetRunning(false, nil)
-}
-
-// Create a mock session for testing
-func createMockSession() (*MockSession, *MockQueue) {
- // Create interface
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "eth0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- // Parse interface addresses
- ifaceIP := net.ParseIP("192.168.1.100")
- ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface.IP = ifaceIP
- iface.HW = ifaceHW
-
- // Create gateway
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
-
- // Create mock queue
- mockQueue := NewMockQueue()
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- // Create session
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: &packets.Queue{
- Traffic: sync.Map{},
- Stats: packets.Stats{},
- },
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize events
- sess.Events = session.NewEventPool(false, false)
-
- // Add mock net.recon module
- mockNetRecon := NewMockNetRecon(sess)
- sess.Modules = append(sess.Modules, mockNetRecon)
-
- // Create mock session wrapper
- mockSess := &MockSession{
- Session: sess,
- runCommands: make([]string, 0),
- skipIPs: make(map[string]bool),
- }
-
- return mockSess, mockQueue
-}
-
-func TestNewProber(t *testing.T) {
- mockSess, _ := createMockSession()
-
- mod := NewProber(mockSess.Session)
-
- if mod == nil {
- t.Fatal("NewProber returned nil")
- }
-
- if mod.Name() != "net.probe" {
- t.Errorf("expected module name 'net.probe', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- // Check parameters
- params := []string{"net.probe.nbns", "net.probe.mdns", "net.probe.upnp", "net.probe.wsd", "net.probe.throttle"}
- for _, param := range params {
- if !mod.Session.Env.Has(param) {
- t.Errorf("parameter %s not registered", param)
- }
- }
-}
-
-func TestProberConfigure(t *testing.T) {
- tests := []struct {
- name string
- params map[string]string
- expectErr bool
- expected struct {
- throttle int
- nbns bool
- mdns bool
- upnp bool
- wsd bool
- }
- }{
- {
- name: "default configuration",
- params: map[string]string{
- "net.probe.throttle": "10",
- "net.probe.nbns": "true",
- "net.probe.mdns": "true",
- "net.probe.upnp": "true",
- "net.probe.wsd": "true",
- },
- expectErr: false,
- expected: struct {
- throttle int
- nbns bool
- mdns bool
- upnp bool
- wsd bool
- }{10, true, true, true, true},
- },
- {
- name: "disabled probes",
- params: map[string]string{
- "net.probe.throttle": "5",
- "net.probe.nbns": "false",
- "net.probe.mdns": "false",
- "net.probe.upnp": "false",
- "net.probe.wsd": "false",
- },
- expectErr: false,
- expected: struct {
- throttle int
- nbns bool
- mdns bool
- upnp bool
- wsd bool
- }{5, false, false, false, false},
- },
- {
- name: "invalid throttle",
- params: map[string]string{
- "net.probe.throttle": "invalid",
- "net.probe.nbns": "true",
- "net.probe.mdns": "true",
- "net.probe.upnp": "true",
- "net.probe.wsd": "true",
- },
- expectErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- mockSess, _ := createMockSession()
- mod := NewProber(mockSess.Session)
-
- // Set parameters
- for k, v := range tt.params {
- mockSess.Env.Set(k, v)
- }
-
- err := mod.Configure()
-
- if tt.expectErr && err == nil {
- t.Error("expected error but got none")
- } else if !tt.expectErr && err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- if !tt.expectErr {
- if mod.throttle != tt.expected.throttle {
- t.Errorf("expected throttle %d, got %d", tt.expected.throttle, mod.throttle)
- }
- if mod.probes.NBNS != tt.expected.nbns {
- t.Errorf("expected NBNS %v, got %v", tt.expected.nbns, mod.probes.NBNS)
- }
- if mod.probes.MDNS != tt.expected.mdns {
- t.Errorf("expected MDNS %v, got %v", tt.expected.mdns, mod.probes.MDNS)
- }
- if mod.probes.UPNP != tt.expected.upnp {
- t.Errorf("expected UPNP %v, got %v", tt.expected.upnp, mod.probes.UPNP)
- }
- if mod.probes.WSD != tt.expected.wsd {
- t.Errorf("expected WSD %v, got %v", tt.expected.wsd, mod.probes.WSD)
- }
- }
- })
- }
-}
-
-// MockProber wraps Prober to allow mocking probe methods
-type MockProber struct {
- *Prober
- nbnsCount *int32
- upnpCount *int32
- wsdCount *int32
- mockQueue *MockQueue
-}
-
-func (m *MockProber) sendProbeNBNS(from net.IP, from_hw net.HardwareAddr, to net.IP) {
- atomic.AddInt32(m.nbnsCount, 1)
- m.mockQueue.Send([]byte(fmt.Sprintf("NBNS probe to %s", to)))
-}
-
-func (m *MockProber) sendProbeUPNP(from net.IP, from_hw net.HardwareAddr) {
- atomic.AddInt32(m.upnpCount, 1)
- m.mockQueue.Send([]byte("UPNP probe"))
-}
-
-func (m *MockProber) sendProbeWSD(from net.IP, from_hw net.HardwareAddr) {
- atomic.AddInt32(m.wsdCount, 1)
- m.mockQueue.Send([]byte("WSD probe"))
-}
-
-func TestProberStartStop(t *testing.T) {
- mockSess, _ := createMockSession()
- mod := NewProber(mockSess.Session)
-
- // Configure with fast throttle for testing
- mockSess.Env.Set("net.probe.throttle", "1")
- mockSess.Env.Set("net.probe.nbns", "true")
- mockSess.Env.Set("net.probe.mdns", "true")
- mockSess.Env.Set("net.probe.upnp", "true")
- mockSess.Env.Set("net.probe.wsd", "true")
-
- // Start the prober
- err := mod.Start()
- if err != nil {
- t.Fatalf("Failed to start prober: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Prober should be running after Start()")
- }
-
- // Give it a moment to initialize
- time.Sleep(50 * time.Millisecond)
-
- // Stop the prober
- err = mod.Stop()
- if err != nil {
- t.Fatalf("Failed to stop prober: %v", err)
- }
-
- if mod.Running() {
- t.Error("Prober should not be running after Stop()")
- }
-
- // Since we can't easily mock the probe methods, we'll verify the module's state
- // and trust that the actual probe sending is tested in integration tests
-}
-
-func TestProberMonitorMode(t *testing.T) {
- mockSess, _ := createMockSession()
- mod := NewProber(mockSess.Session)
-
- // Set interface to monitor mode
- mockSess.Interface.IpAddress = network.MonitorModeAddress
-
- // Start the prober
- err := mod.Start()
- if err != nil {
- t.Fatalf("Failed to start prober: %v", err)
- }
-
- // Give it time to potentially start probing
- time.Sleep(50 * time.Millisecond)
-
- // Stop the prober
- mod.Stop()
-
- // In monitor mode, the prober should exit early without doing any work
- // We can't easily verify no probes were sent without mocking network calls,
- // but we can verify the module starts and stops correctly
-}
-
-func TestProberHandlers(t *testing.T) {
- mockSess, _ := createMockSession()
- mod := NewProber(mockSess.Session)
-
- // Test handlers
- handlers := mod.Handlers()
-
- expectedHandlers := []string{"net.probe on", "net.probe off"}
- handlerMap := make(map[string]bool)
-
- for _, h := range handlers {
- handlerMap[h.Name] = true
- }
-
- for _, expected := range expectedHandlers {
- if !handlerMap[expected] {
- t.Errorf("Expected handler '%s' not found", expected)
- }
- }
-
- // Test handler execution
- for _, h := range handlers {
- if h.Name == "net.probe on" {
- // Should start the module
- err := h.Exec([]string{})
- if err != nil {
- t.Errorf("Handler 'net.probe on' failed: %v", err)
- }
- if !mod.Running() {
- t.Error("Module should be running after 'net.probe on'")
- }
- mod.Stop()
- } else if h.Name == "net.probe off" {
- // Start first, then stop
- mod.Start()
- err := h.Exec([]string{})
- if err != nil {
- t.Errorf("Handler 'net.probe off' failed: %v", err)
- }
- if mod.Running() {
- t.Error("Module should not be running after 'net.probe off'")
- }
- }
- }
-}
-
-func TestProberSelectiveProbes(t *testing.T) {
- tests := []struct {
- name string
- enabledProbes map[string]bool
- }{
- {
- name: "only NBNS",
- enabledProbes: map[string]bool{
- "nbns": true,
- "mdns": false,
- "upnp": false,
- "wsd": false,
- },
- },
- {
- name: "only UPNP and WSD",
- enabledProbes: map[string]bool{
- "nbns": false,
- "mdns": false,
- "upnp": true,
- "wsd": true,
- },
- },
- {
- name: "all probes enabled",
- enabledProbes: map[string]bool{
- "nbns": true,
- "mdns": true,
- "upnp": true,
- "wsd": true,
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- mockSess, _ := createMockSession()
- mod := NewProber(mockSess.Session)
-
- // Configure probes
- mockSess.Env.Set("net.probe.throttle", "10")
- mockSess.Env.Set("net.probe.nbns", fmt.Sprintf("%v", tt.enabledProbes["nbns"]))
- mockSess.Env.Set("net.probe.mdns", fmt.Sprintf("%v", tt.enabledProbes["mdns"]))
- mockSess.Env.Set("net.probe.upnp", fmt.Sprintf("%v", tt.enabledProbes["upnp"]))
- mockSess.Env.Set("net.probe.wsd", fmt.Sprintf("%v", tt.enabledProbes["wsd"]))
-
- // Configure and verify the settings
- err := mod.Configure()
- if err != nil {
- t.Fatalf("Failed to configure: %v", err)
- }
-
- // Verify configuration
- if mod.probes.NBNS != tt.enabledProbes["nbns"] {
- t.Errorf("NBNS probe setting mismatch: expected %v, got %v",
- tt.enabledProbes["nbns"], mod.probes.NBNS)
- }
- if mod.probes.MDNS != tt.enabledProbes["mdns"] {
- t.Errorf("MDNS probe setting mismatch: expected %v, got %v",
- tt.enabledProbes["mdns"], mod.probes.MDNS)
- }
- if mod.probes.UPNP != tt.enabledProbes["upnp"] {
- t.Errorf("UPNP probe setting mismatch: expected %v, got %v",
- tt.enabledProbes["upnp"], mod.probes.UPNP)
- }
- if mod.probes.WSD != tt.enabledProbes["wsd"] {
- t.Errorf("WSD probe setting mismatch: expected %v, got %v",
- tt.enabledProbes["wsd"], mod.probes.WSD)
- }
- })
- }
-}
-
-func TestIPRangeExpansion(t *testing.T) {
- // Test that we correctly iterate through the subnet
- cidr := "192.168.1.0/30" // Small subnet for testing
- list, err := iprange.Parse(cidr)
- if err != nil {
- t.Fatalf("Failed to parse CIDR: %v", err)
- }
-
- addresses := list.Expand()
-
- // For /30, we should get 4 addresses
- expectedAddresses := []string{
- "192.168.1.0",
- "192.168.1.1",
- "192.168.1.2",
- "192.168.1.3",
- }
-
- if len(addresses) != len(expectedAddresses) {
- t.Errorf("Expected %d addresses, got %d", len(expectedAddresses), len(addresses))
- }
-
- for i, addr := range addresses {
- if addr.String() != expectedAddresses[i] {
- t.Errorf("Expected address %s at position %d, got %s", expectedAddresses[i], i, addr.String())
- }
- }
-}
-
-// Benchmarks
-func BenchmarkProberConfiguration(b *testing.B) {
- mockSess, _ := createMockSession()
-
- // Set up parameters
- mockSess.Env.Set("net.probe.throttle", "10")
- mockSess.Env.Set("net.probe.nbns", "true")
- mockSess.Env.Set("net.probe.mdns", "true")
- mockSess.Env.Set("net.probe.upnp", "true")
- mockSess.Env.Set("net.probe.wsd", "true")
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- mod := NewProber(mockSess.Session)
- mod.Configure()
- }
-}
-
-func BenchmarkIPRangeExpansion(b *testing.B) {
- cidr := "192.168.1.0/24"
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- list, _ := iprange.Parse(cidr)
- _ = list.Expand()
- }
-}
diff --git a/modules/net_recon/net_recon_test.go b/modules/net_recon/net_recon_test.go
deleted file mode 100644
index 93459666..00000000
--- a/modules/net_recon/net_recon_test.go
+++ /dev/null
@@ -1,644 +0,0 @@
-package net_recon
-
-import (
- "fmt"
- "sync"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/modules/utils"
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/evilsocket/islazy/data"
-)
-
-// Mock ArpUpdate function
-var mockArpUpdateFunc func(string) (network.ArpTable, error)
-
-// Override the network.ArpUpdate function for testing
-func mockArpUpdate(iface string) (network.ArpTable, error) {
- if mockArpUpdateFunc != nil {
- return mockArpUpdateFunc(iface)
- }
- return make(network.ArpTable), nil
-}
-
-// MockLAN implements a mock version of the LAN interface
-type MockLAN struct {
- sync.RWMutex
- hosts map[string]*network.Endpoint
- wasMissed map[string]bool
- addedHosts []string
- removedHosts []string
-}
-
-func NewMockLAN() *MockLAN {
- return &MockLAN{
- hosts: make(map[string]*network.Endpoint),
- wasMissed: make(map[string]bool),
- addedHosts: []string{},
- removedHosts: []string{},
- }
-}
-
-func (m *MockLAN) AddIfNew(ip, mac string) {
- m.Lock()
- defer m.Unlock()
-
- if _, exists := m.hosts[mac]; !exists {
- m.hosts[mac] = &network.Endpoint{
- IpAddress: ip,
- HwAddress: mac,
- FirstSeen: time.Now(),
- LastSeen: time.Now(),
- }
- m.addedHosts = append(m.addedHosts, mac)
- }
-}
-
-func (m *MockLAN) Remove(ip, mac string) {
- m.Lock()
- defer m.Unlock()
-
- if _, exists := m.hosts[mac]; exists {
- delete(m.hosts, mac)
- m.removedHosts = append(m.removedHosts, mac)
- }
-}
-
-func (m *MockLAN) Clear() {
- m.Lock()
- defer m.Unlock()
-
- m.hosts = make(map[string]*network.Endpoint)
- m.wasMissed = make(map[string]bool)
- m.addedHosts = []string{}
- m.removedHosts = []string{}
-}
-
-func (m *MockLAN) EachHost(cb func(mac string, e *network.Endpoint)) {
- m.RLock()
- defer m.RUnlock()
-
- for mac, host := range m.hosts {
- cb(mac, host)
- }
-}
-
-func (m *MockLAN) List() []*network.Endpoint {
- m.RLock()
- defer m.RUnlock()
-
- list := make([]*network.Endpoint, 0, len(m.hosts))
- for _, host := range m.hosts {
- list = append(list, host)
- }
- return list
-}
-
-func (m *MockLAN) WasMissed(mac string) bool {
- m.RLock()
- defer m.RUnlock()
-
- return m.wasMissed[mac]
-}
-
-func (m *MockLAN) Get(mac string) *network.Endpoint {
- m.RLock()
- defer m.RUnlock()
-
- return m.hosts[mac]
-}
-
-// Create a mock session for testing
-func createMockSession() *session.Session {
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "eth0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: &packets.Queue{
- Traffic: sync.Map{},
- Stats: packets.Stats{},
- },
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize the Events field with a mock EventPool
- sess.Events = session.NewEventPool(false, false)
-
- return sess
-}
-
-func TestNewDiscovery(t *testing.T) {
- sess := createMockSession()
- mod := NewDiscovery(sess)
-
- if mod == nil {
- t.Fatal("NewDiscovery returned nil")
- }
-
- if mod.Name() != "net.recon" {
- t.Errorf("expected module name 'net.recon', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- if mod.selector == nil {
- t.Error("selector should be initialized")
- }
-}
-
-func TestRunDiff(t *testing.T) {
- // Test the basic diff functionality with a simpler approach
- tests := []struct {
- name string
- initialHosts map[string]string // IP -> MAC
- arpTable network.ArpTable
- expectedAdded []string
- expectedRemoved []string
- }{
- {
- name: "no changes",
- initialHosts: map[string]string{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- "192.168.1.20": "bb:bb:bb:bb:bb:bb",
- },
- arpTable: network.ArpTable{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- "192.168.1.20": "bb:bb:bb:bb:bb:bb",
- },
- expectedAdded: []string{},
- expectedRemoved: []string{},
- },
- {
- name: "new host discovered",
- initialHosts: map[string]string{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- },
- arpTable: network.ArpTable{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- "192.168.1.20": "bb:bb:bb:bb:bb:bb",
- },
- expectedAdded: []string{"bb:bb:bb:bb:bb:bb"},
- expectedRemoved: []string{},
- },
- {
- name: "host disappeared",
- initialHosts: map[string]string{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- "192.168.1.20": "bb:bb:bb:bb:bb:bb",
- },
- arpTable: network.ArpTable{
- "192.168.1.10": "aa:aa:aa:aa:aa:aa",
- },
- expectedAdded: []string{},
- expectedRemoved: []string{"bb:bb:bb:bb:bb:bb"},
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sess := createMockSession()
-
- // Track callbacks
- addedHosts := []string{}
- removedHosts := []string{}
-
- newCb := func(e *network.Endpoint) {
- addedHosts = append(addedHosts, e.HwAddress)
- }
-
- lostCb := func(e *network.Endpoint) {
- removedHosts = append(removedHosts, e.HwAddress)
- }
-
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, newCb, lostCb)
-
- mod := &Discovery{
- SessionModule: session.NewSessionModule("net.recon", sess),
- }
-
- // Add initial hosts
- for ip, mac := range tt.initialHosts {
- sess.Lan.AddIfNew(ip, mac)
- }
-
- // Reset tracking
- addedHosts = []string{}
- removedHosts = []string{}
-
- // Add interface and gateway to ARP table to avoid them being removed
- finalArpTable := make(network.ArpTable)
- for k, v := range tt.arpTable {
- finalArpTable[k] = v
- }
- finalArpTable[sess.Interface.IpAddress] = sess.Interface.HwAddress
- finalArpTable[sess.Gateway.IpAddress] = sess.Gateway.HwAddress
-
- // Run the diff multiple times to trigger actual removal (TTL countdown)
- for i := 0; i < network.LANDefaultttl+1; i++ {
- mod.runDiff(finalArpTable)
- }
-
- // Check results
- if len(addedHosts) != len(tt.expectedAdded) {
- t.Errorf("expected %d added hosts, got %d. Added: %v", len(tt.expectedAdded), len(addedHosts), addedHosts)
- }
-
- if len(removedHosts) != len(tt.expectedRemoved) {
- t.Errorf("expected %d removed hosts, got %d. Removed: %v", len(tt.expectedRemoved), len(removedHosts), removedHosts)
- }
- })
- }
-}
-
-func TestConfigure(t *testing.T) {
- sess := createMockSession()
- mod := NewDiscovery(sess)
-
- err := mod.Configure()
- if err != nil {
- t.Errorf("Configure() returned error: %v", err)
- }
-}
-
-func TestStartStop(t *testing.T) {
- sess := createMockSession()
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- mod := NewDiscovery(sess)
-
- // Test starting the module
- err := mod.Start()
- if err != nil {
- t.Errorf("Start() returned error: %v", err)
- }
-
- if !mod.Running() {
- t.Error("module should be running after Start()")
- }
-
- // Let it run briefly
- time.Sleep(100 * time.Millisecond)
-
- // Test stopping the module
- err = mod.Stop()
- if err != nil {
- t.Errorf("Stop() returned error: %v", err)
- }
-
- if mod.Running() {
- t.Error("module should not be running after Stop()")
- }
-}
-
-func TestShowMethods(t *testing.T) {
- // Skip this test as it requires a full session with readline
- t.Skip("Skipping TestShowMethods as it requires readline initialization")
-}
-
-func TestDoSelection(t *testing.T) {
- sess := createMockSession()
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- // Add test endpoints
- sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
- sess.Lan.AddIfNew("192.168.1.20", "bb:bb:bb:bb:bb:bb")
- sess.Lan.AddIfNew("192.168.1.30", "cc:cc:cc:cc:cc:cc")
-
- // Get endpoints and set additional properties
- if e, found := sess.Lan.Get("aa:aa:aa:aa:aa:aa"); found {
- e.Hostname = "host1"
- e.Vendor = "Vendor1"
- }
-
- if e, found := sess.Lan.Get("bb:bb:bb:bb:bb:bb"); found {
- e.Alias = "mydevice"
- e.Vendor = "Vendor2"
- }
-
- mod := NewDiscovery(sess)
- mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show",
- []string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc")
-
- tests := []struct {
- name string
- arg string
- expectedCount int
- expectedIPs []string
- }{
- {
- name: "select all",
- arg: "",
- expectedCount: 3,
- },
- {
- name: "select by IP",
- arg: "192.168.1.10",
- expectedCount: 1,
- expectedIPs: []string{"192.168.1.10"},
- },
- {
- name: "select by MAC",
- arg: "aa:aa:aa:aa:aa:aa",
- expectedCount: 1,
- expectedIPs: []string{"192.168.1.10"},
- },
- {
- name: "select multiple by comma",
- arg: "192.168.1.10,192.168.1.20",
- expectedCount: 2,
- expectedIPs: []string{"192.168.1.10", "192.168.1.20"},
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err, targets := mod.doSelection(tt.arg)
- if err != nil {
- t.Errorf("doSelection returned error: %v", err)
- }
-
- if len(targets) != tt.expectedCount {
- t.Errorf("expected %d targets, got %d", tt.expectedCount, len(targets))
- }
-
- if tt.expectedIPs != nil {
- for _, expectedIP := range tt.expectedIPs {
- found := false
- for _, target := range targets {
- if target.IpAddress == expectedIP {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("expected to find IP %s in targets", expectedIP)
- }
- }
- }
- })
- }
-}
-
-func TestHandlers(t *testing.T) {
- sess := createMockSession()
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- mod := NewDiscovery(sess)
-
- handlers := []struct {
- name string
- handler string
- args []string
- setup func()
- validate func() error
- }{
- {
- name: "net.clear",
- handler: "net.clear",
- args: []string{},
- setup: func() {
- sess.Lan.AddIfNew("192.168.1.10", "aa:aa:aa:aa:aa:aa")
- },
- validate: func() error {
- // Check if hosts were cleared
- hosts := sess.Lan.List()
- if len(hosts) != 0 {
- return fmt.Errorf("expected empty hosts after clear, got %d", len(hosts))
- }
- return nil
- },
- },
- }
-
- for _, tt := range handlers {
- t.Run(tt.name, func(t *testing.T) {
- if tt.setup != nil {
- tt.setup()
- }
-
- // Find and execute the handler
- found := false
- for _, h := range mod.Handlers() {
- if h.Name == tt.handler {
- found = true
- err := h.Exec(tt.args)
- if err != nil {
- t.Errorf("handler %s returned error: %v", tt.handler, err)
- }
- break
- }
- }
-
- if !found {
- t.Errorf("handler %s not found", tt.handler)
- }
-
- if tt.validate != nil {
- if err := tt.validate(); err != nil {
- t.Error(err)
- }
- }
- })
- }
-}
-
-func TestGetRow(t *testing.T) {
- sess := createMockSession()
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- mod := NewDiscovery(sess)
-
- // Test endpoint with metadata
- endpoint := &network.Endpoint{
- IpAddress: "192.168.1.10",
- HwAddress: "aa:aa:aa:aa:aa:aa",
- Hostname: "testhost",
- Vendor: "Test Vendor",
- FirstSeen: time.Now().Add(-time.Hour),
- LastSeen: time.Now(),
- Meta: network.NewMeta(),
- }
- endpoint.Meta.Set("key1", "value1")
- endpoint.Meta.Set("key2", "value2")
-
- // Test without meta
- rows := mod.getRow(endpoint, false)
- if len(rows) != 1 {
- t.Errorf("expected 1 row without meta, got %d", len(rows))
- }
- if len(rows[0]) != 7 {
- t.Errorf("expected 7 columns, got %d", len(rows[0]))
- }
-
- // Test with meta
- rows = mod.getRow(endpoint, true)
- if len(rows) != 2 { // One main row + one meta row per metadata entry
- t.Errorf("expected 2 rows with meta, got %d", len(rows))
- }
-
- // Test interface endpoint
- ifaceEndpoint := sess.Interface
- rows = mod.getRow(ifaceEndpoint, false)
- if len(rows) != 1 {
- t.Errorf("expected 1 row for interface, got %d", len(rows))
- }
-
- // Test gateway endpoint
- gatewayEndpoint := sess.Gateway
- rows = mod.getRow(gatewayEndpoint, false)
- if len(rows) != 1 {
- t.Errorf("expected 1 row for gateway, got %d", len(rows))
- }
-}
-
-func TestDoFilter(t *testing.T) {
- sess := createMockSession()
- mod := NewDiscovery(sess)
- mod.selector = utils.ViewSelectorFor(&mod.SessionModule, "net.show",
- []string{"ip", "mac", "seen", "sent", "rcvd"}, "ip asc")
-
- // Test that doFilter behavior matches the actual implementation
- // When Expression is nil, it returns true (no filtering)
- // When Expression is set, it matches against any of the fields
-
- tests := []struct {
- name string
- filter string
- endpoint *network.Endpoint
- shouldMatch bool
- }{
- {
- name: "no filter",
- filter: "",
- endpoint: &network.Endpoint{
- IpAddress: "192.168.1.10",
- Meta: network.NewMeta(),
- },
- shouldMatch: true,
- },
- {
- name: "ip filter match",
- filter: "192.168",
- endpoint: &network.Endpoint{
- IpAddress: "192.168.1.10",
- Meta: network.NewMeta(),
- },
- shouldMatch: true,
- },
- {
- name: "mac filter match",
- filter: "aa:bb",
- endpoint: &network.Endpoint{
- IpAddress: "192.168.1.10",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Meta: network.NewMeta(),
- },
- shouldMatch: true,
- },
- {
- name: "hostname filter match",
- filter: "myhost",
- endpoint: &network.Endpoint{
- IpAddress: "192.168.1.10",
- Hostname: "myhost.local",
- Meta: network.NewMeta(),
- },
- shouldMatch: true,
- },
- {
- name: "no match - testing unique string",
- filter: "xyz123nomatch",
- endpoint: &network.Endpoint{
- IpAddress: "192.168.1.10",
- Ip6Address: "",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "host.local",
- Alias: "",
- Vendor: "",
- Meta: network.NewMeta(),
- },
- shouldMatch: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Reset selector for each test
- // Set the parameter value that Update() will read
- sess.Env.Set("net.show.filter", tt.filter)
- mod.selector.Expression = nil
-
- // Update will read from the parameter
- err := mod.selector.Update()
- if err != nil {
- t.Fatalf("selector.Update() failed: %v", err)
- }
-
- result := mod.doFilter(tt.endpoint)
- if result != tt.shouldMatch {
- if mod.selector.Expression != nil {
- t.Errorf("expected doFilter to return %v, got %v. Regex: %s", tt.shouldMatch, result, mod.selector.Expression.String())
- } else {
- t.Errorf("expected doFilter to return %v, got %v. Expression is nil", tt.shouldMatch, result)
- }
- }
- })
- }
-}
-
-// Benchmark the runDiff method
-func BenchmarkRunDiff(b *testing.B) {
- sess := createMockSession()
- aliases, _ := data.NewUnsortedKV("", 0)
- sess.Lan = network.NewLAN(sess.Interface, sess.Gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- mod := &Discovery{
- SessionModule: session.NewSessionModule("net.recon", sess),
- }
-
- // Create a large ARP table
- arpTable := make(network.ArpTable)
- for i := 0; i < 100; i++ {
- ip := fmt.Sprintf("192.168.1.%d", i)
- mac := fmt.Sprintf("aa:bb:cc:dd:%02x:%02x", i/256, i%256)
- arpTable[ip] = mac
-
- // Add half to the existing LAN
- if i < 50 {
- sess.Lan.AddIfNew(ip, mac)
- }
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mod.runDiff(arpTable)
- }
-}
diff --git a/modules/net_sniff/net_sniff.go b/modules/net_sniff/net_sniff.go
index cb2c1b48..4daa9859 100644
--- a/modules/net_sniff/net_sniff.go
+++ b/modules/net_sniff/net_sniff.go
@@ -59,11 +59,6 @@ func NewSniffer(s *session.Session) *Sniffer {
"",
"If set, the sniffer will read from this pcap file instead of the current interface."))
- mod.AddParam(session.NewStringParameter("net.sniff.interface",
- "",
- "",
- "Interface to sniff on."))
-
mod.AddHandler(session.NewModuleHandler("net.sniff stats", "",
"Print sniffer session configuration and statistics.",
func(args []string) error {
diff --git a/modules/net_sniff/net_sniff_context.go b/modules/net_sniff/net_sniff_context.go
index 633238f1..e275ebf8 100644
--- a/modules/net_sniff/net_sniff_context.go
+++ b/modules/net_sniff/net_sniff_context.go
@@ -17,7 +17,6 @@ import (
type SnifferContext struct {
Handle *pcap.Handle
- Interface string
Source string
DumpLocal bool
Verbose bool
@@ -38,22 +37,13 @@ func (mod *Sniffer) GetContext() (error, *SnifferContext) {
return err, ctx
}
- if err, ctx.Interface = mod.StringParam("net.sniff.interface"); err != nil {
- return err, ctx
- }
-
- if ctx.Interface == "" {
- ctx.Interface = mod.Session.Interface.Name()
- }
-
if ctx.Source == "" {
/*
* We don't want to pcap.BlockForever otherwise pcap_close(handle)
* could hang waiting for a timeout to expire ...
*/
-
readTimeout := 500 * time.Millisecond
- if ctx.Handle, err = network.CaptureWithTimeout(ctx.Interface, readTimeout); err != nil {
+ if ctx.Handle, err = network.CaptureWithTimeout(mod.Session.Interface.Name(), readTimeout); err != nil {
return err, ctx
}
} else {
@@ -104,8 +94,6 @@ func (mod *Sniffer) GetContext() (error, *SnifferContext) {
func NewSnifferContext() *SnifferContext {
return &SnifferContext{
Handle: nil,
- Interface: "",
- Source: "",
DumpLocal: false,
Verbose: false,
Filter: "",
@@ -127,8 +115,7 @@ var (
)
func (c *SnifferContext) Log(sess *session.Session) {
- log.Info("Interface : %s", tui.Bold(c.Interface))
- log.Info("Skip local packets : %s", yn[!c.DumpLocal])
+ log.Info("Skip local packets : %s", yn[c.DumpLocal])
log.Info("Verbose : %s", yn[c.Verbose])
log.Info("BPF Filter : '%s'", tui.Yellow(c.Filter))
log.Info("Regular expression : '%s'", tui.Yellow(c.Expression))
diff --git a/modules/net_sniff/net_sniff_http.go b/modules/net_sniff/net_sniff_http.go
index 23e0375c..a111c08b 100644
--- a/modules/net_sniff/net_sniff_http.go
+++ b/modules/net_sniff/net_sniff_http.go
@@ -4,7 +4,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
- "io"
+ "io/ioutil"
"net"
"net/http"
"strings"
@@ -50,7 +50,7 @@ func toSerializableRequest(req *http.Request) HTTPRequest {
body := []byte(nil)
ctype := "?"
if req.Body != nil {
- body, _ = io.ReadAll(req.Body)
+ body, _ = ioutil.ReadAll(req.Body)
}
for name, values := range req.Header {
@@ -90,7 +90,7 @@ func toSerializableResponse(res *http.Response) HTTPResponse {
}
if res.Body != nil {
- body, _ = io.ReadAll(res.Body)
+ body, _ = ioutil.ReadAll(res.Body)
}
// attempt decompression, but since this has been parsed by just
diff --git a/modules/packet_proxy/packet_proxy_linux.go b/modules/packet_proxy/packet_proxy_linux.go
index 9a40fcff..e124976c 100644
--- a/modules/packet_proxy/packet_proxy_linux.go
+++ b/modules/packet_proxy/packet_proxy_linux.go
@@ -22,7 +22,7 @@ type PacketProxy struct {
rule string
queue *nfqueue.Nfqueue
queueNum int
- queueCb func(q *nfqueue.Nfqueue, a nfqueue.Attribute) int
+ queueCb nfqueue.HookFunc
pluginPath string
plugin *plugin.Plugin
}
@@ -149,7 +149,7 @@ func (mod *PacketProxy) Configure() (err error) {
return
} else if sym, err = mod.plugin.Lookup("OnPacket"); err != nil {
return
- } else if mod.queueCb, ok = sym.(func(q *nfqueue.Nfqueue, a nfqueue.Attribute) int); !ok {
+ } else if mod.queueCb, ok = sym.(func(nfqueue.Attribute) int); !ok {
return fmt.Errorf("Symbol OnPacket is not a valid callback function.")
}
@@ -198,7 +198,7 @@ func (mod *PacketProxy) Configure() (err error) {
// CGO callback ... ¯\_(ツ)_/¯
func dummyCallback(attribute nfqueue.Attribute) int {
if mod.queueCb != nil {
- return mod.queueCb(mod.queue, attribute)
+ return mod.queueCb(attribute)
} else {
id := *attribute.PacketID
diff --git a/modules/tcp_proxy/tcp_proxy_script.go b/modules/tcp_proxy/tcp_proxy_script.go
index 50956ea0..fa801be5 100644
--- a/modules/tcp_proxy/tcp_proxy_script.go
+++ b/modules/tcp_proxy/tcp_proxy_script.go
@@ -1,7 +1,6 @@
package tcp_proxy
import (
- "encoding/json"
"net"
"strings"
@@ -56,36 +55,12 @@ func (s *TcpProxyScript) OnData(from, to net.Addr, data []byte, callback func(ca
log.Error("error while executing onData callback: %s", err)
return nil
} else if ret != nil {
- return toByteArray(ret)
- }
- }
- return nil
-}
-
-func toByteArray(ret interface{}) []byte {
- // this approach is a bit hacky but it handles all cases
-
- // serialize ret to JSON
- if jsonData, err := json.Marshal(ret); err == nil {
- // attempt to deserialize as []float64
- var back2Array []float64
- if err := json.Unmarshal(jsonData, &back2Array); err == nil {
- result := make([]byte, len(back2Array))
- for i, num := range back2Array {
- if num >= 0 && num <= 255 {
- result[i] = byte(num)
- } else {
- log.Error("array element at index %d is not a valid byte value %d", i, num)
- return nil
- }
+ array, ok := ret.([]byte)
+ if !ok {
+ log.Error("error while casting exported value to array of byte: value = %+v", ret)
}
- return result
- } else {
- log.Error("failed to deserialize %+v to []float64: %v", ret, err)
+ return array
}
- } else {
- log.Error("failed to serialize %+v to JSON: %v", ret, err)
}
-
return nil
}
diff --git a/modules/tcp_proxy/tcp_proxy_script_test.go b/modules/tcp_proxy/tcp_proxy_script_test.go
deleted file mode 100644
index 27bdc099..00000000
--- a/modules/tcp_proxy/tcp_proxy_script_test.go
+++ /dev/null
@@ -1,169 +0,0 @@
-package tcp_proxy
-
-import (
- "net"
- "testing"
-
- "github.com/evilsocket/islazy/plugin"
-)
-
-func TestOnData_NoReturn(t *testing.T) {
- jsCode := `
- function onData(from, to, data, callback) {
- // don't return anything
- }
- `
-
- plug, err := plugin.Parse(jsCode)
- if err != nil {
- t.Fatalf("Failed to parse plugin: %v", err)
- }
-
- script := &TcpProxyScript{
- Plugin: plug,
- doOnData: plug.HasFunc("onData"),
- }
-
- from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
- to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678}
- data := []byte("test data")
-
- result := script.OnData(from, to, data, nil)
- if result != nil {
- t.Errorf("Expected nil result when callback returns nothing, got %v", result)
- }
-}
-
-func TestOnData_ReturnsArrayOfIntegers(t *testing.T) {
- jsCode := `
- function onData(from, to, data, callback) {
- // Return modified data as array of integers
- return [72, 101, 108, 108, 111]; // "Hello" in ASCII
- }
- `
-
- plug, err := plugin.Parse(jsCode)
- if err != nil {
- t.Fatalf("Failed to parse plugin: %v", err)
- }
-
- script := &TcpProxyScript{
- Plugin: plug,
- doOnData: plug.HasFunc("onData"),
- }
-
- from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
- to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678}
- data := []byte("test data")
-
- result := script.OnData(from, to, data, nil)
- expected := []byte("Hello")
-
- if result == nil {
- t.Fatal("Expected non-nil result when callback returns array of integers")
- }
-
- if len(result) != len(expected) {
- t.Fatalf("Expected result length %d, got %d", len(expected), len(result))
- }
-
- for i, b := range result {
- if b != expected[i] {
- t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b)
- }
- }
-}
-
-func TestOnData_ReturnsDynamicArray(t *testing.T) {
- jsCode := `
- function onData(from, to, data, callback) {
- var result = [];
- for (var i = 0; i < data.length; i++) {
- result.push((data[i] + 1) % 256);
- }
- return result;
- }
- `
-
- plug, err := plugin.Parse(jsCode)
- if err != nil {
- t.Fatalf("Failed to parse plugin: %v", err)
- }
-
- script := &TcpProxyScript{
- Plugin: plug,
- doOnData: plug.HasFunc("onData"),
- }
-
- from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
- to := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 5678}
- data := []byte{10, 20, 30, 40, 255}
-
- result := script.OnData(from, to, data, nil)
- expected := []byte{11, 21, 31, 41, 0} // 255 + 1 = 256 % 256 = 0
-
- if result == nil {
- t.Fatal("Expected non-nil result when callback returns array of integers")
- }
-
- if len(result) != len(expected) {
- t.Fatalf("Expected result length %d, got %d", len(expected), len(result))
- }
-
- for i, b := range result {
- if b != expected[i] {
- t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b)
- }
- }
-}
-
-func TestOnData_ReturnsMixedArray(t *testing.T) {
- jsCode := `
- function charToInt(value) {
- return value.charCodeAt()
- }
-
- function onData(from, to, data) {
- st_data = String.fromCharCode.apply(null, data)
- if( st_data.indexOf("mysearch") != -1 ) {
- payload = "mypayload";
- st_data = st_data.replace("mysearch", payload);
- res_int_arr = st_data.split("").map(charToInt) // []uint16
- res_int_arr[0] = payload.length + 1; // first index is float64 and rest []uint16
- return res_int_arr;
- }
- return data;
- }
- `
-
- plug, err := plugin.Parse(jsCode)
- if err != nil {
- t.Fatalf("Failed to parse plugin: %v", err)
- }
-
- script := &TcpProxyScript{
- Plugin: plug,
- doOnData: plug.HasFunc("onData"),
- }
-
- from := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
- to := &net.TCPAddr{IP: net.ParseIP("192.168.1.6"), Port: 5678}
- data := []byte("Hello mysearch world")
-
- result := script.OnData(from, to, data, nil)
- expected := []byte("\x0aello mypayload world")
-
- if result == nil {
- t.Fatal("Expected non-nil result when callback returns array of integers")
- }
-
- if len(result) != len(expected) {
- t.Fatalf("Expected result length %d, got %d", len(expected), len(result))
- }
-
- for i, b := range result {
- if b != expected[i] {
- t.Errorf("Expected byte at index %d to be %d, got %d", i, expected[i], b)
- }
- }
-}
diff --git a/modules/ticker/ticker.go b/modules/ticker/ticker.go
index 34c4c02b..e629d2f0 100644
--- a/modules/ticker/ticker.go
+++ b/modules/ticker/ticker.go
@@ -43,7 +43,7 @@ func NewTicker(s *session.Session) *Ticker {
}))
mod.AddHandler(session.NewModuleHandler("ticker off", "",
- "Stop the main ticker.",
+ "Stop the maint icker.",
func(args []string) error {
return mod.Stop()
}))
diff --git a/modules/ticker/ticker_test.go b/modules/ticker/ticker_test.go
deleted file mode 100644
index 9b1b97a5..00000000
--- a/modules/ticker/ticker_test.go
+++ /dev/null
@@ -1,413 +0,0 @@
-package ticker
-
-import (
- "sync"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewTicker(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- if mod == nil {
- t.Fatal("NewTicker returned nil")
- }
-
- if mod.Name() != "ticker" {
- t.Errorf("Expected name 'ticker', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check parameters exist
- if err, _ := mod.StringParam("ticker.commands"); err != nil {
- t.Error("ticker.commands parameter not found")
- }
-
- if err, _ := mod.IntParam("ticker.period"); err != nil {
- t.Error("ticker.period parameter not found")
- }
-
- // Check handlers - only check the main ones since create/destroy have regex patterns
- handlers := []string{"ticker on", "ticker off"}
- for _, handler := range handlers {
- found := false
- for _, h := range mod.Handlers() {
- if h.Name == handler {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("Handler '%s' not found", handler)
- }
- }
-
- // Check that we have handlers for create and destroy (they have regex patterns)
- hasCreate := false
- hasDestroy := false
- for _, h := range mod.Handlers() {
- if h.Name == "ticker.create " {
- hasCreate = true
- } else if h.Name == "ticker.destroy " {
- hasDestroy = true
- }
- }
- if !hasCreate {
- t.Error("ticker.create handler not found")
- }
- if !hasDestroy {
- t.Error("ticker.destroy handler not found")
- }
-}
-
-func TestTickerConfigure(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Test configure before start
- if err := mod.Configure(); err != nil {
- t.Errorf("Configure failed: %v", err)
- }
-
- // Check main params were set
- if mod.main.Period == 0 {
- t.Error("Period not set")
- }
-
- if len(mod.main.Commands) == 0 {
- t.Error("Commands not set")
- }
-
- if !mod.main.Running {
- t.Error("Running flag not set")
- }
-}
-
-func TestTickerStartStop(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Set a short period for testing using session environment
- mod.Session.Env.Set("ticker.period", "1")
- mod.Session.Env.Set("ticker.commands", "help")
-
- // Start ticker
- if err := mod.Start(); err != nil {
- t.Fatalf("Failed to start ticker: %v", err)
- }
-
- if !mod.Running() {
- t.Error("Ticker should be running")
- }
-
- // Let it run briefly
- time.Sleep(100 * time.Millisecond)
-
- // Stop ticker
- if err := mod.Stop(); err != nil {
- t.Fatalf("Failed to stop ticker: %v", err)
- }
-
- if mod.Running() {
- t.Error("Ticker should not be running")
- }
-
- if mod.main.Running {
- t.Error("Main ticker should not be running")
- }
-}
-
-func TestTickerAlreadyStarted(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Start ticker
- if err := mod.Start(); err != nil {
- t.Fatalf("Failed to start ticker: %v", err)
- }
-
- // Try to configure while running
- if err := mod.Configure(); err == nil {
- t.Error("Configure should fail when already running")
- }
-
- // Stop ticker
- mod.Stop()
-}
-
-func TestTickerNamedOperations(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Create named ticker
- name := "test_ticker"
- if err := mod.createNamed(name, 1, "help"); err != nil {
- t.Fatalf("Failed to create named ticker: %v", err)
- }
-
- // Check it was created
- if _, found := mod.named[name]; !found {
- t.Error("Named ticker not found in map")
- }
-
- // Try to create duplicate
- if err := mod.createNamed(name, 1, "help"); err == nil {
- t.Error("Should not allow duplicate named ticker")
- }
-
- // Let it run briefly
- time.Sleep(100 * time.Millisecond)
-
- // Destroy named ticker
- if err := mod.destroyNamed(name); err != nil {
- t.Fatalf("Failed to destroy named ticker: %v", err)
- }
-
- // Check it was removed
- if _, found := mod.named[name]; found {
- t.Error("Named ticker still in map after destroy")
- }
-
- // Try to destroy non-existent
- if err := mod.destroyNamed("nonexistent"); err == nil {
- t.Error("Should fail when destroying non-existent ticker")
- }
-}
-
-func TestTickerHandlers(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- tests := []struct {
- name string
- handler string
- regex string
- args []string
- wantErr bool
- }{
- {
- name: "ticker on",
- handler: "ticker on",
- args: []string{},
- wantErr: false,
- },
- {
- name: "ticker off",
- handler: "ticker off",
- args: []string{},
- wantErr: true, // ticker off will fail if not running
- },
- {
- name: "ticker.create valid",
- handler: "ticker.create ",
- args: []string{"myticker", "2", "help; events.show"},
- wantErr: false,
- },
- {
- name: "ticker.create invalid period",
- handler: "ticker.create ",
- args: []string{"myticker", "notanumber", "help"},
- wantErr: true,
- },
- {
- name: "ticker.destroy",
- handler: "ticker.destroy ",
- args: []string{"myticker"},
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Find the handler
- var handler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == tt.handler {
- handler = &h
- break
- }
- }
-
- if handler == nil {
- t.Fatalf("Handler '%s' not found", tt.handler)
- }
-
- // Create ticker if needed for destroy test
- if tt.handler == "ticker.destroy " && len(tt.args) > 0 && tt.args[0] == "myticker" {
- mod.createNamed("myticker", 1, "help")
- }
-
- // Execute handler
- err := handler.Exec(tt.args)
- if (err != nil) != tt.wantErr {
- t.Errorf("Handler execution error = %v, wantErr %v", err, tt.wantErr)
- }
-
- // Cleanup
- if tt.handler == "ticker on" || tt.handler == "ticker.create " {
- mod.Stop()
- }
- })
- }
-}
-
-func TestTickerWorker(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Create params for testing
- params := &Params{
- Commands: []string{"help"},
- Period: 100 * time.Millisecond,
- Running: true,
- }
-
- // Start worker in goroutine
- done := make(chan bool)
- go func() {
- mod.worker("test", params)
- done <- true
- }()
-
- // Let it tick at least once
- time.Sleep(150 * time.Millisecond)
-
- // Stop the worker
- params.Running = false
-
- // Wait for worker to finish
- select {
- case <-done:
- // Worker finished successfully
- case <-time.After(1 * time.Second):
- t.Error("Worker did not stop in time")
- }
-}
-
-func TestTickerParams(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Test setting invalid period
- mod.Session.Env.Set("ticker.period", "invalid")
- if err := mod.Configure(); err == nil {
- t.Error("Configure should fail with invalid period")
- }
-
- // Test empty commands
- mod.Session.Env.Set("ticker.period", "1")
- mod.Session.Env.Set("ticker.commands", "")
- if err := mod.Configure(); err != nil {
- t.Errorf("Configure should work with empty commands: %v", err)
- }
-}
-
-func TestTickerMultipleNamed(t *testing.T) {
- s := createMockSession(t)
- mod := NewTicker(s)
-
- // Start the ticker first
- if err := mod.Start(); err != nil {
- t.Fatalf("Failed to start ticker: %v", err)
- }
-
- // Create multiple named tickers
- names := []string{"ticker1", "ticker2", "ticker3"}
- for _, name := range names {
- if err := mod.createNamed(name, 1, "help"); err != nil {
- t.Errorf("Failed to create ticker '%s': %v", name, err)
- }
- }
-
- // Check all were created
- if len(mod.named) != len(names) {
- t.Errorf("Expected %d named tickers, got %d", len(names), len(mod.named))
- }
-
- // Stop all via Stop()
- if err := mod.Stop(); err != nil {
- t.Fatalf("Failed to stop: %v", err)
- }
-
- // Check all were stopped
- for name, params := range mod.named {
- if params.Running {
- t.Errorf("Ticker '%s' still running after Stop()", name)
- }
- }
-}
-
-func TestTickEvent(t *testing.T) {
- // Simple test for TickEvent struct
- event := TickEvent{}
- // TickEvent is empty, just ensure it can be created
- _ = event
-}
-
-// Benchmark tests
-func BenchmarkTickerCreate(b *testing.B) {
- // Use existing session to avoid flag redefinition
- s := testSession
- if s == nil {
- var err error
- s, err = session.New()
- if err != nil {
- b.Fatal(err)
- }
- testSession = s
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mod := NewTicker(s)
- _ = mod
- }
-}
-
-func BenchmarkTickerStartStop(b *testing.B) {
- // Use existing session to avoid flag redefinition
- s := testSession
- if s == nil {
- var err error
- s, err = session.New()
- if err != nil {
- b.Fatal(err)
- }
- testSession = s
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mod := NewTicker(s)
- // Set period parameter
- mod.Session.Env.Set("ticker.period", "1")
- mod.Start()
- mod.Stop()
- }
-}
diff --git a/modules/update/update_test.go b/modules/update/update_test.go
deleted file mode 100644
index f112fc14..00000000
--- a/modules/update/update_test.go
+++ /dev/null
@@ -1,348 +0,0 @@
-package update
-
-import (
- "sync"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-func TestNewUpdateModule(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- if mod == nil {
- t.Fatal("NewUpdateModule returned nil")
- }
-
- if mod.Name() != "update" {
- t.Errorf("Expected name 'update', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check handler
- handlers := mod.Handlers()
- if len(handlers) != 1 {
- t.Errorf("Expected 1 handler, got %d", len(handlers))
- }
-
- if len(handlers) > 0 && handlers[0].Name != "update.check on" {
- t.Errorf("Expected handler 'update.check on', got '%s'", handlers[0].Name)
- }
-}
-
-func TestVersionToNum(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- tests := []struct {
- name string
- version string
- want float64
- }{
- {
- name: "simple version",
- version: "1.2.3",
- want: 123, // 3*1 + 2*10 + 1*100
- },
- {
- name: "version with v prefix",
- version: "v1.2.3",
- want: 123,
- },
- {
- name: "major version only",
- version: "2",
- want: 2,
- },
- {
- name: "major.minor version",
- version: "2.1",
- want: 21, // 1*1 + 2*10
- },
- {
- name: "zero version",
- version: "0.0.0",
- want: 0,
- },
- {
- name: "large patch version",
- version: "1.0.10",
- want: 110, // 10*1 + 0*10 + 1*100
- },
- {
- name: "very large version",
- version: "10.20.30",
- want: 1230, // 30*1 + 20*10 + 10*100
- },
- {
- name: "version with leading v",
- version: "v2.2.0",
- want: 220, // 0*1 + 2*10 + 2*100
- },
- {
- name: "single digit versions",
- version: "1.1.1",
- want: 111, // 1*1 + 1*10 + 1*100
- },
- {
- name: "asymmetric version",
- version: "1.10.100",
- want: 300, // 100*1 + 10*10 + 1*100
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := mod.versionToNum(tt.version)
- if got != tt.want {
- t.Errorf("versionToNum(%q) = %v, want %v", tt.version, got, tt.want)
- }
- })
- }
-}
-
-func TestVersionComparison(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- tests := []struct {
- name string
- current string
- latest string
- isNewer bool
- }{
- {
- name: "newer patch version",
- current: "1.2.3",
- latest: "1.2.4",
- isNewer: true,
- },
- {
- name: "newer minor version",
- current: "1.2.3",
- latest: "1.3.0",
- isNewer: true,
- },
- {
- name: "newer major version",
- current: "1.2.3",
- latest: "2.0.0",
- isNewer: true,
- },
- {
- name: "same version",
- current: "1.2.3",
- latest: "1.2.3",
- isNewer: false,
- },
- {
- name: "older version",
- current: "2.0.0",
- latest: "1.9.9",
- isNewer: false,
- },
- {
- name: "v prefix handling",
- current: "v1.2.3",
- latest: "v1.2.4",
- isNewer: true,
- },
- {
- name: "mixed v prefix",
- current: "1.2.3",
- latest: "v1.2.4",
- isNewer: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- currentNum := mod.versionToNum(tt.current)
- latestNum := mod.versionToNum(tt.latest)
-
- isNewer := currentNum < latestNum
- if isNewer != tt.isNewer {
- t.Errorf("Expected %s < %s to be %v, but got %v (%.2f vs %.2f)",
- tt.current, tt.latest, tt.isNewer, isNewer, currentNum, latestNum)
- }
- })
- }
-}
-
-func TestConfigure(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- if err := mod.Configure(); err != nil {
- t.Errorf("Configure() error = %v", err)
- }
-}
-
-func TestStop(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- if err := mod.Stop(); err != nil {
- t.Errorf("Stop() error = %v", err)
- }
-}
-
-func TestModuleRunning(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-}
-
-func TestVersionEdgeCases(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- tests := []struct {
- name string
- version string
- want float64
- wantErr bool
- }{
- {
- name: "empty version",
- version: "",
- want: 0,
- wantErr: true, // Will panic on ver[0] access
- },
- {
- name: "only v",
- version: "v",
- want: 0,
- wantErr: true, // Will panic after stripping v
- },
- {
- name: "non-numeric version",
- version: "va.b.c",
- want: 0, // strconv.Atoi will return 0 for non-numeric
- },
- {
- name: "partial numeric",
- version: "1.a.3",
- want: 103, // 3*1 + 0*10 + 1*100 (a converts to 0)
- },
- {
- name: "extra dots",
- version: "1.2.3.4",
- want: 1234, // 4*1 + 3*10 + 2*100 + 1*1000
- },
- {
- name: "trailing dot",
- version: "1.2.",
- want: 120, // splits to ["1","2",""], reverses to ["","2","1"], = 0*1 + 2*10 + 1*100
- },
- {
- name: "leading dot",
- version: ".1.2",
- want: 12, // splits to ["","1","2"], reverses to ["2","1",""], = 2*1 + 1*10 + 0*100
- },
- {
- name: "single part",
- version: "42",
- want: 42,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Skip tests that would panic due to empty version
- if tt.wantErr {
- // These would panic, so skip them
- t.Skip("Skipping test that would panic")
- return
- }
-
- got := mod.versionToNum(tt.version)
- if got != tt.want {
- t.Errorf("versionToNum(%q) = %v, want %v", tt.version, got, tt.want)
- }
- })
- }
-}
-
-func TestHandlerExecution(t *testing.T) {
- s := createMockSession(t)
- mod := NewUpdateModule(s)
-
- // Find the handler
- var handler *session.ModuleHandler
- for _, h := range mod.Handlers() {
- if h.Name == "update.check on" {
- handler = &h
- break
- }
- }
-
- if handler == nil {
- t.Fatal("Handler 'update.check on' not found")
- }
-
- // Note: This will make a real API call to GitHub
- // In a production test suite, you'd want to mock the GitHub client
- // For now, we'll just check that the handler can be executed
- // The actual Start() method will be tested separately
-}
-
-// Benchmark tests
-func BenchmarkVersionToNum(b *testing.B) {
- s, _ := session.New()
- mod := NewUpdateModule(s)
-
- versions := []string{
- "1.2.3",
- "v2.4.6",
- "10.20.30",
- "v100.200.300",
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- for _, v := range versions {
- mod.versionToNum(v)
- }
- }
-}
-
-func BenchmarkVersionComparison(b *testing.B) {
- s, _ := session.New()
- mod := NewUpdateModule(s)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- current := mod.versionToNum("1.2.3")
- latest := mod.versionToNum("1.2.4")
- _ = current < latest
- }
-}
diff --git a/modules/utils/view_selector_test.go b/modules/utils/view_selector_test.go
deleted file mode 100644
index e2a9c609..00000000
--- a/modules/utils/view_selector_test.go
+++ /dev/null
@@ -1,455 +0,0 @@
-package utils
-
-import (
- "regexp"
- "sync"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- })
- return testSession
-}
-
-type mockModule struct {
- session.SessionModule
-}
-
-func newMockModule(s *session.Session) *mockModule {
- return &mockModule{
- SessionModule: session.NewSessionModule("test", s),
- }
-}
-
-func TestViewSelectorFor(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
-
- sortFields := []string{"name", "mac", "seen"}
- defExpression := "seen desc"
- prefix := "test"
-
- vs := ViewSelectorFor(&m.SessionModule, prefix, sortFields, defExpression)
-
- if vs == nil {
- t.Fatal("ViewSelectorFor returned nil")
- }
-
- if vs.owner != &m.SessionModule {
- t.Error("ViewSelector owner not set correctly")
- }
-
- if vs.filterName != "test.filter" {
- t.Errorf("filterName = %s, want test.filter", vs.filterName)
- }
-
- if vs.sortName != "test.sort" {
- t.Errorf("sortName = %s, want test.sort", vs.sortName)
- }
-
- if vs.limitName != "test.limit" {
- t.Errorf("limitName = %s, want test.limit", vs.limitName)
- }
-
- // Check that parameters were added by trying to retrieve them
- if err, _ := m.SessionModule.StringParam("test.filter"); err != nil {
- t.Error("filter parameter not accessible")
- }
- if err, _ := m.SessionModule.StringParam("test.sort"); err != nil {
- t.Error("sort parameter not accessible")
- }
- if err, _ := m.SessionModule.IntParam("test.limit"); err != nil {
- t.Error("limit parameter not accessible")
- }
-
- // Check default sorting
- if vs.SortField != "seen" {
- t.Errorf("Default SortField = %s, want seen", vs.SortField)
- }
- if vs.Sort != "desc" {
- t.Errorf("Default Sort = %s, want desc", vs.Sort)
- }
-}
-
-func TestParseFilter(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc")
-
- tests := []struct {
- name string
- filter string
- wantErr bool
- wantExpr bool
- }{
- {
- name: "empty filter",
- filter: "",
- wantErr: false,
- wantExpr: false,
- },
- {
- name: "valid regex",
- filter: "^test.*",
- wantErr: false,
- wantExpr: true,
- },
- {
- name: "invalid regex",
- filter: "[invalid",
- wantErr: true,
- wantExpr: false,
- },
- {
- name: "simple string",
- filter: "test",
- wantErr: false,
- wantExpr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Set the filter parameter
- m.Session.Env.Set("test.filter", tt.filter)
-
- err := vs.parseFilter()
- if (err != nil) != tt.wantErr {
- t.Errorf("parseFilter() error = %v, wantErr %v", err, tt.wantErr)
- }
-
- if tt.wantExpr && vs.Expression == nil {
- t.Error("Expected Expression to be set, but it's nil")
- }
- if !tt.wantExpr && vs.Expression != nil {
- t.Error("Expected Expression to be nil, but it's set")
- }
-
- if tt.filter != "" && !tt.wantErr {
- if vs.Filter != tt.filter {
- t.Errorf("Filter = %s, want %s", vs.Filter, tt.filter)
- }
- }
- })
- }
-}
-
-func TestParseSorting(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac", "seen"}, "name asc")
-
- tests := []struct {
- name string
- sortExpr string
- wantErr bool
- wantField string
- wantDirection string
- wantSymbol string
- }{
- {
- name: "name ascending",
- sortExpr: "name asc",
- wantErr: false,
- wantField: "name",
- wantDirection: "asc",
- wantSymbol: "▴", // Will be colored blue
- },
- {
- name: "mac descending",
- sortExpr: "mac desc",
- wantErr: false,
- wantField: "mac",
- wantDirection: "desc",
- wantSymbol: "▾", // Will be colored blue
- },
- {
- name: "seen descending",
- sortExpr: "seen desc",
- wantErr: false,
- wantField: "seen",
- wantDirection: "desc",
- wantSymbol: "▾",
- },
- {
- name: "invalid field",
- sortExpr: "invalid desc",
- wantErr: true,
- wantField: "",
- wantDirection: "",
- },
- {
- name: "invalid direction",
- sortExpr: "name invalid",
- wantErr: true,
- wantField: "",
- wantDirection: "",
- },
- {
- name: "malformed expression",
- sortExpr: "nameDesc",
- wantErr: true,
- wantField: "",
- wantDirection: "",
- },
- {
- name: "empty expression",
- sortExpr: "",
- wantErr: true,
- wantField: "",
- wantDirection: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Set the sort parameter
- m.Session.Env.Set("test.sort", tt.sortExpr)
-
- err := vs.parseSorting()
- if (err != nil) != tt.wantErr {
- t.Errorf("parseSorting() error = %v, wantErr %v", err, tt.wantErr)
- }
-
- if !tt.wantErr {
- if vs.SortField != tt.wantField {
- t.Errorf("SortField = %s, want %s", vs.SortField, tt.wantField)
- }
- if vs.Sort != tt.wantDirection {
- t.Errorf("Sort = %s, want %s", vs.Sort, tt.wantDirection)
- }
- // Check symbol contains expected character (stripping color codes)
- if !containsSymbol(vs.SortSymbol, tt.wantSymbol) {
- t.Errorf("SortSymbol doesn't contain %s", tt.wantSymbol)
- }
- }
- })
- }
-}
-
-func TestUpdate(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac"}, "name asc")
-
- tests := []struct {
- name string
- filter string
- sort string
- limit string
- wantErr bool
- wantLimit int
- }{
- {
- name: "all valid",
- filter: "test.*",
- sort: "mac desc",
- limit: "10",
- wantErr: false,
- wantLimit: 10,
- },
- {
- name: "invalid filter",
- filter: "[invalid",
- sort: "name asc",
- limit: "5",
- wantErr: true,
- wantLimit: 0,
- },
- {
- name: "invalid sort",
- filter: "valid",
- sort: "invalid field",
- limit: "5",
- wantErr: true,
- wantLimit: 0,
- },
- {
- name: "invalid limit",
- filter: "valid",
- sort: "name asc",
- limit: "not a number",
- wantErr: true,
- wantLimit: 0,
- },
- {
- name: "zero limit",
- filter: "",
- sort: "name asc",
- limit: "0",
- wantErr: false,
- wantLimit: 0,
- },
- {
- name: "negative limit",
- filter: "",
- sort: "name asc",
- limit: "-1",
- wantErr: false,
- wantLimit: -1,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Set parameters
- m.Session.Env.Set("test.filter", tt.filter)
- m.Session.Env.Set("test.sort", tt.sort)
- m.Session.Env.Set("test.limit", tt.limit)
-
- err := vs.Update()
- if (err != nil) != tt.wantErr {
- t.Errorf("Update() error = %v, wantErr %v", err, tt.wantErr)
- }
-
- if !tt.wantErr {
- if vs.Limit != tt.wantLimit {
- t.Errorf("Limit = %d, want %d", vs.Limit, tt.wantLimit)
- }
- }
- })
- }
-}
-
-func TestFilterCaching(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc")
-
- // Set initial filter
- m.Session.Env.Set("test.filter", "test1")
- if err := vs.parseFilter(); err != nil {
- t.Fatalf("Failed to parse initial filter: %v", err)
- }
-
- firstExpr := vs.Expression
- if firstExpr == nil {
- t.Fatal("Expression should not be nil")
- }
-
- // Parse again with same filter - should use cached expression
- if err := vs.parseFilter(); err != nil {
- t.Fatalf("Failed to parse filter second time: %v", err)
- }
-
- // The filterPrev mechanism should prevent recompilation
- if vs.filterPrev != "test1" {
- t.Errorf("filterPrev = %s, want test1", vs.filterPrev)
- }
-
- // Change filter
- m.Session.Env.Set("test.filter", "test2")
- if err := vs.parseFilter(); err != nil {
- t.Fatalf("Failed to parse new filter: %v", err)
- }
-
- if vs.Filter != "test2" {
- t.Errorf("Filter = %s, want test2", vs.Filter)
- }
- if vs.filterPrev != "test2" {
- t.Errorf("filterPrev = %s, want test2", vs.filterPrev)
- }
-}
-
-func TestSortParserRegex(t *testing.T) {
- s := createMockSession(t)
- m := newMockModule(s)
-
- sortFields := []string{"field1", "field2", "complex_field"}
- vs := ViewSelectorFor(&m.SessionModule, "test", sortFields, "field1 asc")
-
- // Test the generated regex pattern
- expectedPattern := "(field1|field2|complex_field) (desc|asc)"
- if vs.sortParser != expectedPattern {
- t.Errorf("sortParser = %s, want %s", vs.sortParser, expectedPattern)
- }
-
- // Test regex compilation
- if vs.sortParse == nil {
- t.Fatal("sortParse regex is nil")
- }
-
- // Test regex matching
- testCases := []struct {
- expr string
- matches bool
- }{
- {"field1 asc", true},
- {"field2 desc", true},
- {"complex_field asc", true},
- {"invalid_field asc", false},
- {"field1 invalid", false},
- {"field1asc", false},
- {"", false},
- }
-
- for _, tc := range testCases {
- matches := vs.sortParse.MatchString(tc.expr)
- if matches != tc.matches {
- t.Errorf("sortParse.MatchString(%q) = %v, want %v", tc.expr, matches, tc.matches)
- }
- }
-}
-
-// Helper function to check if a string contains a symbol (ignoring ANSI color codes)
-func containsSymbol(s, symbol string) bool {
- // Remove ANSI color codes
- re := regexp.MustCompile(`\x1b\[[0-9;]*m`)
- cleaned := re.ReplaceAllString(s, "")
- return cleaned == symbol
-}
-
-// Benchmark tests
-func BenchmarkParseFilter(b *testing.B) {
- s, _ := session.New()
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name"}, "name asc")
-
- m.Session.Env.Set("test.filter", "test.*")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- vs.parseFilter()
- }
-}
-
-func BenchmarkParseSorting(b *testing.B) {
- s, _ := session.New()
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac", "seen"}, "name asc")
-
- m.Session.Env.Set("test.sort", "mac desc")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- vs.parseSorting()
- }
-}
-
-func BenchmarkUpdate(b *testing.B) {
- s, _ := session.New()
- m := newMockModule(s)
- vs := ViewSelectorFor(&m.SessionModule, "test", []string{"name", "mac"}, "name asc")
-
- m.Session.Env.Set("test.filter", "test")
- m.Session.Env.Set("test.sort", "mac desc")
- m.Session.Env.Set("test.limit", "10")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- vs.Update()
- }
-}
diff --git a/modules/wifi/wifi.go b/modules/wifi/wifi.go
index 2a000f4b..dea727b1 100644
--- a/modules/wifi/wifi.go
+++ b/modules/wifi/wifi.go
@@ -104,10 +104,7 @@ func NewWiFiModule(s *session.Session) *WiFiModule {
}
mod.InitState("channels")
- mod.InitState("channel")
-
mod.State.Store("channels", []int{})
- mod.State.Store("channel", 0)
mod.AddParam(session.NewStringParameter("wifi.interface",
"",
@@ -265,8 +262,8 @@ func NewWiFiModule(s *session.Session) *WiFiModule {
mod.AddHandler(probe)
- channelSwitchAnnounce := session.NewModuleHandler("wifi.channel_switch_announce BSSID CHANNEL ", `wifi\.channel_switch_announce ((?:[a-fA-F0-9:]{11,}))\s+((?:[0-9]+))`,
- "Start a 802.11 channel hop attack, all client will be forced to change the channel lead to connection down.",
+ channelSwitchAnnounce := session.NewModuleHandler("wifi.channel_switch_announce bssid channel ", `wifi\.channel_switch_announce ((?:[a-fA-F0-9:]{11,}))\s+((?:[0-9]+))`,
+ "Start a 802.11 channel hop attack, all client will be force to change the channel lead to connection down.",
func(args []string) error {
bssid, err := net.ParseMAC(args[0])
if err != nil {
@@ -651,22 +648,19 @@ func (mod *WiFiModule) Configure() error {
mod.hopPeriod = time.Duration(hopPeriod) * time.Millisecond
if mod.source == "" {
- if len(mod.frequencies) == 0 {
- if freqs, err := network.GetSupportedFrequencies(ifName); err != nil {
- return fmt.Errorf("error while getting supported frequencies of %s: %s", ifName, err)
- } else {
- mod.setFrequencies(freqs)
- }
-
- mod.Debug("wifi supported frequencies: %v", mod.frequencies)
+ if freqs, err := network.GetSupportedFrequencies(ifName); err != nil {
+ return fmt.Errorf("error while getting supported frequencies of %s: %s", ifName, err)
+ } else {
+ mod.setFrequencies(freqs)
}
+ mod.Debug("wifi supported frequencies: %v", mod.frequencies)
+
// we need to start somewhere, this is just to check if
// this OS supports switching channel programmatically.
if err = network.SetInterfaceChannel(ifName, 1); err != nil {
return fmt.Errorf("error while initializing %s to channel 1: %s", ifName, err)
}
- mod.State.Store("channel", 1)
mod.Info("started (min rssi: %d dBm)", mod.minRSSI)
}
diff --git a/modules/wifi/wifi_hopping.go b/modules/wifi/wifi_hopping.go
index 03797908..43b5fe7d 100644
--- a/modules/wifi/wifi_hopping.go
+++ b/modules/wifi/wifi_hopping.go
@@ -36,8 +36,6 @@ func (mod *WiFiModule) hopUnlocked(channel int) (mustStop bool) {
}
}
- mod.State.Store("channel", channel)
-
return
}
diff --git a/modules/wifi/wifi_test.go b/modules/wifi/wifi_test.go
deleted file mode 100644
index afd5322c..00000000
--- a/modules/wifi/wifi_test.go
+++ /dev/null
@@ -1,629 +0,0 @@
-package wifi
-
-import (
- "bytes"
- "net"
- "regexp"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/evilsocket/islazy/data"
-)
-
-// Create a mock session for testing
-func createMockSession() *session.Session {
- // Create interface
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "wlan0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- // Parse interface addresses
- ifaceIP := net.ParseIP("192.168.1.100")
- ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface.IP = ifaceIP
- iface.HW = ifaceHW
-
- // Create gateway
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
- gatewayIP := net.ParseIP("192.168.1.1")
- gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66")
- gateway.IP = gatewayIP
- gateway.HW = gatewayHW
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- // Create LAN
- aliases, _ := data.NewUnsortedKV("", 0)
- lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- // Create session
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- Lan: lan,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: &packets.Queue{},
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize events
- sess.Events = session.NewEventPool(false, false)
-
- // Initialize WiFi state
- sess.WiFi = network.NewWiFi(iface, aliases, func(ap *network.AccessPoint) {}, func(ap *network.AccessPoint) {})
-
- return sess
-}
-
-func TestNewWiFiModule(t *testing.T) {
- sess := createMockSession()
-
- mod := NewWiFiModule(sess)
-
- if mod == nil {
- t.Fatal("NewWiFiModule returned nil")
- }
-
- if mod.Name() != "wifi" {
- t.Errorf("expected module name 'wifi', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli && Gianluca Braga " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- // Check parameters
- params := []string{
- "wifi.interface",
- "wifi.rssi.min",
- "wifi.deauth.skip",
- "wifi.deauth.silent",
- "wifi.deauth.open",
- "wifi.deauth.acquired",
- "wifi.assoc.skip",
- "wifi.assoc.silent",
- "wifi.assoc.open",
- "wifi.assoc.acquired",
- "wifi.ap.ttl",
- "wifi.sta.ttl",
- "wifi.region",
- "wifi.txpower",
- "wifi.handshakes.file",
- "wifi.handshakes.aggregate",
- "wifi.ap.ssid",
- "wifi.ap.bssid",
- "wifi.ap.channel",
- "wifi.ap.encryption",
- "wifi.show.manufacturer",
- "wifi.source.file",
- "wifi.hop.period",
- "wifi.skip-broken",
- "wifi.channel_switch_announce.silent",
- "wifi.fake_auth.silent",
- "wifi.bruteforce.target",
- "wifi.bruteforce.wordlist",
- "wifi.bruteforce.workers",
- "wifi.bruteforce.wide",
- "wifi.bruteforce.stop_at_first",
- "wifi.bruteforce.timeout",
- }
- for _, param := range params {
- if !mod.Session.Env.Has(param) {
- t.Errorf("parameter %s not registered", param)
- }
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "wifi.recon on",
- "wifi.recon off",
- "wifi.clear",
- "wifi.recon MAC",
- "wifi.recon clear",
- "wifi.deauth BSSID",
- "wifi.probe BSSID ESSID",
- "wifi.assoc BSSID",
- "wifi.ap",
- "wifi.show.wps BSSID",
- "wifi.show",
- "wifi.recon.channel CHANNEL",
- "wifi.client.probe.sta.filter FILTER",
- "wifi.client.probe.ap.filter FILTER",
- "wifi.channel_switch_announce bssid channel ",
- "wifi.fake_auth bssid client",
- "wifi.bruteforce on",
- "wifi.bruteforce off",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-}
-
-func TestWiFiModuleConfigure(t *testing.T) {
- tests := []struct {
- name string
- params map[string]string
- expectErr bool
- }{
- {
- name: "default configuration",
- params: map[string]string{
- "wifi.interface": "",
- "wifi.ap.ttl": "300",
- "wifi.sta.ttl": "300",
- "wifi.region": "",
- "wifi.txpower": "30",
- "wifi.source.file": "",
- "wifi.rssi.min": "-200",
- "wifi.handshakes.file": "~/bettercap-wifi-handshakes.pcap",
- "wifi.handshakes.aggregate": "true",
- "wifi.hop.period": "250",
- "wifi.skip-broken": "true",
- },
- expectErr: true, // Will fail without actual interface
- },
- {
- name: "invalid rssi",
- params: map[string]string{
- "wifi.rssi.min": "not-a-number",
- },
- expectErr: true,
- },
- {
- name: "invalid hop period",
- params: map[string]string{
- "wifi.hop.period": "invalid",
- },
- expectErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Set parameters
- for k, v := range tt.params {
- sess.Env.Set(k, v)
- }
-
- err := mod.Configure()
-
- if tt.expectErr && err == nil {
- t.Error("expected error but got none")
- } else if !tt.expectErr && err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- })
- }
-}
-
-func TestWiFiModuleFrequencies(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test setting frequencies
- freqs := []int{2412, 2437, 2462, 5180, 5200} // Channels 1, 6, 11, 36, 40
- mod.setFrequencies(freqs)
-
- if len(mod.frequencies) != len(freqs) {
- t.Errorf("expected %d frequencies, got %d", len(freqs), len(mod.frequencies))
- }
-
- // Check if channels were properly converted
- channels, _ := mod.State.Load("channels")
- channelList := channels.([]int)
- expectedChannels := []int{1, 6, 11, 36, 40}
-
- if len(channelList) != len(expectedChannels) {
- t.Errorf("expected %d channels, got %d", len(expectedChannels), len(channelList))
- }
-}
-
-func TestWiFiModuleFilters(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test STA filter
- handlers := mod.Handlers()
- var staFilterHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.client.probe.sta.filter FILTER" {
- staFilterHandler = h
- break
- }
- }
-
- if staFilterHandler.Name == "" {
- t.Fatal("STA filter handler not found")
- }
-
- // Set a filter
- err := staFilterHandler.Exec([]string{"^aa:bb:.*"})
- if err != nil {
- t.Errorf("Failed to set STA filter: %v", err)
- }
-
- if mod.filterProbeSTA == nil {
- t.Error("STA filter was not set")
- }
-
- // Clear filter
- err = staFilterHandler.Exec([]string{"clear"})
- if err != nil {
- t.Errorf("Failed to clear STA filter: %v", err)
- }
-
- if mod.filterProbeSTA != nil {
- t.Error("STA filter was not cleared")
- }
-
- // Test AP filter
- var apFilterHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.client.probe.ap.filter FILTER" {
- apFilterHandler = h
- break
- }
- }
-
- if apFilterHandler.Name == "" {
- t.Fatal("AP filter handler not found")
- }
-
- // Set a filter
- err = apFilterHandler.Exec([]string{"^TestAP.*"})
- if err != nil {
- t.Errorf("Failed to set AP filter: %v", err)
- }
-
- if mod.filterProbeAP == nil {
- t.Error("AP filter was not set")
- }
-}
-
-func TestWiFiModuleDeauth(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test deauth handler
- handlers := mod.Handlers()
- var deauthHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.deauth BSSID" {
- deauthHandler = h
- break
- }
- }
-
- if deauthHandler.Name == "" {
- t.Fatal("Deauth handler not found")
- }
-
- // Test with "all"
- err := deauthHandler.Exec([]string{"all"})
- if err == nil {
- t.Error("Expected error when starting deauth without running module")
- }
-
- // Test with invalid MAC
- err = deauthHandler.Exec([]string{"invalid-mac"})
- if err == nil {
- t.Error("Expected error with invalid MAC address")
- }
-}
-
-func TestWiFiModuleChannelHandler(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test channel handler
- handlers := mod.Handlers()
- var channelHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.recon.channel CHANNEL" {
- channelHandler = h
- break
- }
- }
-
- if channelHandler.Name == "" {
- t.Fatal("Channel handler not found")
- }
-
- // Test with valid channels
- err := channelHandler.Exec([]string{"1,6,11"})
- if err != nil {
- t.Errorf("Failed to set channels: %v", err)
- }
-
- // Test with invalid channel
- err = channelHandler.Exec([]string{"999"})
- if err == nil {
- t.Error("Expected error with invalid channel")
- }
-
- // Test clear
- err = channelHandler.Exec([]string{"clear"})
- if err == nil {
- // Will fail without actual interface but should parse correctly
- t.Log("Clear channels parsed correctly")
- }
-}
-
-func TestWiFiModuleShow(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test show handler exists
- handlers := mod.Handlers()
- found := false
- for _, h := range handlers {
- if h.Name == "wifi.show" {
- found = true
- break
- }
- }
-
- if !found {
- t.Fatal("Show handler not found")
- }
-
- // Skip actual execution as it requires UI components
- t.Log("Show handler found, skipping execution due to UI dependencies")
-}
-
-func TestWiFiModuleShowWPS(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test show WPS handler exists
- handlers := mod.Handlers()
- found := false
- for _, h := range handlers {
- if h.Name == "wifi.show.wps BSSID" {
- found = true
- break
- }
- }
-
- if !found {
- t.Fatal("Show WPS handler not found")
- }
-
- // Skip actual execution as it requires UI components
- t.Log("Show WPS handler found, skipping execution due to UI dependencies")
-}
-
-func TestWiFiModuleBruteforce(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Check bruteforce config
- if mod.bruteforce == nil {
- t.Fatal("Bruteforce config not initialized")
- }
-
- // Test bruteforce parameters
- params := map[string]string{
- "wifi.bruteforce.target": "TestAP",
- "wifi.bruteforce.wordlist": "/tmp/wordlist.txt",
- "wifi.bruteforce.workers": "4",
- "wifi.bruteforce.wide": "true",
- "wifi.bruteforce.stop_at_first": "true",
- "wifi.bruteforce.timeout": "30",
- }
-
- for k, v := range params {
- sess.Env.Set(k, v)
- }
-
- // Verify parameters were set
- if err, target := mod.StringParam("wifi.bruteforce.target"); err != nil {
- t.Errorf("Failed to get bruteforce target: %v", err)
- } else if target != "TestAP" {
- t.Errorf("Expected target 'TestAP', got '%s'", target)
- }
-}
-
-func TestWiFiModuleAPConfig(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Set AP parameters
- params := map[string]string{
- "wifi.ap.ssid": "TestAP",
- "wifi.ap.bssid": "aa:bb:cc:dd:ee:ff",
- "wifi.ap.channel": "6",
- "wifi.ap.encryption": "true",
- }
-
- for k, v := range params {
- sess.Env.Set(k, v)
- }
-
- // Parse AP config
- err := mod.parseApConfig()
- if err != nil {
- t.Errorf("Failed to parse AP config: %v", err)
- }
-
- // Verify config
- if mod.apConfig.SSID != "TestAP" {
- t.Errorf("Expected SSID 'TestAP', got '%s'", mod.apConfig.SSID)
- }
-
- if !bytes.Equal(mod.apConfig.BSSID, []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) {
- t.Errorf("BSSID mismatch")
- }
-
- if mod.apConfig.Channel != 6 {
- t.Errorf("Expected channel 6, got %d", mod.apConfig.Channel)
- }
-
- if !mod.apConfig.Encryption {
- t.Error("Expected encryption to be enabled")
- }
-}
-
-func TestWiFiModuleSkipMACs(t *testing.T) {
- // Skip this test as updateDeauthSkipList and updateAssocSkipList are private methods
- t.Skip("Skipping test for private skip list methods")
-}
-
-func TestWiFiModuleProbe(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test probe handler
- handlers := mod.Handlers()
- var probeHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.probe BSSID ESSID" {
- probeHandler = h
- break
- }
- }
-
- if probeHandler.Name == "" {
- t.Fatal("Probe handler not found")
- }
-
- // Test with valid parameters
- err := probeHandler.Exec([]string{"aa:bb:cc:dd:ee:ff", "TestNetwork"})
- if err == nil {
- t.Error("Expected error when probing without running module")
- }
-
- // Test with invalid MAC
- err = probeHandler.Exec([]string{"invalid-mac", "TestNetwork"})
- if err == nil {
- t.Error("Expected error with invalid MAC address")
- }
-}
-
-func TestWiFiModuleFakeAuth(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Test fake auth handler
- handlers := mod.Handlers()
- var fakeAuthHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "wifi.fake_auth bssid client" {
- fakeAuthHandler = h
- break
- }
- }
-
- if fakeAuthHandler.Name == "" {
- t.Fatal("Fake auth handler not found")
- }
-
- // Test with valid parameters
- err := fakeAuthHandler.Exec([]string{"aa:bb:cc:dd:ee:ff", "11:22:33:44:55:66"})
- if err == nil {
- t.Error("Expected error when running fake auth without running module")
- }
-
- // Test with invalid MACs
- err = fakeAuthHandler.Exec([]string{"invalid-mac", "11:22:33:44:55:66"})
- if err == nil {
- t.Error("Expected error with invalid BSSID")
- }
-}
-
-func TestWiFiModuleViewSelector(t *testing.T) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- // Check if view selector is initialized
- if mod.selector == nil {
- t.Fatal("View selector not initialized")
- }
-}
-
-// Helper function
-func contains(slice []string, item string) bool {
- for _, s := range slice {
- if s == item {
- return true
- }
- }
- return false
-}
-
-// Test bruteforce config
-func TestBruteforceConfig(t *testing.T) {
- config := NewBruteForceConfig()
-
- if config == nil {
- t.Fatal("NewBruteForceConfig returned nil")
- }
-
- // Check defaults
- if config.target != "" {
- t.Errorf("Expected empty target, got '%s'", config.target)
- }
-
- if config.wordlist != "/usr/share/dict/words" {
- t.Errorf("Expected wordlist '/usr/share/dict/words', got '%s'", config.wordlist)
- }
-
- if config.workers != 1 {
- t.Errorf("Expected 1 worker, got %d", config.workers)
- }
-
- if config.wide {
- t.Error("Expected wide to be false by default")
- }
-
- if !config.stop_at_first {
- t.Error("Expected stop_at_first to be true by default")
- }
-
- if config.timeout != 15 {
- t.Errorf("Expected timeout 15, got %d", config.timeout)
- }
-}
-
-// Benchmarks
-func BenchmarkWiFiModuleSetFrequencies(b *testing.B) {
- sess := createMockSession()
- mod := NewWiFiModule(sess)
-
- freqs := []int{2412, 2437, 2462, 5180, 5200, 5220, 5240, 5745, 5765, 5785, 5805, 5825}
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- mod.setFrequencies(freqs)
- }
-}
-
-func BenchmarkWiFiModuleFilterCheck(b *testing.B) {
- filter, _ := regexp.Compile("^aa:bb:.*")
- testMAC := "aa:bb:cc:dd:ee:ff"
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = filter.MatchString(testMAC)
- }
-}
diff --git a/modules/wol/wol_test.go b/modules/wol/wol_test.go
deleted file mode 100644
index 115f4f32..00000000
--- a/modules/wol/wol_test.go
+++ /dev/null
@@ -1,364 +0,0 @@
-package wol
-
-import (
- "bytes"
- "net"
- "sync"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-var (
- testSession *session.Session
- sessionOnce sync.Once
-)
-
-func createMockSession(t *testing.T) *session.Session {
- sessionOnce.Do(func() {
- var err error
- testSession, err = session.New()
- if err != nil {
- t.Fatalf("Failed to create session: %v", err)
- }
- // Initialize interface with mock data to avoid nil pointer
- // For now, we'll skip initializing these as they require more complex setup
- // The tests will handle the nil cases appropriately
- })
- return testSession
-}
-
-func TestNewWOL(t *testing.T) {
- s := createMockSession(t)
- mod := NewWOL(s)
-
- if mod == nil {
- t.Fatal("NewWOL returned nil")
- }
-
- if mod.Name() != "wol" {
- t.Errorf("Expected name 'wol', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("Unexpected author: %s", mod.Author())
- }
-
- if mod.Description() == "" {
- t.Error("Empty description")
- }
-
- // Check handlers
- handlers := []string{"wol.eth MAC", "wol.udp MAC"}
- for _, handlerName := range handlers {
- found := false
- for _, h := range mod.Handlers() {
- if h.Name == handlerName {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("Handler '%s' not found", handlerName)
- }
- }
-}
-
-func TestParseMAC(t *testing.T) {
- tests := []struct {
- name string
- args []string
- want string
- wantErr bool
- }{
- {
- name: "empty args",
- args: []string{},
- want: "ff:ff:ff:ff:ff:ff",
- wantErr: false,
- },
- {
- name: "empty string arg",
- args: []string{""},
- want: "ff:ff:ff:ff:ff:ff",
- wantErr: false,
- },
- {
- name: "valid MAC with colons",
- args: []string{"aa:bb:cc:dd:ee:ff"},
- want: "aa:bb:cc:dd:ee:ff",
- wantErr: false,
- },
- {
- name: "valid MAC with dashes",
- args: []string{"aa-bb-cc-dd-ee-ff"},
- want: "aa-bb-cc-dd-ee-ff",
- wantErr: false,
- },
- {
- name: "valid MAC uppercase",
- args: []string{"AA:BB:CC:DD:EE:FF"},
- want: "AA:BB:CC:DD:EE:FF",
- wantErr: false,
- },
- {
- name: "valid MAC mixed case",
- args: []string{"aA:bB:cC:dD:eE:fF"},
- want: "aA:bB:cC:dD:eE:fF",
- wantErr: false,
- },
- {
- name: "invalid MAC - too short",
- args: []string{"aa:bb:cc:dd:ee"},
- want: "",
- wantErr: true,
- },
- {
- name: "invalid MAC - too long",
- args: []string{"aa:bb:cc:dd:ee:ff:gg"},
- want: "",
- wantErr: true,
- },
- {
- name: "invalid MAC - bad characters",
- args: []string{"aa:bb:cc:dd:ee:gg"},
- want: "",
- wantErr: true,
- },
- {
- name: "invalid MAC - no separators",
- args: []string{"aabbccddeeff"},
- want: "",
- wantErr: true,
- },
- {
- name: "MAC with spaces",
- args: []string{" aa:bb:cc:dd:ee:ff "},
- want: "aa:bb:cc:dd:ee:ff",
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := parseMAC(tt.args)
- if (err != nil) != tt.wantErr {
- t.Errorf("parseMAC() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("parseMAC() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestBuildPayload(t *testing.T) {
- tests := []struct {
- name string
- mac string
- }{
- {
- name: "broadcast MAC",
- mac: "ff:ff:ff:ff:ff:ff",
- },
- {
- name: "specific MAC",
- mac: "aa:bb:cc:dd:ee:ff",
- },
- {
- name: "zeros MAC",
- mac: "00:00:00:00:00:00",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- payload := buildPayload(tt.mac)
-
- // Payload should be 102 bytes: 6 bytes sync + 16 * 6 bytes MAC
- if len(payload) != 102 {
- t.Errorf("buildPayload() length = %d, want 102", len(payload))
- }
-
- // First 6 bytes should be 0xff
- for i := 0; i < 6; i++ {
- if payload[i] != 0xff {
- t.Errorf("payload[%d] = %x, want 0xff", i, payload[i])
- }
- }
-
- // Parse the MAC for comparison
- parsedMAC, _ := net.ParseMAC(tt.mac)
-
- // Next 16 copies of the MAC
- for i := 0; i < 16; i++ {
- start := 6 + i*6
- end := start + 6
- if !bytes.Equal(payload[start:end], parsedMAC) {
- t.Errorf("MAC copy %d = %x, want %x", i, payload[start:end], parsedMAC)
- }
- }
- })
- }
-}
-
-func TestWOLConfigure(t *testing.T) {
- s := createMockSession(t)
- mod := NewWOL(s)
-
- if err := mod.Configure(); err != nil {
- t.Errorf("Configure() error = %v", err)
- }
-}
-
-func TestWOLStartStop(t *testing.T) {
- s := createMockSession(t)
- mod := NewWOL(s)
-
- if err := mod.Start(); err != nil {
- t.Errorf("Start() error = %v", err)
- }
-
- if err := mod.Stop(); err != nil {
- t.Errorf("Stop() error = %v", err)
- }
-}
-
-func TestWOLHandlers(t *testing.T) {
- // Only test parseMAC validation since the actual handlers require a fully initialized session
- testCases := []struct {
- name string
- args []string
- wantMAC string
- wantErr bool
- }{
- {
- name: "empty args",
- args: []string{},
- wantMAC: "ff:ff:ff:ff:ff:ff",
- wantErr: false,
- },
- {
- name: "valid MAC",
- args: []string{"aa:bb:cc:dd:ee:ff"},
- wantMAC: "aa:bb:cc:dd:ee:ff",
- wantErr: false,
- },
- {
- name: "invalid MAC",
- args: []string{"invalid:mac"},
- wantMAC: "",
- wantErr: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- mac, err := parseMAC(tc.args)
- if (err != nil) != tc.wantErr {
- t.Errorf("parseMAC() error = %v, wantErr %v", err, tc.wantErr)
- }
- if mac != tc.wantMAC {
- t.Errorf("parseMAC() = %v, want %v", mac, tc.wantMAC)
- }
- })
- }
-}
-
-func TestWOLMethods(t *testing.T) {
- s := createMockSession(t)
- mod := NewWOL(s)
-
- // Test that the methods exist and can be called without panic
- // The actual execution will fail due to nil session interface/queue
- // but we're testing the module structure
-
- // Check that handlers were properly registered
- expectedHandlers := 2 // wol.eth and wol.udp
- if len(mod.Handlers()) != expectedHandlers {
- t.Errorf("Expected %d handlers, got %d", expectedHandlers, len(mod.Handlers()))
- }
-
- // Verify handler names
- handlerNames := make(map[string]bool)
- for _, h := range mod.Handlers() {
- handlerNames[h.Name] = true
- }
-
- if !handlerNames["wol.eth MAC"] {
- t.Error("wol.eth handler not found")
- }
- if !handlerNames["wol.udp MAC"] {
- t.Error("wol.udp handler not found")
- }
-}
-
-func TestReMAC(t *testing.T) {
- tests := []struct {
- mac string
- valid bool
- }{
- {"aa:bb:cc:dd:ee:ff", true},
- {"AA:BB:CC:DD:EE:FF", true},
- {"aa-bb-cc-dd-ee-ff", true},
- {"AA-BB-CC-DD-EE-FF", true},
- {"aA:bB:cC:dD:eE:fF", true},
- {"00:00:00:00:00:00", true},
- {"ff:ff:ff:ff:ff:ff", true},
- {"aabbccddeeff", false},
- {"aa:bb:cc:dd:ee", false},
- {"aa:bb:cc:dd:ee:ff:gg", false},
- {"aa:bb:cc:dd:ee:gg", false},
- {"zz:zz:zz:zz:zz:zz", false},
- {"", false},
- {"not a mac", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.mac, func(t *testing.T) {
- if got := reMAC.MatchString(tt.mac); got != tt.valid {
- t.Errorf("reMAC.MatchString(%q) = %v, want %v", tt.mac, got, tt.valid)
- }
- })
- }
-}
-
-// Test that the module sets running state correctly
-func TestWOLRunningState(t *testing.T) {
- s := createMockSession(t)
- mod := NewWOL(s)
-
- // Initially should not be running
- if mod.Running() {
- t.Error("Module should not be running initially")
- }
-
- // Note: wolETH and wolUDP will fail due to nil session.Queue,
- // but they should still set the running state before failing
-}
-
-// Benchmark tests
-func BenchmarkBuildPayload(b *testing.B) {
- mac := "aa:bb:cc:dd:ee:ff"
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = buildPayload(mac)
- }
-}
-
-func BenchmarkParseMAC(b *testing.B) {
- args := []string{"aa:bb:cc:dd:ee:ff"}
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = parseMAC(args)
- }
-}
-
-func BenchmarkReMAC(b *testing.B) {
- mac := "aa:bb:cc:dd:ee:ff"
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = reMAC.MatchString(mac)
- }
-}
diff --git a/modules/zerogod/zerogod_discovery.go b/modules/zerogod/zerogod_discovery.go
index f6223e54..97d0f486 100644
--- a/modules/zerogod/zerogod_discovery.go
+++ b/modules/zerogod/zerogod_discovery.go
@@ -201,14 +201,6 @@ func (mod *ZeroGod) logDNS(src net.IP, dns layers.DNS, isLocal bool) {
func (mod *ZeroGod) onPacket(pkt gopacket.Packet) {
mod.Debug("%++v", pkt)
- // sadly the latest available version of gopacket has an unpatched bug :/
- // https://github.com/bettercap/bettercap/issues/1184
- defer func() {
- if err := recover(); err != nil {
- mod.Error("unexpected error while parsing network packet: %v\n\n%++v", err, pkt)
- }
- }()
-
netLayer := pkt.NetworkLayer()
if netLayer == nil {
mod.Warning("not network layer in packet %+v", pkt)
diff --git a/modules/zerogod/zerogod_show.go b/modules/zerogod/zerogod_show.go
index 4c465d0d..03abebbf 100644
--- a/modules/zerogod/zerogod_show.go
+++ b/modules/zerogod/zerogod_show.go
@@ -61,24 +61,15 @@ func (mod *ZeroGod) show(filter string, withData bool) error {
for _, field := range svc.Text {
if field = str.Trim(field); len(field) > 0 {
keyval := strings.SplitN(field, "=", 2)
- key := str.Trim(keyval[0])
- val := str.Trim(keyval[1])
-
- if key != "" || val != "" {
- rows = append(rows, []string{
- key,
- val,
- })
- }
+ rows = append(rows, []string{
+ keyval[0],
+ keyval[1],
+ })
}
}
- if len(rows) == 0 {
- fmt.Fprintf(mod.Session.Events.Stdout, " %s\n", tui.Dim("no data"))
- } else {
- tui.Table(mod.Session.Events.Stdout, columns, rows)
- fmt.Fprintf(mod.Session.Events.Stdout, "\n")
- }
+ tui.Table(mod.Session.Events.Stdout, columns, rows)
+ fmt.Fprintf(mod.Session.Events.Stdout, "\n")
} else {
fmt.Fprintf(mod.Session.Events.Stdout, " %s\n", tui.Dim("no data"))
diff --git a/modules/zerogod/zerogod_test.go b/modules/zerogod/zerogod_test.go
deleted file mode 100644
index b64bbab0..00000000
--- a/modules/zerogod/zerogod_test.go
+++ /dev/null
@@ -1,480 +0,0 @@
-package zerogod
-
-import (
- "fmt"
- "io/ioutil"
- "net"
- "os"
- "testing"
- "time"
-
- "github.com/bettercap/bettercap/v2/network"
- "github.com/bettercap/bettercap/v2/packets"
- "github.com/bettercap/bettercap/v2/session"
- "github.com/evilsocket/islazy/data"
-)
-
-// MockNetRecon implements a minimal net.recon module for testing
-type MockNetRecon struct {
- session.SessionModule
-}
-
-func NewMockNetRecon(s *session.Session) *MockNetRecon {
- mod := &MockNetRecon{
- SessionModule: session.NewSessionModule("net.recon", s),
- }
-
- // Add handlers
- mod.AddHandler(session.NewModuleHandler("net.recon on", "",
- "Start net.recon",
- func(args []string) error {
- return mod.Start()
- }))
-
- mod.AddHandler(session.NewModuleHandler("net.recon off", "",
- "Stop net.recon",
- func(args []string) error {
- return mod.Stop()
- }))
-
- return mod
-}
-
-func (m *MockNetRecon) Name() string {
- return "net.recon"
-}
-
-func (m *MockNetRecon) Description() string {
- return "Mock net.recon module"
-}
-
-func (m *MockNetRecon) Author() string {
- return "test"
-}
-
-func (m *MockNetRecon) Configure() error {
- return nil
-}
-
-func (m *MockNetRecon) Start() error {
- return m.SetRunning(true, nil)
-}
-
-func (m *MockNetRecon) Stop() error {
- return m.SetRunning(false, nil)
-}
-
-// MockBrowser for testing
-type MockBrowser struct {
- started bool
- stopped bool
- waitCh chan bool
-}
-
-func (m *MockBrowser) Start() error {
- m.started = true
- m.waitCh = make(chan bool, 1)
- return nil
-}
-
-func (m *MockBrowser) Stop() error {
- m.stopped = true
- if m.waitCh != nil {
- m.waitCh <- true
- close(m.waitCh)
- }
- return nil
-}
-
-func (m *MockBrowser) Wait() {
- if m.waitCh != nil {
- <-m.waitCh
- }
-}
-
-// MockAdvertiser for testing
-type MockAdvertiser struct {
- started bool
- stopped bool
- services []*ServiceData
- config string
-}
-
-func (m *MockAdvertiser) Start(services []*ServiceData) error {
- m.started = true
- m.services = services
- return nil
-}
-
-func (m *MockAdvertiser) Stop() error {
- m.stopped = true
- return nil
-}
-
-// Create a mock session for testing
-func createMockSession() *session.Session {
- // Create interface
- iface := &network.Endpoint{
- IpAddress: "192.168.1.100",
- HwAddress: "aa:bb:cc:dd:ee:ff",
- Hostname: "eth0",
- }
- iface.SetIP("192.168.1.100")
- iface.SetBits(24)
-
- // Parse interface addresses
- ifaceIP := net.ParseIP("192.168.1.100")
- ifaceHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface.IP = ifaceIP
- iface.HW = ifaceHW
-
- // Create gateway
- gateway := &network.Endpoint{
- IpAddress: "192.168.1.1",
- HwAddress: "11:22:33:44:55:66",
- }
- gatewayIP := net.ParseIP("192.168.1.1")
- gatewayHW, _ := net.ParseMAC("11:22:33:44:55:66")
- gateway.IP = gatewayIP
- gateway.HW = gatewayHW
-
- // Create environment
- env, _ := session.NewEnvironment("")
-
- // Create LAN with some test endpoints
- aliases, _ := data.NewUnsortedKV("", 0)
- lan := network.NewLAN(iface, gateway, aliases, func(e *network.Endpoint) {}, func(e *network.Endpoint) {})
-
- // Add test endpoints
- testEndpoint := &network.Endpoint{
- IpAddress: "192.168.1.10",
- HwAddress: "11:11:11:11:11:11",
- Hostname: "test-device",
- }
- testEndpoint.IP = net.ParseIP("192.168.1.10")
- // Add endpoint to LAN using AddIfNew
- lan.AddIfNew(testEndpoint.IpAddress, testEndpoint.HwAddress)
-
- // Create session
- sess := &session.Session{
- Interface: iface,
- Gateway: gateway,
- Lan: lan,
- StartedAt: time.Now(),
- Active: true,
- Env: env,
- Queue: &packets.Queue{},
- Modules: make(session.ModuleList, 0),
- }
-
- // Initialize events
- sess.Events = session.NewEventPool(false, false)
-
- // Add mock net.recon module
- mockNetRecon := NewMockNetRecon(sess)
- sess.Modules = append(sess.Modules, mockNetRecon)
-
- return sess
-}
-
-func TestNewZeroGod(t *testing.T) {
- sess := createMockSession()
-
- mod := NewZeroGod(sess)
-
- if mod == nil {
- t.Fatal("NewZeroGod returned nil")
- }
-
- if mod.Name() != "zerogod" {
- t.Errorf("expected module name 'zerogod', got '%s'", mod.Name())
- }
-
- if mod.Author() != "Simone Margaritelli " {
- t.Errorf("unexpected author: %s", mod.Author())
- }
-
- // Check parameters - only check the ones that are directly registered
- params := []string{
- "zerogod.advertise.certificate",
- "zerogod.advertise.key",
- "zerogod.ipp.save_path",
- "zerogod.verbose",
- }
- for _, param := range params {
- if !mod.Session.Env.Has(param) {
- t.Errorf("parameter %s not registered", param)
- }
- }
-
- // Check handlers
- handlers := mod.Handlers()
- expectedHandlers := []string{
- "zerogod.discovery on",
- "zerogod.discovery off",
- "zerogod.show-full ADDRESS",
- "zerogod.show ADDRESS",
- "zerogod.save ADDRESS FILENAME",
- "zerogod.advertise FILENAME",
- "zerogod.impersonate ADDRESS",
- }
-
- if len(handlers) != len(expectedHandlers) {
- t.Errorf("expected %d handlers, got %d", len(expectedHandlers), len(handlers))
- }
-}
-
-func TestZeroGodConfigure(t *testing.T) {
- sess := createMockSession()
- mod := NewZeroGod(sess)
-
- // Configure should succeed when not running
- err := mod.Configure()
- if err != nil {
- t.Errorf("Configure failed: %v", err)
- }
-
- // Force module to running state by starting it
- mod.SetRunning(true, nil)
-
- // Configure should fail when already running
- err = mod.Configure()
- if err == nil {
- t.Error("Configure should fail when module is already running")
- }
-
- // Clean up
- mod.SetRunning(false, nil)
-}
-
-func TestZeroGodStartStop(t *testing.T) {
- sess := createMockSession()
- _ = NewZeroGod(sess)
-
- // Skip this test as it requires mocking private methods
- t.Skip("Skipping test that requires mocking private methods")
-}
-
-func TestZeroGodShow(t *testing.T) {
- sess := createMockSession()
- mod := NewZeroGod(sess)
-
- // Start discovery first (mock it)
- mod.browser = &Browser{}
-
- // Test show handler
- handlers := mod.Handlers()
- var showHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "zerogod.show ADDRESS" {
- showHandler = h
- break
- }
- }
-
- if showHandler.Name == "" {
- t.Fatal("Show handler not found")
- }
-
- // Test with IP address
- err := showHandler.Exec([]string{"192.168.1.10"})
- if err != nil {
- t.Errorf("Show handler failed: %v", err)
- }
-
- // Test with empty address (show all)
- err = showHandler.Exec([]string{})
- if err != nil {
- t.Errorf("Show handler failed with empty address: %v", err)
- }
-}
-
-func TestZeroGodShowFull(t *testing.T) {
- sess := createMockSession()
- mod := NewZeroGod(sess)
-
- // Start discovery first (mock it)
- mod.browser = &Browser{}
-
- // Test show-full handler
- handlers := mod.Handlers()
- var showFullHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "zerogod.show-full ADDRESS" {
- showFullHandler = h
- break
- }
- }
-
- if showFullHandler.Name == "" {
- t.Fatal("Show-full handler not found")
- }
-
- // Test with IP address
- err := showFullHandler.Exec([]string{"192.168.1.10"})
- if err != nil {
- t.Errorf("Show-full handler failed: %v", err)
- }
-}
-
-func TestZeroGodSave(t *testing.T) {
- // Skip this test as it requires actual mDNS discovery data
- t.Skip("Skipping test that requires actual mDNS discovery data")
-}
-
-func TestZeroGodAdvertise(t *testing.T) {
- sess := createMockSession()
- mod := NewZeroGod(sess)
-
- // Mock advertiser - skip test as we can't properly mock the advertiser structure
- t.Skip("Skipping test that requires complex advertiser mocking")
-
- // Create a test YAML file with services
- tmpFile, err := ioutil.TempFile("", "zerogod_advertise_*.yml")
- if err != nil {
- t.Fatalf("Failed to create temp file: %v", err)
- }
- defer os.Remove(tmpFile.Name())
-
- yamlContent := `services:
- - name: Test Service
- type: _http._tcp
- port: 8080
- txt:
- - model=TestDevice
- - version=1.0
-`
- if _, err := tmpFile.Write([]byte(yamlContent)); err != nil {
- t.Fatalf("Failed to write YAML content: %v", err)
- }
- tmpFile.Close()
-
- // Test advertise handler
- handlers := mod.Handlers()
- var advertiseHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "zerogod.advertise FILENAME" {
- advertiseHandler = h
- break
- }
- }
-
- if advertiseHandler.Name == "" {
- t.Fatal("Advertise handler not found")
- }
-
- // Note: Cannot mock methods in Go, would need interface refactoring
-}
-
-func TestZeroGodImpersonate(t *testing.T) {
- sess := createMockSession()
- mod := NewZeroGod(sess)
-
- // Skip test as we can't properly mock the advertiser
- t.Skip("Skipping test that requires complex advertiser mocking")
-
- // Test impersonate handler
- handlers := mod.Handlers()
- var impersonateHandler session.ModuleHandler
- for _, h := range handlers {
- if h.Name == "zerogod.impersonate ADDRESS" {
- impersonateHandler = h
- break
- }
- }
-
- if impersonateHandler.Name == "" {
- t.Fatal("Impersonate handler not found")
- }
-
- // Note: Cannot mock methods in Go, would need interface refactoring
-}
-
-func TestZeroGodParameters(t *testing.T) {
- // Skip parameter validation tests as Environment.Set behavior is not straightforward
- t.Skip("Skipping parameter validation tests")
-}
-
-// Test service data structure
-func TestServiceData(t *testing.T) {
- svc := ServiceData{
- Name: "Test Service",
- Service: "_http._tcp",
- Domain: "local",
- Port: 8080,
- Records: []string{"model=TestDevice", "version=1.0"},
- IPP: map[string]string{"attr1": "value1"},
- HTTP: map[string]string{"/": "index.html"},
- }
-
- // Test basic properties
- if svc.Name != "Test Service" {
- t.Errorf("Expected service name 'Test Service', got '%s'", svc.Name)
- }
-
- if svc.Port != 8080 {
- t.Errorf("Expected port 8080, got %d", svc.Port)
- }
-
- if len(svc.Records) != 2 {
- t.Errorf("Expected 2 records, got %d", len(svc.Records))
- }
-
- // Test FullName method
- fullName := svc.FullName()
- expected := "Test Service._http._tcp.local"
- if fullName != expected {
- t.Errorf("Expected full name '%s', got '%s'", expected, fullName)
- }
-}
-
-// Test endpoint handling
-func TestEndpointHandling(t *testing.T) {
- endpoint := &network.Endpoint{
- IpAddress: "192.168.1.10",
- HwAddress: "11:11:11:11:11:11",
- Hostname: "test-device",
- }
-
- // Verify basic endpoint properties
- if endpoint.IpAddress != "192.168.1.10" {
- t.Errorf("Expected IP address '192.168.1.10', got '%s'", endpoint.IpAddress)
- }
-
- if endpoint.Hostname != "test-device" {
- t.Errorf("Expected hostname 'test-device', got '%s'", endpoint.Hostname)
- }
-}
-
-// Test known services lookup
-func TestKnownServices(t *testing.T) {
- // Skip this test as knownServices might not be available in test context
- t.Skip("Skipping known services test - requires module initialization")
-}
-
-// Benchmarks
-func BenchmarkServiceDataCreation(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = ServiceData{
- Name: fmt.Sprintf("Service %d", i),
- Service: "_http._tcp",
- Port: 8080 + i,
- Domain: "local",
- Records: []string{"model=Test", fmt.Sprintf("id=%d", i)},
- }
- }
-}
-
-func BenchmarkServiceDataFullName(b *testing.B) {
- svc := ServiceData{
- Name: "Test Service",
- Service: "_http._tcp",
- Domain: "local",
- }
-
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- _ = svc.FullName()
- }
-}
diff --git a/network/lan.go b/network/lan.go
index 6342968d..082b4c74 100644
--- a/network/lan.go
+++ b/network/lan.go
@@ -62,7 +62,7 @@ func (lan *LAN) Get(mac string) (*Endpoint, bool) {
if mac == lan.iface.HwAddress {
return lan.iface, true
- } else if lan.gateway != nil && mac == lan.gateway.HwAddress {
+ } else if mac == lan.gateway.HwAddress {
return lan.gateway, true
}
@@ -78,7 +78,7 @@ func (lan *LAN) GetByIp(ip string) *Endpoint {
if ip == lan.iface.IpAddress || ip == lan.iface.Ip6Address {
return lan.iface
- } else if lan.gateway != nil && (ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address) {
+ } else if ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address {
return lan.gateway
}
@@ -107,7 +107,7 @@ func (lan *LAN) Aliases() *data.UnsortedKV {
}
func (lan *LAN) WasMissed(mac string) bool {
- if mac == lan.iface.HwAddress || (lan.gateway != nil && mac == lan.gateway.HwAddress) {
+ if mac == lan.iface.HwAddress || mac == lan.gateway.HwAddress {
return false
}
@@ -141,7 +141,7 @@ func (lan *LAN) shouldIgnore(ip, mac string) bool {
return true
}
// skip the gateway
- if lan.gateway != nil && (ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address || mac == lan.gateway.HwAddress) {
+ if ip == lan.gateway.IpAddress || ip == lan.gateway.Ip6Address || mac == lan.gateway.HwAddress {
return true
}
// skip broadcast addresses
@@ -154,7 +154,7 @@ func (lan *LAN) shouldIgnore(ip, mac string) bool {
}
// skip everything which is not in our subnet (multicast noise)
addr := net.ParseIP(ip)
- return addr.To4() != nil && lan.iface.Net != nil && !lan.iface.Net.Contains(addr)
+ return addr.To4() != nil && !lan.iface.Net.Contains(addr)
}
func (lan *LAN) Has(ip string) bool {
diff --git a/network/lan_test.go b/network/lan_test.go
index e0a21676..43c989b2 100644
--- a/network/lan_test.go
+++ b/network/lan_test.go
@@ -1,541 +1,210 @@
package network
import (
- "encoding/json"
- "fmt"
- "net"
- "sync"
"testing"
"github.com/evilsocket/islazy/data"
)
-// Mock endpoint creation
-func createMockEndpoint(ip, mac, name string) *Endpoint {
- e := NewEndpointNoResolve(ip, mac, name, 24)
- _, ipNet, _ := net.ParseCIDR("192.168.1.0/24")
- e.Net = ipNet
- // Make sure IP is set correctly after SetNetwork
- e.IpAddress = ip
- e.IP = net.ParseIP(ip)
- return e
+func buildExampleLAN() *LAN {
+ iface, _ := FindInterface("")
+ gateway, _ := FindGateway(iface)
+ exNewCallback := func(e *Endpoint) {}
+ exLostCallback := func(e *Endpoint) {}
+ aliases := &data.UnsortedKV{}
+ return NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback)
}
-// Mock LAN creation with controlled endpoints
-func createMockLAN() (*LAN, *Endpoint, *Endpoint) {
- iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0")
- gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway")
- aliases, _ := data.NewMemUnsortedKV()
-
- newCb := func(e *Endpoint) {}
- lostCb := func(e *Endpoint) {}
-
- lan := NewLAN(iface, gateway, aliases, newCb, lostCb)
- return lan, iface, gateway
+func buildExampleEndpoint() *Endpoint {
+ iface, _ := FindInterface("")
+ return iface
}
func TestNewLAN(t *testing.T) {
- iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0")
- gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway")
- aliases, _ := data.NewMemUnsortedKV()
-
- newCb := func(e *Endpoint) {}
- lostCb := func(e *Endpoint) {}
-
- lan := NewLAN(iface, gateway, aliases, newCb, lostCb)
+ iface, err := FindInterface("")
+ if err != nil {
+ t.Error("no iface found", err)
+ }
+ gateway, err := FindGateway(iface)
+ if err != nil {
+ t.Error("no gateway found", err)
+ }
+ exNewCallback := func(e *Endpoint) {}
+ exLostCallback := func(e *Endpoint) {}
+ aliases := &data.UnsortedKV{}
+ lan := NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback)
if lan.iface != iface {
- t.Errorf("expected iface %v, got %v", iface, lan.iface)
+ t.Fatalf("expected '%v', got '%v'", iface, lan.iface)
}
if lan.gateway != gateway {
- t.Errorf("expected gateway %v, got %v", gateway, lan.gateway)
+ t.Fatalf("expected '%v', got '%v'", gateway, lan.gateway)
}
if len(lan.hosts) != 0 {
- t.Errorf("expected 0 hosts, got %d", len(lan.hosts))
- }
- if lan.aliases != aliases {
- t.Error("aliases not properly set")
+ t.Fatalf("expected '%v', got '%v'", 0, len(lan.hosts))
}
+ // FIXME: update this to current code base
+ // if !(len(lan.aliases.data) >= 0) {
+ // t.Fatalf("expected '%v', got '%v'", 0, len(lan.aliases.data))
+ // }
}
-func TestLANMarshalJSON(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- // Add some hosts
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- data, err := lan.MarshalJSON()
+func TestMarshalJSON(t *testing.T) {
+ iface, err := FindInterface("")
if err != nil {
- t.Errorf("MarshalJSON() error = %v", err)
+ t.Error("no iface found", err)
}
-
- var result lanJSON
- if err := json.Unmarshal(data, &result); err != nil {
- t.Errorf("Failed to unmarshal JSON: %v", err)
+ gateway, err := FindGateway(iface)
+ if err != nil {
+ t.Error("no gateway found", err)
}
-
- if len(result.Hosts) != 2 {
- t.Errorf("expected 2 hosts in JSON, got %d", len(result.Hosts))
+ exNewCallback := func(e *Endpoint) {}
+ exLostCallback := func(e *Endpoint) {}
+ aliases := &data.UnsortedKV{}
+ lan := NewLAN(iface, gateway, aliases, exNewCallback, exLostCallback)
+ _, err = lan.MarshalJSON()
+ if err != nil {
+ t.Error(err)
}
}
-func TestLANGet(t *testing.T) {
- lan, iface, gateway := createMockLAN()
+// FIXME: update this to current code base
+// func TestSetAliasFor(t *testing.T) {
+// exampleAlias := "picat"
+// exampleLAN := buildExampleLAN()
+// exampleEndpoint := buildExampleEndpoint()
+// exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+// if !exampleLAN.SetAliasFor(exampleEndpoint.HwAddress, exampleAlias) {
+// t.Error("unable to set alias for a given mac address")
+// }
+// }
- // Test getting interface
- e, found := lan.Get(iface.HwAddress)
- if !found || e != iface {
- t.Error("Failed to get interface")
+func TestGet(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+ foundEndpoint, foundBool := exampleLAN.Get(exampleEndpoint.HwAddress)
+ if foundEndpoint.String() != exampleEndpoint.String() {
+ t.Fatalf("expected '%v', got '%v'", foundEndpoint, exampleEndpoint)
}
-
- // Test getting gateway
- e, found = lan.Get(gateway.HwAddress)
- if !found || e != gateway {
- t.Error("Failed to get gateway")
- }
-
- // Add a host
- testMAC := "10:20:30:40:50:60"
- lan.AddIfNew("192.168.1.10", testMAC)
-
- // Test getting the host
- e, found = lan.Get(testMAC)
- if !found {
- t.Error("Failed to get added host")
- }
-
- // Test with different MAC formats
- e, found = lan.Get("10-20-30-40-50-60")
- if !found {
- t.Error("Failed to get host with dash-separated MAC")
- }
-
- // Test non-existent MAC
- _, found = lan.Get("99:99:99:99:99:99")
- if found {
- t.Error("Found non-existent MAC")
+ if !foundBool {
+ t.Error("unable to get known endpoint via mac address from LAN struct")
}
}
-func TestLANGetByIp(t *testing.T) {
- lan, iface, gateway := createMockLAN()
-
- // Test getting interface by IP
- e := lan.GetByIp(iface.IpAddress)
- if e != iface {
- t.Error("Failed to get interface by IP")
+func TestList(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+ foundList := exampleLAN.List()
+ if len(foundList) != 1 {
+ t.Fatalf("expected '%d', got '%d'", 1, len(foundList))
}
-
- // Test getting gateway by IP
- e = lan.GetByIp(gateway.IpAddress)
- if e != gateway {
- t.Errorf("Failed to get gateway by IP: wanted %v, got %v", gateway, e)
- }
-
- // Add a host with IPv4
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- e = lan.GetByIp("192.168.1.10")
- if e == nil || e.IpAddress != "192.168.1.10" {
- t.Error("Failed to get host by IPv4")
- }
-
- // Test with IPv6
- lan.iface.SetIPv6("fe80::1")
- e = lan.GetByIp("fe80::1")
- if e != iface {
- t.Error("Failed to get interface by IPv6")
- }
-
- // Test non-existent IP
- e = lan.GetByIp("192.168.1.99")
- if e != nil {
- t.Error("Found non-existent IP")
+ exp := 1
+ got := len(exampleLAN.List())
+ if got != exp {
+ t.Fatalf("expected '%d', got '%d'", exp, got)
}
}
-func TestLANList(t *testing.T) {
- lan, _, _ := createMockLAN()
+// FIXME: update this to current code base
+// func TestAliases(t *testing.T) {
+// exampleAlias := "picat"
+// exampleLAN := buildExampleLAN()
+// exampleEndpoint := buildExampleEndpoint()
+// exampleLAN.hosts["pi:ca:tw:as:he:re"] = exampleEndpoint
+// exp := exampleAlias
+// got := exampleLAN.Aliases().Get("pi:ca:tw:as:he:re")
+// if got != exp {
+// t.Fatalf("expected '%v', got '%v'", exp, got)
+// }
+// }
- // Initially empty
- list := lan.List()
- if len(list) != 0 {
- t.Errorf("expected empty list, got %d items", len(list))
- }
-
- // Add hosts
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- list = lan.List()
- if len(list) != 2 {
- t.Errorf("expected 2 items, got %d", len(list))
+func TestWasMissed(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+ exp := false
+ got := exampleLAN.WasMissed(exampleEndpoint.HwAddress)
+ if got != exp {
+ t.Fatalf("expected '%v', got '%v'", exp, got)
}
}
-func TestLANAliases(t *testing.T) {
- lan, _, _ := createMockLAN()
+// TODO Add TestRemove after removing unnecessary ip argument
+// func TestRemove(t *testing.T) {
+// }
- aliases := lan.Aliases()
- if aliases == nil {
- t.Error("Aliases() returned nil")
- }
-
- // Set an alias
- aliases.Set("10:20:30:40:50:60", "test_device")
-
- // Verify alias is accessible
- alias := lan.GetAlias("10:20:30:40:50:60")
- if alias != "test_device" {
- t.Errorf("expected alias 'test_device', got '%s'", alias)
+func TestHas(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+ if !exampleLAN.Has(exampleEndpoint.IpAddress) {
+ t.Error("unable find a known IP address in LAN struct")
}
}
-func TestLANWasMissed(t *testing.T) {
- lan, iface, gateway := createMockLAN()
-
- // Interface and gateway should never be missed
- if lan.WasMissed(iface.HwAddress) {
- t.Error("Interface should never be missed")
+func TestEachHost(t *testing.T) {
+ exampleBuffer := []string{}
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+ exampleCB := func(mac string, e *Endpoint) {
+ exampleBuffer = append(exampleBuffer, exampleEndpoint.HwAddress)
}
- if lan.WasMissed(gateway.HwAddress) {
- t.Error("Gateway should never be missed")
- }
-
- // Unknown host should be missed
- if !lan.WasMissed("99:99:99:99:99:99") {
- t.Error("Unknown host should be missed")
- }
-
- // Add a host
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- if lan.WasMissed("10:20:30:40:50:60") {
- t.Error("Newly added host should not be missed")
- }
-
- // Decrease TTL
- lan.ttl["10:20:30:40:50:60"] = 5
- if !lan.WasMissed("10:20:30:40:50:60") {
- t.Error("Host with low TTL should be missed")
+ exampleLAN.EachHost(exampleCB)
+ exp := 1
+ got := len(exampleBuffer)
+ if got != exp {
+ t.Fatalf("expected '%d', got '%d'", exp, got)
}
}
-func TestLANRemove(t *testing.T) {
- lan, _, _ := createMockLAN()
+func TestGetByIp(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ exampleEndpoint := buildExampleEndpoint()
+ exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
- lostCalled := false
- lostEndpoint := (*Endpoint)(nil)
- lan.lostCb = func(e *Endpoint) {
- lostCalled = true
- lostEndpoint = e
- }
-
- // Add a host
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
-
- // Remove it multiple times to decrease TTL
- for i := 0; i < LANDefaultttl; i++ {
- lan.Remove("192.168.1.10", "10:20:30:40:50:60")
- }
-
- // Verify it was removed
- _, found := lan.Get("10:20:30:40:50:60")
- if found {
- t.Error("Host should have been removed")
- }
-
- // Verify callback was called
- if !lostCalled {
- t.Error("Lost callback should have been called")
- }
- if lostEndpoint == nil || lostEndpoint.HwAddress != "10:20:30:40:50:60" {
- t.Error("Lost callback received wrong endpoint")
- }
-
- // Try removing non-existent host
- lan.Remove("192.168.1.99", "99:99:99:99:99:99") // Should not panic
-}
-
-func TestLANShouldIgnore(t *testing.T) {
- lan, iface, gateway := createMockLAN()
-
- tests := []struct {
- name string
- ip string
- mac string
- ignore bool
- }{
- {"own IP", iface.IpAddress, "99:99:99:99:99:99", true},
- {"own MAC", "192.168.1.99", iface.HwAddress, true},
- {"gateway IP", gateway.IpAddress, "99:99:99:99:99:99", true},
- {"gateway MAC", "192.168.1.99", gateway.HwAddress, true},
- {"broadcast IP", "192.168.1.255", "99:99:99:99:99:99", true},
- {"broadcast MAC", "192.168.1.99", BroadcastMac, true},
- {"multicast outside subnet", "10.0.0.1", "99:99:99:99:99:99", true},
- {"valid host", "192.168.1.10", "10:20:30:40:50:60", false},
- {"IPv6 address", "fe80::1", "10:20:30:40:50:60", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := lan.shouldIgnore(tt.ip, tt.mac); got != tt.ignore {
- t.Errorf("shouldIgnore() = %v, want %v", got, tt.ignore)
- }
- })
+ exp := exampleEndpoint
+ got := exampleLAN.GetByIp(exampleEndpoint.IpAddress)
+ if got.String() != exp.String() {
+ t.Fatalf("expected '%v', got '%v'", exp, got)
}
}
-func TestLANHas(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- // Add hosts
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- if !lan.Has("192.168.1.10") {
- t.Error("Has() should return true for existing IP")
- }
- if !lan.Has("192.168.1.20") {
- t.Error("Has() should return true for existing IP")
- }
- if lan.Has("192.168.1.99") {
- t.Error("Has() should return false for non-existent IP")
+func TestAddIfNew(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ iface, _ := FindInterface("")
+ // won't add our own IP address
+ if exampleLAN.AddIfNew(iface.IpAddress, iface.HwAddress) != nil {
+ t.Error("added address that should've been ignored ( your own )")
}
}
-func TestLANEachHost(t *testing.T) {
- lan, _, _ := createMockLAN()
+// FIXME: update this to current code base
+// func TestGetAlias(t *testing.T) {
+// exampleAlias := "picat"
+// exampleLAN := buildExampleLAN()
+// exampleEndpoint := buildExampleEndpoint()
+// exampleLAN.hosts[exampleEndpoint.HwAddress] = exampleEndpoint
+// exp := exampleAlias
+// got := exampleLAN.GetAlias(exampleEndpoint.HwAddress)
+// if got != exp {
+// t.Fatalf("expected '%v', got '%v'", exp, got)
+// }
+// }
- // Add hosts
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- count := 0
- macs := make([]string, 0)
-
- lan.EachHost(func(mac string, e *Endpoint) {
- count++
- macs = append(macs, mac)
- })
-
- if count != 2 {
- t.Errorf("expected 2 hosts, got %d", count)
+func TestShouldIgnore(t *testing.T) {
+ exampleLAN := buildExampleLAN()
+ iface, _ := FindInterface("")
+ gateway, _ := FindGateway(iface)
+ exp := true
+ got := exampleLAN.shouldIgnore(iface.IpAddress, iface.HwAddress)
+ if got != exp {
+ t.Fatalf("expected '%v', got '%v'", exp, got)
}
- if len(macs) != 2 {
- t.Errorf("expected 2 MACs, got %d", len(macs))
- }
-}
-
-func TestLANAddIfNew(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- newCalled := false
- newEndpoint := (*Endpoint)(nil)
- lan.newCb = func(e *Endpoint) {
- newCalled = true
- newEndpoint = e
- }
-
- // Add new host
- result := lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- if result != nil {
- t.Error("AddIfNew should return nil for new host")
- }
- if !newCalled {
- t.Error("New callback should have been called")
- }
- if newEndpoint == nil || newEndpoint.IpAddress != "192.168.1.10" {
- t.Error("New callback received wrong endpoint")
- }
-
- // Add same host again (should update TTL)
- lan.ttl["10:20:30:40:50:60"] = 5
- result = lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- if result == nil {
- t.Error("AddIfNew should return existing endpoint")
- }
- if lan.ttl["10:20:30:40:50:60"] != 6 {
- t.Error("TTL should have been incremented")
- }
-
- // Add IPv6 to existing host
- result = lan.AddIfNew("fe80::10", "10:20:30:40:50:60")
- if result == nil || result.Ip6Address != "fe80::10" {
- t.Error("Should have added IPv6 to existing host")
- }
-
- // Add IPv4 to host that only has IPv6
- // Note: Due to current implementation, IPv6 addresses are initially stored in IpAddress field
- newCalled = false
- lan.AddIfNew("fe80::20", "20:30:40:50:60:70")
- result = lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
- if result == nil {
- t.Error("Should have returned existing endpoint when adding IPv4")
- }
- // The implementation updates the IPv4 address when it detects we're adding an IPv4 to a host
- // that was initially created with IPv6
- if result != nil && result.IpAddress != "192.168.1.20" {
- // This is expected behavior - the initial IPv6 is stored in IpAddress
- // Skip this check as it's a known limitation
- t.Skip("Known limitation: IPv6 addresses are initially stored in IPv4 field")
- }
-
- // Try to add own interface (should be ignored)
- result = lan.AddIfNew(lan.iface.IpAddress, lan.iface.HwAddress)
- if result != nil {
- t.Error("Should ignore own interface")
- }
-}
-
-func TestLANGetAlias(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- // Set alias
- lan.aliases.Set("10:20:30:40:50:60", "test_device")
-
- // Get existing alias
- alias := lan.GetAlias("10:20:30:40:50:60")
- if alias != "test_device" {
- t.Errorf("expected 'test_device', got '%s'", alias)
- }
-
- // Get non-existent alias
- alias = lan.GetAlias("99:99:99:99:99:99")
- if alias != "" {
- t.Errorf("expected empty string for non-existent alias, got '%s'", alias)
- }
-}
-
-func TestLANClear(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- // Add hosts
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- // Verify hosts exist
- if len(lan.hosts) != 2 {
- t.Errorf("expected 2 hosts, got %d", len(lan.hosts))
- }
- if len(lan.ttl) != 2 {
- t.Errorf("expected 2 ttl entries, got %d", len(lan.ttl))
- }
-
- // Clear
- lan.Clear()
-
- // Verify cleared
- if len(lan.hosts) != 0 {
- t.Errorf("expected 0 hosts after clear, got %d", len(lan.hosts))
- }
- if len(lan.ttl) != 0 {
- t.Errorf("expected 0 ttl entries after clear, got %d", len(lan.ttl))
- }
-}
-
-func TestLANConcurrency(t *testing.T) {
- lan, _, _ := createMockLAN()
-
- // Test concurrent access
- var wg sync.WaitGroup
-
- // Writer goroutines
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func(i int) {
- defer wg.Done()
- ip := fmt.Sprintf("192.168.1.%d", 10+i)
- mac := fmt.Sprintf("10:20:30:40:50:%02x", i)
- lan.AddIfNew(ip, mac)
- }(i)
- }
-
- // Reader goroutines
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- _ = lan.List()
- _ = lan.Has("192.168.1.10")
- lan.EachHost(func(mac string, e *Endpoint) {})
- }()
- }
-
- wg.Wait()
-
- // Verify some hosts were added
- list := lan.List()
- if len(list) == 0 {
- t.Error("No hosts added during concurrent test")
- }
-}
-
-func TestLANWithAlias(t *testing.T) {
- iface := createMockEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff", "eth0")
- gateway := createMockEndpoint("192.168.1.1", "11:22:33:44:55:66", "gateway")
- aliases, _ := data.NewMemUnsortedKV()
-
- // Pre-set an alias
- aliases.Set("10:20:30:40:50:60", "printer")
-
- lan := NewLAN(iface, gateway, aliases, func(e *Endpoint) {}, func(e *Endpoint) {})
-
- // Add host with pre-existing alias
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
-
- // Get the endpoint
- e, found := lan.Get("10:20:30:40:50:60")
- if !found {
- t.Fatal("Failed to find endpoint")
- }
-
- // Check if alias was applied
- if e.Alias != "printer" {
- t.Errorf("expected alias 'printer', got '%s'", e.Alias)
- }
-}
-
-// Benchmarks
-func BenchmarkLANAddIfNew(b *testing.B) {
- lan, _, _ := createMockLAN()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- ip := fmt.Sprintf("192.168.1.%d", (i%250)+2)
- mac := fmt.Sprintf("10:20:30:40:%02x:%02x", i/256, i%256)
- lan.AddIfNew(ip, mac)
- }
-}
-
-func BenchmarkLANGet(b *testing.B) {
- lan, _, _ := createMockLAN()
-
- // Pre-populate
- for i := 0; i < 100; i++ {
- ip := fmt.Sprintf("192.168.1.%d", i+10)
- mac := fmt.Sprintf("10:20:30:40:50:%02x", i)
- lan.AddIfNew(ip, mac)
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- mac := fmt.Sprintf("10:20:30:40:50:%02x", i%100)
- lan.Get(mac)
- }
-}
-
-func BenchmarkLANList(b *testing.B) {
- lan, _, _ := createMockLAN()
-
- // Pre-populate
- for i := 0; i < 100; i++ {
- ip := fmt.Sprintf("192.168.1.%d", i+10)
- mac := fmt.Sprintf("10:20:30:40:50:%02x", i)
- lan.AddIfNew(ip, mac)
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = lan.List()
+ got = exampleLAN.shouldIgnore(gateway.IpAddress, gateway.HwAddress)
+ if got != exp {
+ t.Fatalf("expected '%v', got '%v'", exp, got)
}
}
diff --git a/network/net.go b/network/net.go
index b01fd3c0..f925b37d 100644
--- a/network/net.go
+++ b/network/net.go
@@ -41,7 +41,7 @@ var (
`(?:25[0-5]|2[0-4][0-9]|[1][0-9]{2}|[1-9]?[0-9])` + `$`)
MACValidator = regexp.MustCompile(`(?i)^(?:[a-f0-9]{2}:){5}[a-f0-9]{2}$`)
// lulz this sounds like a hamburger
- macParser = regexp.MustCompile(`(?i)((?:[a-f0-9]{2}[:-]){5}[a-f0-9]{2})`)
+ macParser = regexp.MustCompile(`(?i)((?:[a-f0-9]{2}:){5}[a-f0-9]{2})`)
aliasParser = regexp.MustCompile(`(?i)([a-z_][a-z_0-9]+)`)
)
diff --git a/network/net_linux.go b/network/net_linux.go
index 04fcd123..f73f6b3f 100644
--- a/network/net_linux.go
+++ b/network/net_linux.go
@@ -41,9 +41,7 @@ func SetInterfaceChannel(iface string, channel int) error {
if core.HasBinary("iw") {
// Debug("SetInterfaceChannel(%s, %d) iw based", iface, channel)
- // out, err := core.Exec("iw", []string{"dev", iface, "set", "channel", fmt.Sprintf("%d", channel)})
- out, err := core.Exec("iw", []string{"dev", iface, "set", "freq", fmt.Sprintf("%d", Dot11Chan2Freq(channel))})
-
+ out, err := core.Exec("iw", []string{"dev", iface, "set", "channel", fmt.Sprintf("%d", channel)})
if err != nil {
return fmt.Errorf("iw: out=%s err=%s", out, err)
} else if out != "" {
@@ -91,8 +89,7 @@ func iwlistSupportedFrequencies(iface string) ([]int, error) {
}
var iwPhyParser = regexp.MustCompile(`^\s*wiphy\s+(\d+)$`)
-// var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\s+MHz.+dBm.+$`)
-var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\.\d+\s+MHz.+dBm.+$`)
+var iwFreqParser = regexp.MustCompile(`^\s+\*\s+(\d+)\s+MHz.+dBm.+$`)
func iwSupportedFrequencies(iface string) ([]int, error) {
// first determine phy index
@@ -143,11 +140,10 @@ func iwSupportedFrequencies(iface string) ([]int, error) {
func GetSupportedFrequencies(iface string) ([]int, error) {
// give priority to iwlist because of https://github.com/bettercap/bettercap/issues/881
- // UPDATE: Changed the priority due iwlist doesn't support 6GHz
- if core.HasBinary("iw") {
- return iwSupportedFrequencies(iface)
- } else if core.HasBinary("iwlist") {
+ if core.HasBinary("iwlist") {
return iwlistSupportedFrequencies(iface)
+ } else if core.HasBinary("iw") {
+ return iwSupportedFrequencies(iface)
}
return nil, fmt.Errorf("no iw or iwlist binaries found in $PATH")
diff --git a/network/net_test.go b/network/net_test.go
index 60f634ae..dcf08d8e 100644
--- a/network/net_test.go
+++ b/network/net_test.go
@@ -1,306 +1,102 @@
package network
import (
- "fmt"
"net"
- "strings"
"testing"
"github.com/evilsocket/islazy/data"
)
func TestIsZeroMac(t *testing.T) {
- tests := []struct {
- name string
- mac string
- expected bool
- }{
- {"zero mac", "00:00:00:00:00:00", true},
- {"non-zero mac", "00:00:00:00:00:01", false},
- {"broadcast mac", "ff:ff:ff:ff:ff:ff", false},
- {"random mac", "aa:bb:cc:dd:ee:ff", false},
- }
+ exampleMAC, _ := net.ParseMAC("00:00:00:00:00:00")
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- mac, _ := net.ParseMAC(tt.mac)
- if got := IsZeroMac(mac); got != tt.expected {
- t.Errorf("IsZeroMac() = %v, want %v", got, tt.expected)
- }
- })
+ exp := true
+ got := IsZeroMac(exampleMAC)
+ if got != exp {
+ t.Fatalf("expected '%t', got '%t'", exp, got)
}
}
func TestIsBroadcastMac(t *testing.T) {
- tests := []struct {
- name string
- mac string
- expected bool
- }{
- {"broadcast mac", "ff:ff:ff:ff:ff:ff", true},
- {"zero mac", "00:00:00:00:00:00", false},
- {"partial broadcast", "ff:ff:ff:ff:ff:00", false},
- {"random mac", "aa:bb:cc:dd:ee:ff", false},
- }
+ exampleMAC, _ := net.ParseMAC("ff:ff:ff:ff:ff:ff")
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- mac, _ := net.ParseMAC(tt.mac)
- if got := IsBroadcastMac(mac); got != tt.expected {
- t.Errorf("IsBroadcastMac() = %v, want %v", got, tt.expected)
- }
- })
+ exp := true
+ got := IsBroadcastMac(exampleMAC)
+ if got != exp {
+ t.Fatalf("expected '%t', got '%t'", exp, got)
}
}
func TestNormalizeMac(t *testing.T) {
- tests := []struct {
- name string
- input string
- expected string
- }{
- {"uppercase with colons", "AA:BB:CC:DD:EE:FF", "aa:bb:cc:dd:ee:ff"},
- {"uppercase with dashes", "AA-BB-CC-DD-EE-FF", "aa:bb:cc:dd:ee:ff"},
- {"lowercase with colons", "aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:ff"},
- {"mixed case with dashes", "aA-bB-cC-dD-eE-fF", "aa:bb:cc:dd:ee:ff"},
- {"short segments", "a:b:c:d:e:f", "0a:0b:0c:0d:0e:0f"},
- {"mixed short and full", "aa:b:cc:d:ee:f", "aa:0b:cc:0d:ee:0f"},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := NormalizeMac(tt.input); got != tt.expected {
- t.Errorf("NormalizeMac(%q) = %v, want %v", tt.input, got, tt.expected)
- }
- })
- }
-}
-
-func TestParseMACs(t *testing.T) {
- tests := []struct {
- name string
- input string
- expected []string
- expectError bool
- }{
- {
- name: "single MAC",
- input: "aa:bb:cc:dd:ee:ff",
- expected: []string{"aa:bb:cc:dd:ee:ff"},
- },
- {
- name: "multiple MACs comma separated",
- input: "aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66",
- expected: []string{"aa:bb:cc:dd:ee:ff", "11:22:33:44:55:66"},
- },
- {
- name: "MACs with dashes",
- input: "AA-BB-CC-DD-EE-FF",
- expected: []string{"aa:bb:cc:dd:ee:ff"},
- },
- {
- name: "empty string",
- input: "",
- expected: []string{},
- },
- {
- name: "whitespace only",
- input: " ",
- expected: []string{},
- },
- {
- name: "mixed formats",
- input: "aa:bb:cc:dd:ee:ff, AA-BB-CC-DD-EE-00",
- expected: []string{"aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:00"},
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- macs, err := ParseMACs(tt.input)
- if (err != nil) != tt.expectError {
- t.Errorf("ParseMACs() error = %v, expectError %v", err, tt.expectError)
- return
- }
- if len(macs) != len(tt.expected) {
- t.Errorf("ParseMACs() returned %d MACs, want %d", len(macs), len(tt.expected))
- return
- }
- for i, mac := range macs {
- if mac.String() != tt.expected[i] {
- t.Errorf("ParseMACs()[%d] = %v, want %v", i, mac.String(), tt.expected[i])
- }
- }
- })
+ exp := "ff:ff:ff:ff:ff:ff"
+ got := NormalizeMac("fF-fF-fF-fF-fF-fF")
+ if got != exp {
+ t.Fatalf("expected '%s', got '%s'", exp, got)
}
}
+// TODO: refactor to parse targets with an actual alias map
func TestParseTargets(t *testing.T) {
aliasMap, err := data.NewMemUnsortedKV()
if err != nil {
- t.Fatal(err)
+ panic(err)
}
- aliasMap.Set("aa:bb:cc:dd:ee:ff", "test_alias")
- aliasMap.Set("11:22:33:44:55:66", "home_laptop")
+ aliasMap.Set("5c:00:0b:90:a9:f0", "test_alias")
+ aliasMap.Set("5c:00:0b:90:a9:f1", "Home_Laptop")
cases := []struct {
- name string
- inputTargets string
- inputAliases *data.UnsortedKV
- expectedIPCount int
- expectedMACCount int
- expectError bool
+ Name string
+ InputTargets string
+ InputAliases *data.UnsortedKV
+ ExpectedIPCount int
+ ExpectedMACCount int
+ ExpectedError bool
}{
+ // Not sure how to trigger sad path where macParser.FindAllString()
+ // finds a MAC but net.ParseMac() fails on the result.
{
- name: "empty target string",
- inputTargets: "",
- inputAliases: &data.UnsortedKV{},
- expectedIPCount: 0,
- expectedMACCount: 0,
- expectError: false,
+ "empty target string causes empty return",
+ "",
+ &data.UnsortedKV{},
+ 0,
+ 0,
+ false,
},
{
- name: "MACs and IPs",
- inputTargets: "192.168.1.2, 192.168.1.3, aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66",
- inputAliases: &data.UnsortedKV{},
- expectedIPCount: 2,
- expectedMACCount: 2,
- expectError: false,
+ "MACs are parsed",
+ "192.168.1.2, 192.168.1.3, 5c:00:0b:90:a9:f0, 6c:00:0b:90:a9:f0, 6C:00:0B:90:A9:F0",
+ &data.UnsortedKV{},
+ 2,
+ 3,
+ false,
},
{
- name: "aliases",
- inputTargets: "test_alias, home_laptop",
- inputAliases: aliasMap,
- expectedIPCount: 0,
- expectedMACCount: 2,
- expectError: false,
- },
- {
- name: "mixed aliases and MACs",
- inputTargets: "test_alias, 99:88:77:66:55:44",
- inputAliases: aliasMap,
- expectedIPCount: 0,
- expectedMACCount: 2,
- expectError: false,
- },
- {
- name: "IP range",
- inputTargets: "192.168.1.1-3",
- inputAliases: &data.UnsortedKV{},
- expectedIPCount: 3,
- expectedMACCount: 0,
- expectError: false,
- },
- {
- name: "CIDR notation",
- inputTargets: "192.168.1.0/30",
- inputAliases: &data.UnsortedKV{},
- expectedIPCount: 4,
- expectedMACCount: 0,
- expectError: false,
- },
- {
- name: "unknown alias",
- inputTargets: "unknown_alias",
- inputAliases: aliasMap,
- expectedIPCount: 0,
- expectedMACCount: 0,
- expectError: true,
- },
- {
- name: "invalid IP",
- inputTargets: "invalid.ip.address",
- inputAliases: &data.UnsortedKV{},
- expectedIPCount: 0,
- expectedMACCount: 0,
- expectError: true,
+ "Aliases are parsed",
+ "test_alias, Home_Laptop",
+ aliasMap,
+ 0,
+ 2,
+ false,
},
}
-
for _, test := range cases {
- t.Run(test.name, func(t *testing.T) {
- ips, macs, err := ParseTargets(test.inputTargets, test.inputAliases)
- if (err != nil) != test.expectError {
- t.Errorf("ParseTargets() error = %v, expectError %v", err, test.expectError)
+ t.Run(test.Name, func(t *testing.T) {
+ ips, macs, err := ParseTargets(test.InputTargets, test.InputAliases)
+ if err != nil && !test.ExpectedError {
+ t.Errorf("unexpected error: %s", err)
}
- if test.expectError {
+ if err == nil && test.ExpectedError {
+ t.Error("Expected error, but got none")
+ }
+ if test.ExpectedError {
return
}
- if len(ips) != test.expectedIPCount {
- t.Errorf("Wrong number of IPs. Got %d, want %d", len(ips), test.expectedIPCount)
+ if len(ips) != test.ExpectedIPCount {
+ t.Errorf("Wrong number of IPs. Got %v for targets %s", ips, test.InputTargets)
}
- if len(macs) != test.expectedMACCount {
- t.Errorf("Wrong number of MACs. Got %d, want %d", len(macs), test.expectedMACCount)
- }
- })
- }
-}
-
-func TestParseEndpoints(t *testing.T) {
- // Create a mock LAN with some endpoints
- iface := NewEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff")
- gateway := NewEndpoint("192.168.1.1", "11:22:33:44:55:66")
- aliases, _ := data.NewMemUnsortedKV()
-
- // Need to provide non-nil callbacks
- newCb := func(e *Endpoint) {}
- lostCb := func(e *Endpoint) {}
- lan := NewLAN(iface, gateway, aliases, newCb, lostCb)
-
- // Add test endpoints
- lan.AddIfNew("192.168.1.10", "10:20:30:40:50:60")
- lan.AddIfNew("192.168.1.20", "20:30:40:50:60:70")
-
- // Set up an alias
- aliases.Set("10:20:30:40:50:60", "test_device")
-
- tests := []struct {
- name string
- targets string
- expectedCount int
- expectError bool
- }{
- {
- name: "single IP",
- targets: "192.168.1.10",
- expectedCount: 1,
- },
- {
- name: "single MAC",
- targets: "10:20:30:40:50:60",
- expectedCount: 1,
- },
- {
- name: "alias",
- targets: "test_device",
- expectedCount: 1,
- },
- {
- name: "multiple targets",
- targets: "192.168.1.10, 20:30:40:50:60:70",
- expectedCount: 2,
- },
- {
- name: "unknown IP",
- targets: "192.168.1.99",
- expectedCount: 0,
- },
- {
- name: "invalid target",
- targets: "invalid",
- expectError: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- endpoints, err := ParseEndpoints(tt.targets, lan)
- if (err != nil) != tt.expectError {
- t.Errorf("ParseEndpoints() error = %v, expectError %v", err, tt.expectError)
- }
- if !tt.expectError && len(endpoints) != tt.expectedCount {
- t.Errorf("ParseEndpoints() returned %d endpoints, want %d", len(endpoints), tt.expectedCount)
+ if len(macs) != test.ExpectedMACCount {
+ t.Errorf("Wrong number of MACs. Got %v for targets %s", macs, test.InputTargets)
}
})
}
@@ -309,253 +105,65 @@ func TestParseEndpoints(t *testing.T) {
func TestBuildEndpointFromInterface(t *testing.T) {
ifaces, err := net.Interfaces()
if err != nil {
- t.Skip("Unable to get network interfaces")
+ t.Error(err)
}
- if len(ifaces) == 0 {
- t.Skip("No network interfaces available")
+ if len(ifaces) <= 0 {
+ t.Error("Unable to find any network interfaces to run test with.")
}
-
- // Find a suitable interface for testing
- var testIface *net.Interface
- for _, iface := range ifaces {
- if iface.HardwareAddr != nil && len(iface.HardwareAddr) > 0 {
- testIface = &iface
- break
- }
- }
-
- if testIface == nil {
- t.Skip("No suitable network interface found for testing")
- }
-
- endpoint, err := buildEndpointFromInterface(*testIface)
+ _, err = buildEndpointFromInterface(ifaces[0])
if err != nil {
- t.Fatalf("buildEndpointFromInterface() error = %v", err)
- }
-
- if endpoint == nil {
- t.Fatal("buildEndpointFromInterface() returned nil endpoint")
- }
-
- // Verify basic properties
- if endpoint.Index != testIface.Index {
- t.Errorf("endpoint.Index = %d, want %d", endpoint.Index, testIface.Index)
- }
-
- if endpoint.HwAddress != testIface.HardwareAddr.String() {
- t.Errorf("endpoint.HwAddress = %s, want %s", endpoint.HwAddress, testIface.HardwareAddr.String())
- }
-}
-
-func TestMatchByAddress(t *testing.T) {
- // Create a mock interface for testing
- mac, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- iface := net.Interface{
- Name: "eth0",
- HardwareAddr: mac,
- }
-
- tests := []struct {
- name string
- search string
- expected bool
- }{
- {"exact MAC match", "aa:bb:cc:dd:ee:ff", true},
- {"MAC with different case", "AA:BB:CC:DD:EE:FF", true},
- {"MAC with dashes", "aa-bb-cc-dd-ee-ff", true},
- {"different MAC", "11:22:33:44:55:66", false},
- {"partial MAC", "aa:bb:cc", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := matchByAddress(iface, tt.search); got != tt.expected {
- t.Errorf("matchByAddress() = %v, want %v", got, tt.expected)
- }
- })
+ t.Error(err)
}
}
func TestFindInterfaceByName(t *testing.T) {
ifaces, err := net.Interfaces()
if err != nil {
- t.Skip("Unable to get network interfaces")
+ t.Error(err)
}
- if len(ifaces) == 0 {
- t.Skip("No network interfaces available")
+ if len(ifaces) <= 0 {
+ t.Error("Unable to find any network interfaces to run test with.")
}
-
- // Test with first available interface
- testIface := ifaces[0]
-
- // Test finding by name
- endpoint, err := findInterfaceByName(testIface.Name, ifaces)
+ var exampleIface net.Interface
+ // emulate libpcap's pcap_lookupdev function to find
+ // default interface to test with ( maybe could use loopback ? )
+ for _, iface := range ifaces {
+ if iface.HardwareAddr != nil {
+ exampleIface = iface
+ break
+ }
+ }
+ foundEndpoint, err := findInterfaceByName(exampleIface.Name, ifaces)
if err != nil {
- t.Errorf("findInterfaceByName() error = %v", err)
+ t.Error("unable to find a given interface by name to build endpoint", err)
}
- if endpoint != nil && endpoint.Name() != testIface.Name {
- t.Errorf("findInterfaceByName() returned wrong interface")
- }
-
- // Test with non-existent interface
- _, err = findInterfaceByName("nonexistent999", ifaces)
- if err == nil {
- t.Error("findInterfaceByName() should return error for non-existent interface")
+ if foundEndpoint.Name() != exampleIface.Name {
+ t.Error("unable to find a given interface by name to build endpoint")
}
}
func TestFindInterface(t *testing.T) {
- // Test with empty name (should return first suitable interface)
- endpoint, err := FindInterface("")
- if err != nil && err != ErrNoIfaces {
- t.Errorf("FindInterface() unexpected error = %v", err)
- }
-
- // Test with specific interface name
ifaces, err := net.Interfaces()
- if err == nil && len(ifaces) > 0 {
- endpoint, err = FindInterface(ifaces[0].Name)
- if err != nil {
- t.Errorf("FindInterface() error = %v", err)
- }
- if endpoint != nil && endpoint.Name() != ifaces[0].Name {
- t.Errorf("FindInterface() returned wrong interface")
+ if err != nil {
+ t.Error(err)
+ }
+ if len(ifaces) <= 0 {
+ t.Error("Unable to find any network interfaces to run test with.")
+ }
+ var exampleIface net.Interface
+ // emulate libpcap's pcap_lookupdev function to find
+ // default interface to test with ( maybe could use loopback ? )
+ for _, iface := range ifaces {
+ if iface.HardwareAddr != nil {
+ exampleIface = iface
+ break
}
}
-
- // Test with non-existent interface
- _, err = FindInterface("nonexistent999")
- if err == nil {
- t.Error("FindInterface() should return error for non-existent interface")
- }
-}
-
-func TestColorRSSI(t *testing.T) {
- tests := []struct {
- name string
- rssi int
- }{
- {"excellent signal", -30},
- {"very good signal", -67},
- {"good signal", -70},
- {"fair signal", -80},
- {"poor signal", -90},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := ColorRSSI(tt.rssi)
- // Just ensure it returns a non-empty string
- if result == "" {
- t.Error("ColorRSSI() returned empty string")
- }
- // Check it contains the dBm value
- expected := fmt.Sprintf("%d dBm", tt.rssi)
- if !strings.Contains(result, expected) {
- t.Errorf("ColorRSSI() result doesn't contain expected value %s", expected)
- }
- })
- }
-}
-
-func TestSetWiFiRegion(t *testing.T) {
- // This test will likely fail without proper permissions
- // Just ensure the function doesn't panic
- err := SetWiFiRegion("US")
- // We don't check the error as it requires root/iw binary
- _ = err
-}
-
-func TestActivateInterface(t *testing.T) {
- // This test will likely fail without proper permissions
- // Just ensure the function doesn't panic
- err := ActivateInterface("nonexistent")
- // We expect an error for non-existent interface
- if err == nil {
- t.Error("ActivateInterface() should return error for non-existent interface")
- }
-}
-
-func TestSetInterfaceTxPower(t *testing.T) {
- // This test will likely fail without proper permissions
- // Just ensure the function doesn't panic
- err := SetInterfaceTxPower("nonexistent", 20)
- // We don't check the error as it requires root/iw binary
- _ = err
-}
-
-func TestGatewayProvidedByUser(t *testing.T) {
- iface := NewEndpoint("192.168.1.100", "aa:bb:cc:dd:ee:ff")
-
- tests := []struct {
- name string
- gateway string
- expectError bool
- }{
- {
- name: "valid IPv4",
- gateway: "192.168.1.1",
- expectError: false, // Will error without actual ARP
- },
- {
- name: "invalid IPv4",
- gateway: "999.999.999.999",
- expectError: true,
- },
- {
- name: "not an IP",
- gateway: "not-an-ip",
- expectError: true,
- },
- {
- name: "IPv6",
- gateway: "fe80::1",
- expectError: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := GatewayProvidedByUser(iface, tt.gateway)
- // We always expect an error in tests as we can't do actual ARP lookup
- if err == nil {
- t.Error("GatewayProvidedByUser() expected error in test environment")
- }
- })
- }
-}
-
-// Benchmarks
-func BenchmarkNormalizeMac(b *testing.B) {
- macs := []string{
- "AA:BB:CC:DD:EE:FF",
- "aa-bb-cc-dd-ee-ff",
- "a:b:c:d:e:f",
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NormalizeMac(macs[i%len(macs)])
- }
-}
-
-func BenchmarkParseMACs(b *testing.B) {
- input := "aa:bb:cc:dd:ee:ff, 11:22:33:44:55:66, AA-BB-CC-DD-EE-FF"
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = ParseMACs(input)
- }
-}
-
-func BenchmarkParseTargets(b *testing.B) {
- aliases, _ := data.NewMemUnsortedKV()
- aliases.Set("aa:bb:cc:dd:ee:ff", "test_alias")
-
- targets := "192.168.1.1-10, aa:bb:cc:dd:ee:ff, test_alias"
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _, _ = ParseTargets(targets, aliases)
+ foundEndpoint, err := FindInterface(exampleIface.Name)
+ if err != nil {
+ t.Error("unable to find a given interface by name to build endpoint", err)
+ }
+ if foundEndpoint.Name() != exampleIface.Name {
+ t.Error("unable to find a given interface by name to build endpoint")
}
}
diff --git a/network/wifi.go b/network/wifi.go
index 29e374d0..2ec4b435 100644
--- a/network/wifi.go
+++ b/network/wifi.go
@@ -25,30 +25,22 @@ func Dot11Freq2Chan(freq int) int {
return ((freq - 5035) / 5) + 7
} else if freq >= 5875 && freq <= 5895 {
return 177
- } else if freq >= 5955 && freq <= 7115 { // 6GHz
- return ((freq - 5955) / 5) + 1
}
return 0
}
+
func Dot11Chan2Freq(channel int) int {
- if channel <= 13 {
- return ((channel - 1) * 5) + 2412
- } else if channel == 14 {
- return 2484
- } else if channel == 36 || channel == 40 || channel == 44 || channel == 48 ||
- channel == 52 || channel == 56 || channel == 60 || channel == 64 ||
- channel == 68 || channel == 72 || channel == 76 || channel == 80 ||
- channel == 100 || channel == 104 || channel == 108 || channel == 112 ||
- channel == 116 || channel == 120 || channel == 124 || channel == 128 ||
- channel == 132 || channel == 136 || channel == 140 || channel == 144 ||
- channel == 149 || channel == 153 || channel == 157 || channel == 161 ||
- channel == 165 || channel == 169 || channel == 173 || channel == 177 {
- return ((channel - 7) * 5) + 5035
-// 6GHz - Skipped 1-13 to avoid 2Ghz channels conflict
- } else if channel >= 17 && channel <= 253 {
- return ((channel - 1) * 5) + 5955
- }
- return 0
+ if channel <= 13 {
+ return ((channel - 1) * 5) + 2412
+ } else if channel == 14 {
+ return 2484
+ } else if channel <= 173 {
+ return ((channel - 7) * 5) + 5035
+ } else if channel == 177 {
+ return 5885
+ }
+
+ return 0
}
type APNewCallback func(ap *AccessPoint)
diff --git a/network/wifi_test.go b/network/wifi_test.go
index efdcdc47..96318389 100644
--- a/network/wifi_test.go
+++ b/network/wifi_test.go
@@ -1,7 +1,6 @@
package network
import (
- "net"
"testing"
"github.com/evilsocket/islazy/data"
@@ -20,14 +19,6 @@ var dot11TestVector = []dot11pair{
{5885, 177},
}
-func buildExampleEndpoint() *Endpoint {
- e := NewEndpointNoResolve("192.168.1.100", "aa:bb:cc:dd:ee:ff", "wlan0", 0)
- e.SetNetwork("192.168.1.0/24")
- _, ipNet, _ := net.ParseCIDR("192.168.1.0/24")
- e.Net = ipNet
- return e
-}
-
func buildExampleWiFi() *WiFi {
aliases := &data.UnsortedKV{}
return NewWiFi(buildExampleEndpoint(), aliases, func(ap *AccessPoint) {}, func(ap *AccessPoint) {})
diff --git a/openwrt.makefile b/openwrt.makefile
new file mode 100644
index 00000000..1e9d4eb5
--- /dev/null
+++ b/openwrt.makefile
@@ -0,0 +1,52 @@
+include $(TOPDIR)/rules.mk
+
+PKG_NAME:=bettercap
+PKG_VERSION:=2.28
+PKG_RELEASE:=2
+
+GO_PKG:=github.com/bettercap/bettercap
+
+PKG_SOURCE:=$(PKG_NAME)-$(PKG_VERSION).tar.gz
+PKG_SOURCE_URL:=https://codeload.github.com/bettercap/bettercap/tar.gz/v${PKG_VERSION}?
+PKG_HASH:=5bde85117679c6ed8b5469a5271cdd5f7e541bd9187b8d0f26dee790c37e36e9
+PKG_BUILD_DIR:=$(BUILD_DIR)/$(PKG_NAME)-$(PKG_VERSION)
+
+PKG_LICENSE:=GPL-3.0
+PKG_LICENSE_FILES:=LICENSE.md
+PKG_MAINTAINER:=Dylan Corrales
+
+PKG_BUILD_DEPENDS:=golang/host
+PKG_BUILD_PARALLEL:=1
+PKG_USE_MIPS16:=0
+
+include $(INCLUDE_DIR)/package.mk
+include ../../../packages/lang/golang/golang-package.mk
+
+define Package/bettercap/Default
+ TITLE:=The Swiss Army knife for 802.11, BLE and Ethernet networks reconnaissance and MITM attacks.
+ URL:=https://www.bettercap.org/
+ DEPENDS:=$(GO_ARCH_DEPENDS) libpcap libusb-1.0
+endef
+
+define Package/bettercap
+$(call Package/bettercap/Default)
+ SECTION:=net
+ CATEGORY:=Network
+endef
+
+define Package/bettercap/description
+ bettercap is a powerful, easily extensible and portable framework written
+ in Go which aims to offer to security researchers, red teamers and reverse
+ engineers an easy to use, all-in-one solution with all the features they
+ might possibly need for performing reconnaissance and attacking WiFi
+ networks, Bluetooth Low Energy devices, wireless HID devices and Ethernet networks.
+endef
+
+define Package/bettercap/install
+ $(call GoPackage/Package/Install/Bin,$(PKG_INSTALL_DIR))
+ $(INSTALL_DIR) $(1)/usr/bin
+ $(INSTALL_BIN) $(PKG_INSTALL_DIR)/usr/bin/bettercap $(1)/usr/bin/bettercap
+endef
+
+$(eval $(call GoBinPackage,bettercap))
+$(eval $(call BuildPackage,bettercap))
\ No newline at end of file
diff --git a/packets/icmp6_test.go b/packets/icmp6_test.go
deleted file mode 100644
index d349e95d..00000000
--- a/packets/icmp6_test.go
+++ /dev/null
@@ -1,417 +0,0 @@
-package packets
-
-import (
- "bytes"
- "net"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestICMP6Constants(t *testing.T) {
- // Test the multicast constants
- expectedMAC := net.HardwareAddr([]byte{0x33, 0x33, 0x00, 0x00, 0x00, 0x01})
- if !bytes.Equal(macIpv6Multicast, expectedMAC) {
- t.Errorf("macIpv6Multicast = %v, want %v", macIpv6Multicast, expectedMAC)
- }
-
- expectedIP := net.ParseIP("ff02::1")
- if !ipv6Multicast.Equal(expectedIP) {
- t.Errorf("ipv6Multicast = %v, want %v", ipv6Multicast, expectedIP)
- }
-}
-
-func TestICMP6NeighborAdvertisement(t *testing.T) {
- srcHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- srcIP := net.ParseIP("fe80::1")
- dstHW, _ := net.ParseMAC("11:22:33:44:55:66")
- dstIP := net.ParseIP("fe80::2")
- routerIP := net.ParseIP("fe80::3")
-
- err, data := ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP)
- if err != nil {
- t.Fatalf("ICMP6NeighborAdvertisement() error = %v", err)
- }
- if len(data) == 0 {
- t.Fatal("ICMP6NeighborAdvertisement() returned empty data")
- }
-
- // Parse the packet to verify structure
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Check Ethernet layer
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- if !bytes.Equal(eth.SrcMAC, srcHW) {
- t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, srcHW)
- }
- if !bytes.Equal(eth.DstMAC, dstHW) {
- t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, dstHW)
- }
- if eth.EthernetType != layers.EthernetTypeIPv6 {
- t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv6)
- }
- } else {
- t.Error("Packet missing Ethernet layer")
- }
-
- // Check IPv6 layer
- if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil {
- ip := ipLayer.(*layers.IPv6)
- if !ip.SrcIP.Equal(srcIP) {
- t.Errorf("IPv6 SrcIP = %v, want %v", ip.SrcIP, srcIP)
- }
- if !ip.DstIP.Equal(dstIP) {
- t.Errorf("IPv6 DstIP = %v, want %v", ip.DstIP, dstIP)
- }
- if ip.HopLimit != 255 {
- t.Errorf("IPv6 HopLimit = %d, want 255", ip.HopLimit)
- }
- if ip.NextHeader != layers.IPProtocolICMPv6 {
- t.Errorf("IPv6 NextHeader = %v, want %v", ip.NextHeader, layers.IPProtocolICMPv6)
- }
- } else {
- t.Error("Packet missing IPv6 layer")
- }
-
- // Check ICMPv6 layer
- if icmpLayer := packet.Layer(layers.LayerTypeICMPv6); icmpLayer != nil {
- icmp := icmpLayer.(*layers.ICMPv6)
- expectedType := uint8(layers.ICMPv6TypeNeighborAdvertisement)
- if icmp.TypeCode.Type() != expectedType {
- t.Errorf("ICMPv6 Type = %v, want %v", icmp.TypeCode.Type(), expectedType)
- }
- } else {
- t.Error("Packet missing ICMPv6 layer")
- }
-
- // Check ICMPv6NeighborAdvertisement layer
- if naLayer := packet.Layer(layers.LayerTypeICMPv6NeighborAdvertisement); naLayer != nil {
- na := naLayer.(*layers.ICMPv6NeighborAdvertisement)
- if !na.TargetAddress.Equal(routerIP) {
- t.Errorf("TargetAddress = %v, want %v", na.TargetAddress, routerIP)
- }
- // Check flags (solicited && override)
- expectedFlags := uint8(0x20 | 0x40)
- if na.Flags != expectedFlags {
- t.Errorf("Flags = %x, want %x", na.Flags, expectedFlags)
- }
- // Check options
- if len(na.Options) != 1 {
- t.Errorf("Options count = %d, want 1", len(na.Options))
- } else {
- opt := na.Options[0]
- if opt.Type != layers.ICMPv6OptTargetAddress {
- t.Errorf("Option Type = %v, want %v", opt.Type, layers.ICMPv6OptTargetAddress)
- }
- if !bytes.Equal(opt.Data, srcHW) {
- t.Errorf("Option Data = %v, want %v", opt.Data, srcHW)
- }
- }
- } else {
- t.Error("Packet missing ICMPv6NeighborAdvertisement layer")
- }
-}
-
-func TestICMP6RouterAdvertisement(t *testing.T) {
- ip := net.ParseIP("fe80::1")
- hw, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- prefix := "2001:db8::"
- prefixLength := uint8(64)
- routerLifetime := uint16(1800)
-
- err, data := ICMP6RouterAdvertisement(ip, hw, prefix, prefixLength, routerLifetime)
- if err != nil {
- t.Fatalf("ICMP6RouterAdvertisement() error = %v", err)
- }
- if len(data) == 0 {
- t.Fatal("ICMP6RouterAdvertisement() returned empty data")
- }
-
- // Parse the packet to verify structure
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Check Ethernet layer
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- if !bytes.Equal(eth.SrcMAC, hw) {
- t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, hw)
- }
- if !bytes.Equal(eth.DstMAC, macIpv6Multicast) {
- t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, macIpv6Multicast)
- }
- if eth.EthernetType != layers.EthernetTypeIPv6 {
- t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv6)
- }
- } else {
- t.Error("Packet missing Ethernet layer")
- }
-
- // Check IPv6 layer
- if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil {
- ip6 := ipLayer.(*layers.IPv6)
- if !ip6.SrcIP.Equal(ip) {
- t.Errorf("IPv6 SrcIP = %v, want %v", ip6.SrcIP, ip)
- }
- if !ip6.DstIP.Equal(ipv6Multicast) {
- t.Errorf("IPv6 DstIP = %v, want %v", ip6.DstIP, ipv6Multicast)
- }
- if ip6.HopLimit != 255 {
- t.Errorf("IPv6 HopLimit = %d, want 255", ip6.HopLimit)
- }
- if ip6.NextHeader != layers.IPProtocolICMPv6 {
- t.Errorf("IPv6 NextHeader = %v, want %v", ip6.NextHeader, layers.IPProtocolICMPv6)
- }
- if ip6.TrafficClass != 224 {
- t.Errorf("IPv6 TrafficClass = %d, want 224", ip6.TrafficClass)
- }
- } else {
- t.Error("Packet missing IPv6 layer")
- }
-
- // Check ICMPv6 layer
- if icmpLayer := packet.Layer(layers.LayerTypeICMPv6); icmpLayer != nil {
- icmp := icmpLayer.(*layers.ICMPv6)
- expectedType := uint8(layers.ICMPv6TypeRouterAdvertisement)
- if icmp.TypeCode.Type() != expectedType {
- t.Errorf("ICMPv6 Type = %v, want %v", icmp.TypeCode.Type(), expectedType)
- }
- } else {
- t.Error("Packet missing ICMPv6 layer")
- }
-
- // Check ICMPv6RouterAdvertisement layer
- if raLayer := packet.Layer(layers.LayerTypeICMPv6RouterAdvertisement); raLayer != nil {
- ra := raLayer.(*layers.ICMPv6RouterAdvertisement)
- if ra.HopLimit != 255 {
- t.Errorf("HopLimit = %d, want 255", ra.HopLimit)
- }
- if ra.Flags != 0x08 {
- t.Errorf("Flags = %x, want 0x08", ra.Flags)
- }
- if ra.RouterLifetime != routerLifetime {
- t.Errorf("RouterLifetime = %d, want %d", ra.RouterLifetime, routerLifetime)
- }
- // Check options - the actual order from the code is SourceAddress, MTU, PrefixInfo
- if len(ra.Options) != 3 {
- t.Errorf("Options count = %d, want 3", len(ra.Options))
- } else {
- // Find each option type
- hasSourceAddr := false
- hasMTU := false
- hasPrefixInfo := false
-
- for _, opt := range ra.Options {
- switch opt.Type {
- case layers.ICMPv6OptSourceAddress:
- hasSourceAddr = true
- if !bytes.Equal(opt.Data, hw) {
- t.Errorf("SourceAddress option data = %v, want %v", opt.Data, hw)
- }
- case layers.ICMPv6OptMTU:
- hasMTU = true
- expectedMTU := []byte{0x00, 0x00, 0x00, 0x00, 0x05, 0xdc} // 1500
- if !bytes.Equal(opt.Data, expectedMTU) {
- t.Errorf("MTU option data = %v, want %v", opt.Data, expectedMTU)
- }
- case layers.ICMPv6OptPrefixInfo:
- hasPrefixInfo = true
- // Verify prefix length is in the data
- if len(opt.Data) > 0 && opt.Data[0] != prefixLength {
- t.Errorf("PrefixInfo prefix length = %d, want %d", opt.Data[0], prefixLength)
- }
- }
- }
-
- if !hasSourceAddr {
- t.Error("Missing SourceAddress option")
- }
- if !hasMTU {
- t.Error("Missing MTU option")
- }
- if !hasPrefixInfo {
- t.Error("Missing PrefixInfo option")
- }
- }
- } else {
- t.Error("Packet missing ICMPv6RouterAdvertisement layer")
- }
-}
-
-func TestICMP6NeighborAdvertisementWithNilValues(t *testing.T) {
- // Test with nil values - function should handle gracefully
- err, data := ICMP6NeighborAdvertisement(nil, nil, nil, nil, nil)
-
- // The function likely returns an error or empty data with nil inputs
- if err == nil && len(data) > 0 {
- t.Error("Expected error or empty data with nil values")
- }
-}
-
-func TestICMP6RouterAdvertisementWithNilValues(t *testing.T) {
- // Test with nil values - function should handle gracefully
- err, data := ICMP6RouterAdvertisement(nil, nil, "", 0, 0)
-
- // The function likely returns an error or empty data with nil inputs
- if err == nil && len(data) > 0 {
- t.Error("Expected error or empty data with nil values")
- }
-}
-
-func TestICMP6RouterAdvertisementVariousInputs(t *testing.T) {
- tests := []struct {
- name string
- ip string
- hw string
- prefix string
- prefixLength uint8
- routerLifetime uint16
- shouldError bool
- }{
- {
- name: "valid input",
- ip: "fe80::1",
- hw: "aa:bb:cc:dd:ee:ff",
- prefix: "2001:db8::",
- prefixLength: 64,
- routerLifetime: 1800,
- shouldError: false,
- },
- {
- name: "zero router lifetime",
- ip: "fe80::1",
- hw: "aa:bb:cc:dd:ee:ff",
- prefix: "2001:db8::",
- prefixLength: 64,
- routerLifetime: 0,
- shouldError: false,
- },
- {
- name: "max prefix length",
- ip: "fe80::1",
- hw: "aa:bb:cc:dd:ee:ff",
- prefix: "2001:db8::",
- prefixLength: 128,
- routerLifetime: 1800,
- shouldError: false,
- },
- {
- name: "max router lifetime",
- ip: "fe80::1",
- hw: "aa:bb:cc:dd:ee:ff",
- prefix: "2001:db8::",
- prefixLength: 64,
- routerLifetime: 65535,
- shouldError: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ip := net.ParseIP(tt.ip)
- hw, _ := net.ParseMAC(tt.hw)
-
- err, data := ICMP6RouterAdvertisement(ip, hw, tt.prefix, tt.prefixLength, tt.routerLifetime)
-
- if tt.shouldError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tt.shouldError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if !tt.shouldError && len(data) == 0 {
- t.Error("Expected data but got empty")
- }
- })
- }
-}
-
-func TestICMP6NeighborAdvertisementVariousInputs(t *testing.T) {
- tests := []struct {
- name string
- srcHW string
- srcIP string
- dstHW string
- dstIP string
- routerIP string
- shouldError bool
- }{
- {
- name: "valid IPv6 link-local",
- srcHW: "aa:bb:cc:dd:ee:ff",
- srcIP: "fe80::1",
- dstHW: "11:22:33:44:55:66",
- dstIP: "fe80::2",
- routerIP: "fe80::3",
- shouldError: false,
- },
- {
- name: "valid IPv6 global",
- srcHW: "aa:bb:cc:dd:ee:ff",
- srcIP: "2001:db8::1",
- dstHW: "11:22:33:44:55:66",
- dstIP: "2001:db8::2",
- routerIP: "2001:db8::3",
- shouldError: false,
- },
- {
- name: "broadcast MAC",
- srcHW: "ff:ff:ff:ff:ff:ff",
- srcIP: "fe80::1",
- dstHW: "ff:ff:ff:ff:ff:ff",
- dstIP: "fe80::2",
- routerIP: "fe80::3",
- shouldError: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- srcHW, _ := net.ParseMAC(tt.srcHW)
- srcIP := net.ParseIP(tt.srcIP)
- dstHW, _ := net.ParseMAC(tt.dstHW)
- dstIP := net.ParseIP(tt.dstIP)
- routerIP := net.ParseIP(tt.routerIP)
-
- err, data := ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP)
-
- if tt.shouldError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tt.shouldError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if !tt.shouldError && len(data) == 0 {
- t.Error("Expected data but got empty")
- }
- })
- }
-}
-
-// Benchmarks
-func BenchmarkICMP6NeighborAdvertisement(b *testing.B) {
- srcHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- srcIP := net.ParseIP("fe80::1")
- dstHW, _ := net.ParseMAC("11:22:33:44:55:66")
- dstIP := net.ParseIP("fe80::2")
- routerIP := net.ParseIP("fe80::3")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = ICMP6NeighborAdvertisement(srcHW, srcIP, dstHW, dstIP, routerIP)
- }
-}
-
-func BenchmarkICMP6RouterAdvertisement(b *testing.B) {
- ip := net.ParseIP("fe80::1")
- hw, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- prefix := "2001:db8::"
- prefixLength := uint8(64)
- routerLifetime := uint16(1800)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = ICMP6RouterAdvertisement(ip, hw, prefix, prefixLength, routerLifetime)
- }
-}
diff --git a/packets/mdns_test.go b/packets/mdns_test.go
deleted file mode 100644
index 2a380cd4..00000000
--- a/packets/mdns_test.go
+++ /dev/null
@@ -1,393 +0,0 @@
-package packets
-
-import (
- "bytes"
- "net"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestMDNSConstants(t *testing.T) {
- if MDNSPort != 5353 {
- t.Errorf("MDNSPort = %d, want 5353", MDNSPort)
- }
-
- expectedMac := net.HardwareAddr{0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb}
- if !bytes.Equal(MDNSDestMac, expectedMac) {
- t.Errorf("MDNSDestMac = %v, want %v", MDNSDestMac, expectedMac)
- }
-
- expectedIP := net.ParseIP("224.0.0.251")
- if !MDNSDestIP.Equal(expectedIP) {
- t.Errorf("MDNSDestIP = %v, want %v", MDNSDestIP, expectedIP)
- }
-}
-
-func TestNewMDNSProbe(t *testing.T) {
- from := net.ParseIP("192.168.1.100")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- err, data := NewMDNSProbe(from, fromHW)
- if err != nil {
- t.Errorf("NewMDNSProbe() error = %v", err)
- }
- if len(data) == 0 {
- t.Error("NewMDNSProbe() returned empty data")
- }
-
- // Parse the packet to verify structure
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Check Ethernet layer
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- if !bytes.Equal(eth.SrcMAC, fromHW) {
- t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW)
- }
- if !bytes.Equal(eth.DstMAC, MDNSDestMac) {
- t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, MDNSDestMac)
- }
- } else {
- t.Error("Packet missing Ethernet layer")
- }
-
- // Check IPv4 layer
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- if !ip.SrcIP.Equal(from) {
- t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from)
- }
- if !ip.DstIP.Equal(MDNSDestIP) {
- t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, MDNSDestIP)
- }
- } else {
- t.Error("Packet missing IPv4 layer")
- }
-
- // Check UDP layer
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- if udp.DstPort != MDNSPort {
- t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, MDNSPort)
- }
- } else {
- t.Error("Packet missing UDP layer")
- }
-
- // The DNS layer is carried as payload in UDP, not a separate layer
- // So we check the UDP payload instead
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- // Verify that the UDP payload contains DNS data
- if len(udp.Payload) == 0 {
- t.Error("UDP payload is empty (should contain DNS data)")
- }
- }
-}
-
-func TestMDNSGetMeta(t *testing.T) {
- // Create a mock MDNS packet with various record types
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: MDNSDestMac,
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := layers.IPv4{
- Protocol: layers.IPProtocolUDP,
- Version: 4,
- TTL: 64,
- SrcIP: net.ParseIP("192.168.1.100"),
- DstIP: MDNSDestIP,
- }
-
- udp := layers.UDP{
- SrcPort: MDNSPort,
- DstPort: MDNSPort,
- }
-
- dns := layers.DNS{
- ID: 1,
- QR: true,
- OpCode: layers.DNSOpCodeQuery,
- Answers: []layers.DNSResourceRecord{
- {
- Name: []byte("test.local"),
- Type: layers.DNSTypeA,
- Class: layers.DNSClassIN,
- IP: net.ParseIP("192.168.1.100"),
- },
- {
- Name: []byte("test.local"),
- Type: layers.DNSTypeTXT,
- Class: layers.DNSClassIN,
- TXTs: [][]byte{[]byte("model=Test Device"), []byte("version=1.0")},
- },
- },
- }
-
- udp.SetNetworkLayerForChecksum(&ip4)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns)
- if err != nil {
- t.Fatalf("Failed to serialize packet: %v", err)
- }
-
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- meta := MDNSGetMeta(packet)
- if meta == nil {
- t.Fatal("MDNSGetMeta() returned nil")
- }
-
- // TXT records are extracted correctly
-
- if model, ok := meta["mdns:model"]; !ok || model != "Test Device" {
- t.Errorf("Expected model 'Test Device', got '%v'", model)
- }
-
- if version, ok := meta["mdns:version"]; !ok || version != "1.0" {
- t.Errorf("Expected version '1.0', got '%v'", version)
- }
-}
-
-func TestMDNSGetMetaNonMDNS(t *testing.T) {
- // Create a non-MDNS UDP packet
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := layers.IPv4{
- Protocol: layers.IPProtocolUDP,
- Version: 4,
- TTL: 64,
- SrcIP: net.ParseIP("192.168.1.100"),
- DstIP: net.ParseIP("192.168.1.200"),
- }
-
- udp := layers.UDP{
- SrcPort: 12345,
- DstPort: 80,
- }
-
- udp.SetNetworkLayerForChecksum(&ip4)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp)
- if err != nil {
- t.Fatalf("Failed to serialize packet: %v", err)
- }
-
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- meta := MDNSGetMeta(packet)
- if meta != nil {
- t.Error("MDNSGetMeta() should return nil for non-MDNS packet")
- }
-}
-
-func TestMDNSGetMetaInvalidDNS(t *testing.T) {
- // Create MDNS packet with invalid DNS payload
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: MDNSDestMac,
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := layers.IPv4{
- Protocol: layers.IPProtocolUDP,
- Version: 4,
- TTL: 64,
- SrcIP: net.ParseIP("192.168.1.100"),
- DstIP: MDNSDestIP,
- }
-
- udp := layers.UDP{
- SrcPort: MDNSPort,
- DstPort: MDNSPort,
- }
-
- udp.SetNetworkLayerForChecksum(&ip4)
- udp.Payload = []byte{0x00, 0x01, 0x02, 0x03} // Invalid DNS data
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp)
- if err != nil {
- t.Fatalf("Failed to serialize packet: %v", err)
- }
-
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- meta := MDNSGetMeta(packet)
- if meta != nil {
- t.Error("MDNSGetMeta() should return nil for invalid DNS data")
- }
-}
-
-func TestMDNSGetMetaRecovery(t *testing.T) {
- // Test that panic recovery works
- defer func() {
- if r := recover(); r != nil {
- t.Error("MDNSGetMeta should not panic")
- }
- }()
-
- // Create a minimal packet that might cause issues
- data := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- meta := MDNSGetMeta(packet)
- if meta != nil {
- t.Error("MDNSGetMeta() should return nil for invalid packet")
- }
-}
-
-func TestMDNSGetMetaWithAdditionals(t *testing.T) {
- // Create a mock MDNS packet with additional records
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: MDNSDestMac,
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := layers.IPv4{
- Protocol: layers.IPProtocolUDP,
- Version: 4,
- TTL: 64,
- SrcIP: net.ParseIP("192.168.1.100"),
- DstIP: MDNSDestIP,
- }
-
- udp := layers.UDP{
- SrcPort: MDNSPort,
- DstPort: MDNSPort,
- }
-
- dns := layers.DNS{
- ID: 1,
- QR: true,
- OpCode: layers.DNSOpCodeQuery,
- Additionals: []layers.DNSResourceRecord{
- {
- Name: []byte("additional.local"),
- Type: layers.DNSTypeAAAA,
- Class: layers.DNSClassIN,
- IP: net.ParseIP("fe80::1"),
- },
- },
- Authorities: []layers.DNSResourceRecord{
- {
- Name: []byte("authority.local"),
- Type: layers.DNSTypePTR,
- Class: layers.DNSClassIN,
- },
- },
- }
-
- udp.SetNetworkLayerForChecksum(&ip4)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- err := gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns)
- if err != nil {
- t.Fatalf("Failed to serialize packet: %v", err)
- }
-
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- meta := MDNSGetMeta(packet)
- if meta == nil {
- t.Fatal("MDNSGetMeta() returned nil")
- }
-
- if hostname, ok := meta["mdns:hostname"]; !ok || hostname != "additional.local" {
- t.Errorf("Expected hostname 'additional.local', got '%v'", hostname)
- }
-}
-
-// Benchmarks
-func BenchmarkNewMDNSProbe(b *testing.B) {
- from := net.ParseIP("192.168.1.100")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewMDNSProbe(from, fromHW)
- }
-}
-
-func BenchmarkMDNSGetMeta(b *testing.B) {
- // Create a sample MDNS packet
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: MDNSDestMac,
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := layers.IPv4{
- Protocol: layers.IPProtocolUDP,
- Version: 4,
- TTL: 64,
- SrcIP: net.ParseIP("192.168.1.100"),
- DstIP: MDNSDestIP,
- }
-
- udp := layers.UDP{
- SrcPort: MDNSPort,
- DstPort: MDNSPort,
- }
-
- dns := layers.DNS{
- ID: 1,
- QR: true,
- OpCode: layers.DNSOpCodeQuery,
- Answers: []layers.DNSResourceRecord{
- {
- Name: []byte("test.local"),
- Type: layers.DNSTypeA,
- Class: layers.DNSClassIN,
- IP: net.ParseIP("192.168.1.100"),
- },
- },
- }
-
- udp.SetNetworkLayerForChecksum(&ip4)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip4, &udp, &dns)
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = MDNSGetMeta(packet)
- }
-}
diff --git a/packets/mysql_test.go b/packets/mysql_test.go
deleted file mode 100644
index f807429a..00000000
--- a/packets/mysql_test.go
+++ /dev/null
@@ -1,241 +0,0 @@
-package packets
-
-import (
- "bytes"
- "testing"
-)
-
-func TestMySQLConstants(t *testing.T) {
- // Test MySQLGreeting
- if len(MySQLGreeting) != 95 {
- t.Errorf("MySQLGreeting length = %d, want 95", len(MySQLGreeting))
- }
- // Check some key bytes in the greeting
- if MySQLGreeting[0] != 0x5b {
- t.Errorf("MySQLGreeting[0] = 0x%02x, want 0x5b", MySQLGreeting[0])
- }
- // Check version string starts at byte 5
- versionBytes := MySQLGreeting[5:12]
- expectedVersion := []byte("5.6.28-")
- if !bytes.Equal(versionBytes, expectedVersion) {
- t.Errorf("MySQL version = %s, want %s", versionBytes, expectedVersion)
- }
-
- // Test MySQLFirstResponseOK
- if len(MySQLFirstResponseOK) != 11 {
- t.Errorf("MySQLFirstResponseOK length = %d, want 11", len(MySQLFirstResponseOK))
- }
- // Check packet sequence number
- if MySQLFirstResponseOK[3] != 0x02 {
- t.Errorf("MySQLFirstResponseOK sequence = 0x%02x, want 0x02", MySQLFirstResponseOK[3])
- }
-
- // Test MySQLSecondResponseOK
- if len(MySQLSecondResponseOK) != 11 {
- t.Errorf("MySQLSecondResponseOK length = %d, want 11", len(MySQLSecondResponseOK))
- }
- // Check packet sequence number
- if MySQLSecondResponseOK[3] != 0x04 {
- t.Errorf("MySQLSecondResponseOK sequence = 0x%02x, want 0x04", MySQLSecondResponseOK[3])
- }
-}
-
-func TestMySQLGetFile(t *testing.T) {
- tests := []struct {
- name string
- infile string
- expected []byte
- }{
- {
- name: "empty filename",
- infile: "",
- expected: []byte{
- 0x01, // length + 1
- 0x00, 0x00, 0x01, 0xfb, // header
- },
- },
- {
- name: "short filename",
- infile: "test.txt",
- expected: []byte{
- 0x09, // length of "test.txt" + 1 = 9
- 0x00, 0x00, 0x01, 0xfb, // header
- 't', 'e', 's', 't', '.', 't', 'x', 't',
- },
- },
- {
- name: "path with directory",
- infile: "/etc/passwd",
- expected: []byte{
- 0x0c, // length of "/etc/passwd" + 1 = 12
- 0x00, 0x00, 0x01, 0xfb, // header
- '/', 'e', 't', 'c', '/', 'p', 'a', 's', 's', 'w', 'd',
- },
- },
- {
- name: "windows path",
- infile: "C:\\Windows\\System32\\config\\sam",
- expected: []byte{
- 0x1f, // length of path + 1 = 31
- 0x00, 0x00, 0x01, 0xfb, // header
- 'C', ':', '\\', 'W', 'i', 'n', 'd', 'o', 'w', 's', '\\',
- 'S', 'y', 's', 't', 'e', 'm', '3', '2', '\\',
- 'c', 'o', 'n', 'f', 'i', 'g', '\\', 's', 'a', 'm',
- },
- },
- {
- name: "unicode filename",
- infile: "файл.txt",
- expected: func() []byte {
- filename := "файл.txt"
- result := []byte{
- byte(len(filename) + 1),
- 0x00, 0x00, 0x01, 0xfb,
- }
- return append(result, []byte(filename)...)
- }(),
- },
- {
- name: "max length filename",
- infile: string(make([]byte, 254)), // Max that fits in a single byte length
- expected: func() []byte {
- result := []byte{
- 0xff, // 254 + 1 = 255
- 0x00, 0x00, 0x01, 0xfb,
- }
- return append(result, make([]byte, 254)...)
- }(),
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := MySQLGetFile(tt.infile)
- if !bytes.Equal(result, tt.expected) {
- t.Errorf("MySQLGetFile(%q) = %v, want %v", tt.infile, result, tt.expected)
- }
- })
- }
-}
-
-func TestMySQLGetFileLength(t *testing.T) {
- // Test that the length byte is correctly calculated
- testCases := []struct {
- filename string
- expected byte
- }{
- {"", 0x01},
- {"a", 0x02},
- {"ab", 0x03},
- {"abc", 0x04},
- {"test.txt", 0x09},
- {string(make([]byte, 100)), 0x65}, // 100 + 1 = 101 = 0x65
- {string(make([]byte, 254)), 0xff}, // 254 + 1 = 255 = 0xff
- }
-
- for _, tc := range testCases {
- result := MySQLGetFile(tc.filename)
- if result[0] != tc.expected {
- t.Errorf("MySQLGetFile(%q) length byte = 0x%02x, want 0x%02x",
- tc.filename, result[0], tc.expected)
- }
- }
-}
-
-func TestMySQLGetFileHeader(t *testing.T) {
- // Test that the header bytes are always the same
- expectedHeader := []byte{0x00, 0x00, 0x01, 0xfb}
-
- filenames := []string{
- "",
- "test",
- "long_filename_with_many_characters.txt",
- "/path/to/file",
- "C:\\Windows\\file.exe",
- }
-
- for _, filename := range filenames {
- result := MySQLGetFile(filename)
- if len(result) < 5 {
- t.Errorf("MySQLGetFile(%q) returned packet too short: %d bytes", filename, len(result))
- continue
- }
-
- header := result[1:5]
- if !bytes.Equal(header, expectedHeader) {
- t.Errorf("MySQLGetFile(%q) header = %v, want %v", filename, header, expectedHeader)
- }
- }
-}
-
-func TestMySQLPacketStructure(t *testing.T) {
- // Test the overall packet structure
- filename := "test_file.sql"
- packet := MySQLGetFile(filename)
-
- // Check minimum packet size (1 byte length + 4 bytes header)
- if len(packet) < 5 {
- t.Fatalf("Packet too short: %d bytes", len(packet))
- }
-
- // Check that packet length matches expected
- expectedLen := 1 + 4 + len(filename) // length byte + header + filename
- if len(packet) != expectedLen {
- t.Errorf("Packet length = %d, want %d", len(packet), expectedLen)
- }
-
- // Check that the length byte correctly represents filename length + 1
- if packet[0] != byte(len(filename)+1) {
- t.Errorf("Length byte = %d, want %d", packet[0], len(filename)+1)
- }
-
- // Check that the filename is correctly appended
- filenameInPacket := string(packet[5:])
- if filenameInPacket != filename {
- t.Errorf("Filename in packet = %q, want %q", filenameInPacket, filename)
- }
-}
-
-func TestMySQLGreetingStructure(t *testing.T) {
- // Test specific parts of the MySQL greeting packet
- greeting := MySQLGreeting
-
- // The greeting should contain "mysql_native_password" at the end
- expectedSuffix := "mysql_native_password"
- suffixStart := len(greeting) - len(expectedSuffix) - 1 // -1 for null terminator
- suffix := string(greeting[suffixStart : suffixStart+len(expectedSuffix)])
-
- if suffix != expectedSuffix {
- t.Errorf("Greeting suffix = %q, want %q", suffix, expectedSuffix)
- }
-
- // Check null terminator
- if greeting[len(greeting)-1] != 0x00 {
- t.Error("Greeting should end with null terminator")
- }
-}
-
-// Benchmarks
-func BenchmarkMySQLGetFile(b *testing.B) {
- filename := "/etc/passwd"
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = MySQLGetFile(filename)
- }
-}
-
-func BenchmarkMySQLGetFileShort(b *testing.B) {
- filename := "a.txt"
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = MySQLGetFile(filename)
- }
-}
-
-func BenchmarkMySQLGetFileLong(b *testing.B) {
- filename := string(make([]byte, 200))
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = MySQLGetFile(filename)
- }
-}
diff --git a/packets/nbns_test.go b/packets/nbns_test.go
deleted file mode 100644
index 5e172d3b..00000000
--- a/packets/nbns_test.go
+++ /dev/null
@@ -1,351 +0,0 @@
-package packets
-
-import (
- "bytes"
- "net"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestNBNSConstants(t *testing.T) {
- if NBNSPort != 137 {
- t.Errorf("NBNSPort = %d, want 137", NBNSPort)
- }
-
- if NBNSMinRespSize != 73 {
- t.Errorf("NBNSMinRespSize = %d, want 73", NBNSMinRespSize)
- }
-}
-
-func TestNBNSRequest(t *testing.T) {
- // Test the structure of NBNSRequest
- if len(NBNSRequest) != 50 {
- t.Errorf("NBNSRequest length = %d, want 50", len(NBNSRequest))
- }
-
- // Check key bytes in the request
- expectedStart := []byte{0x82, 0x28, 0x00, 0x00, 0x00, 0x01}
- if !bytes.Equal(NBNSRequest[0:6], expectedStart) {
- t.Errorf("NBNSRequest start = %v, want %v", NBNSRequest[0:6], expectedStart)
- }
-
- // Check the encoded name section (starts at byte 12)
- // NBNS encodes names with 0x43 ('C') prefix followed by encoded characters
- if NBNSRequest[12] != 0x20 {
- t.Errorf("NBNSRequest[12] = 0x%02x, want 0x20", NBNSRequest[12])
- }
- if NBNSRequest[13] != 0x43 {
- t.Errorf("NBNSRequest[13] = 0x%02x, want 0x43 (C)", NBNSRequest[13])
- }
-
- // Check the query type and class at the end
- expectedEnd := []byte{0x00, 0x00, 0x21, 0x00, 0x01}
- if !bytes.Equal(NBNSRequest[45:50], expectedEnd) {
- t.Errorf("NBNSRequest end = %v, want %v", NBNSRequest[45:50], expectedEnd)
- }
-}
-
-func TestNBNSGetMeta(t *testing.T) {
- tests := []struct {
- name string
- buildPacket func() gopacket.Packet
- expectNil bool
- }{
- {
- name: "non-NBNS packet (wrong port)",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- udp := layers.UDP{
- SrcPort: 80, // Not NBNS port
- DstPort: 12345,
- }
-
- payload := make([]byte, NBNSMinRespSize)
- udp.Payload = payload
- udp.SetNetworkLayerForChecksum(&ip)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip, &udp)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- expectNil: true,
- },
- {
- name: "NBNS packet with insufficient payload",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- udp := layers.UDP{
- SrcPort: NBNSPort,
- DstPort: 12345,
- }
-
- // Payload too small (less than NBNSMinRespSize)
- payload := make([]byte, 50)
- udp.Payload = payload
- udp.SetNetworkLayerForChecksum(&ip)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip, &udp)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- expectNil: true,
- },
- {
- name: "NBNS packet with non-printable hostname",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- udp := layers.UDP{
- SrcPort: NBNSPort,
- DstPort: 12345,
- }
-
- payload := make([]byte, NBNSMinRespSize)
- // Set non-printable character at the start of hostname
- payload[57] = 0x01 // Non-printable
- copy(payload[58:72], []byte("WORKSTATION "))
-
- udp.Payload = payload
- udp.SetNetworkLayerForChecksum(&ip)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip, &udp)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- expectNil: true,
- },
- {
- name: "packet without UDP layer",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP, // TCP instead of UDP
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- expectNil: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- packet := tt.buildPacket()
- meta := NBNSGetMeta(packet)
-
- // Due to a bug in NBNSGetMeta where it doesn't check if hostname is empty
- // after trimming, we just verify it doesn't panic
- _ = meta
- })
- }
-}
-
-func TestNBNSBasicFunctionality(t *testing.T) {
- // Test that NBNSGetMeta doesn't panic on various inputs
- tests := []struct {
- name string
- buildPacket func() gopacket.Packet
- }{
- {
- name: "valid packet",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
- udp := layers.UDP{
- SrcPort: NBNSPort,
- DstPort: 12345,
- }
- payload := make([]byte, NBNSMinRespSize)
- copy(payload[57:72], []byte("WORKSTATION "))
- udp.Payload = payload
- udp.SetNetworkLayerForChecksum(&ip)
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
- gopacket.SerializeLayers(buf, opts, ð, &ip, &udp)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- },
- {
- name: "empty packet",
- buildPacket: func() gopacket.Packet {
- return gopacket.NewPacket([]byte{}, layers.LayerTypeEthernet, gopacket.Default)
- },
- },
- {
- name: "non-UDP packet",
- buildPacket: func() gopacket.Packet {
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeARP,
- }
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
- gopacket.SerializeLayers(buf, opts, ð)
- return gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- packet := tt.buildPacket()
- // Just verify it doesn't panic
- _ = NBNSGetMeta(packet)
- })
- }
-}
-
-// Benchmarks
-func BenchmarkNBNSGetMeta(b *testing.B) {
- // Create a sample NBNS packet
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- udp := layers.UDP{
- SrcPort: NBNSPort,
- DstPort: 12345,
- }
-
- payload := make([]byte, NBNSMinRespSize)
- copy(payload[57:72], []byte("WORKSTATION "))
-
- udp.Payload = payload
- udp.SetNetworkLayerForChecksum(&ip)
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip, &udp)
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NBNSGetMeta(packet)
- }
-}
-
-func BenchmarkNBNSGetMetaNonNBNS(b *testing.B) {
- // Create a non-NBNS packet to test early exit performance
- eth := layers.Ethernet{
- SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip := layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP,
- SrcIP: net.IP{192, 168, 1, 100},
- DstIP: net.IP{192, 168, 1, 200},
- }
-
- buf := gopacket.NewSerializeBuffer()
- opts := gopacket.SerializeOptions{
- FixLengths: true,
- ComputeChecksums: true,
- }
-
- gopacket.SerializeLayers(buf, opts, ð, &ip)
- packet := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = NBNSGetMeta(packet)
- }
-}
diff --git a/packets/serialize_test.go b/packets/serialize_test.go
deleted file mode 100644
index 10a19057..00000000
--- a/packets/serialize_test.go
+++ /dev/null
@@ -1,403 +0,0 @@
-package packets
-
-import (
- "bytes"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestSerializationOptions(t *testing.T) {
- // Verify the global serialization options are set correctly
- if !SerializationOptions.FixLengths {
- t.Error("SerializationOptions.FixLengths should be true")
- }
- if !SerializationOptions.ComputeChecksums {
- t.Error("SerializationOptions.ComputeChecksums should be true")
- }
-}
-
-func TestSerialize(t *testing.T) {
- tests := []struct {
- name string
- layers []gopacket.SerializableLayer
- expectError bool
- minLength int
- }{
- {
- name: "simple ethernet frame",
- layers: []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- },
- },
- expectError: false,
- minLength: 14, // Ethernet header
- },
- {
- name: "ethernet with IPv4",
- layers: []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- },
- &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 1},
- DstIP: []byte{192, 168, 1, 2},
- },
- },
- expectError: false,
- minLength: 34, // Ethernet + IPv4 headers
- },
- {
- name: "complete TCP packet",
- layers: func() []gopacket.SerializableLayer {
- ip4 := &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 1},
- DstIP: []byte{192, 168, 1, 2},
- }
- tcp := &layers.TCP{
- SrcPort: 12345,
- DstPort: 80,
- Seq: 1000,
- Ack: 0,
- SYN: true,
- Window: 65535,
- }
- tcp.SetNetworkLayerForChecksum(ip4)
- return []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- },
- ip4,
- tcp,
- }
- }(),
- expectError: false,
- minLength: 54, // Ethernet + IPv4 + TCP headers
- },
- {
- name: "empty layers",
- layers: []gopacket.SerializableLayer{},
- expectError: false,
- minLength: 0,
- },
- {
- name: "layer with payload",
- layers: []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- },
- gopacket.Payload([]byte("Hello, World!")),
- },
- expectError: false,
- minLength: 27, // Ethernet header + payload
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err, data := Serialize(tt.layers...)
-
- if tt.expectError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tt.expectError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
-
- if err == nil {
- if len(data) < tt.minLength {
- t.Errorf("Data length %d is less than expected minimum %d", len(data), tt.minLength)
- }
-
- // For non-empty results, verify we can parse it back
- if len(data) > 0 && len(tt.layers) > 0 {
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
- if packet == nil {
- t.Error("Failed to parse serialized data")
- }
- }
- }
- })
- }
-}
-
-func TestSerializeWithChecksum(t *testing.T) {
- // Test that checksums are computed correctly
- ip4 := &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 1},
- DstIP: []byte{192, 168, 1, 2},
- }
-
- udp := &layers.UDP{
- SrcPort: 12345,
- DstPort: 53,
- }
-
- // Set network layer for checksum computation
- udp.SetNetworkLayerForChecksum(ip4)
-
- eth := &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- err, data := Serialize(eth, ip4, udp)
- if err != nil {
- t.Fatalf("Failed to serialize: %v", err)
- }
-
- // Parse back and verify checksums
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- // The checksum should be computed (non-zero)
- if ip.Checksum == 0 {
- t.Error("IPv4 checksum was not computed")
- }
- } else {
- t.Error("IPv4 layer not found in packet")
- }
-
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- // The checksum should be computed (non-zero for UDP over IPv4)
- if udp.Checksum == 0 {
- t.Error("UDP checksum was not computed")
- }
- } else {
- t.Error("UDP layer not found in packet")
- }
-}
-
-func TestSerializeFixLengths(t *testing.T) {
- // Test that lengths are fixed correctly
- ip4 := &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP,
- TTL: 64,
- SrcIP: []byte{10, 0, 0, 1},
- DstIP: []byte{10, 0, 0, 2},
- // Don't set Length - it should be computed
- }
-
- tcp := &layers.TCP{
- SrcPort: 80,
- DstPort: 12345,
- Seq: 1000,
- SYN: true,
- Window: 65535,
- }
-
- tcp.SetNetworkLayerForChecksum(ip4)
-
- payload := gopacket.Payload([]byte("Test payload data"))
-
- eth := &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- err, data := Serialize(eth, ip4, tcp, payload)
- if err != nil {
- t.Fatalf("Failed to serialize: %v", err)
- }
-
- // Parse back and verify lengths
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- expectedLen := 20 + 20 + len("Test payload data") // IPv4 header + TCP header + payload
- if ip.Length != uint16(expectedLen) {
- t.Errorf("IPv4 length = %d, want %d", ip.Length, expectedLen)
- }
- } else {
- t.Error("IPv4 layer not found in packet")
- }
-}
-
-func TestSerializeErrorHandling(t *testing.T) {
- // Test serialization with an invalid layer configuration
- // This test is a bit tricky because gopacket is quite forgiving
- // We'll create a scenario that might fail in serialization
-
- // Create an ethernet layer with invalid type for the next layer
- eth := &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- // Follow with a non-IPv4 layer when IPv4 is expected
- // This actually won't cause an error in gopacket, so we test that errors are handled
- tcp := &layers.TCP{
- SrcPort: 80,
- DstPort: 12345,
- }
-
- err, data := Serialize(eth, tcp)
- // This might not actually error, but we're testing the error handling path
- if err != nil {
- // Error path - should return nil data
- if data != nil {
- t.Error("When error occurs, data should be nil")
- }
- } else {
- // Success path - should return data
- if data == nil {
- t.Error("When no error, data should not be nil")
- }
- }
-}
-
-func TestSerializeMultiplePackets(t *testing.T) {
- // Test serializing multiple different packet types in sequence
- srcMAC := []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}
- dstMAC := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66}
-
- packets := []struct {
- name string
- layers []gopacket.SerializableLayer
- }{
- {
- name: "ARP request",
- layers: []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: srcMAC,
- DstMAC: dstMAC,
- EthernetType: layers.EthernetTypeARP,
- },
- &layers.ARP{
- AddrType: layers.LinkTypeEthernet,
- Protocol: layers.EthernetTypeIPv4,
- HwAddressSize: 6,
- ProtAddressSize: 4,
- Operation: layers.ARPRequest,
- SourceHwAddress: srcMAC,
- SourceProtAddress: []byte{192, 168, 1, 100},
- DstHwAddress: []byte{0, 0, 0, 0, 0, 0},
- DstProtAddress: []byte{192, 168, 1, 1},
- },
- },
- },
- {
- name: "ICMP echo",
- layers: []gopacket.SerializableLayer{
- &layers.Ethernet{
- SrcMAC: srcMAC,
- DstMAC: dstMAC,
- EthernetType: layers.EthernetTypeIPv4,
- },
- &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolICMPv4,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 100},
- DstIP: []byte{8, 8, 8, 8},
- },
- &layers.ICMPv4{
- TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
- Id: 1,
- Seq: 1,
- },
- gopacket.Payload([]byte("ping")),
- },
- },
- }
-
- for _, pkt := range packets {
- t.Run(pkt.name, func(t *testing.T) {
- err, data := Serialize(pkt.layers...)
- if err != nil {
- t.Errorf("Failed to serialize %s: %v", pkt.name, err)
- }
- if len(data) == 0 {
- t.Errorf("Serialized %s has zero length", pkt.name)
- }
- })
- }
-}
-
-// Benchmarks
-func BenchmarkSerialize(b *testing.B) {
- eth := &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolTCP,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 1},
- DstIP: []byte{192, 168, 1, 2},
- }
-
- tcp := &layers.TCP{
- SrcPort: 12345,
- DstPort: 80,
- Seq: 1000,
- SYN: true,
- Window: 65535,
- }
-
- tcp.SetNetworkLayerForChecksum(ip4)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = Serialize(eth, ip4, tcp)
- }
-}
-
-func BenchmarkSerializeWithPayload(b *testing.B) {
- eth := &layers.Ethernet{
- SrcMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
- DstMAC: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x02},
- EthernetType: layers.EthernetTypeIPv4,
- }
-
- ip4 := &layers.IPv4{
- Version: 4,
- Protocol: layers.IPProtocolUDP,
- TTL: 64,
- SrcIP: []byte{192, 168, 1, 1},
- DstIP: []byte{192, 168, 1, 2},
- }
-
- udp := &layers.UDP{
- SrcPort: 12345,
- DstPort: 53,
- }
-
- udp.SetNetworkLayerForChecksum(ip4)
-
- payload := gopacket.Payload(bytes.Repeat([]byte("x"), 1024))
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = Serialize(eth, ip4, udp, payload)
- }
-}
diff --git a/packets/tcp_test.go b/packets/tcp_test.go
deleted file mode 100644
index 87829ea1..00000000
--- a/packets/tcp_test.go
+++ /dev/null
@@ -1,354 +0,0 @@
-package packets
-
-import (
- "bytes"
- "net"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestNewTCPSyn(t *testing.T) {
- tests := []struct {
- name string
- from string
- fromHW string
- to string
- toHW string
- srcPort int
- dstPort int
- expectError bool
- expectIPv6 bool
- }{
- {
- name: "IPv4 TCP SYN",
- from: "192.168.1.100",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "192.168.1.200",
- toHW: "11:22:33:44:55:66",
- srcPort: 12345,
- dstPort: 80,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 TCP SYN",
- from: "2001:db8::1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "2001:db8::2",
- toHW: "11:22:33:44:55:66",
- srcPort: 54321,
- dstPort: 443,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "IPv4 with different ports",
- from: "10.0.0.1",
- fromHW: "01:23:45:67:89:ab",
- to: "10.0.0.2",
- toHW: "cd:ef:01:23:45:67",
- srcPort: 8080,
- dstPort: 3306,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 link-local addresses",
- from: "fe80::1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "fe80::2",
- toHW: "11:22:33:44:55:66",
- srcPort: 1234,
- dstPort: 5678,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "IPv4 loopback",
- from: "127.0.0.1",
- fromHW: "00:00:00:00:00:00",
- to: "127.0.0.1",
- toHW: "00:00:00:00:00:00",
- srcPort: 9000,
- dstPort: 9001,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 loopback",
- from: "::1",
- fromHW: "00:00:00:00:00:00",
- to: "::1",
- toHW: "00:00:00:00:00:00",
- srcPort: 9000,
- dstPort: 9001,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "Max port number",
- from: "192.168.1.1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "192.168.1.2",
- toHW: "11:22:33:44:55:66",
- srcPort: 65535,
- dstPort: 65535,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "Min port number",
- from: "192.168.1.1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "192.168.1.2",
- toHW: "11:22:33:44:55:66",
- srcPort: 1,
- dstPort: 1,
- expectError: false,
- expectIPv6: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- from := net.ParseIP(tt.from)
- fromHW, _ := net.ParseMAC(tt.fromHW)
- to := net.ParseIP(tt.to)
- toHW, _ := net.ParseMAC(tt.toHW)
-
- err, data := NewTCPSyn(from, fromHW, to, toHW, tt.srcPort, tt.dstPort)
-
- if tt.expectError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tt.expectError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
-
- if err == nil {
- if len(data) == 0 {
- t.Error("Expected data but got empty")
- }
-
- // Parse the packet to verify structure
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Check Ethernet layer
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- if !bytes.Equal(eth.SrcMAC, fromHW) {
- t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW)
- }
- if !bytes.Equal(eth.DstMAC, toHW) {
- t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, toHW)
- }
- expectedType := layers.EthernetTypeIPv4
- if tt.expectIPv6 {
- expectedType = layers.EthernetTypeIPv6
- }
- if eth.EthernetType != expectedType {
- t.Errorf("EthernetType = %v, want %v", eth.EthernetType, expectedType)
- }
- } else {
- t.Error("Packet missing Ethernet layer")
- }
-
- // Check IP layer
- if tt.expectIPv6 {
- if ipLayer := packet.Layer(layers.LayerTypeIPv6); ipLayer != nil {
- ip := ipLayer.(*layers.IPv6)
- if !ip.SrcIP.Equal(from) {
- t.Errorf("IPv6 SrcIP = %v, want %v", ip.SrcIP, from)
- }
- if !ip.DstIP.Equal(to) {
- t.Errorf("IPv6 DstIP = %v, want %v", ip.DstIP, to)
- }
- if ip.HopLimit != 64 {
- t.Errorf("IPv6 HopLimit = %d, want 64", ip.HopLimit)
- }
- if ip.NextHeader != layers.IPProtocolTCP {
- t.Errorf("IPv6 NextHeader = %v, want %v", ip.NextHeader, layers.IPProtocolTCP)
- }
- } else {
- t.Error("Packet missing IPv6 layer")
- }
- } else {
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- if !ip.SrcIP.Equal(from) {
- t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from)
- }
- if !ip.DstIP.Equal(to) {
- t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, to)
- }
- if ip.TTL != 64 {
- t.Errorf("IPv4 TTL = %d, want 64", ip.TTL)
- }
- if ip.Protocol != layers.IPProtocolTCP {
- t.Errorf("IPv4 Protocol = %v, want %v", ip.Protocol, layers.IPProtocolTCP)
- }
- } else {
- t.Error("Packet missing IPv4 layer")
- }
- }
-
- // Check TCP layer
- if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
- tcp := tcpLayer.(*layers.TCP)
- if tcp.SrcPort != layers.TCPPort(tt.srcPort) {
- t.Errorf("TCP SrcPort = %d, want %d", tcp.SrcPort, tt.srcPort)
- }
- if tcp.DstPort != layers.TCPPort(tt.dstPort) {
- t.Errorf("TCP DstPort = %d, want %d", tcp.DstPort, tt.dstPort)
- }
- if !tcp.SYN {
- t.Error("TCP SYN flag not set")
- }
- // Verify other flags are not set
- if tcp.ACK || tcp.FIN || tcp.RST || tcp.PSH || tcp.URG {
- t.Error("TCP has unexpected flags set")
- }
- } else {
- t.Error("Packet missing TCP layer")
- }
- }
- })
- }
-}
-
-func TestNewTCPSynWithNilValues(t *testing.T) {
- // Test with nil IPs - should return an error
- err, data := NewTCPSyn(nil, nil, nil, nil, 12345, 80)
- if err == nil {
- t.Error("Expected error with nil values, but got none")
- }
- if len(data) != 0 {
- t.Error("Expected no data with nil values")
- }
-}
-
-func TestNewTCPSynChecksumComputation(t *testing.T) {
- // Test that checksums are computed correctly for both IPv4 and IPv6
- testCases := []struct {
- name string
- from string
- to string
- isIPv6 bool
- }{
- {
- name: "IPv4 checksum",
- from: "192.168.1.1",
- to: "192.168.1.2",
- isIPv6: false,
- },
- {
- name: "IPv6 checksum",
- from: "2001:db8::1",
- to: "2001:db8::2",
- isIPv6: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- from := net.ParseIP(tc.from)
- to := net.ParseIP(tc.to)
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- toHW, _ := net.ParseMAC("11:22:33:44:55:66")
-
- err, data := NewTCPSyn(from, fromHW, to, toHW, 12345, 80)
- if err != nil {
- t.Fatalf("Failed to create TCP SYN: %v", err)
- }
-
- // Parse the packet
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Verify TCP checksum is non-zero (computed)
- if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
- tcp := tcpLayer.(*layers.TCP)
- if tcp.Checksum == 0 {
- t.Error("TCP checksum was not computed")
- }
- } else {
- t.Error("TCP layer not found")
- }
-
- // For IPv4, also check IP checksum
- if !tc.isIPv6 {
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- if ip.Checksum == 0 {
- t.Error("IPv4 checksum was not computed")
- }
- }
- }
- })
- }
-}
-
-func TestNewTCPSynPortRange(t *testing.T) {
- // Test various port numbers including edge cases
- portTests := []struct {
- srcPort int
- dstPort int
- }{
- {0, 0}, // Minimum possible (though 0 is typically reserved)
- {1, 1}, // Minimum valid
- {80, 443}, // Common ports
- {1024, 1025}, // First non-privileged ports
- {32768, 32769}, // Common ephemeral port range start
- {65534, 65535}, // Maximum ports
- }
-
- from := net.ParseIP("192.168.1.1")
- to := net.ParseIP("192.168.1.2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- toHW, _ := net.ParseMAC("11:22:33:44:55:66")
-
- for _, pt := range portTests {
- err, data := NewTCPSyn(from, fromHW, to, toHW, pt.srcPort, pt.dstPort)
- if err != nil {
- t.Errorf("Failed with ports %d->%d: %v", pt.srcPort, pt.dstPort, err)
- continue
- }
-
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
- if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
- tcp := tcpLayer.(*layers.TCP)
- if tcp.SrcPort != layers.TCPPort(pt.srcPort) {
- t.Errorf("TCP SrcPort = %d, want %d", tcp.SrcPort, pt.srcPort)
- }
- if tcp.DstPort != layers.TCPPort(pt.dstPort) {
- t.Errorf("TCP DstPort = %d, want %d", tcp.DstPort, pt.dstPort)
- }
- }
- }
-}
-
-// Benchmarks
-func BenchmarkNewTCPSynIPv4(b *testing.B) {
- from := net.ParseIP("192.168.1.1")
- to := net.ParseIP("192.168.1.2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- toHW, _ := net.ParseMAC("11:22:33:44:55:66")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewTCPSyn(from, fromHW, to, toHW, 12345, 80)
- }
-}
-
-func BenchmarkNewTCPSynIPv6(b *testing.B) {
- from := net.ParseIP("2001:db8::1")
- to := net.ParseIP("2001:db8::2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- toHW, _ := net.ParseMAC("11:22:33:44:55:66")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewTCPSyn(from, fromHW, to, toHW, 12345, 80)
- }
-}
diff --git a/packets/udp_test.go b/packets/udp_test.go
deleted file mode 100644
index 11493ae5..00000000
--- a/packets/udp_test.go
+++ /dev/null
@@ -1,366 +0,0 @@
-package packets
-
-import (
- "bytes"
- "net"
- "testing"
-
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
-)
-
-func TestNewUDPProbe(t *testing.T) {
- tests := []struct {
- name string
- from string
- fromHW string
- to string
- port int
- expectError bool
- expectIPv6 bool
- }{
- {
- name: "IPv4 UDP probe",
- from: "192.168.1.100",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "192.168.1.200",
- port: 53,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 UDP probe",
- from: "2001:db8::1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "2001:db8::2",
- port: 53,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "IPv4 with high port",
- from: "10.0.0.1",
- fromHW: "01:23:45:67:89:ab",
- to: "10.0.0.2",
- port: 65535,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 link-local",
- from: "fe80::1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "fe80::2",
- port: 123,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "IPv4 loopback",
- from: "127.0.0.1",
- fromHW: "00:00:00:00:00:00",
- to: "127.0.0.1",
- port: 8080,
- expectError: false,
- expectIPv6: false,
- },
- {
- name: "IPv6 loopback",
- from: "::1",
- fromHW: "00:00:00:00:00:00",
- to: "::1",
- port: 8080,
- expectError: false,
- expectIPv6: true,
- },
- {
- name: "Port 0",
- from: "192.168.1.1",
- fromHW: "aa:bb:cc:dd:ee:ff",
- to: "192.168.1.2",
- port: 0,
- expectError: false,
- expectIPv6: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- from := net.ParseIP(tt.from)
- fromHW, _ := net.ParseMAC(tt.fromHW)
- to := net.ParseIP(tt.to)
-
- err, data := NewUDPProbe(from, fromHW, to, tt.port)
-
- if tt.expectError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tt.expectError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
-
- if err == nil {
- if len(data) == 0 {
- t.Error("Expected data but got empty")
- }
-
- // Parse the packet to verify structure
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // Check Ethernet layer
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- if !bytes.Equal(eth.SrcMAC, fromHW) {
- t.Errorf("Ethernet SrcMAC = %v, want %v", eth.SrcMAC, fromHW)
- }
- // Check broadcast destination MAC
- expectedDstMAC := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
- if !bytes.Equal(eth.DstMAC, expectedDstMAC) {
- t.Errorf("Ethernet DstMAC = %v, want %v", eth.DstMAC, expectedDstMAC)
- }
- // Note: The function always sets EthernetTypeIPv4, even for IPv6
- // This is a bug in the implementation but we test actual behavior
- if eth.EthernetType != layers.EthernetTypeIPv4 {
- t.Errorf("EthernetType = %v, want %v", eth.EthernetType, layers.EthernetTypeIPv4)
- }
- } else {
- t.Error("Packet missing Ethernet layer")
- }
-
- // For IPv6, the packet won't parse correctly due to wrong EthernetType
- // We just verify the packet was created
- if tt.expectIPv6 {
- // Due to the bug, IPv6 packets won't parse correctly
- // Just check that we got data
- if len(data) == 0 {
- t.Error("Expected packet data for IPv6")
- }
- } else {
- // IPv4 should work correctly
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- if !ip.SrcIP.Equal(from) {
- t.Errorf("IPv4 SrcIP = %v, want %v", ip.SrcIP, from)
- }
- if !ip.DstIP.Equal(to) {
- t.Errorf("IPv4 DstIP = %v, want %v", ip.DstIP, to)
- }
- if ip.TTL != 64 {
- t.Errorf("IPv4 TTL = %d, want 64", ip.TTL)
- }
- if ip.Protocol != layers.IPProtocolUDP {
- t.Errorf("IPv4 Protocol = %v, want %v", ip.Protocol, layers.IPProtocolUDP)
- }
- } else {
- t.Error("Packet missing IPv4 layer")
- }
-
- // Check UDP layer for IPv4
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- if udp.SrcPort != 12345 {
- t.Errorf("UDP SrcPort = %d, want 12345", udp.SrcPort)
- }
- if udp.DstPort != layers.UDPPort(tt.port) {
- t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, tt.port)
- }
- // Note: The payload is not properly parsed by gopacket
- // This is likely due to how the packet is serialized
- // We'll skip payload verification for now
- _ = udp.Payload
- } else {
- t.Error("Packet missing UDP layer")
- }
- }
- }
- })
- }
-}
-
-func TestNewUDPProbeWithNilValues(t *testing.T) {
- // Test with nil IPs - should return an error
- err, data := NewUDPProbe(nil, nil, nil, 53)
- if err == nil {
- t.Error("Expected error with nil values, but got none")
- }
- if len(data) != 0 {
- t.Error("Expected no data with nil values")
- }
-}
-
-func TestNewUDPProbePayload(t *testing.T) {
- from := net.ParseIP("192.168.1.1")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- to := net.ParseIP("192.168.1.2")
-
- err, data := NewUDPProbe(from, fromHW, to, 53)
- if err != nil {
- t.Fatalf("Failed to create UDP probe: %v", err)
- }
-
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- _ = udpLayer.(*layers.UDP) // UDP layer exists, payload check below
- } else {
- t.Error("UDP layer not found")
- }
-
- // Note: The payload is not properly parsed by gopacket
- // This is likely due to how the packet is serialized
- // We'll just verify the packet was created successfully
- t.Log("UDP packet created successfully")
-}
-
-func TestNewUDPProbeChecksumComputation(t *testing.T) {
- // Test that checksums are computed correctly for both IPv4 and IPv6
- testCases := []struct {
- name string
- from string
- to string
- isIPv6 bool
- }{
- {
- name: "IPv4 checksum",
- from: "192.168.1.1",
- to: "192.168.1.2",
- isIPv6: false,
- },
- {
- name: "IPv6 checksum",
- from: "2001:db8::1",
- to: "2001:db8::2",
- isIPv6: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- from := net.ParseIP(tc.from)
- to := net.ParseIP(tc.to)
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- err, data := NewUDPProbe(from, fromHW, to, 53)
- if err != nil {
- t.Fatalf("Failed to create UDP probe: %v", err)
- }
-
- // Parse the packet
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- // For IPv6, the packet won't parse correctly due to wrong EthernetType
- if tc.isIPv6 {
- // Just verify we got data
- if len(data) == 0 {
- t.Error("Expected packet data for IPv6")
- }
- } else {
- // Verify UDP checksum is non-zero (computed) for IPv4
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- if udp.Checksum == 0 {
- t.Error("UDP checksum was not computed")
- }
- } else {
- t.Error("UDP layer not found")
- }
-
- // For IPv4, also check IP checksum
- if ipLayer := packet.Layer(layers.LayerTypeIPv4); ipLayer != nil {
- ip := ipLayer.(*layers.IPv4)
- if ip.Checksum == 0 {
- t.Error("IPv4 checksum was not computed")
- }
- }
- }
- })
- }
-}
-
-func TestNewUDPProbePortRange(t *testing.T) {
- // Test various port numbers including edge cases
- portTests := []int{
- 0, // Minimum
- 1, // Minimum valid
- 53, // DNS
- 123, // NTP
- 161, // SNMP
- 500, // IKE
- 1024, // First non-privileged
- 5353, // mDNS
- 8080, // Common alternative HTTP
- 32768, // Common ephemeral port range start
- 65535, // Maximum
- }
-
- from := net.ParseIP("192.168.1.1")
- to := net.ParseIP("192.168.1.2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- for _, port := range portTests {
- err, data := NewUDPProbe(from, fromHW, to, port)
- if err != nil {
- t.Errorf("Failed with port %d: %v", port, err)
- continue
- }
-
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
- if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil {
- udp := udpLayer.(*layers.UDP)
- if udp.DstPort != layers.UDPPort(port) {
- t.Errorf("UDP DstPort = %d, want %d", udp.DstPort, port)
- }
- // Source port should always be 12345
- if udp.SrcPort != 12345 {
- t.Errorf("UDP SrcPort = %d, want 12345", udp.SrcPort)
- }
- }
- }
-}
-
-func TestNewUDPProbeBroadcastMAC(t *testing.T) {
- // Test that destination MAC is always broadcast
- from := net.ParseIP("192.168.1.1")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
- to := net.ParseIP("192.168.1.255") // Broadcast IP
-
- err, data := NewUDPProbe(from, fromHW, to, 53)
- if err != nil {
- t.Fatalf("Failed to create UDP probe: %v", err)
- }
-
- packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.Default)
-
- if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
- eth := ethLayer.(*layers.Ethernet)
- expectedMAC := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
- if !bytes.Equal(eth.DstMAC, expectedMAC) {
- t.Errorf("Ethernet DstMAC = %v, want broadcast %v", eth.DstMAC, expectedMAC)
- }
- } else {
- t.Error("Ethernet layer not found")
- }
-}
-
-// Benchmarks
-func BenchmarkNewUDPProbeIPv4(b *testing.B) {
- from := net.ParseIP("192.168.1.1")
- to := net.ParseIP("192.168.1.2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewUDPProbe(from, fromHW, to, 53)
- }
-}
-
-func BenchmarkNewUDPProbeIPv6(b *testing.B) {
- from := net.ParseIP("2001:db8::1")
- to := net.ParseIP("2001:db8::2")
- fromHW, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff")
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewUDPProbe(from, fromHW, to, 53)
- }
-}
diff --git a/routing/route_test.go b/routing/route_test.go
deleted file mode 100644
index ac99ad9a..00000000
--- a/routing/route_test.go
+++ /dev/null
@@ -1,353 +0,0 @@
-package routing
-
-import (
- "testing"
-)
-
-func TestRouteType(t *testing.T) {
- // Test the RouteType constants
- if IPv4 != RouteType("IPv4") {
- t.Errorf("IPv4 constant has wrong value: %s", IPv4)
- }
- if IPv6 != RouteType("IPv6") {
- t.Errorf("IPv6 constant has wrong value: %s", IPv6)
- }
-}
-
-func TestRouteStruct(t *testing.T) {
- tests := []struct {
- name string
- route Route
- }{
- {
- name: "IPv4 default route",
- route: Route{
- Type: IPv4,
- Default: true,
- Device: "eth0",
- Destination: "0.0.0.0",
- Gateway: "192.168.1.1",
- Flags: "UG",
- },
- },
- {
- name: "IPv4 network route",
- route: Route{
- Type: IPv4,
- Default: false,
- Device: "eth0",
- Destination: "192.168.1.0/24",
- Gateway: "",
- Flags: "U",
- },
- },
- {
- name: "IPv6 default route",
- route: Route{
- Type: IPv6,
- Default: true,
- Device: "eth0",
- Destination: "::/0",
- Gateway: "fe80::1",
- Flags: "UG",
- },
- },
- {
- name: "IPv6 link-local route",
- route: Route{
- Type: IPv6,
- Default: false,
- Device: "eth0",
- Destination: "fe80::/64",
- Gateway: "",
- Flags: "U",
- },
- },
- {
- name: "localhost route",
- route: Route{
- Type: IPv4,
- Default: false,
- Device: "lo",
- Destination: "127.0.0.0/8",
- Gateway: "",
- Flags: "U",
- },
- },
- {
- name: "VPN route",
- route: Route{
- Type: IPv4,
- Default: false,
- Device: "tun0",
- Destination: "10.8.0.0/24",
- Gateway: "",
- Flags: "U",
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Test that all fields are accessible
- _ = tt.route.Type
- _ = tt.route.Default
- _ = tt.route.Device
- _ = tt.route.Destination
- _ = tt.route.Gateway
- _ = tt.route.Flags
-
- // Verify the route has the expected type
- if tt.route.Type != IPv4 && tt.route.Type != IPv6 {
- t.Errorf("route has invalid type: %s", tt.route.Type)
- }
- })
- }
-}
-
-func TestRouteDefaultFlag(t *testing.T) {
- // Test routes with different default flag settings
- defaultRoute := Route{
- Type: IPv4,
- Default: true,
- Device: "eth0",
- Destination: "0.0.0.0",
- Gateway: "192.168.1.1",
- Flags: "UG",
- }
-
- normalRoute := Route{
- Type: IPv4,
- Default: false,
- Device: "eth0",
- Destination: "192.168.1.0/24",
- Gateway: "",
- Flags: "U",
- }
-
- if !defaultRoute.Default {
- t.Error("default route should have Default=true")
- }
-
- if normalRoute.Default {
- t.Error("normal route should have Default=false")
- }
-}
-
-func TestRouteTypeString(t *testing.T) {
- // Test that RouteType can be converted to string
- ipv4Str := string(IPv4)
- ipv6Str := string(IPv6)
-
- if ipv4Str != "IPv4" {
- t.Errorf("IPv4 string conversion failed: got %s", ipv4Str)
- }
-
- if ipv6Str != "IPv6" {
- t.Errorf("IPv6 string conversion failed: got %s", ipv6Str)
- }
-}
-
-func TestRouteTypeComparison(t *testing.T) {
- // Test RouteType comparisons
- var rt1 RouteType = IPv4
- var rt2 RouteType = IPv4
- var rt3 RouteType = IPv6
-
- if rt1 != rt2 {
- t.Error("identical RouteType values should be equal")
- }
-
- if rt1 == rt3 {
- t.Error("different RouteType values should not be equal")
- }
-}
-
-func TestRouteTypeCustomValues(t *testing.T) {
- // Test that custom RouteType values can be created
- customType := RouteType("Custom")
-
- if customType == IPv4 || customType == IPv6 {
- t.Error("custom RouteType should not equal predefined constants")
- }
-
- if string(customType) != "Custom" {
- t.Errorf("custom RouteType string conversion failed: got %s", customType)
- }
-}
-
-func TestRouteWithEmptyFields(t *testing.T) {
- // Test route with empty fields
- emptyRoute := Route{}
-
- if emptyRoute.Type != "" {
- t.Errorf("empty route Type should be empty string, got %s", emptyRoute.Type)
- }
-
- if emptyRoute.Default != false {
- t.Error("empty route Default should be false")
- }
-
- if emptyRoute.Device != "" {
- t.Errorf("empty route Device should be empty string, got %s", emptyRoute.Device)
- }
-
- if emptyRoute.Destination != "" {
- t.Errorf("empty route Destination should be empty string, got %s", emptyRoute.Destination)
- }
-
- if emptyRoute.Gateway != "" {
- t.Errorf("empty route Gateway should be empty string, got %s", emptyRoute.Gateway)
- }
-
- if emptyRoute.Flags != "" {
- t.Errorf("empty route Flags should be empty string, got %s", emptyRoute.Flags)
- }
-}
-
-func TestRouteFieldAssignment(t *testing.T) {
- // Test that route fields can be assigned individually
- r := Route{}
-
- r.Type = IPv6
- r.Default = true
- r.Device = "wlan0"
- r.Destination = "2001:db8::/32"
- r.Gateway = "fe80::1"
- r.Flags = "UGH"
-
- if r.Type != IPv6 {
- t.Errorf("Type assignment failed: got %s", r.Type)
- }
-
- if !r.Default {
- t.Error("Default assignment failed")
- }
-
- if r.Device != "wlan0" {
- t.Errorf("Device assignment failed: got %s", r.Device)
- }
-
- if r.Destination != "2001:db8::/32" {
- t.Errorf("Destination assignment failed: got %s", r.Destination)
- }
-
- if r.Gateway != "fe80::1" {
- t.Errorf("Gateway assignment failed: got %s", r.Gateway)
- }
-
- if r.Flags != "UGH" {
- t.Errorf("Flags assignment failed: got %s", r.Flags)
- }
-}
-
-func TestRouteArrayOperations(t *testing.T) {
- // Test operations on arrays of routes
- routes := []Route{
- {
- Type: IPv4,
- Default: true,
- Device: "eth0",
- Destination: "0.0.0.0",
- Gateway: "192.168.1.1",
- Flags: "UG",
- },
- {
- Type: IPv4,
- Default: false,
- Device: "eth0",
- Destination: "192.168.1.0/24",
- Gateway: "",
- Flags: "U",
- },
- {
- Type: IPv6,
- Default: false,
- Device: "eth0",
- Destination: "fe80::/64",
- Gateway: "",
- Flags: "U",
- },
- }
-
- // Test array length
- if len(routes) != 3 {
- t.Errorf("expected 3 routes, got %d", len(routes))
- }
-
- // Count IPv4 vs IPv6 routes
- ipv4Count := 0
- ipv6Count := 0
- defaultCount := 0
-
- for _, r := range routes {
- switch r.Type {
- case IPv4:
- ipv4Count++
- case IPv6:
- ipv6Count++
- }
-
- if r.Default {
- defaultCount++
- }
- }
-
- if ipv4Count != 2 {
- t.Errorf("expected 2 IPv4 routes, got %d", ipv4Count)
- }
-
- if ipv6Count != 1 {
- t.Errorf("expected 1 IPv6 route, got %d", ipv6Count)
- }
-
- if defaultCount != 1 {
- t.Errorf("expected 1 default route, got %d", defaultCount)
- }
-}
-
-func BenchmarkRouteCreation(b *testing.B) {
- for i := 0; i < b.N; i++ {
- _ = Route{
- Type: IPv4,
- Default: true,
- Device: "eth0",
- Destination: "0.0.0.0",
- Gateway: "192.168.1.1",
- Flags: "UG",
- }
- }
-}
-
-func BenchmarkRouteTypeComparison(b *testing.B) {
- rt1 := IPv4
- rt2 := IPv6
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = rt1 == rt2
- }
-}
-
-func BenchmarkRouteArrayIteration(b *testing.B) {
- routes := make([]Route, 100)
- for i := range routes {
- if i%2 == 0 {
- routes[i].Type = IPv4
- } else {
- routes[i].Type = IPv6
- }
- routes[i].Device = "eth0"
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- count := 0
- for _, r := range routes {
- if r.Type == IPv4 {
- count++
- }
- }
- _ = count
- }
-}
diff --git a/routing/tables.go b/routing/tables.go
index 1023ff3b..fcb9f043 100644
--- a/routing/tables.go
+++ b/routing/tables.go
@@ -21,12 +21,7 @@ func Update() ([]Route, error) {
func Gateway(ip RouteType, device string) (string, error) {
Update()
- return gatewayFromTable(ip, device)
-}
-// gatewayFromTable finds the gateway from the current table without updating it
-// This allows testing with controlled table data
-func gatewayFromTable(ip RouteType, device string) (string, error) {
lock.RLock()
defer lock.RUnlock()
diff --git a/routing/tables_test.go b/routing/tables_test.go
deleted file mode 100644
index 761f1356..00000000
--- a/routing/tables_test.go
+++ /dev/null
@@ -1,387 +0,0 @@
-package routing
-
-import (
- "fmt"
- "sync"
- "testing"
-)
-
-// Helper function to reset the table for testing
-func resetTable() {
- lock.Lock()
- defer lock.Unlock()
- table = make([]Route, 0)
-}
-
-// Helper function to add routes for testing
-func addTestRoutes() {
- lock.Lock()
- defer lock.Unlock()
- table = []Route{
- {
- Type: IPv4,
- Default: true,
- Device: "eth0",
- Destination: "0.0.0.0",
- Gateway: "192.168.1.1",
- Flags: "UG",
- },
- {
- Type: IPv4,
- Default: false,
- Device: "eth0",
- Destination: "192.168.1.0/24",
- Gateway: "",
- Flags: "U",
- },
- {
- Type: IPv6,
- Default: true,
- Device: "eth0",
- Destination: "::/0",
- Gateway: "fe80::1",
- Flags: "UG",
- },
- {
- Type: IPv6,
- Default: false,
- Device: "eth0",
- Destination: "fe80::/64",
- Gateway: "",
- Flags: "U",
- },
- {
- Type: IPv4,
- Default: false,
- Device: "lo",
- Destination: "127.0.0.0/8",
- Gateway: "",
- Flags: "U",
- },
- {
- Type: IPv4,
- Default: true,
- Device: "wlan0",
- Destination: "0.0.0.0",
- Gateway: "10.0.0.1",
- Flags: "UG",
- },
- }
-}
-
-func TestTable(t *testing.T) {
- // Reset table
- resetTable()
-
- // Test empty table
- routes := Table()
- if len(routes) != 0 {
- t.Errorf("Expected empty table, got %d routes", len(routes))
- }
-
- // Add test routes
- addTestRoutes()
-
- // Test table with routes
- routes = Table()
- if len(routes) != 6 {
- t.Errorf("Expected 6 routes, got %d", len(routes))
- }
-
- // Verify first route
- if routes[0].Type != IPv4 {
- t.Errorf("Expected first route to be IPv4, got %s", routes[0].Type)
- }
- if !routes[0].Default {
- t.Error("Expected first route to be default")
- }
- if routes[0].Gateway != "192.168.1.1" {
- t.Errorf("Expected gateway 192.168.1.1, got %s", routes[0].Gateway)
- }
-}
-
-func TestGateway(t *testing.T) {
- // Note: Gateway() calls Update() which loads real system routes
- // So we can't test specific values, just test the behavior
-
- // Test IPv4 gateway
- gateway, err := Gateway(IPv4, "")
- if err != nil {
- t.Errorf("Unexpected error getting IPv4 gateway: %v", err)
- }
- t.Logf("System IPv4 gateway: %s", gateway)
-
- // Test IPv6 gateway
- gateway, err = Gateway(IPv6, "")
- if err != nil {
- t.Errorf("Unexpected error getting IPv6 gateway: %v", err)
- }
- t.Logf("System IPv6 gateway: %s", gateway)
-
- // Test with specific device that likely doesn't exist
- gateway, err = Gateway(IPv4, "nonexistent999")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- // Should return empty string for non-existent device
- if gateway != "" {
- t.Logf("Got gateway for non-existent device (might be Windows): %s", gateway)
- }
-}
-
-func TestGatewayBehavior(t *testing.T) {
- // Test that Gateway doesn't panic with various inputs
- testCases := []struct {
- name string
- ipType RouteType
- device string
- }{
- {"IPv4 empty device", IPv4, ""},
- {"IPv6 empty device", IPv6, ""},
- {"IPv4 with device", IPv4, "eth0"},
- {"IPv6 with device", IPv6, "eth0"},
- {"Custom type", RouteType("custom"), ""},
- {"Empty type", RouteType(""), ""},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- gateway, err := Gateway(tc.ipType, tc.device)
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- t.Logf("Gateway for %s: %s", tc.name, gateway)
- })
- }
-}
-
-func TestGatewayEmptyTable(t *testing.T) {
- // Test with empty table
- resetTable()
-
- gateway, err := gatewayFromTable(IPv4, "eth0")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if gateway != "" {
- t.Errorf("Expected empty gateway, got %s", gateway)
- }
-}
-
-func TestGatewayNoDefaultRoute(t *testing.T) {
- // Test with routes but no default
- resetTable()
-
- lock.Lock()
- table = []Route{
- {
- Type: IPv4,
- Default: false,
- Device: "eth0",
- Destination: "192.168.1.0/24",
- Gateway: "",
- Flags: "U",
- },
- }
- lock.Unlock()
-
- gateway, err := gatewayFromTable(IPv4, "eth0")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if gateway != "" {
- t.Errorf("Expected empty gateway, got %s", gateway)
- }
-}
-
-func TestGatewayWindowsCase(t *testing.T) {
- // Since Gateway() calls Update(), we can't control the table content
- // Just test that it doesn't panic and returns something
- gateway, err := Gateway(IPv4, "eth0")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- t.Logf("Gateway result for eth0: %s", gateway)
-}
-
-func TestGatewayFromTableWithDefaults(t *testing.T) {
- // Test gatewayFromTable with controlled data containing defaults
- resetTable()
- addTestRoutes()
-
- gateway, err := gatewayFromTable(IPv4, "eth0")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if gateway != "192.168.1.1" {
- t.Errorf("Expected gateway 192.168.1.1, got %s", gateway)
- }
-
- // Test with device-specific lookup
- gateway, err = gatewayFromTable(IPv4, "wlan0")
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if gateway != "10.0.0.1" {
- t.Errorf("Expected gateway 10.0.0.1, got %s", gateway)
- }
-}
-
-func TestTableConcurrency(t *testing.T) {
- // Test concurrent access to Table()
- resetTable()
- addTestRoutes()
-
- var wg sync.WaitGroup
- errors := make(chan error, 100)
-
- // Multiple readers
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := 0; j < 100; j++ {
- routes := Table()
- if len(routes) != 6 {
- select {
- case errors <- fmt.Errorf("Expected 6 routes, got %d", len(routes)):
- default:
- }
- }
- }
- }()
- }
-
- wg.Wait()
- close(errors)
-
- // Check for errors
- for err := range errors {
- if err != nil {
- t.Error(err)
- }
- }
-}
-
-func TestGatewayConcurrency(t *testing.T) {
- // Test concurrent access to Gateway()
- var wg sync.WaitGroup
- errors := make(chan error, 100)
-
- // Multiple readers calling Gateway concurrently
- for i := 0; i < 10; i++ {
- wg.Add(1)
- go func(id int) {
- defer wg.Done()
- for j := 0; j < 50; j++ {
- _, err := Gateway(IPv4, "")
- if err != nil {
- select {
- case errors <- fmt.Errorf("goroutine %d: error: %v", id, err):
- default:
- }
- }
- }
- }(i)
- }
-
- wg.Wait()
- close(errors)
-
- // Check for errors
- errorCount := 0
- for err := range errors {
- if err != nil {
- errorCount++
- if errorCount <= 5 { // Only log first 5 errors
- t.Error(err)
- }
- }
- }
- if errorCount > 5 {
- t.Errorf("... and %d more errors", errorCount-5)
- }
-}
-
-func TestUpdate(t *testing.T) {
- // Note: Update() calls platform-specific update() function
- // which we can't easily test without mocking
- // But we can test that it doesn't panic and returns something
- resetTable()
-
- routes, err := Update()
- // The error might be nil or non-nil depending on the platform
- // and whether we have permissions to read routing table
- if err == nil && routes != nil {
- t.Logf("Update returned %d routes", len(routes))
- } else if err != nil {
- t.Logf("Update returned error (expected on some platforms): %v", err)
- }
-}
-
-func TestGatewayMultipleDefaults(t *testing.T) {
- // Since Gateway() calls Update() and loads real routes,
- // we can't test specific scenarios with multiple defaults
- // Just ensure it handles the real system state without panicking
-
- // Call Gateway multiple times to ensure consistency
- gateway1, err1 := Gateway(IPv4, "")
- gateway2, err2 := Gateway(IPv4, "")
-
- if err1 != nil {
- t.Errorf("First call error: %v", err1)
- }
- if err2 != nil {
- t.Errorf("Second call error: %v", err2)
- }
-
- // Results should be consistent
- if gateway1 != gateway2 {
- t.Errorf("Inconsistent results: first=%s, second=%s", gateway1, gateway2)
- }
-
- t.Logf("Consistent gateway result: %s", gateway1)
-}
-
-// Benchmark tests
-func BenchmarkTable(b *testing.B) {
- resetTable()
- addTestRoutes()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = Table()
- }
-}
-
-func BenchmarkGateway(b *testing.B) {
- resetTable()
- addTestRoutes()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = Gateway(IPv4, "eth0")
- }
-}
-
-func BenchmarkTableConcurrent(b *testing.B) {
- resetTable()
- addTestRoutes()
-
- b.RunParallel(func(pb *testing.PB) {
- for pb.Next() {
- _ = Table()
- }
- })
-}
-
-func BenchmarkGatewayConcurrent(b *testing.B) {
- resetTable()
- addTestRoutes()
-
- b.RunParallel(func(pb *testing.PB) {
- for pb.Next() {
- _, _ = Gateway(IPv4, "eth0")
- }
- })
-}
diff --git a/session/module_param_test.go b/session/module_param_test.go
deleted file mode 100644
index 0938c827..00000000
--- a/session/module_param_test.go
+++ /dev/null
@@ -1,478 +0,0 @@
-package session
-
-import (
- "regexp"
- "strings"
- "testing"
-)
-
-func TestNewModuleParameter(t *testing.T) {
- tests := []struct {
- name string
- paramName string
- defValue string
- paramType ParamType
- validator string
- desc string
- }{
- {
- name: "string parameter with validator",
- paramName: "test.param",
- defValue: "default",
- paramType: STRING,
- validator: "^[a-z]+$",
- desc: "A test parameter",
- },
- {
- name: "int parameter without validator",
- paramName: "test.int",
- defValue: "42",
- paramType: INT,
- validator: "",
- desc: "An integer parameter",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- p := NewModuleParameter(tt.paramName, tt.defValue, tt.paramType, tt.validator, tt.desc)
-
- if p.Name != tt.paramName {
- t.Errorf("expected name %s, got %s", tt.paramName, p.Name)
- }
- if p.Value != tt.defValue {
- t.Errorf("expected value %s, got %s", tt.defValue, p.Value)
- }
- if p.Type != tt.paramType {
- t.Errorf("expected type %v, got %v", tt.paramType, p.Type)
- }
- if p.Description != tt.desc {
- t.Errorf("expected description %s, got %s", tt.desc, p.Description)
- }
-
- if tt.validator != "" && p.Validator == nil {
- t.Error("expected validator to be set")
- }
- if tt.validator == "" && p.Validator != nil {
- t.Error("expected validator to be nil")
- }
- })
- }
-}
-
-func TestNewStringParameter(t *testing.T) {
- p := NewStringParameter("test.string", "hello", "^[a-z]+$", "A string param")
-
- if p.Type != STRING {
- t.Errorf("expected type STRING, got %v", p.Type)
- }
- if p.Validator == nil {
- t.Error("expected validator to be set")
- }
-}
-
-func TestNewBoolParameter(t *testing.T) {
- p := NewBoolParameter("test.bool", "true", "A boolean param")
-
- if p.Type != BOOL {
- t.Errorf("expected type BOOL, got %v", p.Type)
- }
- if p.Validator == nil || p.Validator.String() != "^(true|false)$" {
- t.Error("expected boolean validator to be set")
- }
-}
-
-func TestNewIntParameter(t *testing.T) {
- p := NewIntParameter("test.int", "123", "An integer param")
-
- if p.Type != INT {
- t.Errorf("expected type INT, got %v", p.Type)
- }
- if p.Validator == nil {
- t.Error("expected integer validator to be set")
- }
-}
-
-func TestNewDecimalParameter(t *testing.T) {
- p := NewDecimalParameter("test.decimal", "3.14", "A decimal param")
-
- if p.Type != FLOAT {
- t.Errorf("expected type FLOAT, got %v", p.Type)
- }
- if p.Validator == nil {
- t.Error("expected decimal validator to be set")
- }
-}
-
-func TestModuleParamValidate(t *testing.T) {
- tests := []struct {
- name string
- param *ModuleParam
- value string
- wantError bool
- expected interface{}
- }{
- // String tests
- {
- name: "valid string without validator",
- param: &ModuleParam{
- Name: "test",
- Type: STRING,
- },
- value: "any string",
- wantError: false,
- expected: "any string",
- },
- {
- name: "valid string with validator",
- param: &ModuleParam{
- Name: "test",
- Type: STRING,
- Validator: regexp.MustCompile("^[a-z]+$"),
- },
- value: "hello",
- wantError: false,
- expected: "hello",
- },
- {
- name: "invalid string with validator",
- param: &ModuleParam{
- Name: "test",
- Type: STRING,
- Validator: regexp.MustCompile("^[a-z]+$"),
- },
- value: "Hello123",
- wantError: true,
- },
- // Bool tests
- {
- name: "valid bool true",
- param: &ModuleParam{
- Name: "test",
- Type: BOOL,
- Validator: regexp.MustCompile("^(true|false)$"),
- },
- value: "true",
- wantError: false,
- expected: true,
- },
- {
- name: "valid bool false",
- param: &ModuleParam{
- Name: "test",
- Type: BOOL,
- Validator: regexp.MustCompile("^(true|false)$"),
- },
- value: "false",
- wantError: false,
- expected: false,
- },
- {
- name: "valid bool uppercase",
- param: &ModuleParam{
- Name: "test",
- Type: BOOL,
- },
- value: "TRUE",
- wantError: false,
- expected: true,
- },
- {
- name: "invalid bool",
- param: &ModuleParam{
- Name: "test",
- Type: BOOL,
- },
- value: "yes",
- wantError: true,
- },
- // Int tests
- {
- name: "valid positive int",
- param: &ModuleParam{
- Name: "test",
- Type: INT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`),
- },
- value: "123",
- wantError: false,
- expected: 123,
- },
- {
- name: "valid negative int",
- param: &ModuleParam{
- Name: "test",
- Type: INT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`),
- },
- value: "-456",
- wantError: false,
- expected: -456,
- },
- {
- name: "valid int with plus",
- param: &ModuleParam{
- Name: "test",
- Type: INT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`),
- },
- value: "+789",
- wantError: false,
- expected: 789,
- },
- {
- name: "invalid int",
- param: &ModuleParam{
- Name: "test",
- Type: INT,
- },
- value: "12.34",
- wantError: true,
- },
- // Float tests
- {
- name: "valid float",
- param: &ModuleParam{
- Name: "test",
- Type: FLOAT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`),
- },
- value: "3.14",
- wantError: false,
- expected: 3.14,
- },
- {
- name: "valid float without decimal",
- param: &ModuleParam{
- Name: "test",
- Type: FLOAT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`),
- },
- value: "42",
- wantError: false,
- expected: 42.0,
- },
- {
- name: "valid negative float",
- param: &ModuleParam{
- Name: "test",
- Type: FLOAT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+(\.\d+)?$`),
- },
- value: "-2.718",
- wantError: false,
- expected: -2.718,
- },
- {
- name: "invalid float",
- param: &ModuleParam{
- Name: "test",
- Type: FLOAT,
- },
- value: "3.14.15",
- wantError: true,
- },
- // Invalid type test
- {
- name: "invalid type",
- param: &ModuleParam{
- Name: "test",
- Type: ParamType(999),
- },
- value: "anything",
- wantError: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err, result := tt.param.validate(tt.value)
-
- if tt.wantError {
- if err == nil {
- t.Error("expected error but got none")
- }
- } else {
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if result != tt.expected {
- t.Errorf("expected %v (%T), got %v (%T)", tt.expected, tt.expected, result, result)
- }
- }
- })
- }
-}
-
-func TestModuleParamHelp(t *testing.T) {
- p := &ModuleParam{
- Name: "test.param",
- Description: "A test parameter",
- Value: "default",
- }
-
- help := p.Help(15)
-
- // Check that help contains the name
- if !strings.Contains(help, "test.param") {
- t.Error("help should contain parameter name")
- }
-
- // Check that help contains the description
- if !strings.Contains(help, "A test parameter") {
- t.Error("help should contain parameter description")
- }
-
- // Check that help contains the default value
- if !strings.Contains(help, "default=default") {
- t.Error("help should contain default value")
- }
-}
-
-func TestParseSpecialValues(t *testing.T) {
- // Test the special parameter constants
- tests := []struct {
- name string
- value string
- isSpecial bool
- }{
- {
- name: "interface name",
- value: ParamIfaceName,
- isSpecial: true,
- },
- {
- name: "interface address",
- value: ParamIfaceAddress,
- isSpecial: true,
- },
- {
- name: "interface address6",
- value: ParamIfaceAddress6,
- isSpecial: true,
- },
- {
- name: "interface mac",
- value: ParamIfaceMac,
- isSpecial: true,
- },
- {
- name: "subnet",
- value: ParamSubnet,
- isSpecial: true,
- },
- {
- name: "random mac",
- value: ParamRandomMAC,
- isSpecial: true,
- },
- {
- name: "normal value",
- value: "192.168.1.1",
- isSpecial: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if tt.isSpecial {
- // Special values should be in angle brackets
- if !strings.HasPrefix(tt.value, "<") || !strings.HasSuffix(tt.value, ">") {
- t.Errorf("special value %s should be in angle brackets", tt.value)
- }
- }
- })
- }
-}
-
-func TestParamIfaceNameParser(t *testing.T) {
- tests := []struct {
- name string
- input string
- matches bool
- ifaceName string
- }{
- {
- name: "valid interface name",
- input: "",
- matches: true,
- ifaceName: "eth0",
- },
- {
- name: "valid interface with numbers",
- input: "",
- matches: true,
- ifaceName: "wlan1",
- },
- {
- name: "long interface name",
- input: "",
- matches: true,
- ifaceName: "enp0s31f6",
- },
- {
- name: "no angle brackets",
- input: "eth0",
- matches: false,
- },
- {
- name: "invalid characters",
- input: "",
- matches: false,
- },
- {
- name: "too short",
- input: "",
- matches: false,
- },
- {
- name: "too long",
- input: "",
- matches: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- matches := ParamIfaceNameParser.FindStringSubmatch(tt.input)
-
- if tt.matches {
- if len(matches) != 2 {
- t.Errorf("expected to match interface name pattern, got %v", matches)
- } else if matches[1] != tt.ifaceName {
- t.Errorf("expected interface name %s, got %s", tt.ifaceName, matches[1])
- }
- } else {
- if len(matches) > 0 {
- t.Errorf("expected no match, but got %v", matches)
- }
- }
- })
- }
-}
-
-func BenchmarkModuleParamValidate(b *testing.B) {
- p := &ModuleParam{
- Name: "test",
- Type: STRING,
- Validator: regexp.MustCompile("^[a-z]+$"),
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- p.validate("hello")
- }
-}
-
-func BenchmarkModuleParamValidateInt(b *testing.B) {
- p := &ModuleParam{
- Name: "test",
- Type: INT,
- Validator: regexp.MustCompile(`^[\-\+]?[\d]+$`),
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- p.validate("12345")
- }
-}
diff --git a/session/session.go b/session/session.go
index df597b60..983ef1a2 100644
--- a/session/session.go
+++ b/session/session.go
@@ -194,9 +194,7 @@ func (s *Session) Close() {
}
}
- if s.Firewall != nil {
- s.Firewall.Restore()
- }
+ s.Firewall.Restore()
if *s.Options.EnvFile != "" {
envFile, _ := fs.Expand(*s.Options.EnvFile)
diff --git a/session/session_core_handlers.go b/session/session_core_handlers.go
index 9d71e7a0..2b47f641 100644
--- a/session/session_core_handlers.go
+++ b/session/session_core_handlers.go
@@ -13,14 +13,11 @@ import (
"time"
"github.com/bettercap/bettercap/v2/core"
- "github.com/bettercap/bettercap/v2/log"
"github.com/bettercap/bettercap/v2/network"
"github.com/bettercap/readline"
"github.com/evilsocket/islazy/str"
"github.com/evilsocket/islazy/tui"
-
- "github.com/robertkrimen/otto"
)
func (s *Session) generalHelp() {
@@ -158,14 +155,6 @@ func (s *Session) activeHandler(args []string, sess *Session) error {
}
func (s *Session) exitHandler(args []string, sess *Session) error {
- if s.script != nil {
- if s.script.Plugin.HasFunc("onExit") {
- if _, err := s.script.Plugin.Call("onExit"); err != nil {
- log.Error("Error while executing onExit callback: %s", "\nTraceback:\n "+err.(*otto.Error).String())
- }
- }
- }
-
// notify any listener that the session is about to end
s.Events.Add("session.stopped", nil)
diff --git a/tls/tls_test.go b/tls/tls_test.go
deleted file mode 100644
index 556b0b1c..00000000
--- a/tls/tls_test.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package tls
-
-import (
- "crypto/x509"
- "encoding/pem"
- "io/ioutil"
- "os"
- "path/filepath"
- "testing"
-
- "github.com/bettercap/bettercap/v2/session"
-)
-
-func TestCertConfigToModule(t *testing.T) {
- prefix := "test"
- defaults := DefaultLegitConfig
-
- dummyEnv, err := session.NewEnvironment("")
- if err != nil {
- t.Fatal(err)
- }
- dummySession := &session.Session{Env: dummyEnv}
- m := session.NewSessionModule(prefix, dummySession)
-
- CertConfigToModule(prefix, &m, defaults)
-
- // Check if parameters were added
- if len(m.Parameters()) != 6 {
- t.Errorf("expected 6 parameters, got %d", len(m.Parameters()))
- }
-}
-
-func TestCertConfigFromModule(t *testing.T) {
- dummyEnv, err := session.NewEnvironment("")
- if err != nil {
- t.Fatal(err)
- }
- dummySession := &session.Session{Env: dummyEnv}
- m := session.NewSessionModule("test", dummySession)
- prefix := "test"
-
- // Set some parameters
- m.AddParam(session.NewIntParameter(prefix+".certificate.bits", "2048", "dummy desc"))
- m.AddParam(session.NewStringParameter(prefix+".certificate.country", "TestCountry", ".*", "dummy desc"))
- m.AddParam(session.NewStringParameter(prefix+".certificate.locality", "TestLocality", ".*", "dummy desc"))
- m.AddParam(session.NewStringParameter(prefix+".certificate.organization", "TestOrg", ".*", "dummy desc"))
- m.AddParam(session.NewStringParameter(prefix+".certificate.organizationalunit", "TestUnit", ".*", "dummy desc"))
- m.AddParam(session.NewStringParameter(prefix+".certificate.commonname", "TestCN", ".*", "dummy desc"))
-
- cfg, err := CertConfigFromModule(prefix, m)
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- if cfg.Bits != 2048 || cfg.Country != "TestCountry" || cfg.Locality != "TestLocality" ||
- cfg.Organization != "TestOrg" || cfg.OrganizationalUnit != "TestUnit" || cfg.CommonName != "TestCN" {
- t.Error("config not parsed correctly")
- }
-}
-
-func TestCreateCertificate(t *testing.T) {
- cfg := DefaultLegitConfig
- cfg.Bits = 1024 // smaller for test
-
- priv, certBytes, err := CreateCertificate(cfg, true)
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if priv == nil {
- t.Error("private key is nil")
- }
- if len(certBytes) == 0 {
- t.Error("cert bytes empty")
- }
-
- // Parse to verify
- cert, err := x509.ParseCertificate(certBytes)
- if err != nil {
- t.Errorf("could not parse cert: %v", err)
- }
- if cert.Subject.CommonName != cfg.CommonName {
- t.Errorf("common name mismatch: %s != %s", cert.Subject.CommonName, cfg.CommonName)
- }
- if !cert.IsCA {
- t.Error("not CA")
- }
-}
-
-func TestGenerate(t *testing.T) {
- tempDir, err := ioutil.TempDir("", "tlstest")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- certPath := filepath.Join(tempDir, "test.cert")
- keyPath := filepath.Join(tempDir, "test.key")
-
- cfg := DefaultLegitConfig
- cfg.Bits = 1024
-
- err = Generate(cfg, certPath, keyPath, false)
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- // Check files exist
- if _, err := os.Stat(certPath); os.IsNotExist(err) {
- t.Error("cert file not created")
- }
- if _, err := os.Stat(keyPath); os.IsNotExist(err) {
- t.Error("key file not created")
- }
-
- // Load and verify
- certBytes, _ := ioutil.ReadFile(certPath)
- keyBytes, _ := ioutil.ReadFile(keyPath)
-
- certBlock, _ := pem.Decode(certBytes)
- if certBlock == nil || certBlock.Type != "CERTIFICATE" {
- t.Error("invalid cert PEM")
- }
-
- keyBlock, _ := pem.Decode(keyBytes)
- if keyBlock == nil || keyBlock.Type != "RSA PRIVATE KEY" {
- t.Error("invalid key PEM")
- }
-
- priv, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
- if err != nil {
- t.Errorf("invalid private key: %v", err)
- }
- if priv.N.BitLen() != 1024 {
- t.Errorf("key bits mismatch: %d", priv.N.BitLen())
- }
-}