"""
Scan Runner - Orchestrates multiple scanners.

This module coordinates the execution of all scanners and
aggregates their results into a unified scan result.
"""
import time
import logging
from typing import Dict, List, Optional, Type
from concurrent.futures import ThreadPoolExecutor, as_completed

from django.conf import settings
from django.utils import timezone

from websites.models import (
    Scan, Issue, Metric, ScanStatus,
    IssueCategory, IssueSeverity, ScannerTool
)
from .base import BaseScanner, ScannerResult, ScannerStatus
from .headers_scanner import HeadersScanner
from .lighthouse_scanner import LighthouseScanner
from .playwright_scanner import PlaywrightScanner
from .zap_scanner import ZAPScanner
from .validators import validate_url

logger = logging.getLogger('scanner')
|
class ScanRunner:
    """
    Orchestrates the execution of multiple scanners.

    Manages the scan lifecycle:
    1. URL validation
    2. Scanner execution (sequential, to avoid resource conflicts)
    3. Result aggregation onto the Scan model
    4. Database persistence of issues and metrics
    """

    # Available scanners in execution order
    SCANNER_CLASSES: List[Type[BaseScanner]] = [
        HeadersScanner,     # Fast, run first
        LighthouseScanner,  # Performance metrics
        PlaywrightScanner,  # Browser analysis
        ZAPScanner,         # Security scan (slowest)
    ]

    # Scanner-name -> ScannerTool mapping. Built once at class level so
    # _map_tool_name does not rebuild the dict on every call.
    TOOL_MAPPING = {
        'lighthouse': ScannerTool.LIGHTHOUSE,
        'owasp_zap': ScannerTool.ZAP,
        'playwright': ScannerTool.PLAYWRIGHT,
        'header_check': ScannerTool.HEADER_CHECK,
        'tls_check': ScannerTool.TLS_CHECK,
    }

    def __init__(self, scan: Scan, config: Optional[dict] = None):
        """
        Initialize the scan runner.

        Args:
            scan: The Scan model instance to update.
            config: Optional configuration overrides. Any falsy value
                (None or an empty dict) falls back to
                settings.SCANNER_CONFIG.
        """
        self.scan = scan
        self.config = config or settings.SCANNER_CONFIG
        # Keyed by scanner name; populated by _run_scanners().
        self.results: Dict[str, ScannerResult] = {}
        self.logger = logging.getLogger('scanner.runner')

    def run(self) -> bool:
        """
        Execute all scanners and persist results.

        Returns:
            True if the scan completed (fully or partially), False if the
            URL was invalid, every scanner failed, or an unexpected error
            aborted the run.
        """
        url = self.scan.website.url

        # Validate the URL before touching any scanner.
        is_valid, normalized_url, error = validate_url(url)
        if not is_valid:
            self._fail_scan(f"Invalid URL: {error}")
            return False

        # Mark the scan as running so observers can see progress.
        self.scan.status = ScanStatus.RUNNING
        self.scan.started_at = timezone.now()
        self.scan.save(update_fields=['status', 'started_at'])

        self.logger.info(f"Starting scan for {url}")

        try:
            # Run all scanners against the normalized URL.
            self._run_scanners(normalized_url)

            # Aggregate and persist results.
            self._aggregate_results()
            self._persist_results()

            # Determine final status from per-scanner outcomes.
            # NOTE(review): skipped scanners are not counted as failures,
            # so an all-skipped run is reported as DONE — confirm intended.
            failed_scanners = [
                name for name, result in self.results.items()
                if result.status == ScannerStatus.FAILED
            ]

            if len(failed_scanners) == len(self.SCANNER_CLASSES):
                self._fail_scan("All scanners failed")
                return False
            elif failed_scanners:
                self.scan.status = ScanStatus.PARTIAL
                self.scan.error_message = f"Some scanners failed: {', '.join(failed_scanners)}"
            else:
                self.scan.status = ScanStatus.DONE

            self.scan.completed_at = timezone.now()
            self.scan.save()

            # Record when the website was last scanned.
            self.scan.website.last_scanned_at = timezone.now()
            self.scan.website.save(update_fields=['last_scanned_at'])

            self.logger.info(f"Scan completed for {url} with status {self.scan.status}")
            return True

        except Exception as e:
            # Top-level boundary: log the traceback, mark the scan failed,
            # and report False rather than propagating.
            self.logger.exception(f"Scan failed for {url}")
            self._fail_scan(str(e))
            return False

    def _run_scanners(self, url: str):
        """
        Run all available scanners against *url*.

        Scanners are run sequentially to avoid resource conflicts,
        especially for browser-based scanners. Each scanner's outcome
        (including skips and failures) is recorded in self.results.
        """
        for scanner_class in self.SCANNER_CLASSES:
            scanner = scanner_class(self.config)
            scanner_name = scanner.name

            self.logger.info(f"Running {scanner_name} for {url}")

            try:
                if not scanner.is_available():
                    # Record the skip so aggregation can see every scanner.
                    self.logger.warning(f"{scanner_name} is not available, skipping")
                    self.results[scanner_name] = ScannerResult(
                        status=ScannerStatus.SKIPPED,
                        scanner_name=scanner_name,
                        error_message="Scanner not available"
                    )
                    continue

                result = scanner.run(url)
                self.results[scanner_name] = result

                self.logger.info(
                    f"{scanner_name} completed with status {result.status.value} "
                    f"in {result.execution_time_seconds:.2f}s"
                )

            except Exception as e:
                # One scanner crashing must not abort the others.
                self.logger.exception(f"{scanner_name} failed with exception")
                self.results[scanner_name] = ScannerResult(
                    status=ScannerStatus.FAILED,
                    scanner_name=scanner_name,
                    error_message=str(e)
                )

    def _aggregate_results(self):
        """Copy scores and raw payloads from successful scanners onto the Scan.

        Only modifies self.scan in memory; persistence happens in
        _persist_results() / run().
        """
        # Lighthouse scores
        lighthouse_result = self.results.get('lighthouse')
        if lighthouse_result and lighthouse_result.status == ScannerStatus.SUCCESS:
            scores = lighthouse_result.scores
            self.scan.performance_score = scores.get('performance')
            self.scan.accessibility_score = scores.get('accessibility')
            self.scan.seo_score = scores.get('seo')
            self.scan.best_practices_score = scores.get('best_practices')
            self.scan.raw_lighthouse_data = lighthouse_result.raw_data

        # ZAP security payload (score itself is derived from issues later)
        zap_result = self.results.get('owasp_zap')
        if zap_result and zap_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_zap_data = zap_result.raw_data

        # Playwright data
        playwright_result = self.results.get('playwright')
        if playwright_result and playwright_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_playwright_data = playwright_result.raw_data

        # Headers data
        headers_result = self.results.get('header_check')
        if headers_result and headers_result.status == ScannerStatus.SUCCESS:
            self.scan.raw_headers_data = headers_result.raw_data

    def _persist_results(self):
        """Persist all issues and metrics to the database.

        Issues/metrics from SUCCESS and PARTIAL scanners are bulk-inserted,
        then the scan's derived scores are recomputed and saved.
        """
        # Collect issues and metrics from scanners with usable output.
        all_issues = []
        all_metrics = []

        for result in self.results.values():
            if result.status in (ScannerStatus.SUCCESS, ScannerStatus.PARTIAL):
                all_issues.extend(result.issues)
                all_metrics.extend(result.metrics)

        # Hoisted out of the loop (previously rebuilt per issue) and stored
        # as sets for O(1) membership tests.
        valid_severities = {s.value for s in IssueSeverity}
        valid_categories = {c.value for c in IssueCategory}

        # Create Issue records.
        issue_objects = []
        for issue_data in all_issues:
            # Map tool name to ScannerTool enum.
            tool = self._map_tool_name(issue_data.get('tool', ''))

            # Coerce unknown severities/categories to safe defaults.
            severity = issue_data.get('severity', 'info')
            if severity not in valid_severities:
                severity = 'info'

            category = issue_data.get('category', 'security')
            if category not in valid_categories:
                category = 'security'

            issue_objects.append(Issue(
                scan=self.scan,
                category=category,
                severity=severity,
                tool=tool,
                title=issue_data.get('title', '')[:500],  # truncate to column limit
                description=issue_data.get('description', ''),
                affected_url=issue_data.get('affected_url'),
                remediation=issue_data.get('remediation'),
                raw_data=issue_data.get('raw_data'),
            ))

        # Bulk create issues.
        if issue_objects:
            Issue.objects.bulk_create(issue_objects, ignore_conflicts=True)

        # Create Metric records, de-duplicated by name (first occurrence wins).
        metric_objects = []
        seen_metrics = set()

        for metric_data in all_metrics:
            metric_name = metric_data.get('name', '')
            if metric_name in seen_metrics:
                continue
            seen_metrics.add(metric_name)

            # Map source to ScannerTool enum.
            source = self._map_tool_name(metric_data.get('source', ''))

            metric_objects.append(Metric(
                scan=self.scan,
                name=metric_name,
                display_name=metric_data.get('display_name', metric_name),
                value=metric_data.get('value', 0),
                unit=metric_data.get('unit', 'count'),
                source=source,
                score=metric_data.get('score'),
            ))

        # Bulk create metrics.
        if metric_objects:
            Metric.objects.bulk_create(metric_objects, ignore_conflicts=True)

        # Derived scores depend on the issues persisted above.
        self.scan.calculate_security_score()
        self.scan.calculate_overall_score()
        self.scan.save()

    def _map_tool_name(self, tool_name: str) -> str:
        """Map a scanner name to its ScannerTool enum value.

        Unknown names fall back to ScannerTool.HEADER_CHECK.
        """
        return self.TOOL_MAPPING.get(tool_name, ScannerTool.HEADER_CHECK)

    def _fail_scan(self, error_message: str):
        """Mark the scan as failed and record the reason.

        Also creates a visible Issue so the failure shows up in reports.
        """
        self.scan.status = ScanStatus.FAILED
        self.scan.error_message = error_message
        self.scan.completed_at = timezone.now()
        self.scan.save()

        # Create an issue for the failure.
        Issue.objects.create(
            scan=self.scan,
            category=IssueCategory.CONTENT,
            severity=IssueSeverity.HIGH,
            tool=ScannerTool.HEADER_CHECK,
            title="Scan Failed",
            description=f"The scan could not be completed: {error_message}",
            remediation="Check the URL and try again. If the problem persists, contact support."
        )
|
|
|
|
|
|
def run_scan(scan_id: str) -> bool:
    """
    Look up the scan identified by *scan_id* and execute it.

    Main entry point for Celery tasks.

    Args:
        scan_id: UUID of the scan to run

    Returns:
        True if scan completed successfully
    """
    try:
        scan = Scan.objects.select_related('website').get(id=scan_id)
    except Scan.DoesNotExist:
        logger.error(f"Scan {scan_id} not found")
        return False

    return ScanRunner(scan).run()