mirror of
https://github.com/bettercap/bettercap
synced 2025-07-15 01:23:42 -07:00
420 lines
7.3 KiB
Go
420 lines
7.3 KiB
Go
package test
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/chzyer/logex"
|
|
)
|
|
|
|
var (
|
|
mainRoot = ""
|
|
RootPath = os.TempDir()
|
|
ErrNotExcept = logex.Define("result not expect")
|
|
ErrNotEqual = logex.Define("result not equals")
|
|
ErrRequireNotEqual = logex.Define("result require not equals")
|
|
StrNotSuchFile = "no such file or directory"
|
|
)
|
|
|
|
func init() {
|
|
println("tmpdir:", RootPath)
|
|
}
|
|
|
|
// testException is the value Panic raises and New recovers in order to
// report an assertion failure.
type testException struct {
	// depth is added to New's base runtime.Caller skip when New locates the
	// line that triggered the failure.
	depth int
	// info is the failure message New prints.
	info string
}
|
|
|
|
func getMainRoot() string {
|
|
if mainRoot != "" {
|
|
return mainRoot
|
|
}
|
|
|
|
cwd, err := os.Getwd()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
for len(cwd) > 1 {
|
|
_, err := os.Stat(filepath.Join(cwd, ".git"))
|
|
if err == nil {
|
|
mainRoot = cwd + string([]rune{filepath.Separator})
|
|
break
|
|
}
|
|
cwd = filepath.Dir(cwd)
|
|
}
|
|
return mainRoot
|
|
}
|
|
|
|
// Skip aborts the calling test function by panicking with a nil value. It is
// meant to be paired with a deferred New(t): when recover() yields nil, New
// returns without reporting a failure, so the test simply stops.
//
// NOTE(review): since Go 1.21, panicking with a nil value is converted to a
// *runtime.PanicNilError, so recover() only returns nil when
// GODEBUG=panicnil=1 is in effect. Confirm the deferred recovery handles that
// value, otherwise Skip turns into a crash.
func Skip() {
	var nothing interface{}
	panic(nothing)
}
|
|
|
|
// Failer is the single capability New needs from the test runner: stopping
// the current test after a failure has been printed. *testing.T and
// *testing.B both satisfy it.
type Failer interface {
	// FailNow marks the test as failed and stops its execution.
	FailNow()
}
|
|
|
|
func New(t Failer) {
|
|
err := recover()
|
|
if err == nil {
|
|
return
|
|
}
|
|
te, ok := err.(*testException)
|
|
if !ok {
|
|
panic(err)
|
|
}
|
|
|
|
_, file, line, _ := runtime.Caller(5 + te.depth)
|
|
if strings.HasPrefix(file, getMainRoot()) {
|
|
file = file[len(getMainRoot()):]
|
|
}
|
|
println(fmt.Sprintf("%s:%d: %s", file, line, te.info))
|
|
t.FailNow()
|
|
}
|
|
|
|
// getErr picks the failure error for an assertion helper. It returns the
// first caller-supplied error in e (even if that element is nil), or def when
// e is empty.
func getErr(def error, e []error) error {
	if len(e) > 0 {
		return e[0]
	}
	return def
}
|
|
|
|
func ReadAt(r io.ReaderAt, b []byte, at int64) {
|
|
n, err := r.ReadAt(b, at)
|
|
if err != nil {
|
|
Panic(0, fmt.Errorf("ReadAt error: %v", err))
|
|
}
|
|
if n != len(b) {
|
|
Panic(0, fmt.Errorf("ReadAt short read: %v, want: %v", n, len(b)))
|
|
}
|
|
}
|
|
|
|
func ReadAndCheck(r io.Reader, b []byte) {
|
|
buf := make([]byte, len(b))
|
|
Read(r, buf)
|
|
equalBytes(1, buf, b)
|
|
}
|
|
|
|
func Read(r io.Reader, b []byte) {
|
|
n, err := r.Read(b)
|
|
if err != nil && !logex.Equal(err, io.EOF) {
|
|
Panic(0, fmt.Errorf("Read error: %v", err))
|
|
}
|
|
if n != len(b) {
|
|
Panic(0, fmt.Errorf("Read: %v, want: %v", n, len(b)))
|
|
}
|
|
}
|
|
|
|
func ReadStringAt(r io.ReaderAt, off int64, s string) {
|
|
buf := make([]byte, len(s))
|
|
n, err := r.ReadAt(buf, off)
|
|
buf = buf[:n]
|
|
if err != nil {
|
|
Panic(0, fmt.Errorf("ReadStringAt: %v", err))
|
|
}
|
|
if string(buf) != s {
|
|
Panic(0, fmt.Errorf(
|
|
"ReadStringAt not match: %v, got: %v",
|
|
strconv.Quote(s),
|
|
strconv.Quote(string(buf)),
|
|
))
|
|
}
|
|
}
|
|
|
|
func ReadString(r io.Reader, s string) {
|
|
buf := make([]byte, len(s))
|
|
n, err := r.Read(buf)
|
|
if err != nil && !logex.Equal(err, io.EOF) {
|
|
Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), err))
|
|
}
|
|
if n != len(buf) {
|
|
Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), n))
|
|
}
|
|
if string(buf) != s {
|
|
Panic(0, fmt.Errorf(
|
|
"ReadString not match: %v, got: %v",
|
|
strconv.Quote(s),
|
|
strconv.Quote(string(buf)),
|
|
))
|
|
}
|
|
}
|
|
|
|
func WriteAt(w io.WriterAt, b []byte, at int64) {
|
|
n, err := w.WriteAt(b, at)
|
|
if err != nil {
|
|
Panic(0, err)
|
|
}
|
|
if n != len(b) {
|
|
Panic(0, "short write")
|
|
}
|
|
}
|
|
|
|
func Write(w io.Writer, b []byte) {
|
|
n, err := w.Write(b)
|
|
if err != nil {
|
|
Panic(0, err)
|
|
}
|
|
if n != len(b) {
|
|
Panic(0, "short write")
|
|
}
|
|
}
|
|
|
|
func WriteString(w io.Writer, s string) {
|
|
n, err := w.Write([]byte(s))
|
|
if err != nil {
|
|
Panic(0, err)
|
|
}
|
|
if n != len(s) {
|
|
Panic(0, "short write")
|
|
}
|
|
}
|
|
|
|
func Equals(o ...interface{}) {
|
|
if len(o)%2 != 0 {
|
|
Panic(0, "invalid Equals arguments")
|
|
}
|
|
for i := 0; i < len(o); i += 2 {
|
|
equal(1, o[i], o[i+1], nil)
|
|
}
|
|
}
|
|
|
|
func NotEqual(a, b interface{}, e ...error) {
|
|
notEqual(1, a, b, e)
|
|
}
|
|
|
|
// toInt converts any built-in integer value to int64 so that Equal can
// compare integers of different widths (e.g. Equal(byte(1), 1)).
//
// It returns (v, true) when a is a signed or unsigned integer whose value
// fits in int64. Otherwise it returns (-1, false): a is not an integer, or a
// uint/uint64/uintptr is above math.MaxInt64. Converting such a value would
// wrap to a negative number and could make unequal values compare equal.
func toInt(a interface{}) (int64, bool) {
	switch n := a.(type) {
	case int:
		return int64(n), true
	case int8:
		return int64(n), true
	case int16:
		return int64(n), true
	case int32:
		return int64(n), true
	case int64:
		return n, true
	case uint8:
		return int64(n), true
	case uint16:
		return int64(n), true
	case uint32:
		return int64(n), true
	case uint:
		// A negative result means the value exceeded math.MaxInt64 and wrapped.
		if v := int64(n); v >= 0 {
			return v, true
		}
	case uint64:
		if v := int64(n); v >= 0 {
			return v, true
		}
	case uintptr:
		if v := int64(n); v >= 0 {
			return v, true
		}
	}
	return -1, false
}
|
|
|
|
// MarkLine prints a separator of 20 dashes to standard error (println
// builtin), useful for visually splitting test output.
func MarkLine() {
	const width = 20
	println(strings.Repeat("-", width))
}
|
|
|
|
// globalMarkInfo holds the context recorded by Mark. When non-empty, Panic
// prefixes it to every failure message.
var globalMarkInfo = ""
|
|
|
|
func Mark(obj ...interface{}) {
|
|
globalMarkInfo = fmt.Sprint(obj...)
|
|
}
|
|
|
|
func EqualBytes(got, want []byte) {
|
|
equalBytes(0, got, want)
|
|
}
|
|
|
|
func equalBytes(n int, got, want []byte) {
|
|
a := got
|
|
b := want
|
|
size := 16
|
|
if len(a) != len(b) {
|
|
Panic(n, fmt.Sprintf("equal bytes, %v != %v", len(a), len(b)))
|
|
}
|
|
if bytes.Equal(a, b) {
|
|
return
|
|
}
|
|
|
|
for off := 0; off < len(a); off += size {
|
|
end := off + size
|
|
if end > len(a) {
|
|
end = len(a)
|
|
}
|
|
if !bytes.Equal(a[off:end], b[off:end]) {
|
|
Panic(n, fmt.Sprintf(
|
|
"equal [%v]byte in [%v, %v]:\n\tgot: %v\n\twant: %v",
|
|
len(a),
|
|
off, off+size,
|
|
a[off:end], b[off:end],
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
func Equal(a, b interface{}, e ...error) {
|
|
if ai, ok := toInt(a); ok {
|
|
if bi, ok := toInt(b); ok {
|
|
equal(1, ai, bi, e)
|
|
return
|
|
}
|
|
}
|
|
equal(1, a, b, e)
|
|
}
|
|
|
|
func CheckError(e error, s string) {
|
|
if e == nil {
|
|
Panic(0, ErrNotExcept)
|
|
}
|
|
if !strings.Contains(e.Error(), s) {
|
|
Panic(0, fmt.Errorf(
|
|
"want: %s, got %s",
|
|
strconv.Quote(s),
|
|
strconv.Quote(e.Error()),
|
|
))
|
|
}
|
|
}
|
|
|
|
// formatMax renders o with fmt.Sprint. If the result is longer than max
// bytes, it is truncated and " ..." is appended.
//
// The cut is moved back to the nearest rune boundary at or before max, so a
// multi-byte UTF-8 character is never split and the result stays valid
// UTF-8. A negative max yields just " ..." for any over-long value instead
// of a slice-bounds panic.
func formatMax(o interface{}, max int) string {
	s := fmt.Sprint(o)
	if len(s) <= max {
		return s
	}
	cut := 0
	for i := range s { // i visits the start byte of each rune
		if i > max {
			break
		}
		cut = i
	}
	return s[:cut] + " ..."
}
|
|
|
|
func notEqual(d int, a, b interface{}, e []error) {
|
|
_, oka := a.(error)
|
|
_, okb := b.(error)
|
|
if oka && okb {
|
|
if logex.Equal(a.(error), b.(error)) {
|
|
Panic(d, fmt.Sprintf("%v: %v",
|
|
getErr(ErrRequireNotEqual, e),
|
|
a,
|
|
))
|
|
}
|
|
return
|
|
}
|
|
if reflect.DeepEqual(a, b) {
|
|
Panic(d, fmt.Sprintf("%v: (%v, %v)",
|
|
getErr(ErrRequireNotEqual, e),
|
|
formatMax(a, 100),
|
|
formatMax(b, 100),
|
|
))
|
|
}
|
|
}
|
|
|
|
func equal(d int, a, b interface{}, e []error) {
|
|
_, oka := a.(error)
|
|
_, okb := b.(error)
|
|
if oka && okb {
|
|
if !logex.Equal(a.(error), b.(error)) {
|
|
Panic(d, fmt.Sprintf("%v: (%v, %v)",
|
|
getErr(ErrNotEqual, e),
|
|
formatMax(a, 100), formatMax(b, 100),
|
|
))
|
|
}
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(a, b) {
|
|
Panic(d, fmt.Sprintf("%v: (%+v, %+v)", getErr(ErrNotEqual, e), a, b))
|
|
}
|
|
}
|
|
|
|
func Should(b bool, e ...error) {
|
|
if !b {
|
|
Panic(0, getErr(ErrNotExcept, e))
|
|
}
|
|
}
|
|
|
|
func NotNil(obj interface{}) {
|
|
if obj == nil {
|
|
Panic(0, "should not nil")
|
|
}
|
|
}
|
|
|
|
func False(obj bool) {
|
|
if obj {
|
|
Panic(0, "should false")
|
|
}
|
|
}
|
|
|
|
func True(obj bool) {
|
|
if !obj {
|
|
Panic(0, "should true")
|
|
}
|
|
}
|
|
|
|
func Nil(obj interface{}) {
|
|
if obj != nil {
|
|
// double check, incase different type with nil value
|
|
if !reflect.ValueOf(obj).IsNil() {
|
|
str := fmt.Sprint(obj)
|
|
if err, ok := obj.(error); ok {
|
|
str = logex.DecodeError(err)
|
|
}
|
|
Panic(0, fmt.Sprintf("should nil: %v", str))
|
|
}
|
|
}
|
|
}
|
|
|
|
func Panic(depth int, obj interface{}) {
|
|
t := &testException{
|
|
depth: depth,
|
|
}
|
|
if err, ok := obj.(error); ok {
|
|
t.info = logex.DecodeError(err)
|
|
} else {
|
|
t.info = fmt.Sprint(obj)
|
|
}
|
|
if globalMarkInfo != "" {
|
|
t.info = "[info:" + globalMarkInfo + "] " + t.info
|
|
}
|
|
panic(t)
|
|
}
|
|
|
|
func CleanTmp() {
|
|
os.RemoveAll(root(2))
|
|
}
|
|
|
|
func TmpFile() (*os.File, error) {
|
|
dir := root(2)
|
|
if err := os.MkdirAll(dir, 0744); err != nil {
|
|
return nil, err
|
|
}
|
|
return ioutil.TempFile(dir, "")
|
|
}
|
|
|
|
func Root() string {
|
|
p := root(2)
|
|
os.RemoveAll(root(2))
|
|
return p
|
|
}
|
|
|
|
func root(n int) string {
|
|
pc, _, _, _ := runtime.Caller(n)
|
|
name := runtime.FuncForPC(pc).Name()
|
|
if idx := strings.LastIndex(name, "."); idx > 0 {
|
|
name = name[:idx] + "/" + name[idx+1:]
|
|
}
|
|
|
|
root := os.Getenv("TEST_ROOT")
|
|
if root == "" {
|
|
root = RootPath
|
|
}
|
|
return filepath.Join(root, name)
|
|
}
|
|
|
|
// RandBytes returns a new n-byte slice filled from crypto/rand. The error
// from rand.Read is deliberately ignored, so the contents are not guaranteed
// to be random if the read fails.
func RandBytes(n int) []byte {
	out := make([]byte, n)
	_, _ = rand.Read(out) // best effort
	return out
}
|
|
|
|
// SeqBytes returns an n-byte slice whose i-th byte is byte(i), so the values
// wrap around every 256 bytes.
func SeqBytes(n int) []byte {
	out := make([]byte, n)
	for i := 0; i < n; i++ {
		out[i] = byte(i)
	}
	return out
}
|