diff --git a/modules/dns_proxy/dns_proxy.go b/modules/dns_proxy/dns_proxy.go index f27cde2c..76de6db9 100644 --- a/modules/dns_proxy/dns_proxy.go +++ b/modules/dns_proxy/dns_proxy.go @@ -113,10 +113,10 @@ func NewDnsProxy(s *session.Session) *DnsProxy { "Address to bind the DNS proxy to.")) mod.AddParam(session.NewStringParameter("dns.proxy.blacklist", "", "", - "Comma separated list of hostnames to skip while proxying (wildcard expressions can be used).")) + "Comma separated list of client IPs to skip while proxying.")) mod.AddParam(session.NewStringParameter("dns.proxy.whitelist", "", "", - "Comma separated list of hostnames to proxy if the blacklist is used (wildcard expressions can be used).")) + "Comma separated list of client IPs to proxy if the blacklist is used.")) mod.AddParam(session.NewStringParameter("dns.proxy.nameserver", "1.1.1.1", diff --git a/modules/dns_proxy/dns_proxy_base.go b/modules/dns_proxy/dns_proxy_base.go index 8eb4e0f6..ac637bf3 100644 --- a/modules/dns_proxy/dns_proxy_base.go +++ b/modules/dns_proxy/dns_proxy_base.go @@ -41,6 +41,24 @@ type DNSProxy struct { tag string } +func (p *DNSProxy) shouldProxy(clientIP string) bool { + // check if this client is in the whitelist + for _, ip := range p.Whitelist { + if clientIP == ip { + return true + } + } + + // check if this client is in the blacklist + for _, ip := range p.Blacklist { + if clientIP == ip { + return false + } + } + + return true +} + func (p *DNSProxy) Configure(address string, dnsPort int, doRedirect bool, nameserver string, netProtocol string, proxyPort int, scriptPath string, certFile string, keyFile string) error { var err error diff --git a/modules/dns_proxy/dns_proxy_base_filters.go b/modules/dns_proxy/dns_proxy_base_filters.go index d66b0ad1..60cdd5dd 100644 --- a/modules/dns_proxy/dns_proxy_base_filters.go +++ b/modules/dns_proxy/dns_proxy_base_filters.go @@ -59,53 +59,57 @@ func (p *DNSProxy) logResponseAction(m *dns.Msg, clientIP string) { } func (p *DNSProxy) onRequestFilter(query *dns.Msg, clientIP string) (req, res *dns.Msg) { - p.Debug("< %s q[%s]", - clientIP, - strings.Join(questionsToStrings(query.Question), ",")) + if p.shouldProxy(clientIP) { + p.Debug("< %s q[%s]", + clientIP, + strings.Join(questionsToStrings(query.Question), ",")) - // do we have a proxy script? - if p.Script == nil { - return query, nil - } + // do we have a proxy script? + if p.Script == nil { + return query, nil + } - // run the module OnRequest callback if defined - jsreq, jsres := p.Script.OnRequest(query, clientIP) - if jsreq != nil { - // the request has been changed by the script - req := jsreq.ToQuery() - p.logRequestAction(req, clientIP) - return req, nil - } else if jsres != nil { - // a fake response has been returned by the script - res := jsres.ToQuery() - p.logResponseAction(res, clientIP) - return query, res + // run the module OnRequest callback if defined + jsreq, jsres := p.Script.OnRequest(query, clientIP) + if jsreq != nil { + // the request has been changed by the script + req := jsreq.ToQuery() + p.logRequestAction(req, clientIP) + return req, nil + } else if jsres != nil { + // a fake response has been returned by the script + res := jsres.ToQuery() + p.logResponseAction(res, clientIP) + return query, res + } } return query, nil } func (p *DNSProxy) onResponseFilter(req, res *dns.Msg, clientIP string) *dns.Msg { - // sometimes it happens ¯\_(ツ)_/¯ - if res == nil { - return nil - } + if p.shouldProxy(clientIP) { + // sometimes it happens ¯\_(ツ)_/¯ + if res == nil { + return nil + } - p.Debug("> %s q[%s] a[%s] e[%s] n[%s]", - clientIP, - strings.Join(questionsToStrings(res.Question), ","), - strings.Join(recordsToStrings(res.Answer), ","), - strings.Join(recordsToStrings(res.Extra), ","), - strings.Join(recordsToStrings(res.Ns), ",")) + p.Debug("> %s q[%s] a[%s] e[%s] n[%s]", + clientIP, + strings.Join(questionsToStrings(res.Question), ","), + strings.Join(recordsToStrings(res.Answer), ","), + strings.Join(recordsToStrings(res.Extra), ","), + strings.Join(recordsToStrings(res.Ns), ",")) - // do we have a proxy script? - if p.Script != nil { - _, jsres := p.Script.OnResponse(req, res, clientIP) - if jsres != nil { - // the response has been changed by the script - res := jsres.ToQuery() - p.logResponseAction(res, clientIP) - return res + // do we have a proxy script? + if p.Script != nil { + _, jsres := p.Script.OnResponse(req, res, clientIP) + if jsres != nil { + // the response has been changed by the script + res := jsres.ToQuery() + p.logResponseAction(res, clientIP) + return res + } } }