diff --git a/session/session.go b/session/session.go index f512da4d..372a5443 100644 --- a/session/session.go +++ b/session/session.go @@ -110,7 +110,7 @@ func (s *Session) Module(name string) (err error, mod Module) { return fmt.Errorf("Module %s not found", name), mod } -func (s *Session) setupInput() error { +func (s *Session) setupReadline() error { var err error pcompleters := make([]readline.PrefixCompleterInterface, 0) @@ -211,86 +211,23 @@ func (s *Session) Close() { func (s *Session) Register(mod Module) error { s.Modules = append(s.Modules, mod) - return nil -} -func (s *Session) Start() error { - var err error - - // make sure modules are always sorted by name - sort.Slice(s.Modules, func(i, j int) bool { - return s.Modules[i].Name() < s.Modules[j].Name() - }) - - net.OuiInit() - - if s.Interface, err = net.FindInterface(*s.Options.InterfaceName); err != nil { - return err - } - - s.Env.Set(PromptVariable, DefaultPrompt) - - s.Env.Set("iface.index", fmt.Sprintf("%d", s.Interface.Index)) - s.Env.Set("iface.name", s.Interface.Name()) - s.Env.Set("iface.ipv4", s.Interface.IpAddress) - s.Env.Set("iface.ipv6", s.Interface.Ip6Address) - s.Env.Set("iface.mac", s.Interface.HwAddress) - - if s.Queue, err = packets.NewQueue(s.Interface); err != nil { - return err - } - - if s.Gateway, err = net.FindGateway(s.Interface); err != nil { - s.Events.Log(core.WARNING, "%s", err.Error()) - } - - if s.Gateway == nil || s.Gateway.IpAddress == s.Interface.IpAddress { - s.Gateway = s.Interface - } - - s.Env.Set("gateway.address", s.Gateway.IpAddress) - s.Env.Set("gateway.mac", s.Gateway.HwAddress) - - s.Targets = NewTargets(s, s.Interface, s.Gateway) - s.Firewall = firewall.Make(s.Interface) - - if err := s.setupInput(); err != nil { - return err - } - - for _, h := range s.CoreHandlers { + for _, h := range mod.Handlers() { if len(h.Name) > s.HelpPadding { s.HelpPadding = len(h.Name) } } - for _, m := range s.Modules { - for _, h := range m.Handlers() { - if len(h.Name) > s.HelpPadding { - s.HelpPadding = len(h.Name) - } - } - for _, p := range m.Parameters() { - if len(p.Name) > s.HelpPadding { - s.HelpPadding = len(p.Name) - } + for _, p := range mod.Parameters() { + if len(p.Name) > s.HelpPadding { + s.HelpPadding = len(p.Name) } } - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - signal.Notify(c, syscall.SIGTERM) - go func() { - <-c - fmt.Println() - s.Events.Log(core.WARNING, "Got SIGTERM") - s.Close() - os.Exit(0) - }() - - s.StartedAt = time.Now() - s.Active = true + return nil +} +func (s *Session) startNetMon() { // keep reading network events in order to add / update endpoints go func() { for event := range s.Queue.Activities { @@ -309,6 +246,73 @@ func (s *Session) Start() error { } } }() +} + +func (s *Session) setupSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + signal.Notify(c, syscall.SIGTERM) + go func() { + <-c + fmt.Println() + s.Events.Log(core.WARNING, "Got SIGTERM") + s.Close() + os.Exit(0) + }() +} + +func (s *Session) setupEnv() { + s.Env.Set(PromptVariable, DefaultPrompt) + s.Env.Set("iface.index", fmt.Sprintf("%d", s.Interface.Index)) + s.Env.Set("iface.name", s.Interface.Name()) + s.Env.Set("iface.ipv4", s.Interface.IpAddress) + s.Env.Set("iface.ipv6", s.Interface.Ip6Address) + s.Env.Set("iface.mac", s.Interface.HwAddress) + s.Env.Set("gateway.address", s.Gateway.IpAddress) + s.Env.Set("gateway.mac", s.Gateway.HwAddress) +} + +func (s *Session) Start() error { + var err error + + // make sure modules are always sorted by name + sort.Slice(s.Modules, func(i, j int) bool { + return s.Modules[i].Name() < s.Modules[j].Name() + }) + + net.OuiInit() + + if s.Interface, err = net.FindInterface(*s.Options.InterfaceName); err != nil { + return err + } + + if s.Queue, err = packets.NewQueue(s.Interface); err != nil { + return err + } + + if s.Gateway, err = net.FindGateway(s.Interface); err != nil { + s.Events.Log(core.WARNING, "%s", err.Error()) + } + + if s.Gateway == nil || s.Gateway.IpAddress == s.Interface.IpAddress { + s.Gateway = s.Interface + } + + s.Targets = NewTargets(s, s.Interface, s.Gateway) + s.Firewall = firewall.Make(s.Interface) + + s.setupEnv() + + if err := s.setupReadline(); err != nil { + return err + } + + s.setupSignals() + + s.StartedAt = time.Now() + s.Active = true + + s.startNetMon() if *s.Options.Debug { s.Events.Add("session.started", nil) diff --git a/session/session_core_handlers.go b/session/session_core_handlers.go index 811c10d7..9796dc33 100644 --- a/session/session_core_handlers.go +++ b/session/session_core_handlers.go @@ -195,6 +195,9 @@ func (s *Session) aliasHandler(args []string, sess *Session) error { func (s *Session) addHandler(h CommandHandler, c *readline.PrefixCompleter) { h.Completer = c s.CoreHandlers = append(s.CoreHandlers, h) + if len(h.Name) > s.HelpPadding { + s.HelpPadding = len(h.Name) + } } func (s *Session) registerCoreHandlers() {