secure-web/backend/scanner/runner.py

"""
Scan Runner - Orchestrates multiple scanners.
This module coordinates the execution of all scanners and
aggregates their results into a unified scan result.
"""
import time
import logging
from typing import Dict, List, Optional, Type
from concurrent.futures import ThreadPoolExecutor, as_completed
from django.conf import settings
from django.utils import timezone
from websites.models import (
    Scan, Issue, Metric, ScanStatus,
    IssueCategory, IssueSeverity, ScannerTool
)
from .base import BaseScanner, ScannerResult, ScannerStatus
from .headers_scanner import HeadersScanner
from .lighthouse_scanner import LighthouseScanner
from .playwright_scanner import PlaywrightScanner
from .zap_scanner import ZAPScanner
from .validators import validate_url

logger = logging.getLogger('scanner')


class ScanRunner:
"""
Orchestrates the execution of multiple scanners.
Manages the scan lifecycle:
1. URL validation
2. Scanner execution (parallel or sequential)
3. Result aggregation
4. Database persistence
"""

    # Available scanners in execution order
    SCANNER_CLASSES: List[Type[BaseScanner]] = [
        HeadersScanner,      # Fast, run first
        LighthouseScanner,   # Performance metrics
        PlaywrightScanner,   # Browser analysis
        ZAPScanner,          # Security scan (slowest)
    ]
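    # Note: scanners currently run one at a time, in the order listed above
    # (see _run_scanners).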

    def __init__(self, scan: Scan, config: Optional[dict] = None):
        """
        Initialize the scan runner.

        Args:
            scan: The Scan model instance to update
            config: Optional configuration overrides
        """
        self.scan = scan
        self.config = config or settings.SCANNER_CONFIG
        self.results: Dict[str, ScannerResult] = {}
        self.logger = logging.getLogger('scanner.runner')

    def run(self) -> bool:
        """
        Execute all scanners and persist results.

        Returns:
            True if scan completed successfully, False otherwise
        """
        url = self.scan.website.url

        # Validate URL
        is_valid, normalized_url, error = validate_url(url)
        if not is_valid:
            self._fail_scan(f"Invalid URL: {error}")
            return False

        # Update scan status to running
        self.scan.status = ScanStatus.RUNNING
        self.scan.started_at = timezone.now()
        self.scan.save(update_fields=['status', 'started_at'])

        self.logger.info(f"Starting scan for {url}")

        try:
            # Run all scanners
            self._run_scanners(normalized_url)

            # Aggregate and persist results
            self._aggregate_results()
            self._persist_results()

            # Determine final status
            failed_scanners = [
                name for name, result in self.results.items()
                if result.status == ScannerStatus.FAILED
            ]

            if len(failed_scanners) == len(self.SCANNER_CLASSES):
                self._fail_scan("All scanners failed")
                return False
            elif failed_scanners:
                self.scan.status = ScanStatus.PARTIAL
                self.scan.error_message = f"Some scanners failed: {', '.join(failed_scanners)}"
            else:
                self.scan.status = ScanStatus.DONE

            self.scan.completed_at = timezone.now()
            self.scan.save()

            # Update website last_scanned_at
            self.scan.website.last_scanned_at = timezone.now()
            self.scan.website.save(update_fields=['last_scanned_at'])

            self.logger.info(f"Scan completed for {url} with status {self.scan.status}")
            return True
        except Exception as e:
            self.logger.exception(f"Scan failed for {url}")
            self._fail_scan(str(e))
            return False

    def _run_scanners(self, url: str):
        """
        Run all available scanners.

        Scanners are run sequentially to avoid resource conflicts,
        especially for browser-based scanners.
        """
        for scanner_class in self.SCANNER_CLASSES:
            scanner = scanner_class(self.config)
            scanner_name = scanner.name

            self.logger.info(f"Running {scanner_name} for {url}")

            try:
                if not scanner.is_available():
                    self.logger.warning(f"{scanner_name} is not available, skipping")
                    self.results[scanner_name] = ScannerResult(
                        status=ScannerStatus.SKIPPED,
                        scanner_name=scanner_name,
                        error_message="Scanner not available"
                    )
                    continue

                result = scanner.run(url)
                self.results[scanner_name] = result

                self.logger.info(
                    f"{scanner_name} completed with status {result.status.value} "
                    f"in {result.execution_time_seconds:.2f}s"
                )
            except Exception as e:
                self.logger.exception(f"{scanner_name} failed with exception")
                self.results[scanner_name] = ScannerResult(
                    status=ScannerStatus.FAILED,
                    scanner_name=scanner_name,
                    error_message=str(e)
                )

    def _aggregate_results(self):
        """Aggregate scores from all scanners."""
        # Lighthouse scores
        lighthouse_result = self.results.get('lighthouse')
        if lighthouse_result and lighthouse_result.status == ScannerStatus.SUCCESS:
            scores = lighthouse_result.scores
            self.scan.performance_score = scores.get('performance')
            self.scan.accessibility_score = scores.get('accessibility')
            self.scan.seo_score = scores.get('seo')
            self.scan.best_practices_score = scores.get('best_practices')
            self.scan.raw_lighthouse_data = lighthouse_result.raw_data

        # ZAP security score
        zap_result = self.results.get('owasp_zap')
        if zap_result and zap_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_zap_data = zap_result.raw_data

        # Playwright data
        playwright_result = self.results.get('playwright')
        if playwright_result and playwright_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_playwright_data = playwright_result.raw_data

        # Headers data
        headers_result = self.results.get('header_check')
        if headers_result and headers_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_headers_data = headers_result.raw_data

    def _persist_results(self):
        """Persist all issues and metrics to the database."""
        # Collect all issues and metrics
        all_issues = []
        all_metrics = []
        for scanner_name, result in self.results.items():
            if result.status in (ScannerStatus.SUCCESS, ScannerStatus.PARTIAL):
                all_issues.extend(result.issues)
                all_metrics.extend(result.metrics)

        # Create Issue records
        issue_objects = []
        for issue_data in all_issues:
            # Map tool name to ScannerTool enum
            tool = self._map_tool_name(issue_data.get('tool', ''))

            # Map severity
            severity = issue_data.get('severity', 'info')
            if severity not in [s.value for s in IssueSeverity]:
                severity = 'info'

            # Map category
            category = issue_data.get('category', 'security')
            if category not in [c.value for c in IssueCategory]:
                category = 'security'

            issue_objects.append(Issue(
                scan=self.scan,
                category=category,
                severity=severity,
                tool=tool,
                title=issue_data.get('title', '')[:500],
                description=issue_data.get('description', ''),
                affected_url=issue_data.get('affected_url'),
                remediation=issue_data.get('remediation'),
                raw_data=issue_data.get('raw_data'),
            ))

        # Bulk create issues
        if issue_objects:
            Issue.objects.bulk_create(issue_objects, ignore_conflicts=True)

        # Create Metric records
        metric_objects = []
        seen_metrics = set()  # Avoid duplicates
        for metric_data in all_metrics:
            metric_name = metric_data.get('name', '')
            if metric_name in seen_metrics:
                continue
            seen_metrics.add(metric_name)

            # Map source to ScannerTool enum
            source = self._map_tool_name(metric_data.get('source', ''))

            metric_objects.append(Metric(
                scan=self.scan,
                name=metric_name,
                display_name=metric_data.get('display_name', metric_name),
                value=metric_data.get('value', 0),
                unit=metric_data.get('unit', 'count'),
                source=source,
                score=metric_data.get('score'),
            ))

        # Bulk create metrics
        if metric_objects:
            Metric.objects.bulk_create(metric_objects, ignore_conflicts=True)

        # Calculate security score based on issues
        self.scan.calculate_security_score()
        # Calculate overall score
        self.scan.calculate_overall_score()
        self.scan.save()

    def _map_tool_name(self, tool_name: str) -> str:
        """Map scanner name to ScannerTool enum value."""
        tool_mapping = {
            'lighthouse': ScannerTool.LIGHTHOUSE,
            'owasp_zap': ScannerTool.ZAP,
            'playwright': ScannerTool.PLAYWRIGHT,
            'header_check': ScannerTool.HEADER_CHECK,
            'tls_check': ScannerTool.TLS_CHECK,
        }
        return tool_mapping.get(tool_name, ScannerTool.HEADER_CHECK)

    def _fail_scan(self, error_message: str):
        """Mark scan as failed with error message."""
        self.scan.status = ScanStatus.FAILED
        self.scan.error_message = error_message
        self.scan.completed_at = timezone.now()
        self.scan.save()

        # Create an issue for the failure
        Issue.objects.create(
            scan=self.scan,
            category=IssueCategory.CONTENT,
            severity=IssueSeverity.HIGH,
            tool=ScannerTool.HEADER_CHECK,
            title="Scan Failed",
            description=f"The scan could not be completed: {error_message}",
            remediation="Check the URL and try again. If the problem persists, contact support."
        )


def run_scan(scan_id: str) -> bool:
    """
    Run a scan by ID.

    This is the main entry point for Celery tasks.

    Args:
        scan_id: UUID of the scan to run

    Returns:
        True if scan completed successfully
    """
    try:
        scan = Scan.objects.select_related('website').get(id=scan_id)
    except Scan.DoesNotExist:
        logger.error(f"Scan {scan_id} not found")
        return False

    runner = ScanRunner(scan)
    return runner.run()
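
# Illustrative only: a minimal sketch of a Celery task wrapper around run_scan(),
# assuming the project defines its tasks elsewhere (e.g. a scanner/tasks.py module).
# The task name and module path below are hypothetical, not part of this file.
#
#     from celery import shared_task
#     from scanner.runner import run_scan
#
#     @shared_task
#     def run_scan_task(scan_id: str) -> bool:
#         return run_scan(scan_id)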