# secure-web/backend/tests/test_validators.py

"""
Tests for URL validation and safety checks.
"""
import pytest
from scanner.validators import (
validate_url,
normalize_url,
is_private_ip,
extract_domain,
)
class TestURLValidation:
    """Tests for validate_url().

    validate_url() returns an ``(is_valid, normalized, error)`` triple. These
    tests cover acceptance and normalization of http(s) URLs, rejection of
    malformed input, and the SSRF guard that refuses loopback and private
    network targets.
    """

    def _assert_blocked(self, url):
        """Assert that validate_url() rejects *url* as a disallowed target.

        Checks both the validity flag and the "not allowed" wording so every
        SSRF case below is held to the same standard. The leading underscore
        keeps pytest from collecting this helper as a test.
        """
        is_valid, _, error = validate_url(url)
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_valid_https_url(self):
        """Test that valid HTTPS URLs pass validation."""
        is_valid, normalized, error = validate_url("https://example.com")
        assert is_valid is True
        assert error is None
        assert normalized == "https://example.com"

    def test_valid_http_url(self):
        """Test that valid HTTP URLs pass validation."""
        is_valid, _, error = validate_url("http://example.com")
        assert is_valid is True
        assert error is None

    def test_url_with_path(self):
        """Test URL with path is normalized correctly."""
        is_valid, normalized, _ = validate_url("https://example.com/page/")
        assert is_valid is True
        assert normalized == "https://example.com/page"  # Trailing slash removed

    def test_url_with_query(self):
        """Test URL with query string is preserved."""
        is_valid, normalized, _ = validate_url("https://example.com/search?q=test")
        assert is_valid is True
        assert "q=test" in normalized

    def test_invalid_scheme(self):
        """Test that non-http/https schemes are rejected."""
        is_valid, _, error = validate_url("ftp://example.com")
        assert is_valid is False
        assert "http or https" in error.lower()

    def test_missing_scheme(self):
        """Test that URLs without a scheme are rejected with a reason."""
        is_valid, _, error = validate_url("example.com")
        assert is_valid is False
        # Every other rejection path in this class reports a message; a
        # missing scheme must not be the one case that fails silently.
        assert error

    def test_empty_url(self):
        """Test that empty URL is rejected."""
        is_valid, _, error = validate_url("")
        assert is_valid is False
        assert "required" in error.lower()

    def test_localhost_blocked(self):
        """Test that localhost is blocked (SSRF protection)."""
        self._assert_blocked("http://localhost")

    def test_localhost_127_blocked(self):
        """Test that 127.0.0.1 is blocked."""
        self._assert_blocked("http://127.0.0.1")

    def test_private_ip_10_blocked(self):
        """Test that 10.x.x.x IPs are blocked."""
        self._assert_blocked("http://10.0.0.1")

    def test_private_ip_172_blocked(self):
        """Test that 172.16-31.x.x IPs are blocked."""
        # Probe both ends of the range: the upper bound (172.31) is where
        # hand-written range checks typically go off by one.
        self._assert_blocked("http://172.16.0.1")
        self._assert_blocked("http://172.31.255.255")

    def test_private_ip_192_blocked(self):
        """Test that 192.168.x.x IPs are blocked."""
        self._assert_blocked("http://192.168.1.1")


class TestPrivateIP:
    """Checks that is_private_ip() separates internal from public addresses."""

    def test_loopback_is_private(self):
        """Loopback addresses must be classified as private."""
        for address in ("127.0.0.1", "127.0.0.2"):
            assert is_private_ip(address) is True

    def test_10_range_is_private(self):
        """Both ends of the 10.x.x.x block must be classified as private."""
        for address in ("10.0.0.1", "10.255.255.255"):
            assert is_private_ip(address) is True

    def test_172_range_is_private(self):
        """Addresses in 172.16-31.x.x must be classified as private."""
        for address in ("172.16.0.1", "172.31.255.255"):
            assert is_private_ip(address) is True

    def test_192_range_is_private(self):
        """Both ends of the 192.168.x.x block must be classified as private."""
        for address in ("192.168.0.1", "192.168.255.255"):
            assert is_private_ip(address) is True

    def test_public_ip_not_private(self):
        """Well-known public addresses must not be classified as private."""
        public_addresses = (
            "8.8.8.8",
            "1.1.1.1",
            "93.184.216.34",  # example.com
        )
        for address in public_addresses:
            assert is_private_ip(address) is False


class TestNormalizeURL:
    """Checks the canonical form produced by normalize_url()."""

    def test_lowercase_host(self):
        """An upper-case hostname comes back in lower case."""
        assert "example.com" in normalize_url("https://EXAMPLE.COM")

    def test_remove_trailing_slash(self):
        """A trailing slash on the path is dropped."""
        assert normalize_url("https://example.com/page/").endswith("/page")

    def test_preserve_query_string(self):
        """The query string survives normalization."""
        assert "foo=bar" in normalize_url("https://example.com?foo=bar")


class TestExtractDomain:
    """Checks the network-location string returned by extract_domain()."""

    def test_simple_domain(self):
        """The path is discarded, leaving only the bare host."""
        assert extract_domain("https://example.com/page") == "example.com"

    def test_subdomain(self):
        """Subdomain labels are kept as part of the returned host."""
        assert extract_domain("https://www.example.com") == "www.example.com"

    def test_with_port(self):
        """An explicit port stays attached to the returned host."""
        assert extract_domain("https://example.com:8080") == "example.com:8080"