@@ -15,7 +15,7 @@ import (
1515
1616var (
1717blocking = false
18- defaultDomains = []string {"github.com" , "api.github.com" , "*.actions.githubusercontent.com" , "results-receiver.actions.githubusercontent.com" , "*.blob.core.windows.net" }
18+ defaultDomains = []string {"github.com" , "api.github.com" , "*.actions.githubusercontent.com" , "results-receiver.actions.githubusercontent.com" , "*.blob.core.windows.net" , "*.githubapp.com" }
1919defaultIps = []string {"168.63.129.16" , "169.254.169.254" , "127.0.0.1" }
2020defaultDNSServers = []string {"127.0.0.53" }
2121)
@@ -27,6 +27,7 @@ const (
2727EGRESS_POLICY_AUDIT = "audit"
2828DNS_POLICY_ALLOWED_DOMAINS_ONLY = "allowed-domains-only"
2929DNS_POLICY_ANY = "any"
30+ DNS_PORT = layers .TCPPort (53 )
3031)
3132
3233type AgentConfig struct {
@@ -63,6 +64,7 @@ func NewAgent(config AgentConfig) *Agent{
6364netInfoProvider : config .NetInfoProvider ,
6465filesystem : config .FileSystem ,
6566 }
67+
6668agent .init (config )
6769return agent
6870}
@@ -153,7 +155,7 @@ func (a *Agent) loadAllowedIp(ips []string){
153155a .addIpToLogs ("allowed" , "unknown" , ip )
154156continue
155157 }
156- fmt .Printf ("Failed to parse IP : %s. Skipping .\n " , ip )
158+ fmt .Printf ("failed to parse ip : %s. skipping .\n " , ip )
157159 }
158160}
159161
@@ -164,13 +166,13 @@ func (a *Agent) addToFirewall(ips map[string]bool, cidr []*net.IPNet) error{
164166for ip := range ips {
165167err := a .firewall .AddIp (ip )
166168if err != nil {
167- return fmt .Errorf ("Error adding %s to firewall: %v\n " , ip , err )
169+ return fmt .Errorf ("error adding %s to firewall: %v" , ip , err )
168170 }
169171 }
170172for _ , c := range cidr {
171173err := a .firewall .AddIp (c .String ())
172174if err != nil {
173- return fmt .Errorf ("Error adding %s to firewall: %v\n " , c .String (), err )
175+ return fmt .Errorf ("error adding %s to firewall: %v" , c .String (), err )
174176 }
175177 }
176178return nil
@@ -228,22 +230,19 @@ func (a *Agent) loadAllowedDNSServers() error{
228230}
229231
230232func getDestinationIP (packet gopacket.Packet ) (string , error ){
231- ipLayer := packet .Layer (layers .LayerTypeIPv4 )
232- if ipLayer == nil {
233- ipLayer = packet .Layer (layers .LayerTypeIPv6 )
234- }
235- if ipLayer == nil {
236- return "" , fmt .Errorf ("Failed to get IP layer" )
233+ netLayer := packet .NetworkLayer ()
234+ if netLayer == nil {
235+ return "" , fmt .Errorf ("failed to get network layer" )
237236 }
238- ip , _ := ipLayer .(* layers.IPv4 )
239- if ip == nil {
240- ip6 , _ := ipLayer .(* layers.IPv6 )
241- if ip6 == nil {
242- return "" , fmt .Errorf ("Failed to get IP layer" )
243- }
244- return ip6 .DstIP .String (), nil
237+
238+ switch v := netLayer .(type ){
239+ case * layers.IPv4 :
240+ return v .DstIP .String (), nil
241+ case * layers.IPv6 :
242+ return v .DstIP .String (), nil
243+ default :
244+ return "" , fmt .Errorf ("unknown network layer type" )
245245 }
246- return ip .DstIP .String (), nil
247246}
248247
249248func extractDomainFromSRV (domain string ) string {
@@ -254,25 +253,18 @@ func extractDomainFromSRV(domain string) string{
254253return re .ReplaceAllString (domain , "" )
255254}
256255
257- func (a * Agent ) processDNSQuery (packet gopacket.Packet ) uint8 {
258- dnsLayer := packet .Layer (layers .LayerTypeDNS )
259- dns , _ := dnsLayer .(* layers.DNS )
256+ func (a * Agent ) processDNSLayer (dns * layers.DNS ) uint8 {
257+ if ! dns .QR {
258+ return a .processDNSQuery (dns )
259+ }
260+ return a .processDNSResponse (dns )
261+ }
262+
263+ func (a * Agent ) processDNSQuery (dns * layers.DNS ) uint8 {
260264for _ , q := range dns .Questions {
261265domain := string (q .Name )
262266fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
263267
264- // making sure the DNS query is using a trusted DNS server
265- destinationIP , err := getDestinationIP (packet )
266- if err != nil {
267- fmt .Println ("Failed to get destination IP" )
268- a .addIpToLogs ("blocked" , domain , "unknown" )
269- return DROP_REQUEST
270- }
271- if ! a .allowedDNSServers [destinationIP ]{
272- fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , domain , destinationIP )
273- a .addIpToLogs ("blocked" , domain , "unknown" )
274- return DROP_REQUEST
275- }
276268if q .Type == layers .DNSTypeSRV {
277269originalDomain := domain
278270domain = extractDomainFromSRV (domain )
@@ -345,11 +337,10 @@ func (a *Agent) processDNSTypeSRVResponse(domain string, answer *layers.DNSResou
345337 }
346338}
347339
348- func (a * Agent ) processDNSResponse (packet gopacket.Packet ) uint8 {
349- dnsLayer := packet .Layer (layers .LayerTypeDNS )
350- dns , _ := dnsLayer .(* layers.DNS )
340+ func (a * Agent ) processDNSResponse (dns * layers.DNS ) uint8 {
351341domain := string (dns .Questions [0 ].Name )
352342for _ , answer := range dns .Answers {
343+ fmt .Printf ("DNS Answer: %s %s %s\n " , answer .Name , answer .Type , answer .IP )
353344if answer .Type == layers .DNSTypeA {
354345a .processDNSTypeAResponse (domain , & answer )
355346 } else if answer .Type == layers .DNSTypeCNAME {
@@ -365,21 +356,108 @@ func (a *Agent) processDNSResponse(packet gopacket.Packet) uint8{
365356return ACCEPT_REQUEST
366357}
367358
368- func (a * Agent ) ProcessPacket (packet gopacket.Packet ) uint8 {
369- if dnsLayer := packet .Layer (layers .LayerTypeDNS ); dnsLayer != nil {
359+ func (a * Agent ) processDNSPacket (packet gopacket.Packet ) uint8 {
360+ dnsLayer := packet .Layer (layers .LayerTypeDNS )
361+ dns , _ := dnsLayer .(* layers.DNS )
362+ for _ , q := range dns .Questions {
363+ fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
364+ }
370365
371- dns , _ := dnsLayer .(* layers.DNS )
372- for _ , q := range dns .Questions {
373- fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
366+ domain := string (dns .Questions [0 ].Name )
367+ // if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
368+ if ! dns .QR {
369+ // making sure the DNS query is using a trusted DNS server
370+ destinationIP , err := getDestinationIP (packet )
371+ if err != nil {
372+ fmt .Printf ("Failed to get destination IP: %v\n " , err )
373+ a .addIpToLogs ("blocked" , domain , "unknown" )
374+ return DROP_REQUEST
374375 }
375- // if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
376- if a .blockDNS && ! dns .QR {
377- return a .processDNSQuery (packet )
378- } else if dns .QR {
379- return a .processDNSResponse (packet )
376+ if ! a .allowedDNSServers [destinationIP ]{
377+ fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , domain , destinationIP )
378+ a .addIpToLogs ("blocked" , domain , destinationIP )
379+ return DROP_REQUEST
380380 }
381381 }
382- return ACCEPT_REQUEST
382+
383+ // if we are not blocking DNS queries, just accept the query request
384+ if ! a .blockDNS && ! dns .QR {
385+ return ACCEPT_REQUEST
386+ }
387+ return a .processDNSLayer (dns )
388+ }
389+
390+ func (a * Agent ) processDNSOverTCPPayload (payload []byte ) uint8 {
391+ // Extract message length from first 2 bytes
392+ // - First byte shifted left 8 bits + second byte
393+ // - Creates 16-bit length prefix
394+ messageLen := int (payload [0 ])<< 8 | int (payload [1 ])
395+ if messageLen == 0 || len (payload ) < messageLen + 2 {
396+ fmt .Println ("Invalid DNS over TCP payload" )
397+ return DROP_REQUEST
398+ }
399+
400+ // We attempt to decode the DNS over TCP payload
401+ // The only way we can accept the request is if the DNS query is contained within a single TCP packet payload
402+ dns := & layers.DNS {}
403+ err := dns .DecodeFromBytes (payload [2 :messageLen + 2 ], gopacket .NilDecodeFeedback )
404+ if err != nil {
405+ fmt .Println ("Failed to decode DNS over TCP payload" , err )
406+ return DROP_REQUEST
407+ }
408+ return a .processDNSLayer (dns )
409+ }
410+
411+ func (a * Agent ) processTCPPacket (packet gopacket.Packet ) uint8 {
412+ tcpLayer := packet .Layer (layers .LayerTypeTCP )
413+ tcp , _ := tcpLayer .(* layers.TCP )
414+ dstPort , srcPort , payload := tcp .DstPort , tcp .SrcPort , tcp .Payload
415+
416+ // Validate DNS server IP
417+ if dstPort == DNS_PORT {
418+ destinationIP , err := getDestinationIP (packet )
419+ if err != nil {
420+ fmt .Printf ("Failed to get destination IP: %v\n " , err )
421+ a .addIpToLogs ("blocked" , "unknown" , "unknown" )
422+ return DROP_REQUEST
423+ }
424+ if ! a .allowedDNSServers [destinationIP ]{
425+ fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , "unknown" , destinationIP )
426+ a .addIpToLogs ("blocked" , "unknown" , destinationIP )
427+ return DROP_REQUEST
428+ }
429+ }
430+
431+ if dstPort != DNS_PORT && srcPort != DNS_PORT {
432+ fmt .Println ("Warning: Destination and source port are not DNS ports. Dropping request" )
433+ return DROP_REQUEST
434+ }
435+
436+ // if we are not blocking DNS queries, just accept the query request
437+ if ! a .blockDNS && dstPort == DNS_PORT {
438+ return ACCEPT_REQUEST
439+ }
440+
441+ if len (payload ) == 0 {
442+ // We only accept DNS over TCP packets with no payload since they are only used for initiating a connection
443+ return ACCEPT_REQUEST
444+ }
445+
446+ // Now we have a payload in the TCP packet, we need to make sure it is a valid DNS over TCP payload and the DNS query is for a known domain. We don't want to exfiltrate data over DNS over TCP
447+ return a .processDNSOverTCPPayload (payload )
448+
449+ }
450+
451+ func (a * Agent ) ProcessPacket (packet gopacket.Packet ) uint8 {
452+ if dnsLayer := packet .Layer (layers .LayerTypeDNS ); dnsLayer != nil {
453+ return a .processDNSPacket (packet )
454+ }
455+ // check dns over tcp
456+ if tcpLayer := packet .Layer (layers .LayerTypeTCP ); tcpLayer != nil {
457+ return a .processTCPPacket (packet )
458+ }
459+ fmt .Println ("Warning: Packet is not DNS or TCP. Dropping request, this shouldn't be happening." )
460+ return DROP_REQUEST
383461}
384462
385463func (a * Agent ) disableSudo () error {
0 commit comments