diff --git a/.github/workflows/gradle-tests.yml b/.github/workflows/gradle-tests.yml index 613e24605..bd90c398a 100644 --- a/.github/workflows/gradle-tests.yml +++ b/.github/workflows/gradle-tests.yml @@ -59,6 +59,7 @@ jobs: - name: Add local.aikido.io to /etc/hosts run: | echo "127.0.0.1 local.aikido.io" | sudo tee -a /etc/hosts + echo "127.0.0.1 app.local.aikido.io" | sudo tee -a /etc/hosts - name: Start databases working-directory: ./sample-apps/databases diff --git a/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/APIResponse.java b/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/APIResponse.java index 052e15892..07dc9f694 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/APIResponse.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/APIResponse.java @@ -1,6 +1,7 @@ package dev.aikido.agent_api.background.cloud.api; import dev.aikido.agent_api.background.Endpoint; +import dev.aikido.agent_api.storage.service_configuration.Domain; import java.util.List; @@ -11,6 +12,8 @@ public record APIResponse( List endpoints, List blockedUserIds, List allowedIPAddresses, + boolean blockNewOutgoingRequests, + List domains, boolean receivedAnyStats, boolean block ) { diff --git a/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/ReportingApiHTTP.java b/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/ReportingApiHTTP.java index 5f658fb1d..190da623d 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/ReportingApiHTTP.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/background/cloud/api/ReportingApiHTTP.java @@ -152,7 +152,7 @@ private static APIResponse getUnsuccessfulAPIResponse(String error) { return new APIResponse( false, // Success error, - 0, null, null, null, false, false // Unimportant values. + 0, null, null, null, false, null, false, false // Unimportant values. ); } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java index d92b8e924..d33c165c9 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java @@ -1,11 +1,14 @@ package dev.aikido.agent_api.collectors; import dev.aikido.agent_api.context.Context; -import dev.aikido.agent_api.storage.Hostnames; +import dev.aikido.agent_api.storage.HostnamesStore; +import dev.aikido.agent_api.storage.PendingHostnamesStore; +import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.statistics.OperationKind; import dev.aikido.agent_api.storage.statistics.StatisticsStore; import dev.aikido.agent_api.vulnerabilities.Attack; import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFDetector; +import dev.aikido.agent_api.vulnerabilities.outbound_blocking.BlockedOutboundException; import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException; import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; @@ -15,6 +18,7 @@ import java.net.InetAddress; import java.util.ArrayList; import java.util.List; +import java.util.Set; import static dev.aikido.agent_api.helpers.ShouldBlockHelper.shouldBlock; import static dev.aikido.agent_api.storage.AttackQueue.attackDetected; @@ -30,38 +34,49 @@ public static void report(String hostname, InetAddress[] inetAddresses) { // store stats StatisticsStore.registerCall("java.net.InetAddress.getAllByName", OperationKind.OUTGOING_HTTP_OP); + // Consume pending ports recorded by URLCollector for this hostname. + // Removing them here ensures each (hostname, port) pair is counted exactly once. + Set ports = PendingHostnamesStore.getAndRemove(hostname); + if (!ports.isEmpty()) { + for (int port : ports) { + HostnamesStore.incrementHits(hostname, port); + } + } else { + // We still need to report a hit to the hostname for outbound domain blocking + HostnamesStore.incrementHits(hostname, 0); + } + + // Block if the hostname is in the blocked domains list + if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) { + logger.debug("Blocking DNS lookup for domain: %s", hostname); + throw BlockedOutboundException.get(); + } + // Convert inetAddresses array to a List of IP strings : List ipAddresses = new ArrayList<>(); for (InetAddress inetAddress : inetAddresses) { ipAddresses.add(inetAddress.getHostAddress()); } - // Fetch hostnames from Context (this is to get port number e.g.) - if (Context.get() != null && Context.get().getHostnames() != null) { - for (Hostnames.HostnameEntry hostnameEntry : Context.get().getHostnames().asArray()) { - if (!hostnameEntry.getHostname().equals(hostname)) { - continue; - } - logger.debug("Hostname: %s, Port: %s, IPs: %s", hostnameEntry.getHostname(), hostnameEntry.getPort(), ipAddresses); - - Attack attack = SSRFDetector.run( - hostname, hostnameEntry.getPort(), ipAddresses, OPERATION_NAME - ); - if (attack == null) { - continue; - } - - logger.debug("SSRF Attack detected due to: %s:%s", hostname, hostnameEntry.getPort()); - attackDetected(attack, Context.get()); - - if (shouldBlock()) { - logger.debug("Blocking SSRF attack..."); - throw SSRFException.get(); - } - - // We don't want to test for a stored SSRF attack. - return; + // Run SSRF check for all ports found in the pending store (empty = no SSRF check) + for (int port : ports) { + logger.debug("Hostname: %s, Port: %s, IPs: %s", hostname, port, ipAddresses); + + Attack attack = SSRFDetector.run(hostname, port, ipAddresses, OPERATION_NAME); + if (attack == null) { + continue; } + + logger.debug("SSRF Attack detected due to: %s:%s", hostname, port); + attackDetected(attack, Context.get()); + + if (shouldBlock()) { + logger.debug("Blocking SSRF attack..."); + throw SSRFException.get(); + } + + // We don't want to test for a stored SSRF attack. + return; } // We don't need the context object to check for stored ssrf, but we do want to run this after our other @@ -76,7 +91,7 @@ public static void report(String hostname, InetAddress[] inetAddresses) { } } - } catch (SSRFException | StoredSSRFException e) { + } catch (BlockedOutboundException | SSRFException | StoredSSRFException e) { throw e; } catch (Throwable e) { logger.trace(e); diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java index d612d92d5..e1c1244b1 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java @@ -1,10 +1,8 @@ package dev.aikido.agent_api.collectors; -import dev.aikido.agent_api.context.Context; -import dev.aikido.agent_api.context.ContextObject; -import dev.aikido.agent_api.storage.HostnamesStore; import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; +import dev.aikido.agent_api.storage.PendingHostnamesStore; import java.net.URL; @@ -15,25 +13,14 @@ public final class URLCollector { private URLCollector() {} public static void report(URL url) { - if(url != null) { + if (url != null) { if (!url.getProtocol().startsWith("http")) { - return; // Non-HTTP(S) URL + return; // Non-HTTP(S) URL } logger.trace("Adding a new URL to the cache: %s", url); - int port = getPortFromURL(url); - - // We store hostname and port in two places, HostnamesStore and Context. HostnamesStore is for reporting - // outbound domains. Context is to have a map of hostnames with used port numbers to detect SSRF attacks. - - // Store (new) hostname hits - HostnamesStore.incrementHits(url.getHost(), port); - - // Add to context : - ContextObject context = Context.get(); - if (context != null) { - context.getHostnames().add(url.getHost(), port); - Context.set(context); - } + // Store hostname+port in the pending store so DNSRecordCollector can pick it + // up during the DNS lookup that follows, for SSRF detection and outbound hostnames + PendingHostnamesStore.add(url.getHost(), getPortFromURL(url)); } } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java index d331965f2..1d66a2126 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java @@ -8,6 +8,7 @@ import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; import dev.aikido.agent_api.storage.AttackQueue; +import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.ServiceConfiguration; import dev.aikido.agent_api.storage.attack_wave_detector.AttackWaveDetectorStore; @@ -36,7 +37,14 @@ private WebRequestCollector() { */ public static Res report(ContextObject newContext) { ServiceConfiguration config = getConfig(); - Context.reset(); // clear context + + // clear context + Context.reset(); + + // Flush pending hostnames on every context change to prevent the store from + // growing unboundedly when a thread is reused across multiple requests. + PendingHostnamesStore.clear(); + if (config.isIpBypassed(newContext.getRemoteAddress())) { return null; // do not set context when the IP address is bypassed (zen = off) } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/context/Context.java b/agent_api/src/main/java/dev/aikido/agent_api/context/Context.java index 7f5dbc7ce..0389cc903 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/context/Context.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/context/Context.java @@ -1,5 +1,7 @@ package dev.aikido.agent_api.context; +import dev.aikido.agent_api.storage.PendingHostnamesStore; + public final class Context { private Context() {} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/context/ContextObject.java b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextObject.java index 7f84d5959..339358384 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/context/ContextObject.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextObject.java @@ -2,7 +2,6 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; -import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.RedirectNode; import java.util.*; @@ -26,10 +25,6 @@ public class ContextObject { protected transient Map> cache = new HashMap<>(); protected transient Optional forcedProtectionOff = Optional.empty(); - // We store hostnames in the context object so we can match a given hostname (by DNS request) - // with its port number (which we know by instrumenting the URLs that get requested). - protected transient Hostnames hostnames = new Hostnames(1000); // max 1000 entries - public boolean middlewareExecuted() {return executedMiddleware; } public void setExecutedMiddleware(boolean value) { executedMiddleware = value; } @@ -97,7 +92,6 @@ public HashMap> getCookies() { return cookies; } public Map> getCache() { return cache; } - public Hostnames getHostnames() { return hostnames; } public void setForcedProtectionOff(boolean forcedProtectionOff) { this.forcedProtectionOff = Optional.of(forcedProtectionOff); diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java new file mode 100644 index 000000000..2efd5ecf1 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java @@ -0,0 +1,43 @@ +package dev.aikido.agent_api.storage; + +import java.util.*; + +/** + * Thread-local bridge between URLCollector and DNSRecordCollector. + * URLCollector records hostname+port here; DNSRecordCollector reads and removes the entry + * so each (hostname, port) pair is processed exactly once per DNS lookup. + */ +public final class PendingHostnamesStore { + private PendingHostnamesStore() {} + + private static final ThreadLocal>> store = + ThreadLocal.withInitial(LinkedHashMap::new); + + public static void add(String hostname, int port) { + Map> map = store.get(); + if (!map.containsKey(hostname)) { + map.put(hostname, new LinkedHashSet<>()); + } + map.get(hostname).add(port); + } + + public static Set getAndRemove(String hostname) { + Set ports = store.get().remove(hostname); + if (ports == null) { + return Collections.emptySet(); + } + return ports; + } + + public static Set getPorts(String hostname) { + Set ports = store.get().get(hostname); + if (ports == null) { + return Collections.emptySet(); + } + return Collections.unmodifiableSet(ports); + } + + public static void clear() { + store.get().clear(); + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfigStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfigStore.java index 41986919e..472a4f722 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfigStore.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfigStore.java @@ -90,4 +90,13 @@ public static void setMiddlewareInstalled(boolean middlewareInstalled) { mutex.writeLock().unlock(); } } + + public static boolean shouldBlockOutgoingRequest(String hostname) { + mutex.readLock().lock(); + try { + return config.shouldBlockOutgoingRequest(hostname); + } finally { + mutex.readLock().unlock(); + } + } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java index 3a7a7c3b3..d6f4597d8 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java @@ -5,11 +5,10 @@ import dev.aikido.agent_api.background.cloud.api.ReportingApi; import dev.aikido.agent_api.helpers.net.IPList; import dev.aikido.agent_api.storage.service_configuration.ParsedFirewallLists; +import dev.aikido.agent_api.vulnerabilities.outbound_blocking.OutboundDomains; import dev.aikido.agent_api.storage.statistics.StatisticsStore; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; +import java.util.*; import static dev.aikido.agent_api.helpers.IPListBuilder.createIPList; import static dev.aikido.agent_api.vulnerabilities.ssrf.IsPrivateIP.isPrivateIp; @@ -26,6 +25,7 @@ public class ServiceConfiguration { private IPList bypassedIPs = new IPList(); private HashSet blockedUserIDs = new HashSet<>(); private List endpoints = new ArrayList<>(); + private OutboundDomains outboundDomains = new OutboundDomains(); public ServiceConfiguration() { this.receivedAnyStats = true; // true by default, waiting for the startup event @@ -46,6 +46,7 @@ public void updateConfig(APIResponse apiResponse) { if (apiResponse.endpoints() != null) { this.endpoints = apiResponse.endpoints(); } + this.outboundDomains.update(apiResponse.domains(), apiResponse.blockNewOutgoingRequests()); this.receivedAnyStats = apiResponse.receivedAnyStats(); } @@ -127,4 +128,8 @@ public boolean isBlockedUserAgent(String userAgent) { public record BlockedResult(boolean blocked, String description) { } + + public boolean shouldBlockOutgoingRequest(String hostname) { + return this.outboundDomains.shouldBlockOutgoingRequest(hostname); + } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/service_configuration/Domain.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/service_configuration/Domain.java new file mode 100644 index 000000000..bbc8c4ad8 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/service_configuration/Domain.java @@ -0,0 +1,4 @@ +package dev.aikido.agent_api.storage.service_configuration; + +public record Domain(String hostname, String mode) { +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java new file mode 100644 index 000000000..da55cb32f --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java @@ -0,0 +1,14 @@ +package dev.aikido.agent_api.vulnerabilities.outbound_blocking; + +import dev.aikido.agent_api.vulnerabilities.AikidoException; + +public class BlockedOutboundException extends AikidoException { + public BlockedOutboundException(String msg) { + super(msg); + } + + public static BlockedOutboundException get() { + String defaultMsg = generateDefaultMessage("an outbound request"); + return new BlockedOutboundException(defaultMsg); + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java new file mode 100644 index 000000000..84e8bc281 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java @@ -0,0 +1,35 @@ +package dev.aikido.agent_api.vulnerabilities.outbound_blocking; + +import dev.aikido.agent_api.storage.service_configuration.Domain; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OutboundDomains { + private Map domains = new HashMap<>(); + private boolean blockNewOutgoingRequests = false; + + public void update(List newDomains, boolean blockNewOutgoingRequests) { + if (newDomains != null) { + this.domains = new HashMap<>(); + for (Domain domain : newDomains) { + this.domains.put(domain.hostname(), domain.mode()); + } + } + this.blockNewOutgoingRequests = blockNewOutgoingRequests; + } + + public boolean shouldBlockOutgoingRequest(String hostname) { + String mode = this.domains.get(hostname); + + if (this.blockNewOutgoingRequests) { + // Only allow outgoing requests if the mode is "allow" + // null means unknown hostname, so they get blocked + return !"allow".equals(mode); + } + + // Only block outgoing requests if the mode is "block" + return "block".equals(mode); + } +} diff --git a/agent_api/src/test/java/ShouldBlockRequestTest.java b/agent_api/src/test/java/ShouldBlockRequestTest.java index a03f6294e..92b3cb89b 100644 --- a/agent_api/src/test/java/ShouldBlockRequestTest.java +++ b/agent_api/src/test/java/ShouldBlockRequestTest.java @@ -87,7 +87,7 @@ public void testUserSet() throws SQLException { ServiceConfigStore.updateFromAPIResponse(new APIResponse( true, "", getUnixTimeMS(), List.of(), /* blockedUserIds */ List.of("ID1", "ID2", "ID3"), List.of(), - false, true + false, null, false, true )); var res2 = ShouldBlockRequest.shouldBlockRequest(); assertTrue(res2.block()); @@ -227,7 +227,7 @@ public void testBlockedUserWithMultipleEndpoints() throws SQLException { ); List blockedUserIds = List.of("ID1"); ServiceConfigStore.updateFromAPIResponse(new APIResponse( - true, "", getUnixTimeMS(), endpoints, blockedUserIds, List.of(), true, false + true, "", getUnixTimeMS(), endpoints, blockedUserIds, List.of(), false, null,true, false )); // Call the method diff --git a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java index 95814a3df..e33676dd3 100644 --- a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java +++ b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java @@ -1,5 +1,6 @@ package collectors; +import dev.aikido.agent_api.background.cloud.api.APIResponse; import dev.aikido.agent_api.background.cloud.api.events.DetectedAttack; import dev.aikido.agent_api.collectors.DNSRecordCollector; import dev.aikido.agent_api.context.Context; @@ -7,8 +8,10 @@ import dev.aikido.agent_api.storage.AttackQueue; import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.HostnamesStore; +import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; -import dev.aikido.agent_api.vulnerabilities.Attack; +import dev.aikido.agent_api.storage.service_configuration.Domain; +import dev.aikido.agent_api.vulnerabilities.outbound_blocking.BlockedOutboundException; import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException; import dev.aikido.agent_api.vulnerabilities.ssrf.StoredSSRFException; import org.junit.jupiter.api.*; @@ -28,62 +31,24 @@ public class DNSRecordCollectorTest { @BeforeEach void setup() throws UnknownHostException { - // We want to define InetAddresses here so it does not interfere with counts of getHostname() inetAddress1 = InetAddress.getByName("1.1.1.1"); inetAddress2 = InetAddress.getByName("127.0.0.1"); imdsAddress1 = InetAddress.getByName("169.254.169.254"); AttackQueue.clear(); + HostnamesStore.clear(); + PendingHostnamesStore.clear(); } @AfterEach public void cleanup() { HostnamesStore.clear(); + PendingHostnamesStore.clear(); Context.set(null); AttackQueue.clear(); - } - - @Test - public void testContextNull() { - // Early return because of Context being null : - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - inetAddress1, inetAddress2 - }); - } - - @Test - public void testThreadCacheHostnames() { - ContextObject myContextObject = mock(ContextObject.class); - Context.set(myContextObject); - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - inetAddress1, inetAddress2 - }); - verify(myContextObject).getHostnames(); - - myContextObject = mock(ContextObject.class); - Hostnames hostnames = new Hostnames(20); - when(myContextObject.getHostnames()).thenReturn(hostnames); - - Context.set(myContextObject); - - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - inetAddress1, inetAddress2 - }); - verify(myContextObject, times(2)).getHostnames(); - } - - @Test - public void testHostnameSame() { - ContextObject myContextObject = mock(ContextObject.class); - Hostnames hostnames = new Hostnames(20); - hostnames.add("dev.aikido.not", 80); - hostnames.add("dev.aikido", 80); - when(myContextObject.getHostnames()).thenReturn(hostnames); - - Context.set(myContextObject); - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - inetAddress1, inetAddress2 - }); - verify(myContextObject, times(2)).getHostnames(); + // Reset domain config + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, false, List.of(), true, false + )); } public static class SampleContextObject extends EmptySampleContextObject { @@ -95,18 +60,37 @@ public SampleContextObject() { } @Test - public void testHostnameSameWithContextAsAttack() { + public void testNoPendingHostnames() { + // No pending hostnames → port 0 recorded, no SSRF check + Context.set(new EmptySampleContextObject()); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1, inetAddress2}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("dev.aikido", entries[0].getHostname()); + assertEquals(0, entries[0].getPort()); + } + + @Test + public void testPendingHostnameOtherThanLookedUp() { + // A pending entry for a different hostname should not affect the looked-up hostname + PendingHostnamesStore.add("dev.aikido.not", 80); + Context.set(new EmptySampleContextObject()); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1, inetAddress2}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("dev.aikido", entries[0].getHostname()); + assertEquals(0, entries[0].getPort()); + } + + @Test + public void testSSRFWithPendingHostname() { ServiceConfigStore.updateBlocking(true); - ContextObject myContextObject = new SampleContextObject(); - myContextObject.getHostnames().add("dev.aikido.not", 80); - myContextObject.getHostnames().add("dev.aikido", 80); - Context.set(myContextObject); + PendingHostnamesStore.add("dev.aikido", 80); + Context.set(new SampleContextObject()); Exception exception = assertThrows(SSRFException.class, () -> { - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - inetAddress1, inetAddress2 - }); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1, inetAddress2}); }); assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage()); } @@ -115,26 +99,110 @@ public void testHostnameSameWithContextAsAttack() { public void testHostnameSameWithContextAsAStoredSSRFAttack() { ServiceConfigStore.updateBlocking(true); - ContextObject myContextObject = new SampleContextObject(); - Context.set(myContextObject); + Context.set(new SampleContextObject()); Exception exception = assertThrows(StoredSSRFException.class, () -> { - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{imdsAddress1, inetAddress2}); }); assertEquals("Aikido Zen has blocked a stored server-side request forgery", exception.getMessage()); assertDoesNotThrow(() -> { - DNSRecordCollector.report("metadata.goog", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); - DNSRecordCollector.report("metadata.google.internal", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); + DNSRecordCollector.report("metadata.goog", new InetAddress[]{imdsAddress1, inetAddress2}); + DNSRecordCollector.report("metadata.google.internal", new InetAddress[]{imdsAddress1, inetAddress2}); }); } + @Test + public void testBlockedDomain() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("blocked.example.com", "block")), true, true + )); + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("blocked.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testAllowedDomainNotBlocked() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("allowed.example.com", "allow")), true, true + )); + assertDoesNotThrow(() -> + DNSRecordCollector.report("allowed.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testUnknownDomainBlockedWhenBlockNewOutgoingRequests() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + true, List.of(), true, true + )); + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("unknown.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testHostnamesStorePort0WhenNoPendingEntry() { + Context.set(null); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("dev.aikido", entries[0].getHostname()); + assertEquals(0, entries[0].getPort()); + } + + @Test + public void testHostnamesStoreUsesPortFromPendingStore() { + PendingHostnamesStore.add("dev.aikido", 8080); + Context.set(mock(ContextObject.class)); + + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("dev.aikido", entries[0].getHostname()); + assertEquals(8080, entries[0].getPort()); + } + + @Test + public void testHostnamesStoreIncrementedForAllPendingPorts() { + PendingHostnamesStore.add("dev.aikido", 80); + PendingHostnamesStore.add("dev.aikido", 443); + Context.set(mock(ContextObject.class)); + + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(2, entries.length); + assertEquals(80, entries[0].getPort()); + assertEquals(443, entries[1].getPort()); + } + + @Test + public void testPendingEntryRemovedAfterDNSLookup() { + PendingHostnamesStore.add("dev.aikido", 8080); + Context.set(mock(ContextObject.class)); + + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1}); + // Entry should have been consumed + assertTrue(PendingHostnamesStore.getPorts("dev.aikido").isEmpty()); + } + + @Test + public void testSSRFStillRunsWhenPendingPortIsZero() { + ServiceConfigStore.updateBlocking(true); + + PendingHostnamesStore.add("dev.aikido", 0); + Context.set(new SampleContextObject()); + + Exception exception = assertThrows(SSRFException.class, () -> { + DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1, inetAddress2}); + }); + assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage()); + } + @Test public void testStoredSSRFWithNoContext() throws InterruptedException { ServiceConfigStore.updateBlocking(true); @@ -142,9 +210,7 @@ public void testStoredSSRFWithNoContext() throws InterruptedException { Context.set(null); Exception exception = assertThrows(StoredSSRFException.class, () -> { - DNSRecordCollector.report("dev.aikido", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); + DNSRecordCollector.report("dev.aikido", new InetAddress[]{imdsAddress1, inetAddress2}); }); DetectedAttack.DetectedAttackEvent event = (DetectedAttack.DetectedAttackEvent) AttackQueue.get(); assertEquals("stored_ssrf", event.attack().kind()); @@ -153,12 +219,8 @@ public void testStoredSSRFWithNoContext() throws InterruptedException { assertEquals("Aikido Zen has blocked a stored server-side request forgery", exception.getMessage()); assertDoesNotThrow(() -> { - DNSRecordCollector.report("metadata.goog", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); - DNSRecordCollector.report("metadata.google.internal", new InetAddress[]{ - imdsAddress1, inetAddress2 - }); + DNSRecordCollector.report("metadata.goog", new InetAddress[]{imdsAddress1, inetAddress2}); + DNSRecordCollector.report("metadata.google.internal", new InetAddress[]{imdsAddress1, inetAddress2}); }); } } diff --git a/agent_api/src/test/java/collectors/URLCollectorTest.java b/agent_api/src/test/java/collectors/URLCollectorTest.java index cc6551c44..140065a3e 100644 --- a/agent_api/src/test/java/collectors/URLCollectorTest.java +++ b/agent_api/src/test/java/collectors/URLCollectorTest.java @@ -2,8 +2,8 @@ import dev.aikido.agent_api.collectors.URLCollector; import dev.aikido.agent_api.context.Context; -import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.HostnamesStore; +import dev.aikido.agent_api.storage.PendingHostnamesStore; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -12,9 +12,9 @@ import java.io.IOException; import java.net.URL; +import java.util.Set; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.*; public class URLCollectorTest { @BeforeAll @@ -29,6 +29,7 @@ static void afterAll() { @BeforeEach void beforeEach() { cleanup(); + PendingHostnamesStore.clear(); } private void setContextAndLifecycle(String url) { @@ -38,62 +39,45 @@ private void setContextAndLifecycle(String url) { @Test public void testNewUrlConnectionWithPort() throws IOException { setContextAndLifecycle(""); - + URLCollector.report(new URL("http://localhost:8080")); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(1, hostnameArray.length); - assertEquals(8080, hostnameArray[0].getPort()); - assertEquals("localhost", hostnameArray[0].getHostname()); + Set ports = PendingHostnamesStore.getPorts("localhost"); + assertEquals(1, ports.size()); + assertTrue(ports.contains(8080)); } @Test public void testNewUrlConnectionWithHttp() throws IOException { setContextAndLifecycle(""); URLCollector.report(new URL("http://app.local.aikido.io")); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(1, hostnameArray.length); - assertEquals(80, hostnameArray[0].getPort()); - assertEquals("app.local.aikido.io", hostnameArray[0].getHostname()); - - Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray(); - assertEquals(1, hostnameArray2.length); - assertEquals(80, hostnameArray2[0].getPort()); - assertEquals("app.local.aikido.io", hostnameArray2[0].getHostname()); + Set ports = PendingHostnamesStore.getPorts("app.local.aikido.io"); + assertEquals(1, ports.size()); + assertTrue(ports.contains(80)); } @Test public void testNewUrlConnectionHttps() throws IOException { setContextAndLifecycle(""); URLCollector.report(new URL("https://aikido.dev")); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(1, hostnameArray.length); - assertEquals(443, hostnameArray[0].getPort()); - assertEquals("aikido.dev", hostnameArray[0].getHostname()); - - Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray(); - assertEquals(1, hostnameArray2.length); - assertEquals(443, hostnameArray2[0].getPort()); - assertEquals("aikido.dev", hostnameArray2[0].getHostname()); + Set ports = PendingHostnamesStore.getPorts("aikido.dev"); + assertEquals(1, ports.size()); + assertTrue(ports.contains(443)); } @Test public void testNewUrlConnectionFaultyProtocol() throws IOException { setContextAndLifecycle(""); URLCollector.report(new URL("ftp://localhost:8080")); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(0, hostnameArray.length); - Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray(); - assertEquals(0, hostnameArray2.length); + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + assertTrue(PendingHostnamesStore.getPorts("localhost").isEmpty()); } @Test public void testWithNullURL() throws IOException { setContextAndLifecycle(""); URLCollector.report(null); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(0, hostnameArray.length); - Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray(); - assertEquals(0, hostnameArray2.length); + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + assertTrue(PendingHostnamesStore.getPorts("localhost").isEmpty()); } @Test @@ -101,21 +85,21 @@ public void testWithNullContext() throws IOException { setContextAndLifecycle(""); Context.reset(); URLCollector.report(new URL("https://aikido.dev")); - Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList(); - assertEquals(1, hostnameArray.length); - assertEquals(443, hostnameArray[0].getPort()); - assertEquals("aikido.dev", hostnameArray[0].getHostname()); + // URLCollector writes to PendingHostnamesStore regardless of context state + Set ports = PendingHostnamesStore.getPorts("aikido.dev"); + assertEquals(1, ports.size()); + assertTrue(ports.contains(443)); assertNull(Context.get()); } @Test - public void testOnlyContext() throws IOException { + public void testOnlyPendingStore() throws IOException { setContextAndLifecycle(""); - HostnamesStore.clear(); URLCollector.report(new URL("https://aikido.dev")); - Hostnames.HostnameEntry[] hostnameArray = Context.get().getHostnames().asArray(); - assertEquals(1, hostnameArray.length); - assertEquals(443, hostnameArray[0].getPort()); - assertEquals("aikido.dev", hostnameArray[0].getHostname()); + // HostnamesStore is only written by DNSRecordCollector, not URLCollector + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + Set ports = PendingHostnamesStore.getPorts("aikido.dev"); + assertEquals(1, ports.size()); + assertTrue(ports.contains(443)); } -} \ No newline at end of file +} diff --git a/agent_api/src/test/java/collectors/WebRequestCollectorTest.java b/agent_api/src/test/java/collectors/WebRequestCollectorTest.java index 75999d50a..b7d04e9f2 100644 --- a/agent_api/src/test/java/collectors/WebRequestCollectorTest.java +++ b/agent_api/src/test/java/collectors/WebRequestCollectorTest.java @@ -212,7 +212,7 @@ void testReport_userAgentBlocked_Ip_Bypassed() { List bypassedIps = List.of("192.168.1.1"); ServiceConfigStore.updateFromAPIResponse(new APIResponse( - true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false + true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false )); @@ -231,7 +231,7 @@ void testReport_ipBlockedUsingLists_Ip_Bypassed() { List bypassedIps = List.of("192.168.1.1"); ServiceConfigStore.updateFromAPIResponse(new APIResponse( - true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false + true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false )); WebRequestCollector.Res response = WebRequestCollector.report(contextObject); @@ -251,7 +251,7 @@ void testReport_ipNotAllowedUsingLists_Ip_Bypassed() { List bypassedIps = List.of("192.168.1.1"); ServiceConfigStore.updateFromAPIResponse(new APIResponse( - true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false + true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false )); diff --git a/agent_api/src/test/java/storage/ServiceConfigStoreTest.java b/agent_api/src/test/java/storage/ServiceConfigStoreTest.java new file mode 100644 index 000000000..acb11df9c --- /dev/null +++ b/agent_api/src/test/java/storage/ServiceConfigStoreTest.java @@ -0,0 +1,61 @@ +package storage; + +import dev.aikido.agent_api.background.cloud.api.APIResponse; +import dev.aikido.agent_api.storage.ServiceConfigStore; +import dev.aikido.agent_api.storage.service_configuration.Domain; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class ServiceConfigStoreTest { + + @BeforeEach + public void setUp() { + // Reset to a known state: no domains, blockNewOutgoingRequests=false + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, null, true, false + )); + } + + @Test + public void testShouldBlockOutgoingRequestNotBlockedByDefault() { + assertFalse(ServiceConfigStore.shouldBlockOutgoingRequest("example.com")); + } + + @Test + public void testShouldBlockOutgoingRequestBlockedDomain() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, + List.of(new Domain("blocked.com", "block")), + true, false + )); + assertTrue(ServiceConfigStore.shouldBlockOutgoingRequest("blocked.com")); + } + + @Test + public void testShouldBlockOutgoingRequestAllowedDomain() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + true, + List.of(new Domain("allowed.com", "allow")), + true, false + )); + assertFalse(ServiceConfigStore.shouldBlockOutgoingRequest("allowed.com")); + } + + @Test + public void testShouldBlockOutgoingRequestUnknownWhenBlockNewEnabled() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + true, + List.of(new Domain("allowed.com", "allow")), + true, false + )); + assertTrue(ServiceConfigStore.shouldBlockOutgoingRequest("unknown.com")); + } +} diff --git a/agent_api/src/test/java/storage/ServiceConfigurationTest.java b/agent_api/src/test/java/storage/ServiceConfigurationTest.java index c1703c309..3926fc577 100644 --- a/agent_api/src/test/java/storage/ServiceConfigurationTest.java +++ b/agent_api/src/test/java/storage/ServiceConfigurationTest.java @@ -8,6 +8,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import dev.aikido.agent_api.storage.service_configuration.Domain; + import java.util.Collections; import java.util.List; @@ -33,6 +35,8 @@ public void testUpdateConfig() { List.of(mock(Endpoint.class)), List.of("user1", "user2"), List.of("192.168.1.1"), + false, + null, true, true ); @@ -62,6 +66,8 @@ public void testUpdateConfigWithUnsuccessfulResponse() { null, null, false, + null, + false, false ); @@ -301,6 +307,8 @@ public void testReceivedAnyStats() { null, null, false, + null, + false, true ); @@ -318,6 +326,8 @@ public void testIsIpBypassedWithEmptyBypassedList() { null, null, Collections.emptyList(), + false, + null, true, true )); @@ -334,6 +344,8 @@ public void testIsIpBypassedWithMultipleBypassedEntries() { null, null, List.of("192.168.1.1", "192.168.1.2"), + false, + null, true, true ); @@ -354,6 +366,8 @@ public void testIsUserBlockedWithEmptyBlockedUserList() { null, Collections.emptyList(), null, + false, + null, true, true )); @@ -370,6 +384,8 @@ public void testIsUserBlockedWithMultipleBlockedUsers() { null, List.of("user1", "user2"), null, + false, + null, true, true ); @@ -389,6 +405,8 @@ public void testGetEndpoints() { List.of(mock(Endpoint.class)), null, null, + false, + null, true, true ); @@ -407,6 +425,8 @@ public void testGetEndpointsWithEmptyList() { Collections.emptyList(), null, null, + false, + null, true, true ); @@ -547,6 +567,39 @@ public void testIsIpBlockedWithOnlyBlockedIPs() { assertFalse(resultNotBlocked.blocked()); } + @Test + public void testShouldBlockOutgoingRequest() { + APIResponse apiResponse = new APIResponse( + true, null, 0L, null, null, null, + true, + List.of(new Domain("example.com", "block"), new Domain("allowed.com", "allow")), + true, true + ); + serviceConfiguration.updateConfig(apiResponse); + + // blockNewOutgoingRequests=true: unknown hostname gets blocked + assertTrue(serviceConfiguration.shouldBlockOutgoingRequest("unknown.com")); + // blockNewOutgoingRequests=true: "allow" mode is not blocked + assertFalse(serviceConfiguration.shouldBlockOutgoingRequest("allowed.com")); + // blockNewOutgoingRequests=true: "block" mode is blocked + assertTrue(serviceConfiguration.shouldBlockOutgoingRequest("example.com")); + + APIResponse apiResponse2 = new APIResponse( + true, null, 0L, null, null, null, + false, + List.of(new Domain("example.com", "block"), new Domain("allowed.com", "allow")), + true, true + ); + serviceConfiguration.updateConfig(apiResponse2); + + // blockNewOutgoingRequests=false: unknown hostname is not blocked + assertFalse(serviceConfiguration.shouldBlockOutgoingRequest("unknown.com")); + // blockNewOutgoingRequests=false: "allow" mode is not blocked + assertFalse(serviceConfiguration.shouldBlockOutgoingRequest("allowed.com")); + // blockNewOutgoingRequests=false: "block" mode is blocked + assertTrue(serviceConfiguration.shouldBlockOutgoingRequest("example.com")); + } + @Test public void testIsIpBlockedWithAllowedIPsAndBlockedIPs() { ReportingApi.ListsResponseEntry allowedEntry1 = new ReportingApi.ListsResponseEntry("key", "source", "allowed", List.of("10.0.0.1")); diff --git a/agent_api/src/test/java/utils/EmptyAPIResponses.java b/agent_api/src/test/java/utils/EmptyAPIResponses.java index 091631a81..ea98520c2 100644 --- a/agent_api/src/test/java/utils/EmptyAPIResponses.java +++ b/agent_api/src/test/java/utils/EmptyAPIResponses.java @@ -12,14 +12,14 @@ public class EmptyAPIResponses { public final static APIResponse emptyAPIResponse = new APIResponse( - true, "", UnixTimeMS.getUnixTimeMS(), List.of(), List.of(), List.of(), true, false + true, "", UnixTimeMS.getUnixTimeMS(), List.of(), List.of(), List.of(), false, null,true, false ); public final static ReportingApi.APIListsResponse emptyAPIListsResponse = new ReportingApi.APIListsResponse( List.of(), List.of(), List.of(), null, null, List.of() ); public static void setEmptyConfigWithEndpointList(List endpoints) { ServiceConfigStore.updateFromAPIResponse(new APIResponse( - true, "", getUnixTimeMS(), endpoints, List.of(), List.of(), true, false + true, "", getUnixTimeMS(), endpoints, List.of(), List.of(), false, null, true, false )); } } diff --git a/agent_api/src/test/java/vulnerabilities/outbound_blocking/OutboundDomainsTest.java b/agent_api/src/test/java/vulnerabilities/outbound_blocking/OutboundDomainsTest.java new file mode 100644 index 000000000..45590dbd3 --- /dev/null +++ b/agent_api/src/test/java/vulnerabilities/outbound_blocking/OutboundDomainsTest.java @@ -0,0 +1,109 @@ +package vulnerabilities.outbound_blocking; + +import dev.aikido.agent_api.storage.service_configuration.Domain; +import dev.aikido.agent_api.vulnerabilities.outbound_blocking.OutboundDomains; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class OutboundDomainsTest { + + private OutboundDomains outboundDomains; + + @BeforeEach + public void setUp() { + outboundDomains = new OutboundDomains(); + } + + // --- shouldBlockOutgoingRequest with blockNewOutgoingRequests=false (default) --- + + @Test + public void testDefaultDoesNotBlockUnknownHostname() { + assertFalse(outboundDomains.shouldBlockOutgoingRequest("unknown.com")); + } + + @Test + public void testBlockModeBlocksHostname() { + outboundDomains.update(List.of(new Domain("blocked.com", "block")), false); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("blocked.com")); + } + + @Test + public void testAllowModeDoesNotBlockHostname() { + outboundDomains.update(List.of(new Domain("allowed.com", "allow")), false); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("allowed.com")); + } + + @Test + public void testUnknownHostnameNotBlockedWhenBlockNewOutgoingRequestsFalse() { + outboundDomains.update(List.of(new Domain("blocked.com", "block")), false); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("unknown.com")); + } + + // --- shouldBlockOutgoingRequest with blockNewOutgoingRequests=true --- + + @Test + public void testUnknownHostnameBlockedWhenBlockNewOutgoingRequestsTrue() { + outboundDomains.update(List.of(), true); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("unknown.com")); + } + + @Test + public void testAllowModeNotBlockedWhenBlockNewOutgoingRequestsTrue() { + outboundDomains.update(List.of(new Domain("allowed.com", "allow")), true); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("allowed.com")); + } + + @Test + public void testBlockModeBlockedWhenBlockNewOutgoingRequestsTrue() { + outboundDomains.update(List.of(new Domain("blocked.com", "block")), true); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("blocked.com")); + } + + // --- update() behaviour --- + + @Test + public void testUpdateWithNullDomainsPreservesExistingDomains() { + outboundDomains.update(List.of(new Domain("blocked.com", "block")), false); + // null domains should not reset the map + outboundDomains.update(null, false); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("blocked.com")); + } + + @Test + public void testUpdateWithNullDomainsUpdatesBlockFlag() { + outboundDomains.update(null, true); + // blockNewOutgoingRequests should now be true even though domains unchanged + assertTrue(outboundDomains.shouldBlockOutgoingRequest("unknown.com")); + } + + @Test + public void testUpdateReplacesDomainsMap() { + outboundDomains.update(List.of(new Domain("old.com", "block")), false); + outboundDomains.update(List.of(new Domain("new.com", "block")), false); + // old entry should be gone + assertFalse(outboundDomains.shouldBlockOutgoingRequest("old.com")); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("new.com")); + } + + @Test + public void testUpdateWithEmptyListClearsDomainsMap() { + outboundDomains.update(List.of(new Domain("blocked.com", "block")), false); + outboundDomains.update(List.of(), false); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("blocked.com")); + } + + @Test + public void testMultipleDomainsWithMixedModes() { + outboundDomains.update(List.of( + new Domain("blocked.com", "block"), + new Domain("allowed.com", "allow") + ), false); + assertTrue(outboundDomains.shouldBlockOutgoingRequest("blocked.com")); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("allowed.com")); + assertFalse(outboundDomains.shouldBlockOutgoingRequest("other.com")); + } +} diff --git a/agent_api/src/test/java/wrappers/HttpURLConnectionTest.java b/agent_api/src/test/java/wrappers/HttpURLConnectionTest.java index ae85b3952..2037f2913 100644 --- a/agent_api/src/test/java/wrappers/HttpURLConnectionTest.java +++ b/agent_api/src/test/java/wrappers/HttpURLConnectionTest.java @@ -3,6 +3,7 @@ import dev.aikido.agent_api.context.Context; import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.HostnamesStore; +import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -25,12 +26,14 @@ public class HttpURLConnectionTest { void cleanup() { Context.set(null); HostnamesStore.clear(); + PendingHostnamesStore.clear(); } @BeforeEach void beforeEach() { cleanup(); ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.clear(); } private void setContextAndLifecycle(String url) { diff --git a/end2end/server/mock_aikido_core.py b/end2end/server/mock_aikido_core.py index 7bb2cff4a..6238f5741 100644 --- a/end2end/server/mock_aikido_core.py +++ b/end2end/server/mock_aikido_core.py @@ -44,6 +44,8 @@ } ], "blockedUserIds": ["12345"], + "blockNewOutgoingRequests": False, + "domains": [], "block": True, }, "lists": {