secure-web/backend/scanner/utils.py

186 lines
4.9 KiB
Python

"""
URL validation and safety utilities.
This module provides functions for validating and normalizing URLs,
including safety checks to prevent SSRF attacks.
"""
import ipaddress
import logging
import socket
from typing import Tuple
from urllib.parse import urlparse, urlunparse
import validators
from django.conf import settings
# Module-level logger; check_url_safety uses it to warn about hostnames that
# fail DNS resolution.
logger = logging.getLogger(__name__)
def validate_url(url: str) -> Tuple[bool, str]:
    """
    Validate and normalize a URL for scanning.

    Checks, in order: non-empty input, syntactic validity (via the
    ``validators`` package), an http/https scheme, a non-empty hostname,
    and finally ``check_url_safety`` on that hostname. On success the URL
    is passed through ``normalize_url``.

    Args:
        url: The URL to validate

    Returns:
        Tuple of (is_valid, normalized_url_or_error_message)
    """
    if not url:
        return False, "URL is required"

    # Basic URL validation
    if not validators.url(url):
        return False, "Invalid URL format"

    try:
        parsed = urlparse(url)
    except ValueError as e:
        # urlparse raises ValueError for malformed input such as an
        # unterminated IPv6 literal ("http://[::1/").
        return False, f"Could not parse URL: {e}"

    if parsed.scheme not in ('http', 'https'):
        return False, "URL must use http or https scheme"

    # Use urllib's hostname parsing instead of splitting netloc on ':'.
    # The naive split yields "user@host" for URLs carrying credentials and
    # "[" for IPv6 literals; both then fail DNS resolution, which
    # check_url_safety treats as safe -- an SSRF bypass. `.hostname` strips
    # userinfo, port and IPv6 brackets, and lowercases the result.
    hostname = parsed.hostname or ''
    if not hostname:
        return False, "URL must have a valid hostname"

    # Safety check: block localhost and private IPs
    is_safe, safety_error = check_url_safety(hostname)
    if not is_safe:
        return False, safety_error

    return True, normalize_url(url)
def normalize_url(url: str) -> str:
    """
    Normalize a URL to a canonical form.

    - Lowercase the host (userinfo, if present, is left untouched)
    - Drop the port when it is the scheme default (:80 for http,
      :443 for https); other ports such as :8080 are kept
    - Remove trailing slashes from the path; an empty path becomes '/'
    - Remove the fragment

    The query string and path parameters are kept verbatim; query
    parameters are NOT reordered.

    Args:
        url: The URL to normalize

    Returns:
        Normalized URL string
    """
    parsed = urlparse(url)

    # Separate credentials from host[:port] so only the host part is
    # lowercased -- passwords are case-sensitive.
    userinfo, at, hostport = parsed.netloc.rpartition('@')
    hostport = hostport.lower()

    # Strip the default port only when it is the exact suffix. A substring
    # test/replace would corrupt ports like :8080 or :4430.
    default_port = {'http': ':80', 'https': ':443'}.get(parsed.scheme)
    if default_port and hostport.endswith(default_port):
        hostport = hostport[:-len(default_port)]
    netloc = f"{userinfo}{at}{hostport}"

    # '/a/' -> '/a'; '/', '' and '///' all collapse to the root '/'.
    path = parsed.path.rstrip('/') or '/'

    return urlunparse((
        parsed.scheme,
        netloc,
        path,
        parsed.params,
        parsed.query,
        '',  # Remove fragment
    ))
def check_url_safety(hostname: str) -> Tuple[bool, str]:
    """
    Check if a hostname is safe to scan (not localhost/private IP).

    The hostname is first compared -- case-insensitively and ignoring a
    trailing dot -- against ``SCANNER_CONFIG['BLOCKED_HOSTS']``. It is then
    resolved with ``socket.getaddrinfo``. Every resolved address must lie
    outside ``SCANNER_CONFIG['BLOCKED_IP_RANGES']`` and must not be private,
    loopback, link-local, reserved, multicast or unspecified.
    IPv4-mapped IPv6 addresses are judged by the IPv4 address they wrap.

    Args:
        hostname: The hostname to check (an IPv6 literal may be bracketed)

    Returns:
        Tuple of (is_safe, error_message_if_not_safe)
    """
    scanner_config = settings.SCANNER_CONFIG

    # Accept "[::1]"-style input as well as bare IPv6 literals.
    lookup_host = hostname.strip('[]')
    # "LOCALHOST" and "localhost." name the same host as "localhost".
    canonical = lookup_host.rstrip('.').lower()
    blocked_hosts = {
        h.strip('[]').rstrip('.').lower()
        for h in scanner_config.get('BLOCKED_HOSTS', [])
    }
    if canonical in blocked_hosts:
        return False, f"Scanning {hostname} is not allowed"

    # Parse the configured ranges once rather than once per resolved address.
    blocked_networks = []
    for blocked_range in scanner_config.get('BLOCKED_IP_RANGES', []):
        try:
            blocked_networks.append(ipaddress.ip_network(blocked_range, strict=False))
        except ValueError:
            logger.warning("Ignoring invalid BLOCKED_IP_RANGES entry: %r", blocked_range)

    try:
        ip_addresses = socket.getaddrinfo(
            lookup_host, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except socket.gaierror:
        # Could not resolve - might be okay for some hostnames.
        # NOTE(review): this deliberately fails open. The scan will
        # re-resolve the name later, so the address actually connected to
        # is not the one vetted here; consider re-checking at connect time.
        logger.warning("Could not resolve hostname: %s", hostname)
        return True, ""
    except UnicodeError:
        # getaddrinfo IDNA-encodes str hosts; empty or over-long labels
        # raise here. Reject instead of letting the exception escape.
        return False, f"Invalid hostname: {hostname}"

    for *_, sockaddr in ip_addresses:
        ip_str = sockaddr[0]
        try:
            ip = ipaddress.ip_address(ip_str)
        except ValueError:
            # Fail closed: an address we cannot parse is one we cannot vet.
            return False, f"Could not verify resolved address ({ip_str})"

        # ::ffff:a.b.c.d reaches the IPv4 host a.b.c.d, so judge it as such.
        mapped = getattr(ip, 'ipv4_mapped', None)
        effective = ip if mapped is None else mapped
        candidates = (ip,) if mapped is None else (ip, mapped)

        # Check if IP is in any blocked range
        if any(addr in network for addr in candidates for network in blocked_networks):
            return False, f"Scanning private/local IP addresses is not allowed ({ip_str})"

        # Additional checks
        if effective.is_private:
            return False, f"Scanning private IP addresses is not allowed ({ip_str})"
        if effective.is_loopback:
            return False, f"Scanning localhost/loopback addresses is not allowed ({ip_str})"
        if effective.is_link_local:
            return False, f"Scanning link-local addresses is not allowed ({ip_str})"
        if effective.is_reserved:
            return False, f"Scanning reserved IP addresses is not allowed ({ip_str})"
        if effective.is_multicast:
            return False, f"Scanning multicast IP addresses is not allowed ({ip_str})"
        if effective.is_unspecified:
            return False, f"Scanning unspecified IP addresses is not allowed ({ip_str})"

    return True, ""
def get_domain_from_url(url: str) -> str:
    """
    Extract the domain from a URL.

    Uses urllib's hostname parsing, so any ``user:pass@`` userinfo, the
    port, and the brackets around an IPv6 literal are removed, and the
    result is lowercased. Splitting ``netloc`` on ':' would instead return
    "user" for credentialed URLs and "[" for IPv6 literals.

    Args:
        url: The URL to extract domain from

    Returns:
        The domain/hostname, or '' when the URL has no host component
    """
    return urlparse(url).hostname or ''