ssrf validation #1

merged
opened by nekomimi.pet targeting main from nekomimi.pet/atproto-connect: main

current code completely trusts all redirects DID documents offer which allow the server to send http requests to arbitrary internal or external hosts

per https://atproto.com/specs/did: The PDS service network location for the account is found under the service array, with id ending #atproto_pds, and type matching AtprotoPersonalDataServer. The first matching entry in the array should be used, and any others ignored. The serviceEndpoint field must contain an HTTPS URL of server. It should contain only the URI scheme (http or https), hostname, and optional port number, not any "userinfo", path prefix, or other components.

Changed files
+67 -2
src
main
kotlin
com
jollywhoppers
+67 -2
src/main/kotlin/com/jollywhoppers/atproto/AtProtoClient.kt
··· 159 */ 160 suspend fun resolveDid(did: String): Result<DidDocument> = runCatching { 161 logger.info("Resolving DID: $did") 162 - 163 val url = when { 164 did.startsWith("did:plc:") -> { 165 val identifier = did.removePrefix("did:plc:") ··· 167 } 168 did.startsWith("did:web:") -> { 169 val domain = did.removePrefix("did:web:") 170 "https://$domain/.well-known/did.json" 171 } 172 else -> throw IllegalArgumentException("Unsupported DID method: $did") ··· 363 ?: throw Exception("No handle found in DID document") 364 val pdsService = didDoc.service.firstOrNull { it.type == "AtprotoPersonalDataServer" } 365 ?: throw Exception("No PDS service found in DID document") 366 - Triple(identifier, handle, pdsService.serviceEndpoint) 367 } 368 } 369 else -> { ··· 382 .rawSchemeSpecificPart 383 .replace("+", "%20") 384 } 385 }
··· 159 */ 160 suspend fun resolveDid(did: String): Result<DidDocument> = runCatching { 161 logger.info("Resolving DID: $did") 162 + 163 val url = when { 164 did.startsWith("did:plc:") -> { 165 val identifier = did.removePrefix("did:plc:") ··· 167 } 168 did.startsWith("did:web:") -> { 169 val domain = did.removePrefix("did:web:") 170 + 171 + // Validate domain format (no IPs, only valid hostnames) 172 + if (!domain.matches(Regex("^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"))) { 173 + throw IllegalArgumentException("Invalid did:web domain format: must be a valid hostname") 174 + } 175 + 176 + // Block private IP ranges and localhost 177 + validateNotPrivateNetwork(domain) 178 + 179 "https://$domain/.well-known/did.json" 180 } 181 else -> throw IllegalArgumentException("Unsupported DID method: $did") ··· 372 ?: throw Exception("No handle found in DID document") 373 val pdsService = didDoc.service.firstOrNull { it.type == "AtprotoPersonalDataServer" } 374 ?: throw Exception("No PDS service found in DID document") 375 + 376 + // Validate serviceEndpoint per AT Protocol spec 377 + val serviceEndpoint = pdsService.serviceEndpoint 378 + val uri = try { 379 + URI.create(serviceEndpoint) 380 + } catch (e: Exception) { 381 + throw Exception("Invalid serviceEndpoint URI: ${e.message}") 382 + } 383 + 384 + // Validate per AT Protocol spec 385 + require(uri.scheme in listOf("http", "https")) { 386 + "serviceEndpoint must use HTTP or HTTPS scheme, got: ${uri.scheme}" 387 + } 388 + require(uri.host != null) { 389 + "serviceEndpoint must have a valid host" 390 + } 391 + require(uri.path.isNullOrEmpty() || uri.path == "/") { 392 + "serviceEndpoint must not contain path, got: ${uri.path}" 393 + } 394 + require(uri.query == null) { 395 + "serviceEndpoint must not contain query parameters" 396 + } 397 + require(uri.fragment == null) { 398 + "serviceEndpoint must not contain fragment" 399 + } 400 + require(uri.userInfo == null) { 401 + "serviceEndpoint must not contain userinfo" 402 + } 403 + 404 + // Block private IP ranges 405 + validateNotPrivateNetwork(uri.host) 406 + 407 + // Reconstruct clean URL 408 + val cleanPdsUrl = "${uri.scheme}://${uri.host}${uri.port.takeIf { it != -1 }?.let { ":$it" } ?: ""}" 409 + Triple(identifier, handle, cleanPdsUrl) 410 } 411 } 412 else -> { ··· 425 .rawSchemeSpecificPart 426 .replace("+", "%20") 427 } 428 + 429 + /** 430 + * Validates that a hostname or domain is not a private network address. 431 + * Throws IllegalArgumentException if the address is localhost or a private IP range. 432 + */ 433 + private fun validateNotPrivateNetwork(host: String) { 434 + val blockedPatterns = listOf( 435 + Regex("^localhost$", RegexOption.IGNORE_CASE), 436 + Regex("^127\\."), 437 + Regex("^10\\."), 438 + Regex("^172\\.(1[6-9]|2[0-9]|3[01])\\."), 439 + Regex("^192\\.168\\."), 440 + Regex("^169\\.254\\."), 441 + Regex("^::1$"), 442 + Regex("^fc00:"), 443 + Regex("^fe80:") 444 + ) 445 + 446 + if (blockedPatterns.any { it.containsMatchIn(host) }) { 447 + throw IllegalArgumentException("Access to private networks is not allowed: $host") 448 + } 449 + } 450 }