156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
"""
Tests for URL validation and safety checks.
"""
|
|
|
|
import pytest
|
|
from scanner.validators import (
|
|
validate_url,
|
|
normalize_url,
|
|
is_private_ip,
|
|
extract_domain,
|
|
)
|
|
|
|
|
|
class TestURLValidation:
    """Tests for URL validation functionality.

    Every test unpacks the result of ``validate_url`` as a 3-tuple of
    ``(is_valid, normalized, error)``. Elements that a given test does not
    inspect are bound to ``_`` so that only the values under test are named.
    """

    def test_valid_https_url(self):
        """Test that valid HTTPS URLs pass validation."""
        is_valid, normalized, error = validate_url("https://example.com")
        assert is_valid is True
        assert error is None
        assert normalized == "https://example.com"

    def test_valid_http_url(self):
        """Test that valid HTTP URLs pass validation."""
        is_valid, _, error = validate_url("http://example.com")
        assert is_valid is True
        assert error is None

    def test_url_with_path(self):
        """Test URL with path is normalized correctly."""
        is_valid, normalized, _ = validate_url("https://example.com/page/")
        assert is_valid is True
        assert normalized == "https://example.com/page"  # Trailing slash removed

    def test_url_with_query(self):
        """Test URL with query string is preserved."""
        is_valid, normalized, _ = validate_url("https://example.com/search?q=test")
        assert is_valid is True
        assert "q=test" in normalized

    def test_invalid_scheme(self):
        """Test that non-http/https schemes are rejected."""
        is_valid, _, error = validate_url("ftp://example.com")
        assert is_valid is False
        assert "http or https" in error.lower()

    def test_missing_scheme(self):
        """Test that URLs without scheme are rejected."""
        is_valid, _, _ = validate_url("example.com")
        assert is_valid is False

    def test_empty_url(self):
        """Test that empty URL is rejected."""
        is_valid, _, error = validate_url("")
        assert is_valid is False
        assert "required" in error.lower()

    def test_localhost_blocked(self):
        """Test that localhost is blocked (SSRF protection)."""
        is_valid, _, error = validate_url("http://localhost")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_localhost_127_blocked(self):
        """Test that 127.0.0.1 is blocked."""
        is_valid, _, error = validate_url("http://127.0.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_10_blocked(self):
        """Test that 10.x.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://10.0.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_172_blocked(self):
        """Test that 172.16-31.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://172.16.0.1")
        assert is_valid is False
        assert "not allowed" in error.lower()

    def test_private_ip_192_blocked(self):
        """Test that 192.168.x.x IPs are blocked."""
        is_valid, _, error = validate_url("http://192.168.1.1")
        assert is_valid is False
        assert "not allowed" in error.lower()
|
|
|
|
|
|
class TestPrivateIP:
    """Exercise ``is_private_ip`` with loopback, private-range and public addresses."""

    def test_loopback_is_private(self):
        """Loopback addresses must be reported as private."""
        for address in ("127.0.0.1", "127.0.0.2"):
            assert is_private_ip(address) is True

    def test_10_range_is_private(self):
        """Addresses at both ends of the 10.x.x.x range are private."""
        for address in ("10.0.0.1", "10.255.255.255"):
            assert is_private_ip(address) is True

    def test_172_range_is_private(self):
        """Addresses at both ends of the 172.16-31.x.x range are private."""
        for address in ("172.16.0.1", "172.31.255.255"):
            assert is_private_ip(address) is True

    def test_192_range_is_private(self):
        """Addresses at both ends of the 192.168.x.x range are private."""
        for address in ("192.168.0.1", "192.168.255.255"):
            assert is_private_ip(address) is True

    def test_public_ip_not_private(self):
        """Public addresses must not be reported as private."""
        public_addresses = (
            "8.8.8.8",
            "1.1.1.1",
            "93.184.216.34",  # example.com
        )
        for address in public_addresses:
            assert is_private_ip(address) is False
|
|
|
|
|
|
class TestNormalizeURL:
    """Checks on the string returned by ``normalize_url``."""

    def test_lowercase_host(self):
        """An upper-case hostname appears lower-cased in the result."""
        assert "example.com" in normalize_url("https://EXAMPLE.COM")

    def test_remove_trailing_slash(self):
        """A trailing slash on the path is dropped from the result."""
        assert normalize_url("https://example.com/page/").endswith("/page")

    def test_preserve_query_string(self):
        """The query string is still present in the result."""
        assert "foo=bar" in normalize_url("https://example.com?foo=bar")
|
|
|
|
|
|
class TestExtractDomain:
    """Checks on the value returned by ``extract_domain``."""

    def test_simple_domain(self):
        """A URL with a path yields just the bare domain."""
        expected = "example.com"
        extracted = extract_domain("https://example.com/page")
        assert extracted == expected

    def test_subdomain(self):
        """The subdomain is kept as part of the extracted domain."""
        expected = "www.example.com"
        extracted = extract_domain("https://www.example.com")
        assert extracted == expected

    def test_with_port(self):
        """An explicit port is kept as part of the extracted domain."""
        expected = "example.com:8080"
        extracted = extract_domain("https://example.com:8080")
        assert extracted == expected
|