337 lines
9.9 KiB
Python
337 lines
9.9 KiB
Python
"""
|
|
Django REST framework (DRF) views for the API.

This module defines the viewsets for scans, websites, and issues, plus the
health-check and API-root endpoints.
|
|
"""
|
|
|
|
import logging

from django.core.cache import cache
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import connection
from django.utils import timezone
from rest_framework import viewsets, status, generics
from rest_framework.decorators import api_view, action
from rest_framework.exceptions import APIException, ValidationError
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from rest_framework.throttling import AnonRateThrottle

from websites.models import Website, Scan, Issue, Metric
from .serializers import (
    WebsiteSerializer,
    WebsiteDetailSerializer,
    ScanListSerializer,
    ScanDetailSerializer,
    ScanCreateSerializer,
    IssueSerializer,
    MetricSerializer,
    HealthCheckSerializer,
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ScanRateThrottle(AnonRateThrottle):
    """Rate limit applied to scan-creation requests."""

    # Allowed request budget, written in DRF's "<count>/<period>" rate syntax.
    rate = '10/hour'
|
|
|
|
|
|
class StandardResultsPagination(PageNumberPagination):
    """Page-number pagination used by the list endpoints of this module."""

    # Query parameter a client can send to request a different page size ...
    page_size_query_param = 'page_size'
    # ... up to this upper bound.
    max_page_size = 100
    # Page size used when the client does not request one.
    page_size = 20
|
|
|
|
|
|
class ScanViewSet(viewsets.ModelViewSet):
    """
    ViewSet for Scan operations.

    Endpoints:
    - POST /api/scans/ - Create a new scan
    - GET /api/scans/ - List all scans
    - GET /api/scans/{id}/ - Get scan details
    - DELETE /api/scans/{id}/ - Delete a scan
    """

    # NOTE: the docstrings of this class and of its actions are kept verbatim.
    # DRF can surface view docstrings as endpoint descriptions, so they are
    # treated as user-facing text; implementation notes live in comments.

    # Each scan is fetched with its website joined in and with its issues and
    # metrics prefetched.
    queryset = Scan.objects.select_related('website').prefetch_related('issues', 'metrics')
    pagination_class = StandardResultsPagination

    def get_serializer_class(self):
        # 'list' and 'create' have dedicated serializers; every other action
        # (retrieve, destroy, the extra actions, ...) uses the detail one.
        per_action = {
            'list': ScanListSerializer,
            'create': ScanCreateSerializer,
        }
        return per_action.get(self.action, ScanDetailSerializer)

    def get_throttles(self):
        # Scan creation is limited by ScanRateThrottle alone; all other
        # actions keep the throttles configured on the view/settings.
        if self.action == 'create':
            return [ScanRateThrottle()]
        return super().get_throttles()

    def create(self, request, *args, **kwargs):
        """
        Create a new scan.

        Request body:
        ```json
        {"url": "https://example.com"}
        ```

        Returns the created scan with pending status.
        The scan will be processed asynchronously.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        try:
            scan = serializer.save()
            # The response uses the detail serializer, not the create one.
            payload = ScanDetailSerializer(scan).data
        except APIException:
            # Errors DRF already knows how to render (validation, throttling,
            # permission, ...) must keep their own status code instead of
            # being flattened into a 500.
            raise
        except Exception:
            # Full details go to the log only; the raw exception text is not
            # echoed to the client because it can expose internals.
            logger.exception("Error creating scan")
            return Response(
                {'error': 'An internal error occurred while creating the scan.'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        return Response(payload, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=['get'])
    def issues(self, request, pk=None):
        """Get all issues for a scan."""
        scan = self.get_object()
        issues = scan.issues.all()

        # Optional exact-match filters taken from the query string; a missing
        # or empty parameter leaves the queryset untouched.
        for field in ('category', 'severity', 'tool'):
            value = request.query_params.get(field)
            if value:
                issues = issues.filter(**{field: value})

        serializer = IssueSerializer(issues, many=True)
        return Response(serializer.data)

    @action(detail=True, methods=['get'])
    def metrics(self, request, pk=None):
        """Get all metrics for a scan."""
        scan = self.get_object()
        metrics = scan.metrics.all()

        # Optional filtering by source
        source = request.query_params.get('source')
        if source:
            metrics = metrics.filter(source=source)

        serializer = MetricSerializer(metrics, many=True)
        return Response(serializer.data)

    # NOTE: this action is named `status`, like the imported
    # `rest_framework.status` module. That is safe here because method bodies
    # resolve `status` in the module scope, not the class scope, but no
    # class-level statement below this point may use `status` as the module.
    @action(detail=True, methods=['get'])
    def status(self, request, pk=None):
        """Get just the status of a scan (for polling)."""
        scan = self.get_object()
        return Response({
            'id': str(scan.id),
            'status': scan.status,
            'status_display': scan.get_status_display(),
            'progress': self._get_scan_progress(scan),
        })

    def _get_scan_progress(self, scan):
        """Estimate scan progress as an integer percentage.

        Returns 100 for a 'done' scan and 0 for a 'failed' scan or any status
        other than 'running'. For a 'running' scan the estimate starts at 10
        and grows by a fixed weight for each raw result field that is already
        populated, capped at 95 so that only 'done' ever reports 100.
        """
        if scan.status == 'done':
            return 100
        if scan.status != 'running':
            # 'failed' and every other non-running status report no progress.
            return 0

        # (attribute holding a tool's raw output, weight it contributes)
        stage_weights = (
            ('raw_headers_data', 20),
            ('raw_playwright_data', 25),
            ('raw_lighthouse_data', 30),
            ('raw_zap_data', 15),
        )
        progress = 10  # Started
        for attr, weight in stage_weights:
            if getattr(scan, attr):
                progress += weight
        return min(progress, 95)
|
|
|
|
|
|
class WebsiteViewSet(viewsets.ReadOnlyModelViewSet):
    """
    ViewSet for Website operations.

    Endpoints:
    - GET /api/websites/ - List all websites
    - GET /api/websites/{id}/ - Get website details
    - GET /api/websites/{id}/scans/ - Get scans for a website
    """

    # Websites are loaded with their related scans prefetched.
    queryset = Website.objects.prefetch_related('scans')
    pagination_class = StandardResultsPagination

    def get_serializer_class(self):
        # Only the single-object view uses the detail serializer.
        return WebsiteDetailSerializer if self.action == 'retrieve' else WebsiteSerializer

    @action(detail=True, methods=['get'])
    def scans(self, request, pk=None):
        """Get all scans for a website."""
        site_scans = self.get_object().scans.all()

        # paginate_queryset() returns None when pagination is not in effect;
        # in that case the complete list is returned as a plain response.
        page = self.paginate_queryset(site_scans)
        if page is None:
            return Response(ScanListSerializer(site_scans, many=True).data)

        return self.get_paginated_response(ScanListSerializer(page, many=True).data)
|
|
|
|
|
|
class IssueViewSet(viewsets.ReadOnlyModelViewSet):
    """
    ViewSet for Issue operations.

    Endpoints:
    - GET /api/issues/ - List all issues (with filtering)
    - GET /api/issues/{id}/ - Get issue details
    """

    # Issues are loaded together with their scan and that scan's website.
    queryset = Issue.objects.select_related('scan', 'scan__website')
    serializer_class = IssueSerializer
    pagination_class = StandardResultsPagination

    def get_queryset(self):
        """Return the issue queryset narrowed by the optional query parameters.

        Supported parameters are ``scan`` (the id of a scan), ``category``,
        ``severity`` and ``tool``. All are exact matches, and a missing or
        empty parameter is ignored.

        Raises:
            rest_framework.exceptions.ValidationError: if ``scan`` cannot be
                converted to the type of the scan primary key. The client
                receives a 400 response instead of an unhandled server error.
        """
        queryset = super().get_queryset()
        params = self.request.query_params

        # Filter by scan
        scan_id = params.get('scan')
        if scan_id:
            try:
                queryset = queryset.filter(scan_id=scan_id)
            except (ValueError, DjangoValidationError):
                # Django converts the lookup value while building the filter:
                # a malformed id raises ValueError (numeric keys) or Django's
                # own ValidationError (e.g. UUID keys). DRF handles neither,
                # so translate them into a DRF validation error.
                raise ValidationError({'scan': 'Invalid scan id.'}) from None

        # Filter by category, severity and tool
        for field in ('category', 'severity', 'tool'):
            value = params.get(field)
            if value:
                queryset = queryset.filter(**{field: value})

        return queryset
|
|
|
|
|
|
@api_view(['GET'])
def health_check(request):
    """
    Health check endpoint.

    Checks:
    - Database connectivity
    - Redis connectivity
    - Celery worker status

    Returns health status of all components.
    """
    # The docstring above is kept verbatim because DRF can surface it as the
    # endpoint description; implementation notes are in comments.
    #
    # Overall status is 'healthy', 'degraded' or 'unhealthy'. The response is
    # HTTP 200 only when it is still 'healthy' at the end, otherwise 503.
    health = {
        'status': 'healthy',
        'database': 'unknown',
        'redis': 'unknown',
        'celery': 'unknown',
        'timestamp': timezone.now(),
    }

    def mark_degraded():
        # Status may only escalate: a later, less severe problem must not
        # overwrite an 'unhealthy' verdict set by the database check.
        if health['status'] == 'healthy':
            health['status'] = 'degraded'

    # NOTE(review): the per-component values below embed the raw exception
    # text. If this endpoint is reachable without authentication, consider
    # logging the details instead of returning them.

    # Check database
    try:
        connection.ensure_connection()
        health['database'] = 'healthy'
    except Exception as e:
        health['database'] = f'unhealthy: {e}'
        health['status'] = 'unhealthy'

    # Check Redis: done as a write/read round trip through Django's cache API
    # (presumably the configured cache backend is Redis; verify in settings).
    try:
        cache.set('health_check', 'ok', 10)
        if cache.get('health_check') == 'ok':
            health['redis'] = 'healthy'
        else:
            health['redis'] = 'unhealthy: cache not working'
            mark_degraded()
    except Exception as e:
        health['redis'] = f'unhealthy: {e}'
        mark_degraded()

    # Check Celery (basic check)
    try:
        # Imported here so that a failing import is reported as a Celery
        # problem by the except clause below instead of breaking the module.
        from core.celery import app as celery_app
        inspector = celery_app.control.inspect()

        # Try to get active workers
        active = inspector.active()
        if active:
            health['celery'] = f'healthy ({len(active)} workers)'
        else:
            health['celery'] = 'degraded: no active workers'
            mark_degraded()
    except Exception as e:
        # A failure to query Celery is reported but does not change the
        # overall status.
        health['celery'] = f'unknown: {e}'

    status_code = 200 if health['status'] == 'healthy' else 503

    serializer = HealthCheckSerializer(health)
    return Response(serializer.data, status=status_code)
|
|
|
|
|
|
@api_view(['GET'])
def api_root(request):
    """
    API root endpoint.

    Returns available endpoints and basic API information.
    """
    # Top-level resource URLs advertised to clients.
    endpoints = dict(
        scans='/api/scans/',
        websites='/api/websites/',
        issues='/api/issues/',
        health='/api/health/',
    )

    # Short usage notes for the scan operations.
    create_scan = {
        'method': 'POST',
        'url': '/api/scans/',
        'body': {'url': 'https://example.com'},
        'description': 'Create a new website scan',
    }
    get_scan = {
        'method': 'GET',
        'url': '/api/scans/{id}/',
        'description': 'Get scan results and details',
    }
    list_scans = {
        'method': 'GET',
        'url': '/api/scans/',
        'description': 'List all scans with pagination',
    }

    body = {
        'message': 'Website Analyzer API',
        'version': '1.0.0',
        'endpoints': endpoints,
        'documentation': {
            'create_scan': create_scan,
            'get_scan': get_scan,
            'list_scans': list_scans,
        },
    }
    return Response(body)
|