From a9a9ed48f1385ab7a504a685c75dc4f2ffd5f12e Mon Sep 17 00:00:00 2001
From: Sereth1
Date: Mon, 8 Dec 2025 10:23:07 +0700
Subject: [PATCH] test

---
 backend/core/urls.py                  |   8 +-
 backend/scanner/base.py               | 169 +++++++++
 backend/scanner/headers_scanner.py    | 508 ++++++++++++++++++++++++++
 backend/scanner/lighthouse_scanner.py | 320 ++++++++++++++++
 backend/scanner/playwright_scanner.py | 432 ++++++++++++++++++++++
 backend/scanner/runner.py             | 315 ++++++++++++++++
 backend/scanner/validators.py         | 192 ++++++++++
 backend/scanner/zap_scanner.py        | 341 +++++++++++++++++
 backend/templates/index.html          | 250 +++++++++++++
 backend/templates/scan_detail.html    | 496 +++++++++++++++++++++++++
 backend/tests/test_scans.py           | 262 +++++++++++++
 backend/tests/test_validators.py      | 155 ++++++++
 12 files changed, 3447 insertions(+), 1 deletion(-)
 create mode 100644 backend/scanner/base.py
 create mode 100644 backend/scanner/headers_scanner.py
 create mode 100644 backend/scanner/lighthouse_scanner.py
 create mode 100644 backend/scanner/playwright_scanner.py
 create mode 100644 backend/scanner/runner.py
 create mode 100644 backend/scanner/validators.py
 create mode 100644 backend/scanner/zap_scanner.py
 create mode 100644 backend/templates/index.html
 create mode 100644 backend/templates/scan_detail.html
 create mode 100644 backend/tests/test_scans.py
 create mode 100644 backend/tests/test_validators.py

diff --git a/backend/core/urls.py b/backend/core/urls.py
index 445efb3..104d2f8 100644
--- a/backend/core/urls.py
+++ b/backend/core/urls.py
@@ -5,6 +5,12 @@ URL configuration for Website Analyzer project.
 from django.contrib import admin
 from django.urls import path, include
 from django.views.generic import TemplateView
+from django.shortcuts import render
+
+
+def scan_detail_view(request, scan_id):
+    """View for scan detail page that passes scan_id to template."""
+    return render(request, 'scan_detail.html', {'scan_id': str(scan_id)})
 
 
 urlpatterns = [
@@ -16,5 +22,5 @@ urlpatterns = [
 
     # Frontend views
     path('', TemplateView.as_view(template_name='index.html'), name='home'),
-    path('scan//', TemplateView.as_view(template_name='scan_detail.html'), name='scan_detail'),
+    path('scan//', scan_detail_view, name='scan_detail'),
 ]
diff --git a/backend/scanner/base.py b/backend/scanner/base.py
new file mode 100644
index 0000000..a9da250
--- /dev/null
+++ b/backend/scanner/base.py
@@ -0,0 +1,169 @@
+"""
+Base scanner interface and common utilities.
+
+All scanner modules inherit from BaseScanner and implement
+the common interface for running scans and returning results.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Optional
+from enum import Enum
+
+logger = logging.getLogger('scanner')
+
+
+class ScannerStatus(Enum):
+    """Status of a scanner execution."""
+    SUCCESS = 'success'
+    PARTIAL = 'partial'
+    FAILED = 'failed'
+    SKIPPED = 'skipped'
+
+
+@dataclass
+class ScannerResult:
+    """
+    Standardized result from any scanner.
+
+    Attributes:
+        status: The execution status of the scanner
+        scanner_name: Name of the scanner that produced this result
+        scores: Dictionary of score values (0-100)
+        metrics: List of metric dictionaries with name, value, unit
+        issues: List of issue dictionaries with category, severity, title, etc.
+ raw_data: Original response from the scanner tool + error_message: Error message if the scan failed + execution_time_seconds: How long the scan took + """ + status: ScannerStatus + scanner_name: str + scores: dict = field(default_factory=dict) + metrics: list = field(default_factory=list) + issues: list = field(default_factory=list) + raw_data: Optional[dict] = None + error_message: Optional[str] = None + execution_time_seconds: float = 0.0 + + def to_dict(self) -> dict: + """Convert result to dictionary for serialization.""" + return { + 'status': self.status.value, + 'scanner_name': self.scanner_name, + 'scores': self.scores, + 'metrics': self.metrics, + 'issues': self.issues, + 'raw_data': self.raw_data, + 'error_message': self.error_message, + 'execution_time_seconds': self.execution_time_seconds, + } + + +class BaseScanner(ABC): + """ + Abstract base class for all scanners. + + Each scanner must implement the `run` method which takes a URL + and returns a ScannerResult. + """ + + name: str = "base_scanner" + + def __init__(self, config: dict = None): + """ + Initialize the scanner with optional configuration. + + Args: + config: Dictionary of configuration options + """ + self.config = config or {} + self.logger = logging.getLogger(f'scanner.{self.name}') + + @abstractmethod + def run(self, url: str) -> ScannerResult: + """ + Run the scanner on the given URL. + + Args: + url: The URL to scan + + Returns: + ScannerResult with scan data + """ + pass + + def is_available(self) -> bool: + """ + Check if the scanner is available and properly configured. + + Returns: + True if the scanner can run, False otherwise + """ + return True + + def _create_issue( + self, + category: str, + severity: str, + title: str, + description: str, + affected_url: str = None, + remediation: str = None, + raw_data: dict = None + ) -> dict: + """ + Helper to create a standardized issue dictionary. + + Args: + category: Issue category (security, performance, etc.) + severity: Severity level (critical, high, medium, low, info) + title: Brief issue title + description: Detailed description + affected_url: Specific URL affected + remediation: Suggested fix + raw_data: Original scanner data + + Returns: + Issue dictionary + """ + return { + 'category': category, + 'severity': severity, + 'title': title, + 'description': description, + 'affected_url': affected_url, + 'remediation': remediation, + 'tool': self.name, + 'raw_data': raw_data, + } + + def _create_metric( + self, + name: str, + display_name: str, + value: float, + unit: str, + score: float = None + ) -> dict: + """ + Helper to create a standardized metric dictionary. + + Args: + name: Machine-readable metric name + display_name: Human-readable name + value: Numeric value + unit: Unit of measurement + score: Optional score (0-1) + + Returns: + Metric dictionary + """ + return { + 'name': name, + 'display_name': display_name, + 'value': value, + 'unit': unit, + 'source': self.name, + 'score': score, + } diff --git a/backend/scanner/headers_scanner.py b/backend/scanner/headers_scanner.py new file mode 100644 index 0000000..4cfeaae --- /dev/null +++ b/backend/scanner/headers_scanner.py @@ -0,0 +1,508 @@ +""" +HTTP Headers Security Scanner. + +This module checks HTTP response headers for security best practices, +including CSP, HSTS, X-Frame-Options, and cookie security flags. 
+""" + +import time +import logging +import ssl +import socket +from typing import Dict, List, Optional, Tuple +from urllib.parse import urlparse + +import httpx + +from .base import BaseScanner, ScannerResult, ScannerStatus + +logger = logging.getLogger('scanner') + + +class HeadersScanner(BaseScanner): + """ + Scanner for HTTP security headers and TLS configuration. + + Checks for: + - Security headers (CSP, HSTS, X-Frame-Options, etc.) + - Cookie security flags (Secure, HttpOnly, SameSite) + - CORS configuration + - TLS/SSL certificate validity + - HTTP to HTTPS redirect + """ + + name = "header_check" + + # Required security headers and their importance + SECURITY_HEADERS = { + 'strict-transport-security': { + 'severity': 'high', + 'title': 'Missing Strict-Transport-Security (HSTS) header', + 'description': 'HSTS ensures browsers only connect via HTTPS, preventing SSL stripping attacks.', + 'remediation': 'Add header: Strict-Transport-Security: max-age=31536000; includeSubDomains; preload' + }, + 'content-security-policy': { + 'severity': 'high', + 'title': 'Missing Content-Security-Policy (CSP) header', + 'description': 'CSP helps prevent XSS attacks by controlling which resources can be loaded.', + 'remediation': "Add a Content-Security-Policy header. Start with: Content-Security-Policy: default-src 'self'" + }, + 'x-frame-options': { + 'severity': 'medium', + 'title': 'Missing X-Frame-Options header', + 'description': 'X-Frame-Options prevents clickjacking attacks by controlling iframe embedding.', + 'remediation': 'Add header: X-Frame-Options: DENY or X-Frame-Options: SAMEORIGIN' + }, + 'x-content-type-options': { + 'severity': 'medium', + 'title': 'Missing X-Content-Type-Options header', + 'description': 'Prevents MIME type sniffing, reducing risk of drive-by downloads.', + 'remediation': 'Add header: X-Content-Type-Options: nosniff' + }, + 'referrer-policy': { + 'severity': 'low', + 'title': 'Missing Referrer-Policy header', + 'description': 'Controls how much referrer information is sent with requests.', + 'remediation': 'Add header: Referrer-Policy: strict-origin-when-cross-origin' + }, + 'permissions-policy': { + 'severity': 'low', + 'title': 'Missing Permissions-Policy header', + 'description': 'Controls which browser features can be used by the page.', + 'remediation': 'Add header: Permissions-Policy: geolocation=(), camera=(), microphone=()' + }, + } + + def __init__(self, config: dict = None): + super().__init__(config) + self.timeout = self.config.get('timeout', 30) + + def run(self, url: str) -> ScannerResult: + """ + Run HTTP headers and TLS security checks. 
+ + Args: + url: The URL to scan + + Returns: + ScannerResult with header analysis + """ + start_time = time.time() + issues = [] + metrics = [] + raw_data = { + 'headers': {}, + 'cookies': [], + 'tls': {}, + 'redirects': [], + } + + try: + # Create HTTP client with redirect following + with httpx.Client( + timeout=self.timeout, + follow_redirects=True, + verify=True + ) as client: + + # Make GET request + response = client.get(url) + + # Store response headers + raw_data['headers'] = dict(response.headers) + raw_data['status_code'] = response.status_code + raw_data['final_url'] = str(response.url) + + # Track redirects + raw_data['redirects'] = [ + { + 'url': str(r.url), + 'status_code': r.status_code + } + for r in response.history + ] + + # Check for HTTP to HTTPS redirect + redirect_issues = self._check_https_redirect(url, response) + issues.extend(redirect_issues) + + # Check security headers + header_issues = self._check_security_headers(response.headers) + issues.extend(header_issues) + + # Check cookies + cookie_issues, cookie_data = self._check_cookies(response) + issues.extend(cookie_issues) + raw_data['cookies'] = cookie_data + + # Check CORS + cors_issues = self._check_cors(response.headers, url) + issues.extend(cors_issues) + + # Check for information disclosure + info_issues = self._check_info_disclosure(response.headers) + issues.extend(info_issues) + + # Check TLS certificate + tls_issues, tls_data = self._check_tls(url) + issues.extend(tls_issues) + raw_data['tls'] = tls_data + + # Add metrics + metrics.append(self._create_metric( + 'security_headers_present', + 'Security Headers Present', + self._count_present_headers(raw_data['headers']), + 'count' + )) + metrics.append(self._create_metric( + 'security_headers_total', + 'Total Security Headers Checked', + len(self.SECURITY_HEADERS), + 'count' + )) + + execution_time = time.time() - start_time + + return ScannerResult( + status=ScannerStatus.SUCCESS, + scanner_name=self.name, + issues=issues, + metrics=metrics, + raw_data=raw_data, + execution_time_seconds=execution_time + ) + + except httpx.TimeoutException as e: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Request timed out: {e}", + execution_time_seconds=time.time() - start_time + ) + except httpx.RequestError as e: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Request failed: {e}", + execution_time_seconds=time.time() - start_time + ) + except Exception as e: + logger.exception(f"Header check failed for {url}") + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Unexpected error: {e}", + execution_time_seconds=time.time() - start_time + ) + + def _check_security_headers(self, headers: httpx.Headers) -> List[dict]: + """Check for missing or misconfigured security headers.""" + issues = [] + headers_lower = {k.lower(): v for k, v in headers.items()} + + for header_name, config in self.SECURITY_HEADERS.items(): + if header_name not in headers_lower: + issues.append(self._create_issue( + category='headers', + severity=config['severity'], + title=config['title'], + description=config['description'], + remediation=config['remediation'] + )) + else: + # Check for weak configurations + value = headers_lower[header_name] + + if header_name == 'strict-transport-security': + # Check for short max-age + if 'max-age' in value.lower(): + try: + max_age = int(value.lower().split('max-age=')[1].split(';')[0]) + if max_age < 
31536000: # Less than 1 year + issues.append(self._create_issue( + category='headers', + severity='low', + title='HSTS max-age is too short', + description=f'HSTS max-age is {max_age} seconds. Recommend at least 1 year (31536000).', + remediation='Increase max-age to at least 31536000 seconds.' + )) + except (IndexError, ValueError): + pass + + elif header_name == 'content-security-policy': + # Check for unsafe directives + if "'unsafe-inline'" in value or "'unsafe-eval'" in value: + issues.append(self._create_issue( + category='headers', + severity='medium', + title='CSP contains unsafe directives', + description="Content-Security-Policy contains 'unsafe-inline' or 'unsafe-eval' which weakens XSS protection.", + remediation="Remove 'unsafe-inline' and 'unsafe-eval' from CSP. Use nonces or hashes instead." + )) + + return issues + + def _check_https_redirect(self, original_url: str, response: httpx.Response) -> List[dict]: + """Check for HTTP to HTTPS redirect.""" + issues = [] + parsed = urlparse(original_url) + + if parsed.scheme == 'http': + final_url = str(response.url) + final_parsed = urlparse(final_url) + + if final_parsed.scheme == 'https': + # Good - redirects to HTTPS + pass + else: + issues.append(self._create_issue( + category='tls', + severity='high', + title='No HTTP to HTTPS redirect', + description='The site does not redirect HTTP requests to HTTPS, allowing insecure connections.', + remediation='Configure server to redirect all HTTP traffic to HTTPS (301 redirect).' + )) + + return issues + + def _check_cookies(self, response: httpx.Response) -> Tuple[List[dict], List[dict]]: + """Check cookie security flags.""" + issues = [] + cookie_data = [] + + # Get cookies from Set-Cookie headers + set_cookie_headers = response.headers.get_list('set-cookie') + + for cookie_header in set_cookie_headers: + cookie_info = self._parse_cookie(cookie_header) + cookie_data.append(cookie_info) + + cookie_name = cookie_info.get('name', 'Unknown') + is_https = str(response.url).startswith('https') + + # Check Secure flag + if is_https and not cookie_info.get('secure'): + issues.append(self._create_issue( + category='security', + severity='medium', + title=f"Cookie '{cookie_name}' missing Secure flag", + description='Cookie transmitted over HTTPS should have the Secure flag to prevent transmission over HTTP.', + remediation='Add the Secure flag to the Set-Cookie header.' + )) + + # Check HttpOnly flag (for session-like cookies) + if not cookie_info.get('httponly'): + # Check if it looks like a session cookie + session_indicators = ['session', 'sess', 'auth', 'token', 'jwt', 'csrf'] + if any(ind in cookie_name.lower() for ind in session_indicators): + issues.append(self._create_issue( + category='security', + severity='medium', + title=f"Cookie '{cookie_name}' missing HttpOnly flag", + description='Session cookies should have HttpOnly flag to prevent JavaScript access (XSS protection).', + remediation='Add the HttpOnly flag to the Set-Cookie header.' + )) + + # Check SameSite attribute + if not cookie_info.get('samesite'): + issues.append(self._create_issue( + category='security', + severity='low', + title=f"Cookie '{cookie_name}' missing SameSite attribute", + description='SameSite attribute helps prevent CSRF attacks.', + remediation='Add SameSite=Strict or SameSite=Lax to the Set-Cookie header.' 
+ )) + + return issues, cookie_data + + def _parse_cookie(self, cookie_header: str) -> dict: + """Parse a Set-Cookie header into a dictionary.""" + parts = cookie_header.split(';') + + # First part is name=value + name_value = parts[0].strip() + if '=' in name_value: + name, value = name_value.split('=', 1) + else: + name, value = name_value, '' + + cookie_info = { + 'name': name.strip(), + 'value': value[:50] + '...' if len(value) > 50 else value, # Truncate for privacy + 'secure': False, + 'httponly': False, + 'samesite': None, + 'path': None, + 'domain': None, + 'expires': None, + } + + for part in parts[1:]: + part = part.strip().lower() + if part == 'secure': + cookie_info['secure'] = True + elif part == 'httponly': + cookie_info['httponly'] = True + elif part.startswith('samesite='): + cookie_info['samesite'] = part.split('=')[1] + elif part.startswith('path='): + cookie_info['path'] = part.split('=')[1] + elif part.startswith('domain='): + cookie_info['domain'] = part.split('=')[1] + + return cookie_info + + def _check_cors(self, headers: httpx.Headers, url: str) -> List[dict]: + """Check CORS configuration for security issues.""" + issues = [] + + acao = headers.get('access-control-allow-origin', '').lower() + acac = headers.get('access-control-allow-credentials', '').lower() + + if acao == '*': + if acac == 'true': + issues.append(self._create_issue( + category='cors', + severity='critical', + title='Dangerous CORS configuration', + description='Access-Control-Allow-Origin: * combined with Access-Control-Allow-Credentials: true is a security vulnerability.', + remediation='Never use wildcard origin with credentials. Specify allowed origins explicitly.' + )) + else: + issues.append(self._create_issue( + category='cors', + severity='info', + title='CORS allows all origins', + description='Access-Control-Allow-Origin is set to *, allowing any website to make requests.', + remediation='Consider restricting CORS to specific trusted origins if the API handles sensitive data.' + )) + + return issues + + def _check_info_disclosure(self, headers: httpx.Headers) -> List[dict]: + """Check for information disclosure in headers.""" + issues = [] + + # Check Server header + server = headers.get('server', '') + if server: + # Check for version numbers + import re + if re.search(r'\d+\.\d+', server): + issues.append(self._create_issue( + category='security', + severity='low', + title='Server version disclosed', + description=f'The Server header reveals version information: {server}', + remediation='Configure the server to hide version information.' + )) + + # Check X-Powered-By + powered_by = headers.get('x-powered-by', '') + if powered_by: + issues.append(self._create_issue( + category='security', + severity='low', + title='X-Powered-By header present', + description=f'The X-Powered-By header reveals technology: {powered_by}', + remediation='Remove the X-Powered-By header to reduce information disclosure.' + )) + + # Check X-AspNet-Version + aspnet = headers.get('x-aspnet-version', '') + if aspnet: + issues.append(self._create_issue( + category='security', + severity='low', + title='ASP.NET version disclosed', + description=f'X-AspNet-Version header reveals: {aspnet}', + remediation='Disable ASP.NET version header in web.config.' 
+ )) + + return issues + + def _check_tls(self, url: str) -> Tuple[List[dict], dict]: + """Check TLS/SSL certificate validity.""" + issues = [] + tls_data = { + 'has_tls': False, + 'certificate_valid': None, + 'issuer': None, + 'expires': None, + 'protocol': None, + } + + parsed = urlparse(url) + + # Only check HTTPS URLs + if parsed.scheme != 'https': + issues.append(self._create_issue( + category='tls', + severity='high', + title='Site not served over HTTPS', + description='The site is served over unencrypted HTTP, exposing data to interception.', + remediation='Enable HTTPS with a valid TLS certificate.' + )) + return issues, tls_data + + hostname = parsed.hostname + port = parsed.port or 443 + + try: + context = ssl.create_default_context() + + with socket.create_connection((hostname, port), timeout=10) as sock: + with context.wrap_socket(sock, server_hostname=hostname) as ssock: + tls_data['has_tls'] = True + tls_data['protocol'] = ssock.version() + + cert = ssock.getpeercert() + if cert: + tls_data['certificate_valid'] = True + tls_data['issuer'] = dict(x[0] for x in cert.get('issuer', [])) + tls_data['subject'] = dict(x[0] for x in cert.get('subject', [])) + tls_data['expires'] = cert.get('notAfter') + + # Check for weak protocols + if ssock.version() in ('SSLv2', 'SSLv3', 'TLSv1', 'TLSv1.1'): + issues.append(self._create_issue( + category='tls', + severity='high', + title=f'Weak TLS version: {ssock.version()}', + description='The server supports outdated TLS versions with known vulnerabilities.', + remediation='Disable TLS 1.0 and 1.1. Use TLS 1.2 or higher.' + )) + + except ssl.SSLCertVerificationError as e: + tls_data['certificate_valid'] = False + issues.append(self._create_issue( + category='tls', + severity='critical', + title='Invalid TLS certificate', + description=f'The TLS certificate failed validation: {e}', + remediation='Obtain and install a valid TLS certificate from a trusted CA.' + )) + except ssl.SSLError as e: + tls_data['has_tls'] = False + issues.append(self._create_issue( + category='tls', + severity='high', + title='TLS connection error', + description=f'Could not establish TLS connection: {e}', + remediation='Check TLS configuration and certificate installation.' + )) + except (socket.timeout, socket.error) as e: + # Network error, not a TLS issue + pass + + return issues, tls_data + + def _count_present_headers(self, headers: dict) -> int: + """Count how many security headers are present.""" + headers_lower = {k.lower(): v for k, v in headers.items()} + count = 0 + for header_name in self.SECURITY_HEADERS: + if header_name in headers_lower: + count += 1 + return count diff --git a/backend/scanner/lighthouse_scanner.py b/backend/scanner/lighthouse_scanner.py new file mode 100644 index 0000000..97db1cf --- /dev/null +++ b/backend/scanner/lighthouse_scanner.py @@ -0,0 +1,320 @@ +""" +Lighthouse Scanner Integration. + +This module integrates with the Lighthouse scanner service +to perform performance, accessibility, SEO, and best practices audits. +""" + +import time +import logging +from typing import Optional + +import httpx + +from django.conf import settings + +from .base import BaseScanner, ScannerResult, ScannerStatus + +logger = logging.getLogger('scanner') + + +class LighthouseScanner(BaseScanner): + """ + Scanner that integrates with the Lighthouse service. 
+ + Lighthouse audits: + - Performance (FCP, LCP, TTI, TBT, CLS, Speed Index) + - Accessibility + - Best Practices + - SEO + """ + + name = "lighthouse" + + def __init__(self, config: dict = None): + super().__init__(config) + self.service_url = self.config.get( + 'lighthouse_url', + settings.SCANNER_CONFIG.get('LIGHTHOUSE_URL', 'http://lighthouse:3001') + ) + self.timeout = self.config.get('timeout', 120) + + def is_available(self) -> bool: + """Check if Lighthouse service is available.""" + try: + response = httpx.get( + f"{self.service_url}/health", + timeout=5 + ) + return response.status_code == 200 + except Exception as e: + self.logger.warning(f"Lighthouse service not available: {e}") + return False + + def run(self, url: str) -> ScannerResult: + """ + Run Lighthouse audit on the given URL. + + Args: + url: The URL to audit + + Returns: + ScannerResult with Lighthouse data + """ + start_time = time.time() + + if not self.is_available(): + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message="Lighthouse service is not available", + execution_time_seconds=time.time() - start_time + ) + + try: + # Call Lighthouse service + response = httpx.post( + f"{self.service_url}/scan", + json={"url": url}, + timeout=self.timeout + ) + + if response.status_code != 200: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Lighthouse returned status {response.status_code}: {response.text}", + execution_time_seconds=time.time() - start_time + ) + + data = response.json() + + # Extract scores + scores = { + 'performance': data.get('scores', {}).get('performance'), + 'accessibility': data.get('scores', {}).get('accessibility'), + 'best_practices': data.get('scores', {}).get('bestPractices'), + 'seo': data.get('scores', {}).get('seo'), + } + + # Extract metrics + metrics = self._extract_metrics(data) + + # Extract issues + issues = self._extract_issues(data) + + execution_time = time.time() - start_time + + return ScannerResult( + status=ScannerStatus.SUCCESS, + scanner_name=self.name, + scores=scores, + metrics=metrics, + issues=issues, + raw_data=data, + execution_time_seconds=execution_time + ) + + except httpx.TimeoutException: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message="Lighthouse scan timed out", + execution_time_seconds=time.time() - start_time + ) + except httpx.RequestError as e: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Lighthouse request failed: {e}", + execution_time_seconds=time.time() - start_time + ) + except Exception as e: + logger.exception(f"Lighthouse scan failed for {url}") + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Unexpected error: {e}", + execution_time_seconds=time.time() - start_time + ) + + def _extract_metrics(self, data: dict) -> list: + """Extract key metrics from Lighthouse data.""" + metrics = [] + + # Core Web Vitals and performance metrics + metrics_config = { + 'first_contentful_paint': ('First Contentful Paint', 'firstContentfulPaint', 'ms'), + 'largest_contentful_paint': ('Largest Contentful Paint', 'largestContentfulPaint', 'ms'), + 'speed_index': ('Speed Index', 'speedIndex', 'ms'), + 'time_to_interactive': ('Time to Interactive', 'timeToInteractive', 'ms'), + 'total_blocking_time': ('Total Blocking Time', 'totalBlockingTime', 'ms'), + 'cumulative_layout_shift': ('Cumulative Layout Shift', 'cumulativeLayoutShift', 
'score'), + } + + lh_metrics = data.get('metrics', {}) + + for metric_name, (display_name, lh_key, unit) in metrics_config.items(): + metric_data = lh_metrics.get(lh_key, {}) + if metric_data and metric_data.get('value') is not None: + metrics.append(self._create_metric( + name=metric_name, + display_name=display_name, + value=metric_data['value'], + unit=unit, + score=metric_data.get('score') + )) + + # Resource metrics + resources = data.get('resources', {}) + diagnostics = data.get('diagnostics', {}) + + if resources.get('totalByteWeight'): + metrics.append(self._create_metric( + name='total_byte_weight', + display_name='Total Page Weight', + value=resources['totalByteWeight'], + unit='bytes' + )) + + if diagnostics.get('numRequests'): + metrics.append(self._create_metric( + name='num_requests', + display_name='Total Requests', + value=diagnostics['numRequests'], + unit='count' + )) + + if diagnostics.get('numScripts'): + metrics.append(self._create_metric( + name='num_scripts', + display_name='JavaScript Files', + value=diagnostics['numScripts'], + unit='count' + )) + + if diagnostics.get('totalTransferSize'): + metrics.append(self._create_metric( + name='total_transfer_size', + display_name='Total Transfer Size', + value=diagnostics['totalTransferSize'], + unit='bytes' + )) + + return metrics + + def _extract_issues(self, data: dict) -> list: + """Extract issues from Lighthouse audit results.""" + issues = [] + + # Convert Lighthouse issues to our format + lh_issues = data.get('issues', []) + + # Map Lighthouse categories to our categories + category_map = { + 'performance': 'performance', + 'accessibility': 'accessibility', + 'best-practices': 'best_practices', + 'seo': 'seo', + } + + for lh_issue in lh_issues: + # Determine severity based on score and impact + score = lh_issue.get('score', 0) + impact = lh_issue.get('impact', 0) + + if score == 0 and impact > 5: + severity = 'high' + elif score < 0.5 and impact > 3: + severity = 'medium' + elif score < 0.5: + severity = 'low' + else: + severity = 'info' + + category = category_map.get(lh_issue.get('category'), 'performance') + + issues.append(self._create_issue( + category=category, + severity=severity, + title=lh_issue.get('title', 'Unknown issue'), + description=lh_issue.get('description', ''), + raw_data={ + 'id': lh_issue.get('id'), + 'displayValue': lh_issue.get('displayValue'), + 'score': score, + 'impact': impact, + } + )) + + # Check for unused resources + resources = data.get('resources', {}) + + # Unused JavaScript + unused_js = resources.get('unusedJavascript', []) + for item in unused_js[:5]: # Top 5 + if item.get('wastedBytes', 0) > 50000: # > 50KB wasted + issues.append(self._create_issue( + category='performance', + severity='medium', + title='Unused JavaScript', + description=f"Remove unused JavaScript to reduce payload. {item.get('url', '')} has {item.get('wastedBytes', 0) / 1024:.1f}KB unused.", + remediation='Remove unused JavaScript code or use code splitting to load only what is needed.', + raw_data=item + )) + + # Unused CSS + unused_css = resources.get('unusedCss', []) + for item in unused_css[:5]: + if item.get('wastedBytes', 0) > 20000: # > 20KB wasted + issues.append(self._create_issue( + category='performance', + severity='low', + title='Unused CSS', + description=f"Remove unused CSS rules. 
{item.get('url', '')} has {item.get('wastedBytes', 0) / 1024:.1f}KB unused.", + remediation='Use tools like PurgeCSS to remove unused CSS.', + raw_data=item + )) + + # Render-blocking resources + blocking = resources.get('renderBlockingResources', []) + if len(blocking) > 3: + issues.append(self._create_issue( + category='performance', + severity='medium', + title='Multiple render-blocking resources', + description=f'Found {len(blocking)} render-blocking resources that delay page rendering.', + remediation='Defer non-critical JavaScript and inline critical CSS.', + raw_data={'resources': blocking[:10]} + )) + + # Large JavaScript bundles + large_scripts = resources.get('scriptTreemap', []) + for script in large_scripts[:5]: + if script.get('resourceBytes', 0) > 500000: # > 500KB + issues.append(self._create_issue( + category='resources', + severity='medium', + title='Large JavaScript bundle', + description=f"Large script bundle detected: {script.get('name', 'Unknown')} ({script.get('resourceBytes', 0) / 1024:.1f}KB)", + remediation='Consider code splitting and lazy loading to reduce bundle size.', + raw_data=script + )) + + # Third-party impact + third_party = resources.get('thirdPartySummary', []) + high_impact_third_party = [ + tp for tp in third_party + if tp.get('blockingTime', 0) > 500 # > 500ms blocking + ] + if high_impact_third_party: + issues.append(self._create_issue( + category='performance', + severity='medium', + title='Third-party scripts impacting performance', + description=f'{len(high_impact_third_party)} third-party scripts are significantly impacting page load time.', + remediation='Consider lazy loading third-party scripts or using async/defer attributes.', + raw_data={'third_parties': high_impact_third_party} + )) + + return issues diff --git a/backend/scanner/playwright_scanner.py b/backend/scanner/playwright_scanner.py new file mode 100644 index 0000000..7be2f69 --- /dev/null +++ b/backend/scanner/playwright_scanner.py @@ -0,0 +1,432 @@ +""" +Playwright Browser Scanner. + +This module uses Playwright to perform browser-based analysis, +including console error capture, resource loading, and basic +memory usage indicators. +""" + +import time +import logging +import asyncio +from typing import Dict, List, Optional, Tuple + +from django.conf import settings + +from .base import BaseScanner, ScannerResult, ScannerStatus + +logger = logging.getLogger('scanner') + + +class PlaywrightScanner(BaseScanner): + """ + Browser-based scanner using Playwright. + + Captures: + - Console errors and warnings + - Network request metrics + - Large images and resources + - JavaScript errors + - Memory usage indicators + - Page load timing + """ + + name = "playwright" + + def __init__(self, config: dict = None): + super().__init__(config) + self.timeout = self.config.get( + 'timeout', + settings.SCANNER_CONFIG.get('PLAYWRIGHT_TIMEOUT', 30000) + ) + self.viewport = self.config.get( + 'viewport', + settings.SCANNER_CONFIG.get('PLAYWRIGHT_VIEWPORT', {'width': 1920, 'height': 1080}) + ) + self.large_image_threshold = settings.SCANNER_CONFIG.get( + 'LARGE_IMAGE_THRESHOLD_BYTES', 1024 * 1024 + ) + + def is_available(self) -> bool: + """Check if Playwright is available.""" + try: + from playwright.sync_api import sync_playwright + return True + except ImportError: + self.logger.warning("Playwright not installed") + return False + + def run(self, url: str) -> ScannerResult: + """ + Run browser-based analysis using Playwright. 
+ + Args: + url: The URL to analyze + + Returns: + ScannerResult with browser analysis data + """ + start_time = time.time() + + if not self.is_available(): + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message="Playwright is not available", + execution_time_seconds=time.time() - start_time + ) + + try: + from playwright.sync_api import sync_playwright + + with sync_playwright() as p: + # Launch browser + browser = p.chromium.launch( + headless=True, + args=[ + '--no-sandbox', + '--disable-dev-shm-usage', + '--disable-gpu', + '--disable-extensions', + ] + ) + + context = browser.new_context( + viewport=self.viewport, + user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' + ) + + page = context.new_page() + + # Collect data + console_messages = [] + network_requests = [] + failed_requests = [] + js_errors = [] + + # Console message handler + def handle_console(msg): + console_messages.append({ + 'type': msg.type, + 'text': msg.text[:500], # Truncate long messages + 'location': str(msg.location) if hasattr(msg, 'location') else None + }) + + # Request handler + def handle_request(request): + network_requests.append({ + 'url': request.url[:200], + 'method': request.method, + 'resource_type': request.resource_type, + }) + + # Response handler + def handle_response(response): + # Find the corresponding request + for req in network_requests: + if req['url'] == response.url[:200]: + req['status'] = response.status + try: + headers = response.headers + content_length = headers.get('content-length', '0') + req['size'] = int(content_length) if content_length else 0 + except: + req['size'] = 0 + break + + # Request failed handler + def handle_request_failed(request): + failed_requests.append({ + 'url': request.url[:200], + 'failure': request.failure, + 'resource_type': request.resource_type, + }) + + # Page error handler + def handle_page_error(error): + js_errors.append({ + 'message': str(error)[:500], + }) + + # Attach handlers + page.on('console', handle_console) + page.on('request', handle_request) + page.on('response', handle_response) + page.on('requestfailed', handle_request_failed) + page.on('pageerror', handle_page_error) + + # Navigate to page + load_start = time.time() + + try: + page.goto(url, timeout=self.timeout, wait_until='networkidle') + except Exception as e: + # Try with less strict wait condition + self.logger.warning(f"Network idle timeout, trying load: {e}") + page.goto(url, timeout=self.timeout, wait_until='load') + + load_time = (time.time() - load_start) * 1000 # ms + + # Wait a bit more for any async content + page.wait_for_timeout(2000) + + # Get performance metrics if available + performance_data = page.evaluate('''() => { + const timing = performance.timing; + const memory = performance.memory || {}; + return { + domContentLoaded: timing.domContentLoadedEventEnd - timing.navigationStart, + loadComplete: timing.loadEventEnd - timing.navigationStart, + domInteractive: timing.domInteractive - timing.navigationStart, + firstPaint: performance.getEntriesByType('paint').find(p => p.name === 'first-paint')?.startTime || null, + firstContentfulPaint: performance.getEntriesByType('paint').find(p => p.name === 'first-contentful-paint')?.startTime || null, + jsHeapSizeLimit: memory.jsHeapSizeLimit || null, + totalJSHeapSize: memory.totalJSHeapSize || null, + usedJSHeapSize: memory.usedJSHeapSize || null, + }; + }''') + + # Close browser + browser.close() + + # Process results + metrics = self._extract_metrics( + 
load_time, + performance_data, + network_requests + ) + + issues = self._extract_issues( + console_messages, + network_requests, + failed_requests, + js_errors, + performance_data + ) + + raw_data = { + 'console_messages': console_messages[:50], # Limit size + 'network_requests': network_requests[:100], + 'failed_requests': failed_requests, + 'js_errors': js_errors, + 'performance': performance_data, + 'load_time_ms': load_time, + } + + execution_time = time.time() - start_time + + return ScannerResult( + status=ScannerStatus.SUCCESS, + scanner_name=self.name, + metrics=metrics, + issues=issues, + raw_data=raw_data, + execution_time_seconds=execution_time + ) + + except Exception as e: + logger.exception(f"Playwright scan failed for {url}") + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"Browser scan failed: {e}", + execution_time_seconds=time.time() - start_time + ) + + def _extract_metrics( + self, + load_time: float, + performance_data: dict, + network_requests: list + ) -> list: + """Extract metrics from browser data.""" + metrics = [] + + # Page load time + metrics.append(self._create_metric( + name='page_load_time', + display_name='Page Load Time', + value=load_time, + unit='ms' + )) + + # DOM Content Loaded + if performance_data.get('domContentLoaded'): + metrics.append(self._create_metric( + name='dom_content_loaded', + display_name='DOM Content Loaded', + value=performance_data['domContentLoaded'], + unit='ms' + )) + + # DOM Interactive + if performance_data.get('domInteractive'): + metrics.append(self._create_metric( + name='dom_interactive', + display_name='DOM Interactive', + value=performance_data['domInteractive'], + unit='ms' + )) + + # Network metrics + total_requests = len(network_requests) + total_size = sum(r.get('size', 0) for r in network_requests) + + metrics.append(self._create_metric( + name='total_requests_playwright', + display_name='Total Network Requests', + value=total_requests, + unit='count' + )) + + metrics.append(self._create_metric( + name='total_download_size', + display_name='Total Downloaded', + value=total_size, + unit='bytes' + )) + + # Request type breakdown + scripts = [r for r in network_requests if r.get('resource_type') == 'script'] + stylesheets = [r for r in network_requests if r.get('resource_type') == 'stylesheet'] + images = [r for r in network_requests if r.get('resource_type') == 'image'] + fonts = [r for r in network_requests if r.get('resource_type') == 'font'] + + metrics.append(self._create_metric( + name='script_requests', + display_name='Script Requests', + value=len(scripts), + unit='count' + )) + + metrics.append(self._create_metric( + name='image_requests', + display_name='Image Requests', + value=len(images), + unit='count' + )) + + # Memory metrics + if performance_data.get('usedJSHeapSize'): + metrics.append(self._create_metric( + name='js_heap_used', + display_name='JS Heap Used', + value=performance_data['usedJSHeapSize'], + unit='bytes' + )) + + if performance_data.get('totalJSHeapSize'): + metrics.append(self._create_metric( + name='js_heap_total', + display_name='JS Heap Total', + value=performance_data['totalJSHeapSize'], + unit='bytes' + )) + + return metrics + + def _extract_issues( + self, + console_messages: list, + network_requests: list, + failed_requests: list, + js_errors: list, + performance_data: dict + ) -> list: + """Extract issues from browser data.""" + issues = [] + + # Console errors + errors = [m for m in console_messages if m.get('type') == 'error'] + if 
errors: + issues.append(self._create_issue( + category='content', + severity='medium', + title=f'{len(errors)} console error(s) detected', + description='JavaScript console errors were detected on the page.', + remediation='Review and fix JavaScript errors to improve user experience.', + raw_data={'errors': errors[:10]} + )) + + # Console warnings + warnings = [m for m in console_messages if m.get('type') == 'warning'] + if len(warnings) > 5: + issues.append(self._create_issue( + category='content', + severity='low', + title=f'{len(warnings)} console warning(s) detected', + description='Multiple JavaScript warnings were detected on the page.', + remediation='Review console warnings for potential issues.', + raw_data={'warnings': warnings[:10]} + )) + + # JavaScript page errors + if js_errors: + issues.append(self._create_issue( + category='content', + severity='high', + title=f'{len(js_errors)} JavaScript error(s) detected', + description='Uncaught JavaScript exceptions were detected.', + remediation='Fix JavaScript errors that could break page functionality.', + raw_data={'errors': js_errors} + )) + + # Failed network requests + if failed_requests: + issues.append(self._create_issue( + category='content', + severity='medium', + title=f'{len(failed_requests)} failed network request(s)', + description='Some resources failed to load.', + remediation='Ensure all resources are available and URLs are correct.', + raw_data={'failed': failed_requests} + )) + + # Large images + large_images = [ + r for r in network_requests + if r.get('resource_type') == 'image' and r.get('size', 0) > self.large_image_threshold + ] + if large_images: + issues.append(self._create_issue( + category='resources', + severity='medium', + title=f'{len(large_images)} large image(s) detected (>1MB)', + description='Large images slow down page load and increase bandwidth usage.', + remediation='Compress images and use modern formats like WebP or AVIF.', + raw_data={'images': [{'url': i['url'], 'size': i.get('size')} for i in large_images]} + )) + + # Too many requests + if len(network_requests) > 100: + issues.append(self._create_issue( + category='performance', + severity='medium', + title='High number of network requests', + description=f'Page makes {len(network_requests)} network requests, which can slow loading.', + remediation='Combine files, use sprites, and reduce third-party scripts.' + )) + + # High memory usage (potential memory issues) + used_heap = performance_data.get('usedJSHeapSize', 0) + total_heap = performance_data.get('totalJSHeapSize', 0) + + if used_heap > 100 * 1024 * 1024: # > 100MB + issues.append(self._create_issue( + category='resources', + severity='medium', + title='High JavaScript memory usage', + description=f'Page uses {used_heap / (1024*1024):.1f}MB of JavaScript heap memory.', + remediation='Review for memory leaks and optimize JavaScript memory usage.' + )) + + if total_heap > 0 and used_heap / total_heap > 0.9: + issues.append(self._create_issue( + category='resources', + severity='high', + title='JavaScript heap near capacity', + description='JavaScript heap is using >90% of available memory, risking out-of-memory errors.', + remediation='Investigate potential memory leaks and reduce memory consumption.' + )) + + return issues diff --git a/backend/scanner/runner.py b/backend/scanner/runner.py new file mode 100644 index 0000000..2103c0b --- /dev/null +++ b/backend/scanner/runner.py @@ -0,0 +1,315 @@ +""" +Scan Runner - Orchestrates multiple scanners. 
+ +This module coordinates the execution of all scanners and +aggregates their results into a unified scan result. +""" + +import time +import logging +from typing import Dict, List, Optional, Type +from concurrent.futures import ThreadPoolExecutor, as_completed + +from django.conf import settings +from django.utils import timezone + +from websites.models import ( + Scan, Issue, Metric, ScanStatus, + IssueCategory, IssueSeverity, ScannerTool +) +from .base import BaseScanner, ScannerResult, ScannerStatus +from .headers_scanner import HeadersScanner +from .lighthouse_scanner import LighthouseScanner +from .playwright_scanner import PlaywrightScanner +from .zap_scanner import ZAPScanner +from .validators import validate_url + +logger = logging.getLogger('scanner') + + +class ScanRunner: + """ + Orchestrates the execution of multiple scanners. + + Manages the scan lifecycle: + 1. URL validation + 2. Scanner execution (parallel or sequential) + 3. Result aggregation + 4. Database persistence + """ + + # Available scanners in execution order + SCANNER_CLASSES: List[Type[BaseScanner]] = [ + HeadersScanner, # Fast, run first + LighthouseScanner, # Performance metrics + PlaywrightScanner, # Browser analysis + ZAPScanner, # Security scan (slowest) + ] + + def __init__(self, scan: Scan, config: dict = None): + """ + Initialize the scan runner. + + Args: + scan: The Scan model instance to update + config: Optional configuration overrides + """ + self.scan = scan + self.config = config or settings.SCANNER_CONFIG + self.results: Dict[str, ScannerResult] = {} + self.logger = logging.getLogger('scanner.runner') + + def run(self) -> bool: + """ + Execute all scanners and persist results. + + Returns: + True if scan completed successfully, False otherwise + """ + url = self.scan.website.url + + # Validate URL + is_valid, normalized_url, error = validate_url(url) + if not is_valid: + self._fail_scan(f"Invalid URL: {error}") + return False + + # Update scan status to running + self.scan.status = ScanStatus.RUNNING + self.scan.started_at = timezone.now() + self.scan.save(update_fields=['status', 'started_at']) + + self.logger.info(f"Starting scan for {url}") + + try: + # Run all scanners + self._run_scanners(normalized_url) + + # Aggregate and persist results + self._aggregate_results() + self._persist_results() + + # Determine final status + failed_scanners = [ + name for name, result in self.results.items() + if result.status == ScannerStatus.FAILED + ] + + if len(failed_scanners) == len(self.SCANNER_CLASSES): + self._fail_scan("All scanners failed") + return False + elif failed_scanners: + self.scan.status = ScanStatus.PARTIAL + self.scan.error_message = f"Some scanners failed: {', '.join(failed_scanners)}" + else: + self.scan.status = ScanStatus.DONE + + self.scan.completed_at = timezone.now() + self.scan.save() + + # Update website last_scanned_at + self.scan.website.last_scanned_at = timezone.now() + self.scan.website.save(update_fields=['last_scanned_at']) + + self.logger.info(f"Scan completed for {url} with status {self.scan.status}") + return True + + except Exception as e: + self.logger.exception(f"Scan failed for {url}") + self._fail_scan(str(e)) + return False + + def _run_scanners(self, url: str): + """ + Run all available scanners. + + Scanners are run sequentially to avoid resource conflicts, + especially for browser-based scanners. 
+ """ + for scanner_class in self.SCANNER_CLASSES: + scanner = scanner_class(self.config) + scanner_name = scanner.name + + self.logger.info(f"Running {scanner_name} for {url}") + + try: + if not scanner.is_available(): + self.logger.warning(f"{scanner_name} is not available, skipping") + self.results[scanner_name] = ScannerResult( + status=ScannerStatus.SKIPPED, + scanner_name=scanner_name, + error_message="Scanner not available" + ) + continue + + result = scanner.run(url) + self.results[scanner_name] = result + + self.logger.info( + f"{scanner_name} completed with status {result.status.value} " + f"in {result.execution_time_seconds:.2f}s" + ) + + except Exception as e: + self.logger.exception(f"{scanner_name} failed with exception") + self.results[scanner_name] = ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=scanner_name, + error_message=str(e) + ) + + def _aggregate_results(self): + """Aggregate scores from all scanners.""" + # Lighthouse scores + lighthouse_result = self.results.get('lighthouse') + if lighthouse_result and lighthouse_result.status == ScannerStatus.SUCCESS: + scores = lighthouse_result.scores + self.scan.performance_score = scores.get('performance') + self.scan.accessibility_score = scores.get('accessibility') + self.scan.seo_score = scores.get('seo') + self.scan.best_practices_score = scores.get('best_practices') + self.scan.raw_lighthouse_data = lighthouse_result.raw_data + + # ZAP security score + zap_result = self.results.get('owasp_zap') + if zap_result and zap_result.status == ScannerStatus.SUCCESS: + self.scan.raw_zap_data = zap_result.raw_data + + # Playwright data + playwright_result = self.results.get('playwright') + if playwright_result and playwright_result.status == ScannerStatus.SUCCESS: + self.scan.raw_playwright_data = playwright_result.raw_data + + # Headers data + headers_result = self.results.get('header_check') + if headers_result and headers_result.status == ScannerStatus.SUCCESS: + self.scan.raw_headers_data = headers_result.raw_data + + def _persist_results(self): + """Persist all issues and metrics to the database.""" + # Collect all issues and metrics + all_issues = [] + all_metrics = [] + + for scanner_name, result in self.results.items(): + if result.status in (ScannerStatus.SUCCESS, ScannerStatus.PARTIAL): + all_issues.extend(result.issues) + all_metrics.extend(result.metrics) + + # Create Issue records + issue_objects = [] + for issue_data in all_issues: + # Map tool name to ScannerTool enum + tool = self._map_tool_name(issue_data.get('tool', '')) + + # Map severity + severity = issue_data.get('severity', 'info') + if severity not in [s.value for s in IssueSeverity]: + severity = 'info' + + # Map category + category = issue_data.get('category', 'security') + if category not in [c.value for c in IssueCategory]: + category = 'security' + + issue_objects.append(Issue( + scan=self.scan, + category=category, + severity=severity, + tool=tool, + title=issue_data.get('title', '')[:500], + description=issue_data.get('description', ''), + affected_url=issue_data.get('affected_url'), + remediation=issue_data.get('remediation'), + raw_data=issue_data.get('raw_data'), + )) + + # Bulk create issues + if issue_objects: + Issue.objects.bulk_create(issue_objects, ignore_conflicts=True) + + # Create Metric records + metric_objects = [] + seen_metrics = set() # Avoid duplicates + + for metric_data in all_metrics: + metric_name = metric_data.get('name', '') + if metric_name in seen_metrics: + continue + seen_metrics.add(metric_name) + + # Map 
source to ScannerTool enum + source = self._map_tool_name(metric_data.get('source', '')) + + metric_objects.append(Metric( + scan=self.scan, + name=metric_name, + display_name=metric_data.get('display_name', metric_name), + value=metric_data.get('value', 0), + unit=metric_data.get('unit', 'count'), + source=source, + score=metric_data.get('score'), + )) + + # Bulk create metrics + if metric_objects: + Metric.objects.bulk_create(metric_objects, ignore_conflicts=True) + + # Calculate security score based on issues + self.scan.calculate_security_score() + + # Calculate overall score + self.scan.calculate_overall_score() + + self.scan.save() + + def _map_tool_name(self, tool_name: str) -> str: + """Map scanner name to ScannerTool enum value.""" + tool_mapping = { + 'lighthouse': ScannerTool.LIGHTHOUSE, + 'owasp_zap': ScannerTool.ZAP, + 'playwright': ScannerTool.PLAYWRIGHT, + 'header_check': ScannerTool.HEADER_CHECK, + 'tls_check': ScannerTool.TLS_CHECK, + } + return tool_mapping.get(tool_name, ScannerTool.HEADER_CHECK) + + def _fail_scan(self, error_message: str): + """Mark scan as failed with error message.""" + self.scan.status = ScanStatus.FAILED + self.scan.error_message = error_message + self.scan.completed_at = timezone.now() + self.scan.save() + + # Create an issue for the failure + Issue.objects.create( + scan=self.scan, + category=IssueCategory.CONTENT, + severity=IssueSeverity.HIGH, + tool=ScannerTool.HEADER_CHECK, + title="Scan Failed", + description=f"The scan could not be completed: {error_message}", + remediation="Check the URL and try again. If the problem persists, contact support." + ) + + +def run_scan(scan_id: str) -> bool: + """ + Run a scan by ID. + + This is the main entry point for Celery tasks. + + Args: + scan_id: UUID of the scan to run + + Returns: + True if scan completed successfully + """ + try: + scan = Scan.objects.select_related('website').get(id=scan_id) + except Scan.DoesNotExist: + logger.error(f"Scan {scan_id} not found") + return False + + runner = ScanRunner(scan) + return runner.run() diff --git a/backend/scanner/validators.py b/backend/scanner/validators.py new file mode 100644 index 0000000..fb1b89d --- /dev/null +++ b/backend/scanner/validators.py @@ -0,0 +1,192 @@ +""" +URL validation and safety utilities. + +This module provides functions to validate URLs and ensure they +don't target internal/private networks (SSRF protection). +""" + +import ipaddress +import socket +import logging +from typing import Tuple, Optional +from urllib.parse import urlparse + +from django.conf import settings + +logger = logging.getLogger('scanner') + + +def get_blocked_ip_ranges() -> list: + """Get list of blocked IP ranges from settings.""" + return settings.SCANNER_CONFIG.get('BLOCKED_IP_RANGES', [ + '10.0.0.0/8', + '172.16.0.0/12', + '192.168.0.0/16', + '127.0.0.0/8', + '169.254.0.0/16', + '::1/128', + 'fc00::/7', + 'fe80::/10', + ]) + + +def get_blocked_hosts() -> list: + """Get list of blocked hostnames from settings.""" + return settings.SCANNER_CONFIG.get('BLOCKED_HOSTS', [ + 'localhost', + 'localhost.localdomain', + ]) + + +def is_private_ip(ip_str: str) -> bool: + """ + Check if an IP address is private, loopback, or otherwise blocked. 
+ + Args: + ip_str: IP address string + + Returns: + True if the IP should be blocked + """ + try: + ip = ipaddress.ip_address(ip_str) + + # Check standard private/reserved ranges + if ip.is_private or ip.is_loopback or ip.is_link_local: + return True + if ip.is_reserved or ip.is_multicast: + return True + + # Check custom blocked ranges + for cidr in get_blocked_ip_ranges(): + try: + network = ipaddress.ip_network(cidr, strict=False) + if ip in network: + return True + except ValueError: + continue + + return False + + except ValueError: + # Invalid IP format + return True + + +def resolve_hostname(hostname: str) -> Optional[str]: + """ + Resolve a hostname to its IP address. + + Args: + hostname: The hostname to resolve + + Returns: + IP address string or None if resolution fails + """ + try: + # Get all addresses and return the first one + result = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC) + if result: + return result[0][4][0] + return None + except socket.gaierror: + return None + + +def validate_url(url: str) -> Tuple[bool, str, Optional[str]]: + """ + Validate a URL for scanning safety. + + Checks: + 1. URL format and scheme (must be http or https) + 2. Hostname is not in blocked list + 3. Resolved IP is not private/internal + + Args: + url: The URL to validate + + Returns: + Tuple of (is_valid, normalized_url, error_message) + """ + if not url: + return False, url, "URL is required" + + # Parse the URL + try: + parsed = urlparse(url) + except Exception as e: + return False, url, f"Invalid URL format: {e}" + + # Check scheme + if parsed.scheme not in ('http', 'https'): + return False, url, "URL must use http or https scheme" + + # Check hostname exists + if not parsed.netloc: + return False, url, "URL must have a valid hostname" + + # Extract hostname (without port) + hostname = parsed.hostname + if not hostname: + return False, url, "Could not extract hostname from URL" + + # Normalize hostname + hostname = hostname.lower() + + # Check against blocked hostnames + if hostname in get_blocked_hosts(): + return False, url, f"Scanning {hostname} is not allowed" + + # Check if hostname is an IP address + try: + ip = ipaddress.ip_address(hostname) + if is_private_ip(hostname): + return False, url, "Scanning private/internal IP addresses is not allowed" + except ValueError: + # Not an IP, it's a hostname - resolve it + resolved_ip = resolve_hostname(hostname) + if resolved_ip: + if is_private_ip(resolved_ip): + return False, url, f"URL resolves to private IP ({resolved_ip}), scanning not allowed" + # If we can't resolve, we'll let the scanner handle the error + + # Normalize the URL + # Remove trailing slash from path, lowercase scheme and host + normalized = f"{parsed.scheme}://{parsed.netloc.lower()}" + if parsed.path and parsed.path != '/': + normalized += parsed.path.rstrip('/') + if parsed.query: + normalized += f"?{parsed.query}" + + return True, normalized, None + + +def normalize_url(url: str) -> str: + """ + Normalize a URL to a canonical form. + + Args: + url: The URL to normalize + + Returns: + Normalized URL string + """ + is_valid, normalized, _ = validate_url(url) + return normalized if is_valid else url + + +def extract_domain(url: str) -> str: + """ + Extract the domain from a URL. 
+ + Args: + url: The URL to extract domain from + + Returns: + Domain string + """ + try: + parsed = urlparse(url) + return parsed.netloc.lower() + except Exception: + return "" diff --git a/backend/scanner/zap_scanner.py b/backend/scanner/zap_scanner.py new file mode 100644 index 0000000..775821a --- /dev/null +++ b/backend/scanner/zap_scanner.py @@ -0,0 +1,341 @@ +""" +OWASP ZAP Security Scanner Integration. + +This module integrates with OWASP ZAP (Zed Attack Proxy) to perform +security vulnerability scanning. +""" + +import time +import logging +from typing import Dict, List, Optional + +import httpx + +from django.conf import settings + +from .base import BaseScanner, ScannerResult, ScannerStatus + +logger = logging.getLogger('scanner') + + +class ZAPScanner(BaseScanner): + """ + Security scanner using OWASP ZAP. + + Performs: + - Spider crawling + - Passive scanning + - Active scanning (optional) + - Security vulnerability detection + """ + + name = "owasp_zap" + + # ZAP risk levels mapped to our severity + RISK_MAPPING = { + '0': 'info', # Informational + '1': 'low', # Low + '2': 'medium', # Medium + '3': 'high', # High + } + + def __init__(self, config: dict = None): + super().__init__(config) + self.zap_url = self.config.get( + 'zap_url', + settings.SCANNER_CONFIG.get('ZAP_HOST', 'http://zap:8080') + ) + self.api_key = self.config.get( + 'api_key', + settings.SCANNER_CONFIG.get('ZAP_API_KEY', '') + ) + self.timeout = self.config.get( + 'timeout', + settings.SCANNER_CONFIG.get('ZAP_TIMEOUT', 120) + ) + # Whether to run active scan (slower, more intrusive) + self.active_scan = self.config.get('active_scan', False) + + def is_available(self) -> bool: + """Check if ZAP is available.""" + try: + response = httpx.get( + f"{self.zap_url}/JSON/core/view/version/", + params={'apikey': self.api_key}, + timeout=10 + ) + return response.status_code == 200 + except Exception as e: + self.logger.warning(f"ZAP not available: {e}") + return False + + def run(self, url: str) -> ScannerResult: + """ + Run ZAP security scan on the given URL. + + Args: + url: The URL to scan + + Returns: + ScannerResult with security findings + """ + start_time = time.time() + + if not self.is_available(): + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message="OWASP ZAP is not available. 
Check ZAP service configuration.", + execution_time_seconds=time.time() - start_time + ) + + try: + # Access the URL to seed ZAP + self._access_url(url) + + # Run spider to crawl the site + self._run_spider(url) + + # Wait for passive scan to complete + self._wait_for_passive_scan() + + # Optionally run active scan + if self.active_scan: + self._run_active_scan(url) + + # Get alerts + alerts = self._get_alerts(url) + + # Process alerts into issues + issues = self._process_alerts(alerts) + + # Calculate security score based on findings + scores = self._calculate_scores(issues) + + raw_data = { + 'alerts': alerts, + 'alert_count': len(alerts), + 'scan_type': 'active' if self.active_scan else 'passive', + } + + execution_time = time.time() - start_time + + return ScannerResult( + status=ScannerStatus.SUCCESS, + scanner_name=self.name, + scores=scores, + issues=issues, + raw_data=raw_data, + execution_time_seconds=execution_time + ) + + except httpx.TimeoutException: + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message="ZAP scan timed out", + execution_time_seconds=time.time() - start_time + ) + except Exception as e: + logger.exception(f"ZAP scan failed for {url}") + return ScannerResult( + status=ScannerStatus.FAILED, + scanner_name=self.name, + error_message=f"ZAP scan error: {e}", + execution_time_seconds=time.time() - start_time + ) + + def _zap_request(self, endpoint: str, params: dict = None) -> dict: + """Make a request to ZAP API.""" + params = params or {} + params['apikey'] = self.api_key + + response = httpx.get( + f"{self.zap_url}{endpoint}", + params=params, + timeout=self.timeout + ) + response.raise_for_status() + return response.json() + + def _access_url(self, url: str): + """Access URL through ZAP to initialize scanning.""" + self.logger.info(f"Accessing URL through ZAP: {url}") + self._zap_request('/JSON/core/action/accessUrl/', {'url': url}) + time.sleep(2) # Give ZAP time to process + + def _run_spider(self, url: str): + """Run ZAP spider to crawl the site.""" + self.logger.info(f"Starting ZAP spider for: {url}") + + # Start spider + result = self._zap_request('/JSON/spider/action/scan/', { + 'url': url, + 'maxChildren': '10', # Limit crawl depth + 'recurse': 'true', + }) + + scan_id = result.get('scan') + if not scan_id: + return + + # Wait for spider to complete (with timeout) + max_wait = 60 # seconds + waited = 0 + while waited < max_wait: + status = self._zap_request('/JSON/spider/view/status/', {'scanId': scan_id}) + progress = int(status.get('status', '100')) + + if progress >= 100: + break + + time.sleep(2) + waited += 2 + + self.logger.info("Spider completed") + + def _wait_for_passive_scan(self): + """Wait for passive scanning to complete.""" + self.logger.info("Waiting for passive scan...") + + max_wait = 30 + waited = 0 + while waited < max_wait: + result = self._zap_request('/JSON/pscan/view/recordsToScan/') + records = int(result.get('recordsToScan', '0')) + + if records == 0: + break + + time.sleep(2) + waited += 2 + + self.logger.info("Passive scan completed") + + def _run_active_scan(self, url: str): + """Run active security scan (more intrusive).""" + self.logger.info(f"Starting active scan for: {url}") + + result = self._zap_request('/JSON/ascan/action/scan/', { + 'url': url, + 'recurse': 'true', + 'inScopeOnly': 'false', + }) + + scan_id = result.get('scan') + if not scan_id: + return + + # Wait for active scan (with timeout) + max_wait = 120 + waited = 0 + while waited < max_wait: + status = 
self._zap_request('/JSON/ascan/view/status/', {'scanId': scan_id}) + progress = int(status.get('status', '100')) + + if progress >= 100: + break + + time.sleep(5) + waited += 5 + + self.logger.info("Active scan completed") + + def _get_alerts(self, url: str) -> List[dict]: + """Get all security alerts for the URL.""" + result = self._zap_request('/JSON/core/view/alerts/', { + 'baseurl': url, + 'start': '0', + 'count': '100', # Limit alerts + }) + + return result.get('alerts', []) + + def _process_alerts(self, alerts: List[dict]) -> List[dict]: + """Convert ZAP alerts to our issue format.""" + issues = [] + + for alert in alerts: + risk = alert.get('risk', '0') + severity = self.RISK_MAPPING.get(risk, 'info') + + # Determine category based on alert type + category = self._categorize_alert(alert) + + # Build remediation from ZAP's solution + remediation = alert.get('solution', '') + if alert.get('reference'): + remediation += f"\n\nReferences: {alert.get('reference')}" + + issues.append(self._create_issue( + category=category, + severity=severity, + title=alert.get('alert', 'Unknown vulnerability'), + description=alert.get('description', ''), + affected_url=alert.get('url'), + remediation=remediation.strip() if remediation else None, + raw_data={ + 'pluginId': alert.get('pluginId'), + 'cweid': alert.get('cweid'), + 'wascid': alert.get('wascid'), + 'evidence': alert.get('evidence', '')[:200], # Truncate + 'param': alert.get('param'), + 'attack': alert.get('attack', '')[:200], + 'confidence': alert.get('confidence'), + } + )) + + return issues + + def _categorize_alert(self, alert: dict) -> str: + """Categorize ZAP alert into our categories.""" + alert_name = alert.get('alert', '').lower() + cwe_id = alert.get('cweid', '') + + # XSS related + if 'xss' in alert_name or 'cross-site scripting' in alert_name or cwe_id == '79': + return 'security' + + # SQL Injection + if 'sql' in alert_name and 'injection' in alert_name or cwe_id == '89': + return 'security' + + # Header related + if any(h in alert_name for h in ['header', 'csp', 'hsts', 'x-frame', 'x-content-type']): + return 'headers' + + # Cookie related + if 'cookie' in alert_name: + return 'security' + + # TLS/SSL related + if any(t in alert_name for t in ['ssl', 'tls', 'certificate', 'https']): + return 'tls' + + # CORS related + if 'cors' in alert_name or 'cross-origin' in alert_name: + return 'cors' + + # Default to security + return 'security' + + def _calculate_scores(self, issues: List[dict]) -> dict: + """Calculate security score based on issues found.""" + # Start at 100, deduct based on severity + score = 100 + + severity_deductions = { + 'critical': 25, + 'high': 15, + 'medium': 8, + 'low': 3, + 'info': 1, + } + + for issue in issues: + severity = issue.get('severity', 'info') + score -= severity_deductions.get(severity, 0) + + return { + 'zap_security': max(0, min(100, score)) + } diff --git a/backend/templates/index.html b/backend/templates/index.html new file mode 100644 index 0000000..f17c369 --- /dev/null +++ b/backend/templates/index.html @@ -0,0 +1,250 @@ +{% extends "base.html" %} + +{% block title %}Website Analyzer - Scan Your Website{% endblock %} + +{% block content %} +
+ +
+

+ Analyze Your Website +

+

+ Get comprehensive insights on performance, security, accessibility, and SEO. + Powered by Lighthouse, OWASP ZAP, and advanced browser analysis. +

+
+ + +
+
+
+
+ + +
+ +
+ + +
+
+ + + + +
+
+
+
+ + +
+ +
+
+
+ + + +
+

Performance

+
+

+ Core Web Vitals, load times, bundle sizes, and render-blocking resources. +

+
+ + +
+
+
+ + + +
+

Security

+
+

+ OWASP ZAP scanning, security headers, TLS/SSL, and cookie analysis. +

+
+ + +
+
+
+ + + +
+

Accessibility

+
+

+ WCAG compliance, ARIA labels, color contrast, and keyboard navigation. +

+
+ + +
+
+
+ + + +
+

SEO

+
+

+ Meta tags, structured data, crawlability, and search engine optimization. +

+
+
+ + +
+

Recent Scans

+
+ + + + + + + + + + + + + +
Website | Status | Score | Date | Actions
+
+
+
+ + +{% endblock %} diff --git a/backend/templates/scan_detail.html b/backend/templates/scan_detail.html new file mode 100644 index 0000000..28b4e18 --- /dev/null +++ b/backend/templates/scan_detail.html @@ -0,0 +1,496 @@ +{% extends "base.html" %} + +{% block title %}Scan Results - Website Analyzer{% endblock %} + +{% block content %} +
+ +
+
+ + + + +

Loading scan results...

+
+
+ + +
+
+
+
+ + + +
+

Scan in Progress

+

+ Analyzing +

+
+

⏱️ This typically takes 1-3 minutes

+

🔍 Running Lighthouse, OWASP ZAP, and browser analysis

+
+
+
+
+
+
+
+
+
+ + +
+
+
+ + + +

Error Loading Scan

+

+ + Start New Scan + +
+
+
+ + +
+ +
+
+
+

+ Scan Results +

+

+ +

+

+ Scanned +

+
+
+ +
+
+
+ + +
+ +
+
+ + + + + +
+

Overall Score

+
+ + +
+
+

Performance

+
+ + +
+
+

Security

+
+ + +
+
+

Accessibility

+
+ + +
+
+

SEO

+
+
+ + +
+

Key Metrics

+
+ +
+
+ + +
+
+
+

+ Issues Found + +

+ +
+ + + + +
+
+
+ + +
+ + + +
+ + + +

No issues found!

+

No issues found for the selected severity.

+
+
+
+ + +
+
+

Issues by Category

+ +
+
+

Issues by Severity

+ +
+
+ + +
+
+ + + +
+

Scan Warning

+

+
+
+
+ + +
+ + New Scan + + +
+
+
+ + +{% endblock %} diff --git a/backend/tests/test_scans.py b/backend/tests/test_scans.py new file mode 100644 index 0000000..0343889 --- /dev/null +++ b/backend/tests/test_scans.py @@ -0,0 +1,262 @@ +""" +Tests for scan creation and management. +""" + +import pytest +import json +from unittest.mock import patch, MagicMock +from django.test import TestCase, Client +from django.urls import reverse +from rest_framework.test import APITestCase +from rest_framework import status + +from websites.models import Website, Scan, Issue, Metric, ScanStatus + + +class TestScanCreation(APITestCase): + """Tests for scan creation via API.""" + + def test_create_scan_valid_url(self): + """Test creating a scan with a valid URL.""" + response = self.client.post( + '/api/scans/', + {'url': 'https://example.com'}, + format='json' + ) + + assert response.status_code == status.HTTP_201_CREATED + assert 'id' in response.data + assert response.data['status'] == 'pending' + + def test_create_scan_creates_website(self): + """Test that creating a scan also creates a Website record.""" + response = self.client.post( + '/api/scans/', + {'url': 'https://newsite.example.com'}, + format='json' + ) + + assert response.status_code == status.HTTP_201_CREATED + assert Website.objects.filter(url='https://newsite.example.com').exists() + + def test_create_scan_reuses_website(self): + """Test that scanning same URL reuses Website record.""" + # Create first scan + self.client.post( + '/api/scans/', + {'url': 'https://example.com'}, + format='json' + ) + + # Create second scan for same URL + self.client.post( + '/api/scans/', + {'url': 'https://example.com'}, + format='json' + ) + + # Should only have one website + assert Website.objects.filter(domain='example.com').count() == 1 + # But two scans + assert Scan.objects.count() == 2 + + def test_create_scan_invalid_url(self): + """Test that invalid URL is rejected.""" + response = self.client.post( + '/api/scans/', + {'url': 'not-a-valid-url'}, + format='json' + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_create_scan_localhost_blocked(self): + """Test that localhost URLs are blocked.""" + response = self.client.post( + '/api/scans/', + {'url': 'http://localhost:8000'}, + format='json' + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_create_scan_private_ip_blocked(self): + """Test that private IP URLs are blocked.""" + response = self.client.post( + '/api/scans/', + {'url': 'http://192.168.1.1'}, + format='json' + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +class TestScanRetrieval(APITestCase): + """Tests for retrieving scan results.""" + + def setUp(self): + """Set up test data.""" + self.website = Website.objects.create( + url='https://example.com', + domain='example.com' + ) + self.scan = Scan.objects.create( + website=self.website, + status=ScanStatus.DONE, + performance_score=85, + accessibility_score=90, + seo_score=75, + best_practices_score=80, + security_score=70, + overall_score=80 + ) + # Add some issues + Issue.objects.create( + scan=self.scan, + category='security', + severity='medium', + tool='header_check', + title='Missing CSP Header', + description='Content-Security-Policy header is not set.' + ) + Issue.objects.create( + scan=self.scan, + category='performance', + severity='low', + tool='lighthouse', + title='Large JavaScript bundle', + description='Bundle size exceeds recommended limit.' 
+ ) + # Add some metrics + Metric.objects.create( + scan=self.scan, + name='first_contentful_paint', + display_name='First Contentful Paint', + value=1200, + unit='ms', + source='lighthouse' + ) + + def test_get_scan_detail(self): + """Test retrieving scan details.""" + response = self.client.get(f'/api/scans/{self.scan.id}/') + + assert response.status_code == status.HTTP_200_OK + assert response.data['id'] == str(self.scan.id) + assert response.data['overall_score'] == 80 + + def test_scan_includes_issues(self): + """Test that scan detail includes issues.""" + response = self.client.get(f'/api/scans/{self.scan.id}/') + + assert 'issues' in response.data + assert len(response.data['issues']) == 2 + + def test_scan_includes_metrics(self): + """Test that scan detail includes metrics.""" + response = self.client.get(f'/api/scans/{self.scan.id}/') + + assert 'metrics' in response.data + assert len(response.data['metrics']) == 1 + + def test_list_scans(self): + """Test listing all scans.""" + response = self.client.get('/api/scans/') + + assert response.status_code == status.HTTP_200_OK + assert 'results' in response.data + assert len(response.data['results']) >= 1 + + def test_get_nonexistent_scan(self): + """Test retrieving a scan that doesn't exist.""" + import uuid + fake_id = uuid.uuid4() + response = self.client.get(f'/api/scans/{fake_id}/') + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestScanModel(TestCase): + """Tests for Scan model methods.""" + + def setUp(self): + """Set up test data.""" + self.website = Website.objects.create( + url='https://example.com', + domain='example.com' + ) + self.scan = Scan.objects.create( + website=self.website, + status=ScanStatus.DONE + ) + + def test_calculate_security_score_no_issues(self): + """Test security score calculation with no issues.""" + score = self.scan.calculate_security_score() + assert score == 100 + + def test_calculate_security_score_with_issues(self): + """Test security score calculation with security issues.""" + Issue.objects.create( + scan=self.scan, + category='security', + severity='high', + tool='owasp_zap', + title='XSS Vulnerability', + description='Cross-site scripting vulnerability found.' + ) + + score = self.scan.calculate_security_score() + assert score == 85 # 100 - 15 (high severity) + + def test_calculate_security_score_multiple_issues(self): + """Test security score with multiple issues.""" + Issue.objects.create( + scan=self.scan, + category='security', + severity='critical', + tool='owasp_zap', + title='SQL Injection', + description='SQL injection vulnerability.' + ) + Issue.objects.create( + scan=self.scan, + category='headers', + severity='medium', + tool='header_check', + title='Missing HSTS', + description='HSTS header not set.' 
+ ) + + score = self.scan.calculate_security_score() + # 100 - 25 (critical) - 8 (medium) = 67 + assert score == 67 + + def test_calculate_overall_score(self): + """Test overall score calculation.""" + self.scan.performance_score = 80 + self.scan.security_score = 90 + self.scan.accessibility_score = 85 + self.scan.seo_score = 75 + self.scan.best_practices_score = 70 + self.scan.save() + + overall = self.scan.calculate_overall_score() + + # Weighted average based on model weights + assert overall is not None + assert 0 <= overall <= 100 + + def test_calculate_overall_score_missing_values(self): + """Test overall score with some missing values.""" + self.scan.performance_score = 80 + self.scan.security_score = None + self.scan.accessibility_score = None + self.scan.seo_score = None + self.scan.best_practices_score = None + self.scan.save() + + overall = self.scan.calculate_overall_score() + + # Should still calculate with available scores + assert overall == 80 # Only performance available diff --git a/backend/tests/test_validators.py b/backend/tests/test_validators.py new file mode 100644 index 0000000..5182d0b --- /dev/null +++ b/backend/tests/test_validators.py @@ -0,0 +1,155 @@ +""" +Tests for URL validation and safety checks. +""" + +import pytest +from scanner.validators import ( + validate_url, + normalize_url, + is_private_ip, + extract_domain, +) + + +class TestURLValidation: + """Tests for URL validation functionality.""" + + def test_valid_https_url(self): + """Test that valid HTTPS URLs pass validation.""" + is_valid, normalized, error = validate_url("https://example.com") + assert is_valid is True + assert error is None + assert normalized == "https://example.com" + + def test_valid_http_url(self): + """Test that valid HTTP URLs pass validation.""" + is_valid, normalized, error = validate_url("http://example.com") + assert is_valid is True + assert error is None + + def test_url_with_path(self): + """Test URL with path is normalized correctly.""" + is_valid, normalized, error = validate_url("https://example.com/page/") + assert is_valid is True + assert normalized == "https://example.com/page" # Trailing slash removed + + def test_url_with_query(self): + """Test URL with query string is preserved.""" + is_valid, normalized, error = validate_url("https://example.com/search?q=test") + assert is_valid is True + assert "q=test" in normalized + + def test_invalid_scheme(self): + """Test that non-http/https schemes are rejected.""" + is_valid, _, error = validate_url("ftp://example.com") + assert is_valid is False + assert "http or https" in error.lower() + + def test_missing_scheme(self): + """Test that URLs without scheme are rejected.""" + is_valid, _, error = validate_url("example.com") + assert is_valid is False + + def test_empty_url(self): + """Test that empty URL is rejected.""" + is_valid, _, error = validate_url("") + assert is_valid is False + assert "required" in error.lower() + + def test_localhost_blocked(self): + """Test that localhost is blocked (SSRF protection).""" + is_valid, _, error = validate_url("http://localhost") + assert is_valid is False + assert "not allowed" in error.lower() + + def test_localhost_127_blocked(self): + """Test that 127.0.0.1 is blocked.""" + is_valid, _, error = validate_url("http://127.0.0.1") + assert is_valid is False + assert "not allowed" in error.lower() + + def test_private_ip_10_blocked(self): + """Test that 10.x.x.x IPs are blocked.""" + is_valid, _, error = validate_url("http://10.0.0.1") + assert is_valid is False + assert "not 
allowed" in error.lower() + + def test_private_ip_172_blocked(self): + """Test that 172.16-31.x.x IPs are blocked.""" + is_valid, _, error = validate_url("http://172.16.0.1") + assert is_valid is False + assert "not allowed" in error.lower() + + def test_private_ip_192_blocked(self): + """Test that 192.168.x.x IPs are blocked.""" + is_valid, _, error = validate_url("http://192.168.1.1") + assert is_valid is False + assert "not allowed" in error.lower() + + +class TestPrivateIP: + """Tests for private IP detection.""" + + def test_loopback_is_private(self): + """Test that loopback addresses are detected as private.""" + assert is_private_ip("127.0.0.1") is True + assert is_private_ip("127.0.0.2") is True + + def test_10_range_is_private(self): + """Test that 10.x.x.x range is private.""" + assert is_private_ip("10.0.0.1") is True + assert is_private_ip("10.255.255.255") is True + + def test_172_range_is_private(self): + """Test that 172.16-31.x.x range is private.""" + assert is_private_ip("172.16.0.1") is True + assert is_private_ip("172.31.255.255") is True + + def test_192_range_is_private(self): + """Test that 192.168.x.x range is private.""" + assert is_private_ip("192.168.0.1") is True + assert is_private_ip("192.168.255.255") is True + + def test_public_ip_not_private(self): + """Test that public IPs are not marked as private.""" + assert is_private_ip("8.8.8.8") is False + assert is_private_ip("1.1.1.1") is False + assert is_private_ip("93.184.216.34") is False # example.com + + +class TestNormalizeURL: + """Tests for URL normalization.""" + + def test_lowercase_host(self): + """Test that hostname is lowercased.""" + result = normalize_url("https://EXAMPLE.COM") + assert "example.com" in result + + def test_remove_trailing_slash(self): + """Test that trailing slash is removed from path.""" + result = normalize_url("https://example.com/page/") + assert result.endswith("/page") + + def test_preserve_query_string(self): + """Test that query string is preserved.""" + result = normalize_url("https://example.com?foo=bar") + assert "foo=bar" in result + + +class TestExtractDomain: + """Tests for domain extraction.""" + + def test_simple_domain(self): + """Test extracting domain from simple URL.""" + domain = extract_domain("https://example.com/page") + assert domain == "example.com" + + def test_subdomain(self): + """Test extracting domain with subdomain.""" + domain = extract_domain("https://www.example.com") + assert domain == "www.example.com" + + def test_with_port(self): + """Test extracting domain with port.""" + domain = extract_domain("https://example.com:8080") + assert domain == "example.com:8080"