misc: several improvements to the graph module

This commit is contained in:
Simone Margaritelli 2021-04-08 18:41:30 +02:00
commit 71634058a7
6 changed files with 468 additions and 263 deletions

View file

@ -1,36 +1,33 @@
package graph
import (
"encoding/json"
"fmt"
"strings"
"github.com/bettercap/bettercap/session"
"github.com/evilsocket/islazy/fs"
"io/ioutil"
"os"
"path"
"regexp"
"sort"
"sync"
"time"
)
var edgesParser = regexp.MustCompile(`^edges_(.+_[a-fA-F0-9:]{17})_(.+_.+)\.json$`)
type NodeCallback func(*Node)
type EdgeCallback func(*Node, *Edge, *Node)
type EdgeCallback func(*Node, []Edge, *Node)
type Graph struct {
sync.Mutex
path string
path string
edges *Edges
}
func NewGraph(path string) (*Graph, error) {
g := &Graph{
path: path,
if edges, err := LoadEdges(path); err != nil {
return nil, err
} else {
return &Graph{
path: path,
edges: edges,
}, nil
}
return g, nil
}
func (g *Graph) EachNode(cb NodeCallback) error {
@ -39,13 +36,10 @@ func (g *Graph) EachNode(cb NodeCallback) error {
for _, nodeType := range NodeTypes {
err := fs.Glob(g.path, fmt.Sprintf("%s_*.json", nodeType), func(fileName string) error {
var node Node
if raw, err := ioutil.ReadFile(fileName); err != nil {
return fmt.Errorf("error while reading %s: %v", fileName, err)
} else if err = json.Unmarshal(raw, &node); err != nil {
return fmt.Errorf("error while decoding %s: %v", fileName, err)
if node, err := ReadNode(fileName); err != nil {
return err
} else {
cb(&node)
cb(node)
}
return nil
})
@ -60,36 +54,21 @@ func (g *Graph) EachEdge(cb EdgeCallback) error {
g.Lock()
defer g.Unlock()
return fs.Glob(g.path, "edges_*.json", func(fileName string) error {
matches := edgesParser.FindAllStringSubmatch(path.Base(fileName), -1)
if len(matches) > 0 && len(matches[0]) == 3 {
var left, right Node
leftFileName := path.Join(g.path, matches[0][1]+".json")
rightFileName := path.Join(g.path, matches[0][2]+".json")
return g.edges.ForEachEdge(func(fromID string, edges []Edge, toID string) error {
var left, right *Node
var err error
if raw, err := ioutil.ReadFile(leftFileName); err != nil {
return fmt.Errorf("error while reading %s: %v", leftFileName, err)
} else if err = json.Unmarshal(raw, &left); err != nil {
return fmt.Errorf("error while decoding %s: %v", leftFileName, err)
} else if raw, err = ioutil.ReadFile(rightFileName); err != nil {
return fmt.Errorf("error while reading %s: %v", rightFileName, err)
} else if err = json.Unmarshal(raw, &right); err != nil {
return fmt.Errorf("error while decoding %s: %v", rightFileName, err)
}
leftFileName := path.Join(g.path, fromID+".json")
rightFileName := path.Join(g.path, toID+".json")
var edges []*Edge
if raw, err := ioutil.ReadFile(fileName); err != nil {
return fmt.Errorf("error while reading %s: %v", fileName, err)
} else if err = json.Unmarshal(raw, &edges); err != nil {
return fmt.Errorf("error while decoding %s: %v", fileName, err)
}
for _, edge := range edges {
cb(&left, edge, &right)
}
} else {
return fmt.Errorf("filename %s doesn't match edges parser", fileName)
if left, err = ReadNode(leftFileName); err != nil {
return err
} else if right, err = ReadNode(rightFileName); err != nil {
return err
}
cb(left, edges, right)
return nil
})
}
@ -110,14 +89,13 @@ func (g *Graph) Traverse(root string, onNode NodeCallback, onEdge EdgeCallback)
}
stack := NewStack()
for _, root := range roots {
stack.Push(root)
}
type edgeBucket struct {
left *Node
edge *Edge
left *Node
edges []Edge
right *Node
}
@ -138,74 +116,122 @@ func (g *Graph) Traverse(root string, onNode NodeCallback, onEdge EdgeCallback)
onNode(node)
// find all edges starting from this node
edgesFilter := fmt.Sprintf("edges_%s_*.json", nodeID)
err = fs.Glob(g.path, edgesFilter, func(edgeFileName string) error {
right := new(Node)
base := path.Base(edgeFileName)
base = strings.ReplaceAll(base, "edges_", "")
base = strings.ReplaceAll(base, nodeID + "_", "")
// read right node
rightFileName := path.Join(g.path, base)
if raw, err := ioutil.ReadFile(rightFileName); err != nil {
return fmt.Errorf("error while reading %s: %v", rightFileName, err)
} else if err = json.Unmarshal(raw, right); err != nil {
return fmt.Errorf("error while decoding %s: %v", rightFileName, err)
}
stack.Push(right)
// read edges
var edges []*Edge
if raw, err := ioutil.ReadFile(edgeFileName); err != nil {
return fmt.Errorf("error while reading %s: %v", edgeFileName, err)
} else if err = json.Unmarshal(raw, &edges); err != nil {
return fmt.Errorf("error while decoding %s: %v", edgeFileName, err)
}
for _, edge := range edges {
allEdges = append(allEdges, edgeBucket {
left: node,
edge: edge,
// collect all edges starting from this node
err = g.edges.ForEachEdgeFrom(nodeID, func(_ string, edges []Edge, toID string) error {
rightFileName := path.Join(g.path, toID+".json")
if right, err := ReadNode(rightFileName); err != nil {
return err
} else {
// collect new node
if _, found := visited[toID]; !found {
stack.Push(right)
}
// collect all edges, we'll emit this later
allEdges = append(allEdges, edgeBucket{
left: node,
edges: edges,
right: right,
})
}
return nil
})
if err != nil {
return err
}
}
}
for _, edge := range allEdges {
onEdge(edge.left, edge.edge, edge.right)
onEdge(edge.left, edge.edges, edge.right)
}
}
return nil
}
func (g *Graph) Dot(filter, layout, name string) (string, error) {
func (g *Graph) Dot(filter, layout, name string, disconnected bool) (string, int, int, error) {
size := 0
discarded := 0
data := fmt.Sprintf("digraph %s {\n", name)
data += fmt.Sprintf(" layout=%s\n", layout)
if err := g.Traverse(filter, func(node *Node) {
data += fmt.Sprintf(" %s\n", node.Dot(filter == node.ID))
}, func(left *Node, edge *Edge, right *Node) {
data += fmt.Sprintf(" %s\n", edge.Dot(left, right))
}); err != nil {
return "", err
typeMap := make(map[NodeType]bool)
type typeCount struct {
edge Edge
count int
}
if err := g.Traverse(filter, func(node *Node) {
include := false
if disconnected || node.Type == SSID { // we don't create backwards edges for SSID
include = true
} else {
include = g.edges.IsConnected(node.String())
}
if include {
size++
typeMap[node.Type] = true
data += fmt.Sprintf(" %s\n", node.Dot(filter == node.ID))
} else {
discarded++
}
}, func(left *Node, edges []Edge, right *Node) {
// collect counters by edge type in order to calculate proportional widths
byType := make(map[string]typeCount)
tot := len(edges)
for _, edge := range edges {
if c, found := byType[string(edge.Type)]; found {
c.count++
} else {
byType[string(edge.Type)] = typeCount{
edge: edge,
count: 1,
}
}
}
max := 2.0
for _, c := range byType {
w := max * float64(c.count/tot)
if w < 0.5 {
w = 0.5
}
data += fmt.Sprintf(" %s\n", c.edge.Dot(left, right, w))
}
}); err != nil {
return "", 0, 0, err
}
data += "\n"
data += "node [style=filled height=0.55 fontname=\"Verdana\" fontsize=10];\n"
data += "subgraph legend {\n" +
"graph[style=dotted];\n" +
"label = \"Legend\";\n"
var types []NodeType
for nodeType, _ := range typeMap {
types = append(types, nodeType)
node := Node{
Type: nodeType,
Annotations: nodeTypeDescs[nodeType],
Dummy: true,
}
data += fmt.Sprintf(" %s\n", node.Dot(false))
}
ntypes := len(types)
for i := 0; i < ntypes - 1; i++ {
data += fmt.Sprintf(" \"%s\" -> \"%s\" [style=invis];\n", types[i], types[i + 1])
}
data += "}\n"
data += "\n"
data += " overlap=false\n"
data += "}"
return data, nil
return data, size, discarded, nil
}
func (g *Graph) FindNode(t NodeType, id string) (*Node, error) {
@ -214,13 +240,7 @@ func (g *Graph) FindNode(t NodeType, id string) (*Node, error) {
nodeFileName := path.Join(g.path, fmt.Sprintf("%s_%s.json", t, id))
if fs.Exists(nodeFileName) {
var node Node
if raw, err := ioutil.ReadFile(nodeFileName); err != nil {
return nil, fmt.Errorf("error while reading %s: %v", nodeFileName, err)
} else if err = json.Unmarshal(raw, &node); err != nil {
return nil, fmt.Errorf("error while decoding %s: %v", nodeFileName, err)
}
return &node, nil
return ReadNode(nodeFileName)
}
return nil, nil
@ -235,13 +255,10 @@ func (g *Graph) FindOtherTypes(t NodeType, id string) ([]*Node, error) {
for _, otherType := range NodeTypes {
if otherType != t {
if nodeFileName := path.Join(g.path, fmt.Sprintf("%s_%s.json", otherType, id)); fs.Exists(nodeFileName) {
var node Node
if raw, err := ioutil.ReadFile(nodeFileName); err != nil {
return nil, fmt.Errorf("error while reading %s: %v", nodeFileName, err)
} else if err = json.Unmarshal(raw, &node); err != nil {
return nil, fmt.Errorf("error while decoding %s: %v", nodeFileName, err)
if node, err := ReadNode(nodeFileName); err != nil {
return nil, err
} else {
otherNodes = append(otherNodes, &node)
otherNodes = append(otherNodes, node)
}
}
}
@ -257,16 +274,13 @@ func (g *Graph) CreateNode(t NodeType, id string, entity interface{}, annotation
node := &Node{
Type: t,
ID: id,
CreatedAt: time.Now(),
Entity: entity,
Annotations: annotations,
}
nodeFileName := path.Join(g.path, fmt.Sprintf("%s.json", node.String()))
if raw, err := json.Marshal(node); err != nil {
return nil, fmt.Errorf("error creating data for %s: %v", nodeFileName, err)
} else if err = ioutil.WriteFile(nodeFileName, raw, os.ModePerm); err != nil {
return nil, fmt.Errorf("error creating %s: %v", nodeFileName, err)
if err := CreateNode(nodeFileName, node); err != nil {
return nil, err
}
session.I.Events.Add("graph.node.new", node)
@ -278,86 +292,40 @@ func (g *Graph) UpdateNode(node *Node) error {
g.Lock()
defer g.Unlock()
node.UpdatedAt = time.Now()
nodeFileName := path.Join(g.path, fmt.Sprintf("%s.json", node.String()))
if raw, err := json.Marshal(node); err != nil {
return fmt.Errorf("error creating new data for %s: %v", nodeFileName, err)
} else if err = ioutil.WriteFile(nodeFileName, raw, os.ModePerm); err != nil {
return fmt.Errorf("error updating %s: %v", nodeFileName, err)
if err := UpdateNode(nodeFileName, node); err != nil {
return err
}
return nil
}
func (g *Graph) findEdgesUnlocked(from, to *Node) (string, []*Edge, error) {
edgesFileName := path.Join(g.path, fmt.Sprintf("edges_%s_%s.json", from.String(), to.String()))
if fs.Exists(edgesFileName) {
var edges []*Edge
if raw, err := ioutil.ReadFile(edgesFileName); err != nil {
return edgesFileName, nil, fmt.Errorf("error while reading %s: %v", edgesFileName, err)
} else if err = json.Unmarshal(raw, &edges); err != nil {
return edgesFileName, nil, fmt.Errorf("error while decoding %s: %v", edgesFileName, err)
}
// sort edges from oldest to newer
sort.Slice(edges, func(i, j int) bool {
return edges[i].CreatedAt.Before(edges[j].CreatedAt)
})
return edgesFileName, edges, nil
}
return edgesFileName, nil, nil
}
func (g *Graph) FindEdges(from, to *Node) ([]*Edge, error) {
g.Lock()
defer g.Unlock()
_, edges, err := g.findEdgesUnlocked(from, to)
return edges, err
}
func (g *Graph) FindLastEdgeOfType(from, to *Node, edgeType EdgeType) (*Edge, error) {
g.Lock()
defer g.Unlock()
if _, edges, err := g.findEdgesUnlocked(from, to); err != nil {
return nil, err
} else {
num := len(edges)
for i := range edges {
// loop backwards
idx := num - 1 - i
edge := edges[idx]
if edge.Type == edgeType {
return edge, nil
}
edges := g.edges.FindEdges(from.String(), to.String(), true)
num := len(edges)
for i := range edges {
// loop backwards
idx := num - 1 - i
edge := edges[idx]
if edge.Type == edgeType {
return &edge, nil
}
}
return nil, nil
}
func (g *Graph) FindLastRecentEdgeOfType(from, to *Node, edgeType EdgeType, staleTime time.Duration) (*Edge, error) {
g.Lock()
defer g.Unlock()
if _, edges, err := g.findEdgesUnlocked(from, to); err != nil {
return nil, err
} else {
num := len(edges)
for i := range edges {
// loop backwards
idx := num - 1 - i
edge := edges[idx]
if edge.Type == edgeType {
if time.Since(edge.CreatedAt) >= staleTime {
return nil, nil
}
// edge is still fresh
return edge, nil
edges := g.edges.FindEdges(from.String(), to.String(), true)
num := len(edges)
for i := range edges {
// loop backwards
idx := num - 1 - i
edge := edges[idx]
if edge.Type == edgeType {
if time.Since(edge.CreatedAt) >= staleTime {
return nil, nil
}
return &edge, nil
}
}
@ -365,35 +333,24 @@ func (g *Graph) FindLastRecentEdgeOfType(from, to *Node, edgeType EdgeType, stal
}
func (g *Graph) CreateEdge(from, to *Node, edgeType EdgeType) (*Edge, error) {
g.Lock()
defer g.Unlock()
var edgesFileName string
var edges []*Edge
edge := &Edge{
edge := Edge{
Type: edgeType,
CreatedAt: time.Now(),
Position: session.I.GPS,
}
if edgesFileName, edges, _ = g.findEdgesUnlocked(from, to); edges != nil {
edges = append(edges, edge)
} else {
edges = []*Edge{edge}
if session.I.GPS.Updated.IsZero() == false {
edge.Position = &session.I.GPS
}
if raw, err := json.Marshal(edges); err != nil {
return nil, fmt.Errorf("error creating data for %s: %v", edgesFileName, err)
} else if err = ioutil.WriteFile(edgesFileName, raw, os.ModePerm); err != nil {
return nil, fmt.Errorf("error writing %s: %v", edgesFileName, err)
if err := g.edges.Connect(from.String(), to.String(), edge); err != nil {
return nil, err
}
session.I.Events.Add("graph.edge.new", EdgeEvent{
Left: from,
Edge: edge,
Left: from,
Edge: &edge,
Right: to,
})
return edge, nil
return &edge, nil
}