secure-web/backend/api/views.py

337 lines
9.9 KiB
Python

"""
DRF Views for the API.
This module defines API views for scans, websites, and issues.
"""
import logging

from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import connection
from django.utils import timezone
from rest_framework import viewsets, status, generics
from rest_framework.decorators import api_view, action
from rest_framework.exceptions import APIException
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from rest_framework.throttling import AnonRateThrottle

from websites.models import Website, Scan, Issue, Metric

from .serializers import (
    WebsiteSerializer,
    WebsiteDetailSerializer,
    ScanListSerializer,
    ScanDetailSerializer,
    ScanCreateSerializer,
    IssueSerializer,
    MetricSerializer,
    HealthCheckSerializer,
)
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
class ScanRateThrottle(AnonRateThrottle):
    """Throttle applied to scan creation for unauthenticated clients.

    Because it derives from ``AnonRateThrottle``, only anonymous requests are
    counted; the limit is fixed by ``rate`` below rather than looked up in the
    project's throttle-rate settings.
    """

    # Use a dedicated scope so the request history is cached under its own
    # key. With the inherited ``'anon'`` scope the history would be shared
    # with any other AnonRateThrottle the project enables, letting ordinary
    # anonymous traffic eat into the scan-creation quota.
    scope = 'scan_create'

    # Maximum number of scan-creation requests per client per hour.
    rate = '10/hour'
class StandardResultsPagination(PageNumberPagination):
    """Page-number pagination used by the list endpoints of this module."""

    # Query-string parameter through which a client may pick its own page size.
    page_size_query_param = 'page_size'

    # Items returned per page when the client does not ask for a size.
    page_size = 20

    # Hard ceiling on the client-requested page size.
    max_page_size = 100
class ScanViewSet(viewsets.ModelViewSet):
    """
    ViewSet for Scan operations.

    Endpoints (paths assume the viewset is routed under /api/scans/):
    - POST /api/scans/ - Create a new scan
    - GET /api/scans/ - List all scans
    - GET /api/scans/{id}/ - Get scan details
    - DELETE /api/scans/{id}/ - Delete a scan
    - GET /api/scans/{id}/issues/ - Issues found by a scan (filterable)
    - GET /api/scans/{id}/metrics/ - Metrics recorded by a scan (filterable)
    - GET /api/scans/{id}/status/ - Lightweight status/progress for polling
    """

    queryset = Scan.objects.select_related('website').prefetch_related('issues', 'metrics')
    pagination_class = StandardResultsPagination

    # Query parameters accepted by the ``issues`` action; each one is applied
    # as an exact-match filter on the issue field of the same name.
    _ISSUE_FILTER_PARAMS = ('category', 'severity', 'tool')

    # Query parameters accepted by the ``metrics`` action (same convention).
    _METRIC_FILTER_PARAMS = ('source',)

    # (scan attribute, progress points) pairs used to estimate how far a
    # running scan has got: each populated raw-data attribute adds its points.
    _PROGRESS_WEIGHTS = (
        ('raw_headers_data', 20),
        ('raw_playwright_data', 25),
        ('raw_lighthouse_data', 30),
        ('raw_zap_data', 15),
    )

    def get_serializer_class(self):
        """Pick the serializer matching the current action.

        Listing and creation use their dedicated serializers; every other
        action falls back to the detail serializer.
        """
        if self.action == 'list':
            return ScanListSerializer
        if self.action == 'create':
            return ScanCreateSerializer
        return ScanDetailSerializer

    def get_throttles(self):
        """Apply the scan-creation throttle to ``create`` only.

        All other actions keep the throttles configured on the viewset.
        """
        if self.action == 'create':
            return [ScanRateThrottle()]
        return super().get_throttles()

    def create(self, request, *args, **kwargs):
        """
        Create a new scan.

        Request body:
        ```json
        {"url": "https://example.com"}
        ```

        Returns the created scan, rendered with the detail serializer, and
        HTTP 201. Invalid input is rejected with HTTP 400 by the serializer
        validation. How and when the scan is actually processed is decided by
        the create serializer, which is not visible in this module.

        Unexpected failures are logged with their traceback and answered with
        a generic HTTP 500 body; the exception text is never sent to the
        client because it may reveal internal details.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        try:
            scan = serializer.save()
            # Render inside the guarded section so a serialization failure
            # also produces the JSON error response below.
            payload = ScanDetailSerializer(scan).data
        except APIException:
            # Errors raised on purpose while saving (validation, throttling,
            # permissions, ...) already carry the right status code; let DRF's
            # exception handler render them instead of masking them as 500.
            raise
        except Exception:
            logger.exception("Error creating scan")
            return Response(
                {'error': 'An internal error occurred while creating the scan.'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        return Response(payload, status=status.HTTP_201_CREATED)

    @staticmethod
    def _apply_exact_filters(queryset, query_params, names):
        """Narrow ``queryset`` by every non-empty query parameter in ``names``.

        Each parameter name doubles as the model field it is matched against.
        """
        for name in names:
            value = query_params.get(name)
            if value:
                queryset = queryset.filter(**{name: value})
        return queryset

    @action(detail=True, methods=['get'])
    def issues(self, request, pk=None):
        """Get all issues for a scan.

        Optional query parameters ``category``, ``severity`` and ``tool``
        restrict the result to exact matches. The response is an unpaginated
        list.
        """
        scan = self.get_object()
        issues = self._apply_exact_filters(
            scan.issues.all(), request.query_params, self._ISSUE_FILTER_PARAMS
        )
        return Response(IssueSerializer(issues, many=True).data)

    @action(detail=True, methods=['get'])
    def metrics(self, request, pk=None):
        """Get all metrics for a scan.

        The optional ``source`` query parameter restricts the result to an
        exact match. The response is an unpaginated list.
        """
        scan = self.get_object()
        metrics = self._apply_exact_filters(
            scan.metrics.all(), request.query_params, self._METRIC_FILTER_PARAMS
        )
        return Response(MetricSerializer(metrics, many=True).data)

    @action(detail=True, methods=['get'])
    def status(self, request, pk=None):
        """Get just the status of a scan (for polling).

        Returns the scan id, its raw and human-readable status, and an
        estimated progress percentage.
        """
        # NOTE: this method is named ``status`` like the imported DRF module.
        # Inside methods the name still resolves to the module (method bodies
        # do not see class scope), but do not reference ``status.<CODE>`` in
        # the class body below this point.
        scan = self.get_object()
        return Response({
            'id': str(scan.id),
            'status': scan.status,
            'status_display': scan.get_status_display(),
            'progress': self._get_scan_progress(scan),
        })

    def _get_scan_progress(self, scan):
        """Estimate scan progress (0-100) from its status and collected data.

        ``done`` is always 100; any status other than ``running`` (``failed``
        included) is 0. A running scan starts at 10 and gains points for each
        raw-data attribute that has been filled in, capped at 95 so that only
        a finished scan ever reports 100.
        """
        if scan.status == 'done':
            return 100
        if scan.status != 'running':
            return 0

        progress = 10  # Started
        for attribute, points in self._PROGRESS_WEIGHTS:
            if getattr(scan, attribute):
                progress += points
        return min(progress, 95)
class WebsiteViewSet(viewsets.ReadOnlyModelViewSet):
    """
    Read-only ViewSet for Website records.

    Endpoints:
    - GET /api/websites/ - List all websites
    - GET /api/websites/{id}/ - Get website details
    - GET /api/websites/{id}/scans/ - Get scans for a website
    """

    pagination_class = StandardResultsPagination
    queryset = Website.objects.prefetch_related('scans')

    def get_serializer_class(self):
        """Use the detail serializer for ``retrieve`` and the compact one otherwise."""
        return WebsiteDetailSerializer if self.action == 'retrieve' else WebsiteSerializer

    @action(detail=True, methods=['get'])
    def scans(self, request, pk=None):
        """Return the scans belonging to one website.

        The result is paginated whenever the paginator yields a page;
        otherwise the full list is returned in a plain response.
        """
        site_scans = self.get_object().scans.all()

        current_page = self.paginate_queryset(site_scans)
        if current_page is None:
            # Pagination is not active for this request: send everything.
            return Response(ScanListSerializer(site_scans, many=True).data)

        page_data = ScanListSerializer(current_page, many=True).data
        return self.get_paginated_response(page_data)
class IssueViewSet(viewsets.ReadOnlyModelViewSet):
    """
    ViewSet for Issue operations.

    Endpoints:
    - GET /api/issues/ - List all issues (with filtering)
    - GET /api/issues/{id}/ - Get issue details

    Supported query parameters: ``scan`` (scan id), ``category``,
    ``severity`` and ``tool``. All are optional exact-match filters.
    """

    queryset = Issue.objects.select_related('scan', 'scan__website')
    serializer_class = IssueSerializer
    pagination_class = StandardResultsPagination

    # Query parameters that map one-to-one onto an Issue field of the same name.
    _FIELD_FILTER_PARAMS = ('category', 'severity', 'tool')

    def get_queryset(self):
        """Return the base queryset narrowed by the request's filter parameters.

        Empty or missing parameters are ignored. A ``scan`` value that is not
        a well-formed id yields an empty result instead of an error.
        """
        queryset = super().get_queryset()
        params = self.request.query_params

        # Filter by scan
        scan_id = params.get('scan')
        if scan_id:
            try:
                queryset = queryset.filter(scan_id=scan_id)
            except (ValueError, ValidationError):
                # The ORM rejects values it cannot coerce to the key type
                # (ValueError for integer keys, ValidationError for UUID
                # keys). Such an id cannot match any scan, so answer with
                # no rows rather than letting the request fail with a 500.
                return queryset.none()

        # Filter by category / severity / tool
        for name in self._FIELD_FILTER_PARAMS:
            value = params.get(name)
            if value:
                queryset = queryset.filter(**{name: value})

        return queryset
# Overall health states ordered by severity; a higher number is worse.
_HEALTH_SEVERITY = {'healthy': 0, 'degraded': 1, 'unhealthy': 2}


def _escalate_status(current, candidate):
    """Return the more severe of two overall health states (ties keep ``current``)."""
    return max(current, candidate, key=_HEALTH_SEVERITY.__getitem__)


@api_view(['GET'])
def health_check(request):
    """
    Health check endpoint.

    Checks:
    - Database connectivity (failure makes the service ``unhealthy``)
    - Cache round-trip through Django's cache framework, reported under the
      ``redis`` key (failure makes the service ``degraded``;
      NOTE(review): assumes the default cache is Redis-backed - confirm)
    - Celery worker status (no active workers makes the service ``degraded``;
      an error while inspecting is reported but does not change the status)

    The overall status is the most severe state reported by any check.
    Responds with HTTP 200 only when the overall status is ``healthy`` and
    with HTTP 503 otherwise.
    """
    health = {
        'status': 'healthy',
        'database': 'unknown',
        'redis': 'unknown',
        'celery': 'unknown',
        'timestamp': timezone.now(),
    }

    # Check database
    try:
        connection.ensure_connection()
        health['database'] = 'healthy'
    except Exception as e:
        health['database'] = f'unhealthy: {e}'
        health['status'] = _escalate_status(health['status'], 'unhealthy')

    # Check Redis
    # The status is escalated rather than assigned so that a failure here can
    # never downgrade an 'unhealthy' verdict from the database check.
    try:
        cache.set('health_check', 'ok', 10)
        if cache.get('health_check') == 'ok':
            health['redis'] = 'healthy'
        else:
            health['redis'] = 'unhealthy: cache not working'
            health['status'] = _escalate_status(health['status'], 'degraded')
    except Exception as e:
        health['redis'] = f'unhealthy: {e}'
        health['status'] = _escalate_status(health['status'], 'degraded')

    # Check Celery (basic check)
    try:
        # Imported lazily so this module can be loaded even if the Celery
        # app cannot be imported; the failure is then reported below.
        from core.celery import app as celery_app

        inspect = celery_app.control.inspect()
        # Try to get active workers
        active = inspect.active()
        if active:
            health['celery'] = f'healthy ({len(active)} workers)'
        else:
            health['celery'] = 'degraded: no active workers'
            health['status'] = _escalate_status(health['status'], 'degraded')
    except Exception as e:
        # Best effort: being unable to inspect Celery is reported but is not
        # treated as a failure of the service itself.
        health['celery'] = f'unknown: {e}'

    status_code = 200 if health['status'] == 'healthy' else 503
    serializer = HealthCheckSerializer(health)
    return Response(serializer.data, status=status_code)
@api_view(['GET'])
def api_root(request):
    """
    API root endpoint.

    Responds with a static description of the API: its name and version,
    the top-level endpoint paths, and short usage notes for the scan
    endpoints.
    """
    # Top-level entry points of the API.
    endpoints = {
        'scans': '/api/scans/',
        'websites': '/api/websites/',
        'issues': '/api/issues/',
        'health': '/api/health/',
    }

    # Minimal usage notes for the most common scan operations.
    create_scan_doc = {
        'method': 'POST',
        'url': '/api/scans/',
        'body': {'url': 'https://example.com'},
        'description': 'Create a new website scan',
    }
    get_scan_doc = {
        'method': 'GET',
        'url': '/api/scans/{id}/',
        'description': 'Get scan results and details',
    }
    list_scans_doc = {
        'method': 'GET',
        'url': '/api/scans/',
        'description': 'List all scans with pagination',
    }

    body = {
        'message': 'Website Analyzer API',
        'version': '1.0.0',
        'endpoints': endpoints,
        'documentation': {
            'create_scan': create_scan_doc,
            'get_scan': get_scan_doc,
            'list_scans': list_scans_doc,
        },
    }
    return Response(body)