193 lines
5.0 KiB
Python
193 lines
5.0 KiB
Python
"""
URL validation and safety utilities.

This module provides functions to validate URLs and ensure they
don't target internal/private networks (SSRF protection).
"""
|
|
|
|
import ipaddress
|
|
import socket
|
|
import logging
|
|
from typing import Tuple, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
from django.conf import settings
|
|
|
|
logger = logging.getLogger('scanner')
|
|
|
|
|
|
def get_blocked_ip_ranges() -> list:
    """Return the CIDR ranges that must never be scanned.

    The list is read from the ``BLOCKED_IP_RANGES`` key of
    ``settings.SCANNER_CONFIG``. When that key is absent, a built-in list of
    IPv4 and IPv6 ranges is returned instead.

    Returns:
        List of CIDR strings.
    """
    scanner_config = settings.SCANNER_CONFIG
    fallback_ranges = [
        # IPv4
        '10.0.0.0/8',
        '172.16.0.0/12',
        '192.168.0.0/16',
        '127.0.0.0/8',
        '169.254.0.0/16',
        # IPv6
        '::1/128',
        'fc00::/7',
        'fe80::/10',
    ]
    return scanner_config.get('BLOCKED_IP_RANGES', fallback_ranges)
|
|
|
|
|
|
def get_blocked_hosts() -> list:
    """Return the hostnames that must never be scanned.

    The list is read from the ``BLOCKED_HOSTS`` key of
    ``settings.SCANNER_CONFIG``. When that key is absent, the built-in
    localhost names are returned instead.

    Returns:
        List of hostname strings.
    """
    scanner_config = settings.SCANNER_CONFIG
    fallback_hosts = ['localhost', 'localhost.localdomain']
    return scanner_config.get('BLOCKED_HOSTS', fallback_hosts)
|
|
|
|
|
|
def is_private_ip(ip_str: str) -> bool:
    """
    Check if an IP address is private, loopback, or otherwise blocked.

    An address is blocked when any of the following holds:

    * it cannot be parsed as an IPv4/IPv6 address (fail closed),
    * it is private, loopback, link-local, reserved or multicast,
    * it is not globally routable -- this additionally covers the
      100.64.0.0/10 shared address space, which ``is_private`` does not
      report,
    * it falls inside one of the CIDR ranges returned by
      ``get_blocked_ip_ranges()``.

    Args:
        ip_str: IP address string

    Returns:
        True if the IP should be blocked
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # Invalid IP format: refuse rather than let an unparseable value through.
        return True

    # Check standard private/reserved ranges
    if ip.is_private or ip.is_loopback or ip.is_link_local:
        return True
    if ip.is_reserved or ip.is_multicast:
        return True
    # Shared address space (100.64.0.0/10) is neither "private" nor "global";
    # it is not a public destination, so block it as well.
    if not ip.is_global:
        return True

    # Check custom blocked ranges
    for cidr in get_blocked_ip_ranges():
        try:
            network = ipaddress.ip_network(cidr, strict=False)
        except ValueError:
            # A malformed entry must not disable the remaining checks, but it
            # should not go unnoticed either.
            logger.warning("Ignoring invalid blocked IP range %r", cidr)
            continue
        if ip in network:
            return True

    return False
|
|
|
|
|
|
def resolve_hostname(hostname: str) -> Optional[str]:
    """
    Resolve a hostname to its IP address.

    Only the first address returned by the resolver is reported; a name
    that maps to several addresses has the others ignored here.

    Args:
        hostname: The hostname to resolve

    Returns:
        IP address string or None if resolution fails
    """
    try:
        addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC)
    except (socket.gaierror, UnicodeError):
        # gaierror: the resolver could not find the name.
        # UnicodeError: the name failed IDNA encoding before any lookup was
        # attempted (for example an empty label or a label over 63 chars).
        return None

    if not addr_infos:
        return None

    # Each entry is (family, type, proto, canonname, sockaddr); the host
    # is the first element of sockaddr for both IPv4 and IPv6.
    return addr_infos[0][4][0]
|
|
|
|
|
|
def _resolve_all_addresses(hostname: str) -> list:
    """Return every distinct IP address *hostname* resolves to ([] on failure)."""
    try:
        addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC)
    except (socket.gaierror, UnicodeError):
        # UnicodeError is raised for names that fail IDNA encoding
        # (empty label, label longer than 63 characters).
        return []
    # dict.fromkeys removes duplicates while keeping resolver order.
    return list(dict.fromkeys(info[4][0] for info in addr_infos))


def validate_url(url: str) -> Tuple[bool, str, Optional[str]]:
    """
    Validate a URL for scanning safety.

    Checks:
    1. URL format and scheme (must be http or https)
    2. Hostname is not in blocked list (a trailing dot is ignored for
       this comparison)
    3. No address the hostname resolves to is private/internal

    A hostname that cannot be resolved is not rejected here; the scanner
    is left to report that error.

    Args:
        url: The URL to validate

    Returns:
        Tuple of (is_valid, normalized_url, error_message). On failure the
        second element is the URL exactly as passed in; on success the third
        element is None.
    """
    if not url:
        return False, url, "URL is required"

    # Parse the URL
    try:
        parsed = urlparse(url)
    except Exception as e:
        return False, url, f"Invalid URL format: {e}"

    # Check scheme
    if parsed.scheme not in ('http', 'https'):
        return False, url, "URL must use http or https scheme"

    # Check hostname exists
    if not parsed.netloc:
        return False, url, "URL must have a valid hostname"

    # Extract hostname (without port)
    hostname = parsed.hostname
    if not hostname:
        return False, url, "Could not extract hostname from URL"

    # Normalize hostname
    hostname = hostname.lower()

    # Check against blocked hostnames. "localhost." names the same host as
    # "localhost", so compare the form without the trailing dot as well.
    blocked_hosts = get_blocked_hosts()
    if hostname in blocked_hosts or hostname.rstrip('.') in blocked_hosts:
        return False, url, f"Scanning {hostname} is not allowed"

    # Check if hostname is an IP address
    try:
        ipaddress.ip_address(hostname)
    except ValueError:
        is_ip_literal = False
    else:
        is_ip_literal = True

    if is_ip_literal:
        if is_private_ip(hostname):
            return False, url, "Scanning private/internal IP addresses is not allowed"
    else:
        # Not an IP, it's a hostname - resolve it. Every returned address is
        # checked: a name may carry both a public and a private record, and
        # inspecting only the first one would let the private one through.
        # NOTE(review): whoever fetches the URL later presumably resolves the
        # name again; this check alone cannot rule out DNS rebinding.
        for resolved_ip in _resolve_all_addresses(hostname):
            if is_private_ip(resolved_ip):
                return False, url, f"URL resolves to private IP ({resolved_ip}), scanning not allowed"
        # If we can't resolve, we'll let the scanner handle the error

    # Normalize the URL
    # Remove trailing slash from path, lowercase scheme and host.
    # Only the host[:port] part is lowercased; any "user:password@" prefix
    # is case-sensitive and is kept as given.
    userinfo, at_sign, host_port = parsed.netloc.rpartition('@')
    normalized = f"{parsed.scheme}://{userinfo}{at_sign}{host_port.lower()}"
    if parsed.path and parsed.path != '/':
        normalized += parsed.path.rstrip('/')
    if parsed.query:
        normalized += f"?{parsed.query}"

    return True, normalized, None
|
|
|
|
|
|
def normalize_url(url: str) -> str:
    """
    Normalize a URL to a canonical form.

    Delegates to ``validate_url``; the normalized form is only produced
    for URLs that pass validation.

    Args:
        url: The URL to normalize

    Returns:
        Normalized URL string, or the input unchanged when validation fails
    """
    passed, canonical, _error = validate_url(url)
    if not passed:
        return url
    return canonical
|
|
|
|
|
|
def extract_domain(url: str) -> str:
    """
    Extract the domain from a URL.

    The value returned is the lowercased network-location part of the
    URL, so any port or ``user@`` prefix present in the URL is included.

    Args:
        url: The URL to extract domain from

    Returns:
        Domain string, or "" when the URL cannot be parsed
    """
    try:
        netloc = urlparse(url).netloc
        domain = netloc.lower()
    except Exception:
        domain = ""
    return domain
|