"""
|
|
Celery tasks for background scanning.
|
|
|
|
This module defines the Celery tasks that orchestrate website scans
|
|
in the background.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import timedelta
|
|
from typing import Optional
|
|
|
|
from celery import shared_task
|
|
from celery.exceptions import SoftTimeLimitExceeded
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
|
|
from websites.models import Website, Scan, ScanStatus, Issue, Metric
|
|
from scanner.scanners import ScanRunner
|
|
from scanner.utils import validate_url, get_domain_from_url
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@shared_task(
    bind=True,
    max_retries=2,
    default_retry_delay=60,
    soft_time_limit=300,
    time_limit=330,
)
def run_scan_task(self, scan_id: str) -> dict:
    """
    Main Celery task for running a website scan.

    This task:
    1. Updates the scan status to running
    2. Orchestrates all scanners
    3. Saves results to the database
    4. Handles errors and partial results

    Args:
        scan_id: UUID of the Scan record

    Returns:
        Dict summarizing the scan results
    """
    logger.info(f"Starting scan task for scan_id: {scan_id}")

    try:
        # Get the scan record
        scan = Scan.objects.select_related('website').get(id=scan_id)
    except Scan.DoesNotExist:
        logger.error(f"Scan {scan_id} not found")
        return {'error': f'Scan {scan_id} not found'}

    # Update status to running
    scan.status = ScanStatus.RUNNING
    scan.started_at = timezone.now()
    scan.celery_task_id = self.request.id
    scan.save(update_fields=['status', 'started_at', 'celery_task_id'])

    url = scan.website.url

    try:
        # Run the scan pipeline
        runner = ScanRunner()
        results = runner.run(url)

        # Save results to the database
        _save_scan_results(scan, results)

        # Update the website's last_scanned_at timestamp
        scan.website.last_scanned_at = timezone.now()
        scan.website.save(update_fields=['last_scanned_at'])

        logger.info(f"Scan {scan_id} completed successfully")

        return {
            'scan_id': str(scan_id),
            'status': scan.status,
            'overall_score': scan.overall_score,
            'issues_count': scan.issues.count(),
            'metrics_count': scan.metrics.count(),
        }

    except SoftTimeLimitExceeded:
        logger.warning(f"Scan {scan_id} timed out")
        scan.status = ScanStatus.PARTIAL
        scan.error_message = "Scan timed out before completing all checks"
        scan.completed_at = timezone.now()
        scan.save(update_fields=['status', 'error_message', 'completed_at'])

        return {
            'scan_id': str(scan_id),
            'status': 'partial',
            'error': 'Scan timed out',
        }

    except Exception as e:
        logger.exception(f"Scan {scan_id} failed with error: {e}")
        scan.status = ScanStatus.FAILED
        scan.error_message = str(e)
        scan.completed_at = timezone.now()
        scan.save(update_fields=['status', 'error_message', 'completed_at'])

        # Retry if any attempts remain; self.retry() raises, ending this run
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)

        return {
            'scan_id': str(scan_id),
            'status': 'failed',
            'error': str(e),
        }

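# Example usage (a sketch, not part of the scan pipeline): a view or service
# would typically create a Scan row and enqueue this task by id, after running
# the guard helpers defined at the bottom of this module. `website` is assumed
# to be an existing Website instance, and Scan field defaults are assumed to
# cover everything except the foreign key.
#
#     error = check_rate_limit(website.url) or check_concurrent_scan_limit()
#     if error:
#         ...  # surface the error to the caller instead of enqueueing
#     else:
#         scan = Scan.objects.create(website=website)
#         run_scan_task.delay(str(scan.id))

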
def _save_scan_results(scan: Scan, results: dict) -> None:
    """
    Save scan results to the database.

    Args:
        scan: The Scan model instance
        results: Aggregated results from ScanRunner
    """
    # Update scan status
    status_map = {
        'done': ScanStatus.DONE,
        'partial': ScanStatus.PARTIAL,
        'failed': ScanStatus.FAILED,
    }
    scan.status = status_map.get(results['status'], ScanStatus.DONE)
    scan.completed_at = timezone.now()

    # Save scores
    scores = results.get('scores', {})
    scan.performance_score = scores.get('performance')
    scan.accessibility_score = scores.get('accessibility')
    scan.seo_score = scores.get('seo')
    scan.best_practices_score = scores.get('best_practices')

    # Save raw data
    raw_data = results.get('raw_data', {})
    scan.raw_lighthouse_data = raw_data.get('lighthouse')
    scan.raw_zap_data = raw_data.get('owasp_zap')
    scan.raw_playwright_data = raw_data.get('playwright')
    scan.raw_headers_data = raw_data.get('header_check')

    # Save errors if any
    if results.get('errors'):
        scan.error_message = '\n'.join(
            f"{e['scanner']}: {e['error']}"
            for e in results['errors']
        )

    scan.save()

    # Create Issue records
    issues_to_create = []
    for issue_data in results.get('issues', []):
        issues_to_create.append(Issue(
            scan=scan,
            category=issue_data['category'],
            severity=issue_data['severity'],
            title=issue_data['title'][:500],  # Truncate if too long
            description=issue_data['description'],
            tool=issue_data['tool'],
            affected_url=issue_data.get('affected_url'),
            remediation=issue_data.get('remediation'),
            raw_data=issue_data.get('raw_data'),
        ))

    if issues_to_create:
        Issue.objects.bulk_create(issues_to_create)

    # Create Metric records
    # Map unit strings to model choices
    unit_map = {
        'ms': 'ms',
        'milliseconds': 'ms',
        's': 's',
        'seconds': 's',
        'bytes': 'bytes',
        'kb': 'kb',
        'kilobytes': 'kb',
        'mb': 'mb',
        'megabytes': 'mb',
        'score': 'score',
        'percent': 'percent',
        'count': 'count',
    }

    metrics_to_create = []
    seen_metrics = set()  # Track unique metrics

    for metric_data in results.get('metrics', []):
        metric_key = metric_data['name']
        if metric_key in seen_metrics:
            continue  # Skip duplicates
        seen_metrics.add(metric_key)

        unit = unit_map.get(metric_data['unit'].lower(), 'count')

        metrics_to_create.append(Metric(
            scan=scan,
            name=metric_data['name'],
            display_name=metric_data['display_name'][:200],
            value=metric_data['value'],
            unit=unit,
            source=metric_data['source'],
            score=metric_data.get('score'),
        ))

    if metrics_to_create:
        Metric.objects.bulk_create(metrics_to_create)

    # Calculate security score based on issues
    scan.calculate_security_score()

    # Calculate overall score
    scan.calculate_overall_score()

    scan.save(update_fields=['security_score', 'overall_score'])

    logger.info(
        f"Saved scan results: {len(issues_to_create)} issues, "
        f"{len(metrics_to_create)} metrics"
    )

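# For reference, _save_scan_results above assumes an aggregated results dict
# shaped roughly as follows (a sketch reconstructed from the keys it reads;
# the values are illustrative only):
#
#     {
#         'status': 'done',                     # or 'partial' / 'failed'
#         'scores': {'performance': 92, 'accessibility': 88, 'seo': 95,
#                    'best_practices': 90},
#         'raw_data': {'lighthouse': {...}, 'owasp_zap': {...},
#                      'playwright': {...}, 'header_check': {...}},
#         'errors': [{'scanner': 'lighthouse', 'error': 'timed out'}],
#         'issues': [{'category': ..., 'severity': ..., 'title': ...,
#                     'description': ..., 'tool': ..., 'affected_url': ...,
#                     'remediation': ..., 'raw_data': ...}],
#         'metrics': [{'name': ..., 'display_name': ..., 'value': ...,
#                      'unit': 'ms', 'source': ..., 'score': ...}],
#     }

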
@shared_task
def cleanup_old_scans(days: int = 30) -> dict:
    """
    Clean up old scan data to prevent unbounded database growth.

    Args:
        days: Number of days of scan history to keep

    Returns:
        Dict with cleanup statistics
    """
    cutoff_date = timezone.now() - timedelta(days=days)

    # Count the scans first, since QuerySet.delete() reports the total number
    # of deleted rows including the cascaded issues and metrics
    old_scans = Scan.objects.filter(created_at__lt=cutoff_date)
    deleted_count = old_scans.count()

    # Delete old scans (cascades to issues and metrics)
    old_scans.delete()

    logger.info(f"Cleaned up {deleted_count} old scans")

    return {
        'deleted_scans': deleted_count,
        'cutoff_date': cutoff_date.isoformat(),
    }

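# Example (a sketch): cleanup_old_scans is intended to run periodically; one
# way to wire that up is a Celery beat entry in the Django settings. The task
# path below assumes this module lives at scanner/tasks.py and that the
# project uses the CELERY_ settings namespace.
#
#     from celery.schedules import crontab
#
#     CELERY_BEAT_SCHEDULE = {
#         'cleanup-old-scans': {
#             'task': 'scanner.tasks.cleanup_old_scans',
#             'schedule': crontab(hour=3, minute=0),
#             'kwargs': {'days': 30},
#         },
#     }

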
def check_rate_limit(url: str) -> Optional[str]:
    """
    Check whether scanning of this URL is rate limited.

    Args:
        url: The URL to check

    Returns:
        Error message if rate limited, None otherwise
    """
    from django.core.cache import cache

    scanner_config = settings.SCANNER_CONFIG
    rate_limit_minutes = scanner_config.get('SCAN_RATE_LIMIT_MINUTES', 5)

    # Create a cache key based on the URL's domain
    domain = get_domain_from_url(url)
    cache_key = f"scan_rate_limit:{domain}"

    # Check if this domain was already scanned recently
    last_scan_time = cache.get(cache_key)
    if last_scan_time:
        return (
            "This URL was scanned recently. "
            f"Please wait {rate_limit_minutes} minutes between scans."
        )

    # Record this scan so further requests are limited for the window
    cache.set(cache_key, timezone.now().isoformat(), timeout=rate_limit_minutes * 60)

    return None

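# Note (illustrative): the rate limit above is keyed per domain, not per full
# URL. For "https://example.com/pricing" the cache entry would be
# "scan_rate_limit:example.com" (assuming get_domain_from_url returns the bare
# host), so repeat scans of any page on that domain share one window.

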
def check_concurrent_scan_limit() -> Optional[str]:
    """
    Check whether the maximum number of concurrent scans has been reached.

    Returns:
        Error message if the limit is reached, None otherwise
    """
    scanner_config = settings.SCANNER_CONFIG
    max_concurrent = scanner_config.get('MAX_CONCURRENT_SCANS', 3)

    running_count = Scan.objects.filter(status=ScanStatus.RUNNING).count()

    if running_count >= max_concurrent:
        return (
            f"Maximum concurrent scans ({max_concurrent}) reached. "
            "Please wait for current scans to complete."
        )

    return None