secure-web/backend/scanner/tasks.py

"""
Celery tasks for background scanning.
This module defines the Celery tasks that orchestrate website scans
in the background.
"""
import logging
from datetime import timedelta
from typing import Optional
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from django.conf import settings
from django.utils import timezone
from websites.models import Website, Scan, ScanStatus, Issue, Metric
from scanner.scanners import ScanRunner
from scanner.utils import validate_url, get_domain_from_url
logger = logging.getLogger(__name__)
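

# Celery enforces two limits on the task below: `soft_time_limit` raises
# SoftTimeLimitExceeded inside the task (caught further down to record a
# partial result), while `time_limit` hard-kills the worker process 30
# seconds later if the task still has not returned.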
@shared_task(
    bind=True,
    max_retries=2,
    default_retry_delay=60,
    soft_time_limit=300,
    time_limit=330,
)
def run_scan_task(self, scan_id: str) -> dict:
    """
    Main Celery task for running a website scan.

    This task:
    1. Updates scan status to running
    2. Orchestrates all scanners
    3. Saves results to the database
    4. Handles errors and partial results

    Args:
        scan_id: UUID of the Scan record

    Returns:
        Dict with a scan results summary
    """
logger.info(f"Starting scan task for scan_id: {scan_id}")
try:
# Get the scan record
scan = Scan.objects.select_related('website').get(id=scan_id)
except Scan.DoesNotExist:
logger.error(f"Scan {scan_id} not found")
return {'error': f'Scan {scan_id} not found'}
# Update status to running
scan.status = ScanStatus.RUNNING
scan.started_at = timezone.now()
scan.celery_task_id = self.request.id
scan.save(update_fields=['status', 'started_at', 'celery_task_id'])
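
    # The stored celery_task_id presumably lets callers poll or revoke this
    # task later (e.g. via celery.result.AsyncResult); that wiring, if any,
    # lives outside this module.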

    url = scan.website.url

    try:
        # Run the scan pipeline
        runner = ScanRunner()
        results = runner.run(url)

        # Save results to database
        _save_scan_results(scan, results)

        # Update website last_scanned_at
        scan.website.last_scanned_at = timezone.now()
        scan.website.save(update_fields=['last_scanned_at'])

        logger.info(f"Scan {scan_id} completed successfully")
        return {
            'scan_id': str(scan_id),
            'status': scan.status,
            'overall_score': scan.overall_score,
            'issues_count': scan.issues.count(),
            'metrics_count': scan.metrics.count(),
        }
    except SoftTimeLimitExceeded:
        logger.warning(f"Scan {scan_id} timed out")
        scan.status = ScanStatus.PARTIAL
        scan.error_message = "Scan timed out before completing all checks"
        scan.completed_at = timezone.now()
        scan.save(update_fields=['status', 'error_message', 'completed_at'])
        return {
            'scan_id': str(scan_id),
            'status': 'partial',
            'error': 'Scan timed out',
        }
    except Exception as e:
        logger.exception(f"Scan {scan_id} failed with error: {e}")
        scan.status = ScanStatus.FAILED
        scan.error_message = str(e)
        scan.completed_at = timezone.now()
        scan.save(update_fields=['status', 'error_message', 'completed_at'])

        # Retry while attempts remain (any exception triggers a retry here)
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)
        return {
            'scan_id': str(scan_id),
            'status': 'failed',
            'error': str(e),
        }
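

# A minimal usage sketch: callers are assumed to create the Scan row first
# and then enqueue this task by id. The wiring below is illustrative (the
# `website` variable is hypothetical), not a copy of the actual API code:
#
#     scan = Scan.objects.create(website=website)
#     run_scan_task.delay(str(scan.id))
#
# delay() returns immediately with an AsyncResult whose id is what the task
# stores back onto the Scan as celery_task_id.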


def _save_scan_results(scan: Scan, results: dict) -> None:
    """
    Save scan results to the database.

    Args:
        scan: The Scan model instance
        results: Aggregated results from ScanRunner
    """
    # Update scan status
    status_map = {
        'done': ScanStatus.DONE,
        'partial': ScanStatus.PARTIAL,
        'failed': ScanStatus.FAILED,
    }
    scan.status = status_map.get(results['status'], ScanStatus.DONE)
    scan.completed_at = timezone.now()

    # Save scores
    scores = results.get('scores', {})
    scan.performance_score = scores.get('performance')
    scan.accessibility_score = scores.get('accessibility')
    scan.seo_score = scores.get('seo')
    scan.best_practices_score = scores.get('best_practices')

    # Save raw data
    raw_data = results.get('raw_data', {})
    scan.raw_lighthouse_data = raw_data.get('lighthouse')
    scan.raw_zap_data = raw_data.get('owasp_zap')
    scan.raw_playwright_data = raw_data.get('playwright')
    scan.raw_headers_data = raw_data.get('header_check')

    # Save errors if any
    if results.get('errors'):
        scan.error_message = '\n'.join(
            f"{e['scanner']}: {e['error']}"
            for e in results['errors']
        )

    scan.save()

    # Create Issue records
    issues_to_create = []
    for issue_data in results.get('issues', []):
        issues_to_create.append(Issue(
            scan=scan,
            category=issue_data['category'],
            severity=issue_data['severity'],
            title=issue_data['title'][:500],  # Truncate if too long
            description=issue_data['description'],
            tool=issue_data['tool'],
            affected_url=issue_data.get('affected_url'),
            remediation=issue_data.get('remediation'),
            raw_data=issue_data.get('raw_data'),
        ))
    if issues_to_create:
        Issue.objects.bulk_create(issues_to_create)

    # Create Metric records
    # Map unit strings to model choices
    unit_map = {
        'ms': 'ms',
        'milliseconds': 'ms',
        's': 's',
        'seconds': 's',
        'bytes': 'bytes',
        'kb': 'kb',
        'kilobytes': 'kb',
        'mb': 'mb',
        'megabytes': 'mb',
        'score': 'score',
        'percent': 'percent',
        'count': 'count',
    }

    metrics_to_create = []
    seen_metrics = set()  # Track unique metrics
    for metric_data in results.get('metrics', []):
        metric_key = metric_data['name']
        if metric_key in seen_metrics:
            continue  # Skip duplicates
        seen_metrics.add(metric_key)

        unit = unit_map.get(metric_data['unit'].lower(), 'count')
        metrics_to_create.append(Metric(
            scan=scan,
            name=metric_data['name'],
            display_name=metric_data['display_name'][:200],
            value=metric_data['value'],
            unit=unit,
            source=metric_data['source'],
            score=metric_data.get('score'),
        ))
    if metrics_to_create:
        Metric.objects.bulk_create(metrics_to_create)

    # Calculate security score based on issues
    scan.calculate_security_score()
    # Calculate overall score
    scan.calculate_overall_score()
    scan.save(update_fields=['security_score', 'overall_score'])

    logger.info(
        f"Saved scan results: {len(issues_to_create)} issues, "
        f"{len(metrics_to_create)} metrics"
    )
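

# For reference, `_save_scan_results` expects a ScanRunner payload shaped
# roughly like the sketch below. The authoritative keys live in
# scanner/scanners.py; this is an illustration, not a schema:
#
#     {
#         'status': 'done' | 'partial' | 'failed',
#         'scores': {'performance': 95, 'accessibility': 88, ...},
#         'raw_data': {'lighthouse': {...}, 'owasp_zap': {...}, ...},
#         'errors': [{'scanner': 'lighthouse', 'error': '...'}],
#         'issues': [{'category': ..., 'severity': ..., 'title': ..., ...}],
#         'metrics': [{'name': ..., 'unit': 'ms', 'value': 1234, ...}],
#     }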


@shared_task
def cleanup_old_scans(days: int = 30) -> dict:
    """
    Clean up old scan data to prevent unbounded database growth.

    Args:
        days: Number of days of scans to keep

    Returns:
        Dict with cleanup statistics
    """
    cutoff_date = timezone.now() - timedelta(days=days)

    # Delete old scans; the delete cascades to their issues and metrics.
    # QuerySet.delete() returns (total_objects, per_model_counts), and the
    # total includes cascaded rows, so report the Scan count specifically.
    _, deleted_by_model = Scan.objects.filter(
        created_at__lt=cutoff_date
    ).delete()
    deleted_scans = deleted_by_model.get('websites.Scan', 0)

    logger.info(f"Cleaned up {deleted_scans} old scans")
    return {
        'deleted_scans': deleted_scans,
        'cutoff_date': cutoff_date.isoformat(),
    }
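

# A sketch of how this task could be scheduled with Celery beat (whether the
# project actually does this is an assumption; names and timing below are
# illustrative). `crontab` comes from celery.schedules:
#
#     CELERY_BEAT_SCHEDULE = {
#         'cleanup-old-scans': {
#             'task': 'scanner.tasks.cleanup_old_scans',
#             'schedule': crontab(hour=3, minute=0),  # daily at 03:00
#             'kwargs': {'days': 30},
#         },
#     }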


def check_rate_limit(url: str) -> Optional[str]:
    """
    Check whether scanning this URL is rate limited.

    Note: a passing check also records the current time, i.e. calling this
    function starts the rate-limit window for the URL's domain.

    Args:
        url: The URL to check

    Returns:
        Error message if rate limited, None otherwise
    """
    from django.core.cache import cache

    scanner_config = settings.SCANNER_CONFIG
    rate_limit_minutes = scanner_config.get('SCAN_RATE_LIMIT_MINUTES', 5)

    # Create a cache key based on the URL's domain
    domain = get_domain_from_url(url)
    cache_key = f"scan_rate_limit:{domain}"

    # Check if this domain was already scanned recently
    last_scan_time = cache.get(cache_key)
    if last_scan_time:
        return (
            "This URL was scanned recently. "
            f"Please wait {rate_limit_minutes} minutes between scans."
        )

    # Start the rate-limit window
    cache.set(cache_key, timezone.now().isoformat(), timeout=rate_limit_minutes * 60)
    return None
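

# Note that the get/set pair above is not atomic: two near-simultaneous
# requests for the same domain can both pass the check. If that matters,
# Django's cache.add() only writes when the key is absent and returns whether
# it did, giving an atomic check-and-set. A sketch of that variant (not what
# the function above currently does):
#
#     if not cache.add(cache_key, timezone.now().isoformat(),
#                      timeout=rate_limit_minutes * 60):
#         return "This URL was scanned recently. ..."
#     return None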


def check_concurrent_scan_limit() -> Optional[str]:
    """
    Check whether the maximum number of concurrent scans has been reached.

    Returns:
        Error message if the limit is reached, None otherwise
    """
    scanner_config = settings.SCANNER_CONFIG
    max_concurrent = scanner_config.get('MAX_CONCURRENT_SCANS', 3)

    running_count = Scan.objects.filter(status=ScanStatus.RUNNING).count()
    if running_count >= max_concurrent:
        return (
            f"Maximum concurrent scans ({max_concurrent}) reached. "
            "Please wait for current scans to complete."
        )
    return None
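

# How the guards are presumably combined before queueing a scan; the view
# code below (including the DRF-style Response) is an illustrative sketch,
# not the project's actual API layer:
#
#     error = check_rate_limit(url) or check_concurrent_scan_limit()
#     if error:
#         return Response({'detail': error}, status=429)
#     scan = Scan.objects.create(website=website)
#     run_scan_task.delay(str(scan.id))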