"""
Tests for URL validation and safety checks.
"""

import pytest

from scanner.validators import (
    validate_url,
    normalize_url,
    is_private_ip,
    extract_domain,
)


class TestURLValidation:
    """Tests for URL validation functionality.

    ``validate_url`` returns a ``(is_valid, normalized, error)`` tuple; the
    rejection tests below rely on ``error`` being a message string whenever
    ``is_valid`` is False.
    """

    def test_valid_https_url(self):
        """Test that valid HTTPS URLs pass validation."""
        is_valid, normalized, error = validate_url("https://example.com")
        assert is_valid is True
        assert error is None
        assert normalized == "https://example.com"

    def test_valid_http_url(self):
        """Test that valid HTTP URLs pass validation."""
        is_valid, normalized, error = validate_url("http://example.com")
        assert is_valid is True
        assert error is None

    def test_url_with_path(self):
        """Test URL with path is normalized correctly."""
        is_valid, normalized, error = validate_url("https://example.com/page/")
        assert is_valid is True
        assert normalized == "https://example.com/page"  # Trailing slash removed

    def test_url_with_query(self):
        """Test URL with query string is preserved."""
        is_valid, normalized, error = validate_url("https://example.com/search?q=test")
        assert is_valid is True
        assert "q=test" in normalized

    def test_invalid_scheme(self):
        """Test that non-http/https schemes are rejected."""
        is_valid, _, error = validate_url("ftp://example.com")
        assert is_valid is False
        assert "http or https" in error.lower()

    def test_missing_scheme(self):
        """Test that URLs without scheme are rejected with an error message."""
        is_valid, _, error = validate_url("example.com")
        assert is_valid is False
        # Every other rejection path reports a message; require one here too so
        # a silent rejection (error=None) is caught.
        assert error

    def test_empty_url(self):
        """Test that empty URL is rejected."""
        is_valid, _, error = validate_url("")
        assert is_valid is False
        assert "required" in error.lower()

    # --- SSRF protection: internal hosts must be rejected -------------------
    # NOTE(review): consider also covering IPv6 loopback (http://[::1]),
    # link-local / cloud-metadata (169.254.169.254) and 0.0.0.0 once the
    # expected behaviour of validate_url for those hosts is confirmed.

    def test_localhost_blocked(self):
        """Test that localhost is blocked (SSRF protection)."""
        is_valid, _, error = validate_url("http://localhost")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_localhost_127_blocked(self):
        """Test that 127.0.0.1 is blocked."""
        is_valid, _, error = validate_url("http://127.0.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_10_blocked(self):
        """Test that 10.x.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://10.0.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_172_blocked(self):
        """Test that 172.16-31.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://172.16.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_192_blocked(self):
        """Test that 192.168.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://192.168.1.1")
        assert is_valid is False
        assert "not allowed" in error.lower()


class TestPrivateIP:
    """Tests for private IP detection."""

    def test_loopback_is_private(self):
        """Test that loopback addresses are detected as private."""
        assert is_private_ip("127.0.0.1") is True
        assert is_private_ip("127.0.0.2") is True

    def test_10_range_is_private(self):
        """Test that 10.x.x.x range is private."""
        assert is_private_ip("10.0.0.1") is True
        assert is_private_ip("10.255.255.255") is True

    def test_172_range_is_private(self):
        """Test that 172.16-31.x.x range is private."""
        assert is_private_ip("172.16.0.1") is True
        assert is_private_ip("172.31.255.255") is True

    @pytest.mark.parametrize("ip", ["172.15.255.255", "172.32.0.0"])
    def test_172_range_boundaries_not_private(self, ip):
        """Test that addresses just outside 172.16.0.0/12 are not private.

        Guards against a prefix-only check such as ``startswith("172.")``,
        which would wrongly classify these public addresses as private.
        """
        assert is_private_ip(ip) is False

    def test_192_range_is_private(self):
        """Test that 192.168.x.x range is private."""
        assert is_private_ip("192.168.0.1") is True
        assert is_private_ip("192.168.255.255") is True

    def test_public_ip_not_private(self):
        """Test that public IPs are not marked as private."""
        assert is_private_ip("8.8.8.8") is False
        assert is_private_ip("1.1.1.1") is False
        assert is_private_ip("93.184.216.34") is False  # example.com


class TestNormalizeURL:
    """Tests for URL normalization."""

    def test_lowercase_host(self):
        """Test that hostname is lowercased."""
        result = normalize_url("https://EXAMPLE.COM")
        assert "example.com" in result

    def test_remove_trailing_slash(self):
        """Test that trailing slash is removed from path."""
        result = normalize_url("https://example.com/page/")
        assert result.endswith("/page")

    def test_preserve_query_string(self):
        """Test that query string is preserved."""
        result = normalize_url("https://example.com?foo=bar")
        assert "foo=bar" in result


class TestExtractDomain:
    """Tests for domain extraction."""

    def test_simple_domain(self):
        """Test extracting domain from simple URL."""
        domain = extract_domain("https://example.com/page")
        assert domain == "example.com"

    def test_subdomain(self):
        """Test extracting domain with subdomain."""
        domain = extract_domain("https://www.example.com")
        assert domain == "www.example.com"

    def test_with_port(self):
        """Test extracting domain with port (the port is kept in the result)."""
        domain = extract_domain("https://example.com:8080")
        assert domain == "example.com:8080"