186 lines
4.9 KiB
Python
186 lines
4.9 KiB
Python
"""
URL validation and safety utilities.

This module provides functions for validating and normalizing URLs,
including safety checks to prevent SSRF attacks.
"""
|
|
|
|
import ipaddress
|
|
import logging
|
|
import socket
|
|
from typing import Tuple
|
|
from urllib.parse import urlparse, urlunparse
|
|
|
|
import validators
|
|
|
|
from django.conf import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def validate_url(url: str) -> Tuple[bool, str]:
    """
    Validate and normalize a URL for scanning.

    The URL must pass ``validators.url``, use the http or https scheme,
    contain a hostname, and pass ``check_url_safety`` before it is
    normalized with ``normalize_url``.

    Args:
        url: The URL to validate

    Returns:
        Tuple of (is_valid, normalized_url_or_error_message)
    """
    if not url:
        return False, "URL is required"

    # Basic URL validation
    if not validators.url(url):
        return False, "Invalid URL format"

    # Parse the URL
    try:
        parsed = urlparse(url)
    except ValueError as e:
        return False, f"Could not parse URL: {e}"

    # Check scheme
    if parsed.scheme not in ('http', 'https'):
        return False, "URL must use http or https scheme"

    # Check hostname.
    # Use the parsed ``hostname`` attribute rather than splitting ``netloc``
    # on ':' by hand: the manual split yields the *username* for
    # "http://user:pw@127.0.0.1/" and "[" for "http://[::1]/", so the safety
    # check would inspect the wrong string and let internal targets through.
    hostname = (parsed.hostname or '').lower()

    if not hostname:
        return False, "URL must have a valid hostname"

    # Safety check: block localhost and private IPs
    is_safe, safety_error = check_url_safety(hostname)
    if not is_safe:
        return False, safety_error

    # Normalize URL
    return True, normalize_url(url)
|
|
|
|
|
|
def normalize_url(url: str) -> str:
    """
    Normalize a URL to a canonical form.

    - Lowercase hostname (any "user:password@" prefix keeps its case)
    - Remove trailing slashes from path (an empty path becomes "/")
    - Remove default ports (80 for http, 443 for https)
    - Remove the fragment

    The query string is passed through unchanged.

    Args:
        url: The URL to normalize

    Returns:
        Normalized URL string
    """
    parsed = urlparse(url)

    # Separate optional credentials from "host[:port]" so that only the
    # host part is lowercased; credentials are case-sensitive.
    userinfo, at_sign, host_port = parsed.netloc.rpartition('@')
    host_port = host_port.lower()

    # Remove the port only when it is *exactly* the scheme's default.
    # A substring test such as "':80' in netloc" would also hit ":8080"
    # and rewrite "example.com:8080" to "example.com80".
    default_port = {'http': '80', 'https': '443'}.get(parsed.scheme)
    host, colon, port = host_port.rpartition(':')
    if colon and port == default_port:
        host_port = host

    netloc = f"{userinfo}{at_sign}{host_port}"

    # Normalize path (remove trailing slashes, but keep "/" for the root)
    path = parsed.path.rstrip('/') or '/'

    # Reconstruct the URL without the fragment
    return urlunparse((
        parsed.scheme,
        netloc,
        path,
        parsed.params,
        parsed.query,
        '',
    ))
|
|
|
|
|
|
def check_url_safety(hostname: str) -> Tuple[bool, str]:
    """
    Check if a hostname is safe to scan (not localhost/private IP).

    The hostname is rejected when it appears in
    ``settings.SCANNER_CONFIG['BLOCKED_HOSTS']``, or when any address it
    resolves to falls inside ``settings.SCANNER_CONFIG['BLOCKED_IP_RANGES']``
    or is private, loopback, link-local or reserved.

    A hostname that cannot be resolved is treated as safe (a warning is
    logged).

    Args:
        hostname: The hostname to check

    Returns:
        Tuple of (is_safe, error_message_if_not_safe)
    """
    scanner_config = settings.SCANNER_CONFIG
    blocked_hosts = scanner_config.get('BLOCKED_HOSTS', [])
    blocked_ranges = scanner_config.get('BLOCKED_IP_RANGES', [])

    # Check blocked hostnames. Compare case-insensitively and ignore a
    # trailing root dot so "Example.com." cannot slip past "example.com".
    blocked_names = {str(name).lower().rstrip('.') for name in blocked_hosts}
    if hostname.lower().rstrip('.') in blocked_names:
        return False, f"Scanning {hostname} is not allowed"

    # Try to resolve hostname to IP
    try:
        addr_infos = socket.getaddrinfo(
            hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except socket.gaierror:
        # Could not resolve - might be okay for some hostnames
        logger.warning("Could not resolve hostname: %s", hostname)
        return True, ""

    # Parse the configured ranges once instead of once per resolved address.
    blocked_networks = []
    for blocked_range in blocked_ranges:
        try:
            blocked_networks.append(
                ipaddress.ip_network(blocked_range, strict=False)
            )
        except ValueError:
            logger.warning(
                "Ignoring invalid BLOCKED_IP_RANGES entry: %r", blocked_range
            )

    for addr_info in addr_infos:
        ip_str = addr_info[4][0]

        try:
            # Drop any IPv6 zone id ("fe80::1%eth0"); older Python versions
            # reject it, which would skip the address without checking it.
            ip = ipaddress.ip_address(ip_str.split('%', 1)[0])
        except ValueError:
            # Not a valid IP address format
            continue

        # Also inspect the embedded IPv4 address of an IPv4-mapped IPv6
        # address ("::ffff:10.0.0.1") so IPv4 blocked ranges apply to it.
        candidates = [ip]
        mapped = getattr(ip, 'ipv4_mapped', None)
        if mapped is not None:
            candidates.append(mapped)

        for candidate in candidates:
            # Check if IP is in any blocked range
            if any(candidate in network for network in blocked_networks):
                return False, f"Scanning private/local IP addresses is not allowed ({ip_str})"

            # Additional checks
            if candidate.is_private:
                return False, f"Scanning private IP addresses is not allowed ({ip_str})"

            if candidate.is_loopback:
                return False, f"Scanning localhost/loopback addresses is not allowed ({ip_str})"

            if candidate.is_link_local:
                return False, f"Scanning link-local addresses is not allowed ({ip_str})"

            if candidate.is_reserved:
                return False, f"Scanning reserved IP addresses is not allowed ({ip_str})"

    return True, ""
|
|
|
|
|
|
def get_domain_from_url(url: str) -> str:
    """
    Extract the domain from a URL.

    Credentials ("user:password@"), the port and IPv6 brackets are not part
    of the result.

    Args:
        url: The URL to extract domain from

    Returns:
        The lowercased domain/hostname, or an empty string when the URL has
        no network location
    """
    # ``hostname`` strips userinfo, port and IPv6 brackets; splitting
    # ``netloc`` on ':' by hand returns "user" for "http://user:pw@host/"
    # and "[" for "http://[::1]/".
    return (urlparse(url).hostname or '').lower()
|