diff --git a/firewall/firewall_linux.go b/firewall/firewall_linux.go index 16218969..8a817ead 100644 --- a/firewall/firewall_linux.go +++ b/firewall/firewall_linux.go @@ -65,9 +65,40 @@ func (f LinuxFirewall) EnableForwarding(enabled bool) error { return f.enableFeature(IPV4ForwardingFile, enabled) } -func (f *LinuxFirewall) EnableRedirection(r *Redirection, enabled bool) error { - var opts []string +func (f *LinuxFirewall) getCommandLine(r *Redirection, enabled bool) (cmdLine []string) { + action := "-A" + if enabled == false { + action = "-D" + } + if r.SrcAddress == "" { + cmdLine = []string{ + "-t", "nat", + action, "PREROUTING", + "-i", r.Interface, + "-p", r.Protocol, + "--dport", fmt.Sprintf("%d", r.SrcPort), + "-j", "DNAT", + "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), + } + } else { + cmdLine = []string{ + "-t", "nat", + action, "PREROUTING", + "-i", r.Interface, + "-p", r.Protocol, + "-d", r.SrcAddress, + "--dport", fmt.Sprintf("%d", r.SrcPort), + "-j", "DNAT", + "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), + } + } + + return +} + +func (f *LinuxFirewall) EnableRedirection(r *Redirection, enabled bool) error { + cmdLine := f.getCommandLine(r, enabled) rkey := r.String() _, found := f.redirections[rkey] @@ -81,31 +112,7 @@ func (f *LinuxFirewall) EnableRedirection(r *Redirection, enabled bool) error { // accept all if _, err := core.Exec("iptables", []string{"-P", "FORWARD", "ACCEPT"}); err != nil { return err - } - - if r.SrcAddress == "" { - opts = []string{ - "-t", "nat", - "-A", "PREROUTING", - "-i", r.Interface, - "-p", r.Protocol, - "--dport", fmt.Sprintf("%d", r.SrcPort), - "-j", "DNAT", - "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), - } - } else { - opts = []string{ - "-t", "nat", - "-A", "PREROUTING", - "-i", r.Interface, - "-p", r.Protocol, - "-d", r.SrcAddress, - "--dport", fmt.Sprintf("%d", r.SrcPort), - "-j", "DNAT", - "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), - } - } - if _, err := core.Exec("iptables", opts); err != nil { + } else if _, err := core.Exec("iptables", cmdLine); err != nil { return err } } else { @@ -115,29 +122,7 @@ func (f *LinuxFirewall) EnableRedirection(r *Redirection, enabled bool) error { delete(f.redirections, r.String()) - if r.SrcAddress == "" { - opts = []string{ - "-t", "nat", - "-D", "PREROUTING", - "-i", r.Interface, - "-p", r.Protocol, - "--dport", fmt.Sprintf("%d", r.SrcPort), - "-j", "DNAT", - "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), - } - } else { - opts = []string{ - "-t", "nat", - "-D", "PREROUTING", - "-i", r.Interface, - "-p", r.Protocol, - "-d", r.SrcAddress, - "--dport", fmt.Sprintf("%d", r.SrcPort), - "-j", "DNAT", - "--to", fmt.Sprintf("%s:%d", r.DstAddress, r.DstPort), - } - } - if _, err := core.Exec("iptables", opts); err != nil { + if _, err := core.Exec("iptables", cmdLine); err != nil { return err } }