""" Celery tasks for background scanning. This module defines the Celery tasks that orchestrate website scans in the background. """ import logging from datetime import timedelta from typing import Optional from celery import shared_task from celery.exceptions import SoftTimeLimitExceeded from django.conf import settings from django.utils import timezone from websites.models import Website, Scan, ScanStatus, Issue, Metric from scanner.scanners import ScanRunner from scanner.utils import validate_url, get_domain_from_url logger = logging.getLogger(__name__) @shared_task( bind=True, max_retries=2, default_retry_delay=60, soft_time_limit=300, time_limit=330, ) def run_scan_task(self, scan_id: str) -> dict: """ Main Celery task for running a website scan. This task: 1. Updates scan status to running 2. Orchestrates all scanners 3. Saves results to database 4. Handles errors and partial results Args: scan_id: UUID of the Scan record Returns: Dict with scan results summary """ logger.info(f"Starting scan task for scan_id: {scan_id}") try: # Get the scan record scan = Scan.objects.select_related('website').get(id=scan_id) except Scan.DoesNotExist: logger.error(f"Scan {scan_id} not found") return {'error': f'Scan {scan_id} not found'} # Update status to running scan.status = ScanStatus.RUNNING scan.started_at = timezone.now() scan.celery_task_id = self.request.id scan.save(update_fields=['status', 'started_at', 'celery_task_id']) url = scan.website.url try: # Run the scan pipeline runner = ScanRunner() results = runner.run(url) # Save results to database _save_scan_results(scan, results) # Update website last_scanned_at scan.website.last_scanned_at = timezone.now() scan.website.save(update_fields=['last_scanned_at']) logger.info(f"Scan {scan_id} completed successfully") return { 'scan_id': str(scan_id), 'status': scan.status, 'overall_score': scan.overall_score, 'issues_count': scan.issues.count(), 'metrics_count': scan.metrics.count(), } except SoftTimeLimitExceeded: logger.warning(f"Scan {scan_id} timed out") scan.status = ScanStatus.PARTIAL scan.error_message = "Scan timed out before completing all checks" scan.completed_at = timezone.now() scan.save(update_fields=['status', 'error_message', 'completed_at']) return { 'scan_id': str(scan_id), 'status': 'partial', 'error': 'Scan timed out' } except Exception as e: logger.exception(f"Scan {scan_id} failed with error: {e}") scan.status = ScanStatus.FAILED scan.error_message = str(e) scan.completed_at = timezone.now() scan.save(update_fields=['status', 'error_message', 'completed_at']) # Retry on certain errors if self.request.retries < self.max_retries: raise self.retry(exc=e) return { 'scan_id': str(scan_id), 'status': 'failed', 'error': str(e) } def _save_scan_results(scan: Scan, results: dict) -> None: """ Save scan results to the database. 
def _save_scan_results(scan: Scan, results: dict) -> None:
    """
    Save scan results to the database.

    Args:
        scan: The Scan model instance
        results: Aggregated results from ScanRunner
    """
    # Update scan status
    status_map = {
        'done': ScanStatus.DONE,
        'partial': ScanStatus.PARTIAL,
        'failed': ScanStatus.FAILED,
    }
    scan.status = status_map.get(results['status'], ScanStatus.DONE)
    scan.completed_at = timezone.now()

    # Save scores
    scores = results.get('scores', {})
    scan.performance_score = scores.get('performance')
    scan.accessibility_score = scores.get('accessibility')
    scan.seo_score = scores.get('seo')
    scan.best_practices_score = scores.get('best_practices')

    # Save raw data
    raw_data = results.get('raw_data', {})
    scan.raw_lighthouse_data = raw_data.get('lighthouse')
    scan.raw_zap_data = raw_data.get('owasp_zap')
    scan.raw_playwright_data = raw_data.get('playwright')
    scan.raw_headers_data = raw_data.get('header_check')

    # Save errors if any
    if results.get('errors'):
        scan.error_message = '\n'.join(
            f"{e['scanner']}: {e['error']}" for e in results['errors']
        )

    scan.save()

    # Create Issue records
    issues_to_create = []
    for issue_data in results.get('issues', []):
        issues_to_create.append(Issue(
            scan=scan,
            category=issue_data['category'],
            severity=issue_data['severity'],
            title=issue_data['title'][:500],  # Truncate if too long
            description=issue_data['description'],
            tool=issue_data['tool'],
            affected_url=issue_data.get('affected_url'),
            remediation=issue_data.get('remediation'),
            raw_data=issue_data.get('raw_data'),
        ))

    if issues_to_create:
        Issue.objects.bulk_create(issues_to_create)

    # Map unit strings to model choices (built once, outside the loop)
    unit_map = {
        'ms': 'ms',
        'milliseconds': 'ms',
        's': 's',
        'seconds': 's',
        'bytes': 'bytes',
        'kb': 'kb',
        'kilobytes': 'kb',
        'mb': 'mb',
        'megabytes': 'mb',
        'score': 'score',
        'percent': 'percent',
        'count': 'count',
    }

    # Create Metric records
    metrics_to_create = []
    seen_metrics = set()  # Track unique metric names

    for metric_data in results.get('metrics', []):
        metric_key = metric_data['name']
        if metric_key in seen_metrics:
            continue  # Skip duplicates
        seen_metrics.add(metric_key)

        unit = unit_map.get(metric_data['unit'].lower(), 'count')

        metrics_to_create.append(Metric(
            scan=scan,
            name=metric_data['name'],
            display_name=metric_data['display_name'][:200],
            value=metric_data['value'],
            unit=unit,
            source=metric_data['source'],
            score=metric_data.get('score'),
        ))

    if metrics_to_create:
        Metric.objects.bulk_create(metrics_to_create)

    # Calculate security score based on issues
    scan.calculate_security_score()

    # Calculate overall score
    scan.calculate_overall_score()

    scan.save(update_fields=['security_score', 'overall_score'])

    logger.info(
        f"Saved scan results: {len(issues_to_create)} issues, "
        f"{len(metrics_to_create)} metrics"
    )


@shared_task
def cleanup_old_scans(days: int = 30) -> dict:
    """
    Clean up old scan data to prevent unbounded database growth.

    Args:
        days: Number of days of scans to keep

    Returns:
        Dict with cleanup statistics
    """
    cutoff_date = timezone.now() - timedelta(days=days)

    # Delete old scans; the delete cascades to their issues and metrics.
    # QuerySet.delete() returns (total_rows_deleted, per_model_counts), and
    # the total includes the cascaded Issue/Metric rows, so report only the
    # Scan count ('websites.Scan' is the label implied by the import above).
    _, deleted_by_model = Scan.objects.filter(
        created_at__lt=cutoff_date
    ).delete()
    deleted_scans = deleted_by_model.get('websites.Scan', 0)

    logger.info(f"Cleaned up {deleted_scans} old scans")
    return {
        'deleted_scans': deleted_scans,
        'cutoff_date': cutoff_date.isoformat(),
    }
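
# Example of wiring cleanup_old_scans into Celery beat (illustrative; this
# assumes the project reads Celery config from Django settings with a CELERY
# namespace and that this module is importable as scanner.tasks — adjust both
# to match the actual project layout):
#
#     from celery.schedules import crontab
#
#     CELERY_BEAT_SCHEDULE = {
#         'cleanup-old-scans': {
#             'task': 'scanner.tasks.cleanup_old_scans',
#             'schedule': crontab(hour=3, minute=0),  # daily at 03:00
#             'kwargs': {'days': 30},
#         },
#     }
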
def check_rate_limit(url: str) -> Optional[str]:
    """
    Check whether scanning this URL is currently rate limited.

    Note: a successful check also records the scan time, so calling this
    function starts the rate-limit window for the URL's domain.

    Args:
        url: The URL to check

    Returns:
        Error message if rate limited, None otherwise
    """
    from django.core.cache import cache

    scanner_config = settings.SCANNER_CONFIG
    rate_limit_minutes = scanner_config.get('SCAN_RATE_LIMIT_MINUTES', 5)

    # Create a cache key based on the URL's domain
    domain = get_domain_from_url(url)
    cache_key = f"scan_rate_limit:{domain}"

    # Check if this domain was already scanned recently
    last_scan_time = cache.get(cache_key)
    if last_scan_time:
        return (
            f"This URL was scanned recently. "
            f"Please wait {rate_limit_minutes} minutes between scans."
        )

    # Start the rate-limit window for this domain
    cache.set(cache_key, timezone.now().isoformat(), timeout=rate_limit_minutes * 60)
    return None


def check_concurrent_scan_limit() -> Optional[str]:
    """
    Check whether the maximum number of concurrent scans has been reached.

    Returns:
        Error message if the limit is reached, None otherwise
    """
    scanner_config = settings.SCANNER_CONFIG
    max_concurrent = scanner_config.get('MAX_CONCURRENT_SCANS', 3)

    running_count = Scan.objects.filter(status=ScanStatus.RUNNING).count()

    if running_count >= max_concurrent:
        return (
            f"Maximum concurrent scans ({max_concurrent}) reached. "
            "Please wait for current scans to complete."
        )

    return None
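
# How the guard helpers above might be combined before enqueueing a scan
# (illustrative sketch; the actual pre-flight logic lives wherever scans are
# created). The concurrency check runs first because check_rate_limit()
# starts the rate-limit window as a side effect:
#
#     error = check_concurrent_scan_limit() or check_rate_limit(url)
#     if error:
#         return JsonResponse({'error': error}, status=429)
#     run_scan_task.delay(str(scan.id))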