mirror of
https://github.com/bettercap/bettercap
synced 2025-08-21 05:53:20 -07:00
new: implemented api.rest.record and api.rest.replay
This commit is contained in:
parent
4713d25ea7
commit
0a31ac8167
76 changed files with 7610 additions and 48 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bettercap/bettercap/session"
|
||||
|
@ -26,6 +27,13 @@ type RestAPI struct {
|
|||
useWebsocket bool
|
||||
upgrader websocket.Upgrader
|
||||
quit chan bool
|
||||
|
||||
recording bool
|
||||
recTime int
|
||||
replaying bool
|
||||
recordFileName string
|
||||
recordWait *sync.WaitGroup
|
||||
record *Record
|
||||
}
|
||||
|
||||
func NewRestAPI(s *session.Session) *RestAPI {
|
||||
|
@ -39,8 +47,21 @@ func NewRestAPI(s *session.Session) *RestAPI {
|
|||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
recording: false,
|
||||
recTime: 0,
|
||||
replaying: false,
|
||||
recordFileName: "",
|
||||
recordWait: &sync.WaitGroup{},
|
||||
record: nil,
|
||||
}
|
||||
|
||||
mod.State.Store("recording", &mod.recording)
|
||||
mod.State.Store("replaying", &mod.replaying)
|
||||
mod.State.Store("rec_time", &mod.recTime)
|
||||
mod.State.Store("rec_filename", &mod.recordFileName)
|
||||
mod.State.Store("rec_frames", 0)
|
||||
mod.State.Store("rec_cur_frame", 0)
|
||||
|
||||
mod.AddParam(session.NewStringParameter("api.rest.address",
|
||||
"127.0.0.1",
|
||||
session.IPv4Validator,
|
||||
|
@ -93,6 +114,30 @@ func NewRestAPI(s *session.Session) *RestAPI {
|
|||
return mod.Stop()
|
||||
}))
|
||||
|
||||
mod.AddHandler(session.NewModuleHandler("api.rest.record off", "",
|
||||
"Stop recording the session.",
|
||||
func(args []string) error {
|
||||
return mod.stopRecording()
|
||||
}))
|
||||
|
||||
mod.AddHandler(session.NewModuleHandler("api.rest.record FILENAME", `api\.rest\.record (.+)`,
|
||||
"Start polling the rest API every second recording each sample as a session file that can be replayed.",
|
||||
func(args []string) error {
|
||||
return mod.startRecording(args[0])
|
||||
}))
|
||||
|
||||
mod.AddHandler(session.NewModuleHandler("api.rest.replay off", "",
|
||||
"Stop replaying the recorded session.",
|
||||
func(args []string) error {
|
||||
return mod.stopReplay()
|
||||
}))
|
||||
|
||||
mod.AddHandler(session.NewModuleHandler("api.rest.replay FILENAME", `api\.rest\.replay (.+)`,
|
||||
"Start the rest API module in replay mode using FILENAME as the recorded session file.",
|
||||
func(args []string) error {
|
||||
return mod.startReplay(args[0])
|
||||
}))
|
||||
|
||||
return mod
|
||||
}
|
||||
|
||||
|
@ -205,7 +250,9 @@ func (mod *RestAPI) Configure() error {
|
|||
}
|
||||
|
||||
func (mod *RestAPI) Start() error {
|
||||
if err := mod.Configure(); err != nil {
|
||||
if mod.replaying {
|
||||
return fmt.Errorf("the api is currently in replay mode, run api.rest.replay off before starting it")
|
||||
} else if err := mod.Configure(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -229,6 +276,12 @@ func (mod *RestAPI) Start() error {
|
|||
}
|
||||
|
||||
func (mod *RestAPI) Stop() error {
|
||||
if mod.recording {
|
||||
mod.stopRecording()
|
||||
} else if mod.replaying {
|
||||
mod.stopReplay()
|
||||
}
|
||||
|
||||
return mod.SetRunning(false, func() {
|
||||
go func() {
|
||||
mod.quit <- true
|
||||
|
|
|
@ -36,7 +36,7 @@ func (mod *RestAPI) setAuthFailed(w http.ResponseWriter, r *http.Request) {
|
|||
func (mod *RestAPI) toJSON(w http.ResponseWriter, o interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(o); err != nil {
|
||||
mod.Error("error while encoding object to JSON: %v", err)
|
||||
fmt.Printf("error while encoding object to JSON: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,8 +64,68 @@ func (mod *RestAPI) checkAuth(r *http.Request) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (mod *RestAPI) patchFrame(buf []byte) (frame map[string]interface{}, err error) {
|
||||
// this is ugly but necessary: since we're replaying, the
|
||||
// api.rest state object is filled with *old* values (the
|
||||
// recorded ones), but the UI needs updated values at least
|
||||
// of that in order to understand that a replay is going on
|
||||
// and where we are at it. So we need to parse the record
|
||||
// back into a session object and update only the api.rest.state
|
||||
frame = make(map[string]interface{})
|
||||
|
||||
if err = json.Unmarshal(buf, &frame); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, i := range frame["modules"].([]interface{}) {
|
||||
m := i.(map[string]interface{})
|
||||
if m["name"] == "api.rest" {
|
||||
state := m["state"].(map[string]interface{})
|
||||
mod.State.Range(func(key interface{}, value interface{}) bool {
|
||||
state[key.(string)] = value
|
||||
return true
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showSession(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I)
|
||||
if mod.replaying {
|
||||
if !mod.record.Session.Over() {
|
||||
from := mod.record.Session.CurFrame() - 1
|
||||
q := r.URL.Query()
|
||||
vals := q["from"]
|
||||
if len(vals) > 0 {
|
||||
if n, err := strconv.Atoi(vals[0]); err == nil {
|
||||
from = n
|
||||
}
|
||||
}
|
||||
mod.record.Session.SetFrom(from)
|
||||
|
||||
mod.Debug("replaying session %d of %d from %s",
|
||||
mod.record.Session.CurFrame(),
|
||||
mod.record.Session.Frames(),
|
||||
mod.recordFileName)
|
||||
|
||||
mod.State.Store("rec_frames", mod.record.Session.Frames())
|
||||
mod.State.Store("rec_cur_frame", mod.record.Session.CurFrame())
|
||||
|
||||
buf := mod.record.Session.Next()
|
||||
if frame, err := mod.patchFrame(buf); err != nil {
|
||||
mod.Error("%v", err)
|
||||
} else {
|
||||
mod.toJSON(w, frame)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
mod.stopReplay()
|
||||
}
|
||||
}
|
||||
|
||||
mod.toJSON(w, mod.Session)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showBLE(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -73,8 +133,8 @@ func (mod *RestAPI) showBLE(w http.ResponseWriter, r *http.Request) {
|
|||
mac := strings.ToLower(params["mac"])
|
||||
|
||||
if mac == "" {
|
||||
mod.toJSON(w, session.I.BLE)
|
||||
} else if dev, found := session.I.BLE.Get(mac); found {
|
||||
mod.toJSON(w, mod.Session.BLE)
|
||||
} else if dev, found := mod.Session.BLE.Get(mac); found {
|
||||
mod.toJSON(w, dev)
|
||||
} else {
|
||||
http.Error(w, "Not Found", 404)
|
||||
|
@ -86,8 +146,8 @@ func (mod *RestAPI) showHID(w http.ResponseWriter, r *http.Request) {
|
|||
mac := strings.ToLower(params["mac"])
|
||||
|
||||
if mac == "" {
|
||||
mod.toJSON(w, session.I.HID)
|
||||
} else if dev, found := session.I.HID.Get(mac); found {
|
||||
mod.toJSON(w, mod.Session.HID)
|
||||
} else if dev, found := mod.Session.HID.Get(mac); found {
|
||||
mod.toJSON(w, dev)
|
||||
} else {
|
||||
http.Error(w, "Not Found", 404)
|
||||
|
@ -95,19 +155,19 @@ func (mod *RestAPI) showHID(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (mod *RestAPI) showEnv(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Env)
|
||||
mod.toJSON(w, mod.Session.Env)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showGateway(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Gateway)
|
||||
mod.toJSON(w, mod.Session.Gateway)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showInterface(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Interface)
|
||||
mod.toJSON(w, mod.Session.Interface)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showModules(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Modules)
|
||||
mod.toJSON(w, mod.Session.Modules)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showLAN(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -115,8 +175,8 @@ func (mod *RestAPI) showLAN(w http.ResponseWriter, r *http.Request) {
|
|||
mac := strings.ToLower(params["mac"])
|
||||
|
||||
if mac == "" {
|
||||
mod.toJSON(w, session.I.Lan)
|
||||
} else if host, found := session.I.Lan.Get(mac); found {
|
||||
mod.toJSON(w, mod.Session.Lan)
|
||||
} else if host, found := mod.Session.Lan.Get(mac); found {
|
||||
mod.toJSON(w, host)
|
||||
} else {
|
||||
http.Error(w, "Not Found", 404)
|
||||
|
@ -124,15 +184,15 @@ func (mod *RestAPI) showLAN(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (mod *RestAPI) showOptions(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Options)
|
||||
mod.toJSON(w, mod.Session.Options)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showPackets(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.Queue)
|
||||
mod.toJSON(w, mod.Session.Queue)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showStartedAt(w http.ResponseWriter, r *http.Request) {
|
||||
mod.toJSON(w, session.I.StartedAt)
|
||||
mod.toJSON(w, mod.Session.StartedAt)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showWiFi(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -140,10 +200,10 @@ func (mod *RestAPI) showWiFi(w http.ResponseWriter, r *http.Request) {
|
|||
mac := strings.ToLower(params["mac"])
|
||||
|
||||
if mac == "" {
|
||||
mod.toJSON(w, session.I.WiFi)
|
||||
} else if station, found := session.I.WiFi.Get(mac); found {
|
||||
mod.toJSON(w, mod.Session.WiFi)
|
||||
} else if station, found := mod.Session.WiFi.Get(mac); found {
|
||||
mod.toJSON(w, station)
|
||||
} else if client, found := session.I.WiFi.GetClient(mac); found {
|
||||
} else if client, found := mod.Session.WiFi.GetClient(mac); found {
|
||||
mod.toJSON(w, client)
|
||||
} else {
|
||||
http.Error(w, "Not Found", 404)
|
||||
|
@ -170,42 +230,72 @@ func (mod *RestAPI) runSessionCommand(w http.ResponseWriter, r *http.Request) {
|
|||
mod.toJSON(w, APIResponse{Success: true})
|
||||
}
|
||||
|
||||
func (mod *RestAPI) getEvents(limit int) []session.Event {
|
||||
events := make([]session.Event, 0)
|
||||
for _, e := range mod.Session.Events.Sorted() {
|
||||
if mod.Session.EventsIgnoreList.Ignored(e) == false {
|
||||
events = append(events, e)
|
||||
}
|
||||
}
|
||||
|
||||
nevents := len(events)
|
||||
nmax := nevents
|
||||
n := nmax
|
||||
|
||||
if limit > 0 && limit < nmax {
|
||||
n = limit
|
||||
}
|
||||
|
||||
return events[nevents-n:]
|
||||
}
|
||||
|
||||
func (mod *RestAPI) showEvents(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
q := r.URL.Query()
|
||||
|
||||
if mod.replaying {
|
||||
if !mod.record.Events.Over() {
|
||||
from := mod.record.Events.CurFrame() - 1
|
||||
vals := q["from"]
|
||||
if len(vals) > 0 {
|
||||
if n, err := strconv.Atoi(vals[0]); err == nil {
|
||||
from = n
|
||||
}
|
||||
}
|
||||
mod.record.Events.SetFrom(from)
|
||||
|
||||
mod.Debug("replaying events %d of %d from %s",
|
||||
mod.record.Events.CurFrame(),
|
||||
mod.record.Events.Frames(),
|
||||
mod.recordFileName)
|
||||
|
||||
buf := mod.record.Events.Next()
|
||||
if _, err := w.Write(buf); err != nil {
|
||||
mod.Error("%v", err)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
mod.stopReplay()
|
||||
}
|
||||
}
|
||||
|
||||
if mod.useWebsocket {
|
||||
mod.startStreamingEvents(w, r)
|
||||
} else {
|
||||
events := make([]session.Event, 0)
|
||||
for _, e := range session.I.Events.Sorted() {
|
||||
if mod.Session.EventsIgnoreList.Ignored(e) == false {
|
||||
events = append(events, e)
|
||||
}
|
||||
}
|
||||
|
||||
nevents := len(events)
|
||||
nmax := nevents
|
||||
n := nmax
|
||||
|
||||
q := r.URL.Query()
|
||||
vals := q["n"]
|
||||
limit := 0
|
||||
if len(vals) > 0 {
|
||||
n, err = strconv.Atoi(q["n"][0])
|
||||
if err == nil {
|
||||
if n > nmax {
|
||||
n = nmax
|
||||
}
|
||||
} else {
|
||||
n = nmax
|
||||
if n, err := strconv.Atoi(q["n"][0]); err == nil {
|
||||
limit = n
|
||||
}
|
||||
}
|
||||
|
||||
mod.toJSON(w, events[nevents-n:])
|
||||
mod.toJSON(w, mod.getEvents(limit))
|
||||
}
|
||||
}
|
||||
|
||||
func (mod *RestAPI) clearEvents(w http.ResponseWriter, r *http.Request) {
|
||||
session.I.Events.Clear()
|
||||
mod.Session.Events.Clear()
|
||||
}
|
||||
|
||||
func (mod *RestAPI) corsRoute(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -227,10 +317,10 @@ func (mod *RestAPI) sessionRoute(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
session.I.Lock()
|
||||
defer session.I.Unlock()
|
||||
mod.Session.Lock()
|
||||
defer mod.Session.Unlock()
|
||||
|
||||
path := r.URL.String()
|
||||
path := r.URL.Path
|
||||
switch {
|
||||
case path == "/api/session":
|
||||
mod.showSession(w, r)
|
||||
|
|
107
modules/api_rest/api_rest_record.go
Normal file
107
modules/api_rest/api_rest_record.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package api_rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/evilsocket/islazy/fs"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotRecording = errors.New("not recording")
|
||||
)
|
||||
|
||||
func (mod *RestAPI) errAlreadyRecording() error {
|
||||
return fmt.Errorf("the module is already recording to %s", mod.recordFileName)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) recordState() error {
|
||||
mod.Session.Lock()
|
||||
defer mod.Session.Unlock()
|
||||
|
||||
session := new(bytes.Buffer)
|
||||
encoder := json.NewEncoder(session)
|
||||
|
||||
if err := encoder.Encode(mod.Session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := new(bytes.Buffer)
|
||||
encoder = json.NewEncoder(events)
|
||||
|
||||
if err := encoder.Encode(mod.getEvents(0)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mod.record.NewState(session.Bytes(), events.Bytes())
|
||||
}
|
||||
|
||||
func (mod *RestAPI) recorder() {
|
||||
mod.recTime = 0
|
||||
mod.recording = true
|
||||
mod.replaying = false
|
||||
mod.record = NewRecord(mod.recordFileName)
|
||||
|
||||
mod.Info("started recording to %s ...", mod.recordFileName)
|
||||
|
||||
mod.recordWait.Add(1)
|
||||
defer mod.recordWait.Done()
|
||||
|
||||
tick := time.NewTicker(1 * time.Second)
|
||||
for range tick.C {
|
||||
if !mod.recording {
|
||||
break
|
||||
}
|
||||
|
||||
mod.recTime++
|
||||
|
||||
if err := mod.recordState(); err != nil {
|
||||
mod.Error("error while recording: %s", err)
|
||||
mod.recording = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mod.Info("stopped recording to %s ...", mod.recordFileName)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) startRecording(filename string) (err error) {
|
||||
if mod.recording {
|
||||
return mod.errAlreadyRecording()
|
||||
} else if mod.replaying {
|
||||
return mod.errAlreadyReplaying()
|
||||
} else if mod.recordFileName, err = fs.Expand(filename); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we need the api itself up and running
|
||||
if !mod.Running() {
|
||||
if err = mod.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
go mod.recorder()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mod *RestAPI) stopRecording() error {
|
||||
if !mod.recording {
|
||||
return errNotRecording
|
||||
}
|
||||
|
||||
mod.recording = false
|
||||
|
||||
mod.recordWait.Wait()
|
||||
|
||||
err := mod.record.Flush()
|
||||
|
||||
mod.recordFileName = ""
|
||||
mod.record = nil
|
||||
|
||||
return err
|
||||
}
|
63
modules/api_rest/api_rest_replay.go
Normal file
63
modules/api_rest/api_rest_replay.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package api_rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/evilsocket/islazy/fs"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotReplaying = errors.New("not replaying")
|
||||
)
|
||||
|
||||
func (mod *RestAPI) errAlreadyReplaying() error {
|
||||
return fmt.Errorf("the module is already replaying a session from %s", mod.recordFileName)
|
||||
}
|
||||
|
||||
func (mod *RestAPI) startReplay(filename string) (err error) {
|
||||
if mod.replaying {
|
||||
return mod.errAlreadyReplaying()
|
||||
} else if mod.recording {
|
||||
return mod.errAlreadyRecording()
|
||||
} else if mod.recordFileName, err = fs.Expand(filename); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mod.Info("loading %s ...", mod.recordFileName)
|
||||
|
||||
start := time.Now()
|
||||
if mod.record, err = LoadRecord(mod.recordFileName); err != nil {
|
||||
return err
|
||||
}
|
||||
loadedIn := time.Since(start)
|
||||
|
||||
// we need the api itself up and running
|
||||
if !mod.Running() {
|
||||
if err := mod.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
mod.replaying = true
|
||||
mod.recording = false
|
||||
|
||||
mod.Info("loaded %d frames in %s, started replaying ...", mod.record.Session.Frames(), loadedIn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mod *RestAPI) stopReplay() error {
|
||||
if !mod.replaying {
|
||||
return errNotReplaying
|
||||
}
|
||||
|
||||
mod.replaying = false
|
||||
|
||||
mod.Info("stopped replaying from %s ...", mod.recordFileName)
|
||||
|
||||
mod.recordFileName = ""
|
||||
|
||||
return nil
|
||||
}
|
233
modules/api_rest/record.go
Normal file
233
modules/api_rest/record.go
Normal file
|
@ -0,0 +1,233 @@
|
|||
package api_rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/evilsocket/islazy/fs"
|
||||
"github.com/kr/binarydist"
|
||||
)
|
||||
|
||||
type patch []byte
|
||||
type frame []byte
|
||||
|
||||
type RecordEntry struct {
|
||||
sync.Mutex
|
||||
|
||||
Data []byte `json:"data"`
|
||||
Cur []byte `json:"-"`
|
||||
States []patch `json:"states"`
|
||||
NumStates int `json:"-"`
|
||||
CurState int `json:"-"`
|
||||
|
||||
frames []frame
|
||||
}
|
||||
|
||||
func NewRecordEntry() *RecordEntry {
|
||||
return &RecordEntry{
|
||||
Data: nil,
|
||||
Cur: nil,
|
||||
States: make([]patch, 0),
|
||||
NumStates: 0,
|
||||
CurState: 0,
|
||||
frames: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *RecordEntry) AddState(state []byte) error {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
|
||||
// set reference state
|
||||
if e.Data == nil {
|
||||
e.Data = state
|
||||
} else {
|
||||
// create a patch
|
||||
oldReader := bytes.NewReader(e.Cur)
|
||||
newReader := bytes.NewReader(state)
|
||||
writer := new(bytes.Buffer)
|
||||
|
||||
if err := binarydist.Diff(oldReader, newReader, writer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.States = append(e.States, patch(writer.Bytes()))
|
||||
e.NumStates++
|
||||
e.CurState = 0
|
||||
}
|
||||
e.Cur = state
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *RecordEntry) Reset() {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
e.Cur = e.Data
|
||||
e.NumStates = len(e.States)
|
||||
e.CurState = 0
|
||||
}
|
||||
|
||||
func (e *RecordEntry) Compile() error {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
|
||||
// reset the state
|
||||
e.Cur = e.Data
|
||||
e.NumStates = len(e.States)
|
||||
e.CurState = 0
|
||||
e.frames = make([]frame, e.NumStates+1)
|
||||
|
||||
// first is the master frame
|
||||
e.frames[0] = frame(e.Data)
|
||||
// precompute frames so they can be accessed by index
|
||||
for i := 0; i < e.NumStates; i++ {
|
||||
patch := e.States[i]
|
||||
oldReader := bytes.NewReader(e.Cur)
|
||||
patchReader := bytes.NewReader(patch)
|
||||
newWriter := new(bytes.Buffer)
|
||||
|
||||
if err := binarydist.Patch(oldReader, newWriter, patchReader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Cur = newWriter.Bytes()
|
||||
e.frames[i+1] = e.Cur
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *RecordEntry) Frames() int {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
// master + sub states
|
||||
return e.NumStates + 1
|
||||
}
|
||||
|
||||
func (e *RecordEntry) CurFrame() int {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
return e.CurState + 1
|
||||
}
|
||||
|
||||
func (e *RecordEntry) SetFrom(from int) {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
e.CurState = from
|
||||
}
|
||||
|
||||
func (e *RecordEntry) Over() bool {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
return e.CurState > e.NumStates
|
||||
}
|
||||
|
||||
func (e *RecordEntry) Next() []byte {
|
||||
e.Lock()
|
||||
defer e.Unlock()
|
||||
cur := e.CurState
|
||||
e.CurState++
|
||||
return e.frames[cur]
|
||||
}
|
||||
|
||||
// the Record object represents a recorded session
|
||||
type Record struct {
|
||||
sync.Mutex
|
||||
|
||||
fileName string `json:"-"`
|
||||
Session *RecordEntry `json:"session"`
|
||||
Events *RecordEntry `json:"events"`
|
||||
}
|
||||
|
||||
func NewRecord(fileName string) *Record {
|
||||
return &Record{
|
||||
fileName: fileName,
|
||||
Session: NewRecordEntry(),
|
||||
Events: NewRecordEntry(),
|
||||
}
|
||||
}
|
||||
|
||||
func LoadRecord(fileName string) (*Record, error) {
|
||||
if !fs.Exists(fileName) {
|
||||
return nil, fmt.Errorf("%s does not exist", fileName)
|
||||
}
|
||||
|
||||
compressed, err := ioutil.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while reading %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
decompress, err := gzip.NewReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while reading gzip file %s: %s", fileName, err)
|
||||
}
|
||||
defer decompress.Close()
|
||||
|
||||
raw, err := ioutil.ReadAll(decompress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while decompressing %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
rec := &Record{}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(raw))
|
||||
if err = decoder.Decode(rec); err != nil {
|
||||
return nil, fmt.Errorf("error while parsing %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
rec.fileName = fileName
|
||||
|
||||
// reset state and precompute frames
|
||||
if err = rec.Session.Compile(); err != nil {
|
||||
return nil, err
|
||||
} else if err = rec.Events.Compile(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
func (r *Record) NewState(session []byte, events []byte) error {
|
||||
if err := r.Session.AddState(session); err != nil {
|
||||
return err
|
||||
} else if err := r.Events.AddState(events); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.Flush()
|
||||
}
|
||||
|
||||
func (r *Record) save() error {
|
||||
buf := new(bytes.Buffer)
|
||||
encoder := json.NewEncoder(buf)
|
||||
|
||||
if err := encoder.Encode(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := buf.Bytes()
|
||||
|
||||
compressed := new(bytes.Buffer)
|
||||
compress := gzip.NewWriter(compressed)
|
||||
|
||||
if _, err := compress.Write(data); err != nil {
|
||||
return err
|
||||
} else if err = compress.Flush(); err != nil {
|
||||
return err
|
||||
} else if err = compress.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(r.fileName, compressed.Bytes(), os.ModePerm)
|
||||
}
|
||||
|
||||
func (r *Record) Flush() error {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return r.save()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue