Middleware¶
Middleware provides a powerful way to intercept and process RPC requests and responses in the Pyvider RPC Plugin framework. This guide covers implementing custom middleware for logging, authentication, rate limiting, and request transformation.
Overview¶
Middleware in Pyvider RPC operates at the gRPC service level, allowing you to:
- Log requests and responses for debugging and monitoring
- Authenticate and authorize incoming requests
- Rate limit clients to prevent abuse
- Transform data before processing or after responses
- Handle errors consistently across services
- Collect metrics for performance monitoring
Basic Middleware Implementation¶
Request Interceptor Middleware¶
import asyncio
import logging
import time
from typing import Any, Callable, Awaitable
from grpc import ServicerContext
from grpc.aio import ServicerInterceptor
logger = logging.getLogger(__name__)
class RequestLoggingInterceptor(ServicerInterceptor):
    """Log every incoming RPC with its method name and elapsed time.

    A "started" line is logged before the call is handed to ``continuation``.
    A "completed" line (or a "failed" line, if ``continuation`` raises)
    follows with the elapsed time in seconds. Exceptions are re-raised
    unchanged.
    """

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        method = handler_call_details.method
        # perf_counter() is monotonic. Unlike time.time() it does not jump when
        # the wall clock is adjusted (NTP sync, manual change), so measured
        # durations can never come out negative or wildly inflated.
        start_time = time.perf_counter()
        logger.info("RPC started: %s", method)
        try:
            response = await continuation(handler_call_details)
        except Exception as e:
            duration = time.perf_counter() - start_time
            logger.error("RPC failed: %s (%.3fs) - %s", method, duration, e)
            raise
        duration = time.perf_counter() - start_time
        logger.info("RPC completed: %s (%.3fs)", method, duration)
        return response
class MetricsInterceptor(ServicerInterceptor):
    """Collect simple in-process counters for RPC calls.

    Counters are plain attributes on the instance and are never reset.
    Read them through ``get_metrics()``.
    """

    def __init__(self) -> None:
        # Calls that entered the interceptor (including ones still in flight).
        self.request_count = 0
        # Calls whose continuation raised an exception.
        self.error_count = 0
        # Summed elapsed seconds of finished calls, successful AND failed.
        self.total_duration = 0.0

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        self.request_count += 1
        start_time = time.perf_counter()  # monotonic: immune to clock changes
        try:
            return await continuation(handler_call_details)
        except Exception:
            self.error_count += 1
            raise
        finally:
            # Record the duration of failed calls too. avg_duration divides by
            # request_count (all calls), so summing only successful durations
            # would systematically understate the average.
            self.total_duration += time.perf_counter() - start_time

    def get_metrics(self) -> dict[str, Any]:
        """Return a snapshot of the counters.

        ``error_rate`` and ``avg_duration`` use ``max(request_count, 1)`` as
        the divisor, so both are 0.0 before any request has been seen.
        """
        divisor = max(self.request_count, 1)
        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / divisor,
            "avg_duration": self.total_duration / divisor,
        }
Authentication Middleware¶
import jwt
from grpc import StatusCode
from grpc.aio import ServicerContext, ServicerInterceptor
from pyvider.exceptions import AuthenticationError
class JWTAuthInterceptor(ServicerInterceptor):
    """Require a valid HS256 JWT in the ``authorization`` metadata entry.

    Calls to methods listed in ``exempt_methods`` bypass the check. For all
    other calls the value must be ``Bearer <token>`` (scheme matched
    case-insensitively). The token is decoded with ``secret_key``.

    On success the token's ``user_id`` and ``permissions`` claims are attached
    to the context for downstream use (e.g. by ``RoleBasedAuthInterceptor``).
    Missing, malformed or invalid tokens abort the call with UNAUTHENTICATED.
    """

    def __init__(self, secret_key: str, exempt_methods: set[str] | None = None):
        self.secret_key = secret_key
        self.exempt_methods = exempt_methods or set()

    @staticmethod
    def _extract_bearer_token(auth_header: str | None) -> str | None:
        """Return the token from a ``Bearer <token>`` header value, or None."""
        if not auth_header:
            return None
        scheme, _, token = auth_header.partition(" ")
        # Auth schemes are case-insensitive (RFC 7235, section 2.1), so
        # "bearer abc" must be accepted just like "Bearer abc".
        if scheme.lower() != "bearer":
            return None
        token = token.strip()
        return token or None

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        method = handler_call_details.method

        # Skip authentication for exempt methods (e.g. health checks).
        if method in self.exempt_methods:
            return await continuation(handler_call_details)

        # NOTE(review): in stock grpc.aio, HandlerCallDetails.invocation_metadata
        # is the metadata sequence, not a ServicerContext. This guide treats it
        # as the context; confirm against Pyvider's interceptor contract.
        context: ServicerContext = handler_call_details.invocation_metadata
        auth_header = next(
            (
                value
                for key, value in context.invocation_metadata()
                if key.lower() == 'authorization'
            ),
            None,
        )

        token = self._extract_bearer_token(auth_header)
        if token is None:
            await context.abort(StatusCode.UNAUTHENTICATED, 'Missing or invalid token')
            return

        try:
            # Pin the accepted algorithm list so a token cannot pick its own
            # algorithm (e.g. "none").
            payload = jwt.decode(token, self.secret_key, algorithms=['HS256'])
        except jwt.InvalidTokenError as e:
            await context.abort(StatusCode.UNAUTHENTICATED, f'Invalid token: {e}')
            return

        # Expose identity to downstream interceptors. A null "permissions"
        # claim is normalised to an empty list so later set() calls don't fail.
        context.user_id = payload.get('user_id')
        context.permissions = payload.get('permissions') or []
        return await continuation(handler_call_details)
class RoleBasedAuthInterceptor(ServicerInterceptor):
    """Enforce per-method permission checks.

    Args:
        method_permissions: maps a full method name (as found in
            ``handler_call_details.method``) to the permissions that grant
            access. Methods absent from the map, or mapped to an empty set,
            are not restricted.
        require_all: when False (the default, unchanged behaviour) the caller
            needs *any one* of the listed permissions. When True the caller
            must hold *every* listed permission.

    Calls without sufficient permissions are aborted with PERMISSION_DENIED.
    """

    def __init__(
        self,
        method_permissions: dict[str, set[str]],
        *,
        require_all: bool = False,
    ):
        self.method_permissions = method_permissions
        self.require_all = require_all

    def _is_authorized(self, required: set[str], granted: set[str]) -> bool:
        """Apply the any-of / all-of rule selected by ``require_all``."""
        if self.require_all:
            return required <= granted
        return not required.isdisjoint(granted)

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        method = handler_call_details.method

        required_perms = self.method_permissions.get(method, set())
        if not required_perms:
            return await continuation(handler_call_details)

        context: ServicerContext = handler_call_details.invocation_metadata
        # Permissions are expected to have been attached by the JWT
        # interceptor. A missing or null value means the caller has none.
        user_perms = set(getattr(context, 'permissions', None) or ())

        if not self._is_authorized(set(required_perms), user_perms):
            await context.abort(
                StatusCode.PERMISSION_DENIED,
                f'Insufficient permissions for {method}'
            )
            return

        return await continuation(handler_call_details)
Rate Limiting Middleware¶
import asyncio
import time
from collections import defaultdict
from grpc import StatusCode
from grpc.aio import ServicerContext, ServicerInterceptor
class TokenBucketRateLimiter:
    """Token-bucket rate limiter for use from coroutines on one event loop.

    The bucket starts full with ``capacity`` tokens and refills continuously
    at ``refill_rate`` tokens per second, never exceeding ``capacity``.
    ``consume`` is serialised by an ``asyncio.Lock``.
    """

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens: float = capacity
        # time.monotonic() cannot go backwards. With time.time(), a backwards
        # clock step produced negative "elapsed" time (draining the bucket)
        # and a forward step refilled it prematurely.
        self.last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from the bucket.

        Returns True and deducts the tokens if enough are available after
        refilling. Otherwise returns False and deducts nothing. A request
        larger than ``capacity`` can never succeed.
        """
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self.last_refill
            self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
            self.last_refill = now

            if self.tokens < tokens:
                return False
            self.tokens -= tokens
            return True
class RateLimitingInterceptor(ServicerInterceptor):
    """Rate-limit RPC calls per client host (optionally per host and method).

    Each key gets its own TokenBucketRateLimiter holding up to
    ``burst_capacity`` tokens and refilling at ``requests_per_second``.
    Calls that find the bucket empty are aborted with RESOURCE_EXHAUSTED.

    NOTE(review): ``self.limiters`` keeps one bucket per distinct key forever;
    with many distinct client hosts consider evicting idle buckets.
    """

    def __init__(
        self,
        requests_per_second: int = 100,
        burst_capacity: int = 200,
        per_method: bool = False
    ):
        self.requests_per_second = requests_per_second
        self.burst_capacity = burst_capacity
        self.per_method = per_method
        self.limiters: dict[str, TokenBucketRateLimiter] = defaultdict(
            lambda: TokenBucketRateLimiter(burst_capacity, requests_per_second)
        )

    @staticmethod
    def _peer_host(peer: str) -> str:
        """Drop the port from an ``ipv4:``/``ipv6:`` gRPC peer string.

        gRPC peers typically look like ``ipv4:10.0.0.5:53122`` or
        ``ipv6:[::1]:53122``. The port is the client's ephemeral port and
        changes with every new connection. Keying on it would let a client
        reset its limit simply by reconnecting, and would create a new bucket
        per connection. Other forms (e.g. ``unix:...``) are returned unchanged.
        """
        if peer.startswith(("ipv4:", "ipv6:")):
            host, sep, port = peer.rpartition(":")
            if sep and port.isdigit():
                return host
        return peer

    def _get_client_key(self, context: ServicerContext, method: str) -> str:
        """Return the limiter key: the client host, plus the method if per_method."""
        client = self._peer_host(context.peer())
        if self.per_method:
            return f"{client}:{method}"
        return client

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        method = handler_call_details.method
        context: ServicerContext = handler_call_details.invocation_metadata

        limiter = self.limiters[self._get_client_key(context, method)]
        if not await limiter.consume():
            await context.abort(
                StatusCode.RESOURCE_EXHAUSTED,
                'Rate limit exceeded'
            )
            return

        return await continuation(handler_call_details)
Advanced Middleware Patterns¶
Request/Response Transformation¶
from google.protobuf.message import Message
from google.protobuf.json_format import MessageToJson, Parse
class RequestTransformInterceptor(ServicerInterceptor):
    """Rewrite the request of selected methods before the handler sees it.

    ``transformers`` maps a method name to a callable that receives the
    original request and returns its replacement. Methods with no entry
    pass through untouched.
    """

    def __init__(self, transformers: dict[str, Callable[[Any], Any]]):
        self.transformers = transformers

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        transform = self.transformers.get(handler_call_details.method)
        if not transform:
            return await continuation(handler_call_details)

        # _replace() returns a modified copy, so the details object must be
        # namedtuple-like and expose a ``request`` field.
        # NOTE(review): stock grpc.HandlerCallDetails only carries ``method``
        # and ``invocation_metadata``; confirm Pyvider supplies ``request``.
        rewritten_details = handler_call_details._replace(
            request=transform(handler_call_details.request)
        )
        return await continuation(rewritten_details)
class ResponseCacheInterceptor(ServicerInterceptor):
    """Cache responses of read-style (``Get*``/``List*``/``Search*``) methods.

    Entries live for ``cache_ttl`` seconds. An expired entry is dropped when
    it is next looked up; otherwise the cache grows without bound.

    NOTE(review): the cache key does not include caller identity, so one
    caller can be served a response computed for another. Only enable this
    for methods whose responses do not depend on who is calling.
    """

    # Method-name prefixes treated as idempotent reads.
    CACHEABLE_PREFIXES: tuple[str, ...] = ('Get', 'List', 'Search')

    def __init__(self, cache_ttl: int = 300):
        # cache_key -> (response, time.monotonic() when the call started)
        self.cache: dict[str, tuple[Any, float]] = {}
        self.cache_ttl = cache_ttl

    def _get_cache_key(self, method: str, request: Message) -> str:
        """Build a cache key from the method and the full request content.

        The complete JSON text is embedded rather than ``hash()`` of it.
        Distinct requests whose hashes collide would otherwise share one
        entry and receive each other's responses. ``sort_keys`` keeps the
        text stable for equal messages.
        """
        request_json = MessageToJson(request, indent=None, sort_keys=True)
        return f"{method}:{request_json}"

    def _is_cacheable(self, method: str) -> bool:
        """True if the RPC name (last path segment) starts with a read verb."""
        # method is a full path such as "/pkg.Service/GetUser". The verb is a
        # *prefix* of the last segment, so testing endswith() on the whole
        # path missed ordinary names like GetUser or ListOrders.
        rpc_name = method.rsplit('/', 1)[-1]
        return rpc_name.startswith(self.CACHEABLE_PREFIXES)

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        method = handler_call_details.method
        if not self._is_cacheable(method):
            return await continuation(handler_call_details)

        cache_key = self._get_cache_key(method, handler_call_details.request)
        # Monotonic clock: TTL checks are unaffected by wall-clock changes.
        now = time.monotonic()

        entry = self.cache.get(cache_key)
        if entry is not None:
            cached_response, stored_at = entry
            if now - stored_at < self.cache_ttl:
                return cached_response
            self.cache.pop(cache_key, None)  # expired

        response = await continuation(handler_call_details)
        self.cache[cache_key] = (response, now)
        return response
Error Handling Middleware¶
from grpc import StatusCode
from pyvider.exceptions import (
ValidationError,
AuthenticationError,
AuthorizationError,
NotFoundError,
InternalError
)
class ErrorHandlingInterceptor(ServicerInterceptor):
    """Translate exceptions raised downstream into gRPC status codes.

    The following exceptions are reported with their message:
    - ValidationError -> INVALID_ARGUMENT
    - AuthenticationError -> UNAUTHENTICATED
    - AuthorizationError -> PERMISSION_DENIED
    - NotFoundError -> NOT_FOUND

    InternalError and any other exception are logged with their traceback.
    The client receives only a generic INTERNAL "Internal server error", so
    internal details do not leak. A call that was already aborted through
    ``context.abort`` keeps its original status.
    """

    # Checked in order; the first matching type decides the status code.
    _CLIENT_ERRORS = (
        (ValidationError, StatusCode.INVALID_ARGUMENT),
        (AuthenticationError, StatusCode.UNAUTHENTICATED),
        (AuthorizationError, StatusCode.PERMISSION_DENIED),
        (NotFoundError, StatusCode.NOT_FOUND),
    )

    @classmethod
    def _status_for(cls, exc: Exception) -> StatusCode | None:
        """Return the client-facing status for ``exc``, or None if it is internal."""
        for exc_type, status in cls._CLIENT_ERRORS:
            if isinstance(exc, exc_type):
                return status
        return None

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        # grpc.aio raises AbortError from context.abort().
        from grpc.aio import AbortError

        context: ServicerContext = handler_call_details.invocation_metadata
        method = handler_call_details.method
        try:
            return await continuation(handler_call_details)
        except AbortError:
            # The call already carries a deliberate status (e.g. NOT_FOUND set
            # by the handler). Re-raise instead of masking it as INTERNAL.
            raise
        except Exception as exc:
            status = self._status_for(exc)
            if status is not None:
                await context.abort(status, str(exc))
                return None
            if isinstance(exc, InternalError):
                logger.exception("Internal error in %s: %s", method, exc)
            else:
                logger.exception("Unexpected error in %s", method)
            await context.abort(StatusCode.INTERNAL, "Internal server error")
            return None
class CircuitBreakerInterceptor(ServicerInterceptor):
    """Fail fast with UNAVAILABLE after consecutive downstream failures.

    States:
        CLOSED: calls pass through. Each ``expected_exception`` increments
            ``failure_count``; a success resets it to zero.
        OPEN: entered once ``failure_count`` reaches ``failure_threshold``.
            Calls are rejected until more than ``recovery_timeout`` seconds
            have passed since the last failure.
        HALF_OPEN: calls are let through again. A success closes the
            circuit; a failure re-opens it immediately, because
            ``failure_count`` is still at or above the threshold.

    NOTE(review): with the default ``expected_exception=Exception``, aborts
    raised by other interceptors (auth, rate limiting) also count as
    failures. Consider narrowing it to genuine backend errors.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type[BaseException] = Exception
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.failure_count = 0
        # time.monotonic() of the most recent failure, or None if none yet.
        self.last_failure_time: float | None = None
        self.state = 'CLOSED'  # CLOSED, OPEN, HALF_OPEN

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        context: ServicerContext = handler_call_details.invocation_metadata

        if self.state == 'OPEN':
            if (self.last_failure_time is not None and
                    time.monotonic() - self.last_failure_time > self.recovery_timeout):
                self.state = 'HALF_OPEN'
            else:
                await context.abort(
                    StatusCode.UNAVAILABLE,
                    'Service temporarily unavailable'
                )
                return

        try:
            response = await continuation(handler_call_details)
        except self.expected_exception:
            self.failure_count += 1
            self.last_failure_time = time.monotonic()
            if self.failure_count >= self.failure_threshold:
                self.state = 'OPEN'
            raise

        # Any success ends the failure streak. Previously the count was reset
        # only in HALF_OPEN, so while CLOSED it grew forever and widely spaced
        # errors eventually tripped the breaker. A straggler call that started
        # before the breaker opened must not close it, hence the OPEN guard.
        if self.state != 'OPEN':
            self.state = 'CLOSED'
            self.failure_count = 0
        return response
Integration with Server¶
Configuring Middleware Stack¶
from pyvider.server import RPCPluginServer
from pyvider.config import ServerConfig
async def create_server_with_middleware():
    """Assemble an RPCPluginServer behind the complete example interceptor chain.

    Interceptors are registered in this order: request logging, metrics,
    rate limiting, JWT authentication, circuit breaker, error handling.

    Returns:
        ``(server, metrics_interceptor)``. The metrics interceptor is handed
        back so the caller can poll ``get_metrics()``.
    """
    metrics_interceptor = MetricsInterceptor()

    interceptors = [
        RequestLoggingInterceptor(),
        metrics_interceptor,
        RateLimitingInterceptor(requests_per_second=100, burst_capacity=200),
        JWTAuthInterceptor(
            # Placeholder only: load the real signing key from a secret store
            # or the environment instead of hard-coding it.
            secret_key="<YOUR_SECRET_KEY>",
            # Health probes must stay reachable without a token.
            exempt_methods={'/grpc.health.v1.Health/Check'},
        ),
        CircuitBreakerInterceptor(),
        ErrorHandlingInterceptor(),  # Should be last for proper error handling
    ]

    config = ServerConfig(host="localhost", port=50051, interceptors=interceptors)
    return RPCPluginServer(config), metrics_interceptor
# Usage example
async def main():
    """Start the example server and print its metrics every 30 seconds until stopped."""
    server, metrics = await create_server_with_middleware()
    await server.start()
    try:
        while True:
            await asyncio.sleep(30)
            print("Metrics:", metrics.get_metrics())
    finally:
        # Inside asyncio.run(), Ctrl+C reaches this coroutine as task
        # cancellation (CancelledError), not KeyboardInterrupt. The old
        # `except KeyboardInterrupt` therefore never ran and the server was
        # never stopped. `finally` stops it on every exit path.
        await server.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() re-raises KeyboardInterrupt after cleanup; exit quietly.
        pass
Conditional Middleware¶
class ConditionalMiddleware(ServicerInterceptor):
    """Route a call through ``middleware`` only when ``condition`` says so.

    ``condition`` is called with the handler call details. If it returns a
    falsy value, the wrapped middleware is skipped and the call goes
    straight to ``continuation``.
    """

    def __init__(self, condition: Callable[[Any], bool], middleware: ServicerInterceptor):
        self.condition = condition
        self.middleware = middleware

    async def intercept_service(
        self,
        continuation: Callable[[Any], Awaitable[Any]],
        handler_call_details: Any,
    ) -> Any:
        if not self.condition(handler_call_details):
            return await continuation(handler_call_details)
        return await self.middleware.intercept_service(continuation, handler_call_details)
# Example: Only apply auth to specific methods
def requires_auth(handler_call_details) -> bool:
    """Return False for health-check and metrics endpoints, True otherwise."""
    public_suffixes = ('Health/Check', 'Metrics')
    is_public = handler_call_details.method.endswith(public_suffixes)
    return not is_public
# JWT auth that is skipped for the public endpoints selected by requires_auth.
auth_middleware = ConditionalMiddleware(
    middleware=JWTAuthInterceptor(secret_key="<YOUR_SECRET_KEY>"),
    condition=requires_auth,
)
Testing Middleware¶
Unit Testing¶
import pytest
from unittest.mock import Mock, AsyncMock
from grpc.aio import ServicerContext
@pytest.fixture
def mock_handler_call_details():
    """Stand-in handler call details: a fixed method name and a mock request."""
    return Mock(method="/test.Service/TestMethod", request=Mock())
@pytest.fixture
def mock_context():
    """ServicerContext mock with no metadata, a fixed peer and an awaitable abort()."""
    ctx = Mock(spec=ServicerContext)
    ctx.configure_mock(**{
        "invocation_metadata.return_value": [],
        "peer.return_value": "192.168.1.1:12345",
    })
    ctx.abort = AsyncMock()
    return ctx
@pytest.mark.asyncio
async def test_request_logging_interceptor(mock_handler_call_details, mock_context):
    """The logging interceptor awaits the continuation once and returns its result."""
    interceptor = RequestLoggingInterceptor()
    continuation = AsyncMock(return_value="response")

    result = await interceptor.intercept_service(continuation, mock_handler_call_details)

    assert result == "response"
    # assert_awaited_* also fails when the coroutine was created but never
    # awaited; assert_called_* on an AsyncMock would not catch that bug.
    continuation.assert_awaited_once_with(mock_handler_call_details)
@pytest.mark.asyncio
async def test_rate_limiting_interceptor(mock_handler_call_details, mock_context):
    """With a burst of one, the first call passes and an immediate second call is rejected."""
    from grpc import StatusCode

    interceptor = RateLimitingInterceptor(requests_per_second=1, burst_capacity=1)
    mock_handler_call_details.invocation_metadata = mock_context
    continuation = AsyncMock(return_value="response")

    # First request should succeed without aborting.
    result1 = await interceptor.intercept_service(continuation, mock_handler_call_details)
    assert result1 == "response"
    mock_context.abort.assert_not_awaited()

    # Second request (well under a second later) should be rate limited...
    await interceptor.intercept_service(continuation, mock_handler_call_details)
    mock_context.abort.assert_awaited_once_with(
        StatusCode.RESOURCE_EXHAUSTED,
        'Rate limit exceeded'
    )
    # ...and must never reach the handler.
    continuation.assert_awaited_once()
Integration Testing¶
@pytest.mark.asyncio
async def test_middleware_integration():
    """Run a live server behind a small interceptor chain and call it end to end."""
    server = RPCPluginServer(
        ServerConfig(
            host="localhost",
            port=0,  # 0 lets the OS choose a free port
            interceptors=[
                RequestLoggingInterceptor(),
                MetricsInterceptor(),
                ErrorHandlingInterceptor(),
            ],
        )
    )
    client = None
    try:
        await server.start()
        client = await create_test_client(server.port)

        # Happy path.
        ok_response = await client.test_method(TestRequest(data="test"))
        assert ok_response.status == "success"

        # Failure path.
        # NOTE(review): ErrorHandlingInterceptor turns a server-side
        # ValidationError into an INVALID_ARGUMENT status. This expectation
        # holds only if create_test_client maps that status back to
        # ValidationError. Otherwise expect grpc.aio.AioRpcError here.
        with pytest.raises(ValidationError):
            await client.test_method(TestRequest(data="invalid"))
    finally:
        if client:
            await client.close()
        await server.stop()
Best Practices¶
Performance Considerations¶
# Use connection pooling for external services
class DatabaseMiddleware(ServicerInterceptor):
    """Lend each call a pooled asyncpg connection for the duration of the call.

    An asyncpg pool must be awaited (initialised) before ``acquire()`` can be
    used, and ``__init__`` cannot await. The pool is therefore created lazily
    on the first call. Call ``close()`` on shutdown to release its
    connections.
    """

    def __init__(self, pool_size: int = 10):
        self.pool_size = pool_size
        self.pool = None  # created on first use by _get_pool()
        self._pool_lock = asyncio.Lock()

    async def _get_pool(self):
        """Create the pool once, even when several calls arrive concurrently."""
        if self.pool is None:
            async with self._pool_lock:
                if self.pool is None:  # another call may have created it meanwhile
                    self.pool = await asyncpg.create_pool(
                        "postgresql://...",
                        min_size=1,
                        max_size=self.pool_size,
                    )
        return self.pool

    async def intercept_service(self, continuation, handler_call_details):
        pool = await self._get_pool()
        # The connection goes back to the pool when this block exits, so it
        # is only valid while continuation() is running.
        async with pool.acquire() as conn:
            handler_call_details.db_connection = conn
            return await continuation(handler_call_details)

    async def close(self):
        """Close the pool if it was created; safe to call more than once."""
        if self.pool is not None:
            pool, self.pool = self.pool, None
            await pool.close()
# Cache expensive computations
class ComputationCacheMiddleware(ServicerInterceptor):
    """Share expensive computed values between calls without pinning them in memory.

    Values live in a ``weakref.WeakValueDictionary``. An entry vanishes
    automatically once nothing else references its value, so the cache
    cannot leak and needs no per-request sweep.

    Values must support weak references. Instances of ordinary classes do;
    ``int``, ``str``, ``tuple``, ``list`` and ``dict`` do not.
    """

    def __init__(self):
        import weakref

        self.cache = weakref.WeakValueDictionary()
        self.cache_lock = asyncio.Lock()

    async def get_or_compute(self, key, factory):
        """Return the cached value for ``key``, awaiting ``factory()`` on a miss.

        The single lock means concurrent misses for the same key compute it
        only once. It also serialises misses for different keys, so keep
        ``factory`` reasonably fast.
        """
        async with self.cache_lock:
            value = self.cache.get(key)
            if value is None:
                value = await factory()
                self.cache[key] = value
            # Returning the value gives the caller a strong reference, which
            # keeps the entry alive while it is in use.
            return value

    async def intercept_service(self, continuation, handler_call_details):
        # Dead entries are removed by the WeakValueDictionary itself, so there
        # is nothing to prune here.
        return await continuation(handler_call_details)
Security Guidelines¶
- Always validate middleware order - Authentication should come before authorization
- Sanitize log output - Don't log sensitive data like tokens or passwords
- Use secure defaults - Rate limiting should be conservative initially
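- Implement proper cleanup - Release resources (connection pools, caches, open handles) in an explicit `close()`/shutdown hook rather than in `__del__`. Python may run a destructor late or never, and a destructor cannot await asynchronous cleanup.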
- Monitor middleware performance - Track latency added by each middleware
Error Handling¶
class RobustMiddleware(ServicerInterceptor):
    """Template for middleware with pre-process, post-process and cleanup hooks.

    Subclasses override ``_pre_process``, ``_post_process`` and ``_cleanup``.

    - Failures inside the pre/post hooks are logged as middleware errors and
      re-raised.
    - Exceptions from the downstream call propagate untouched and are not
      logged here.
    - ``_cleanup`` always runs. If it fails while another exception is
      already propagating, its failure is logged rather than allowed to
      replace (mask) the original error.
    """

    async def intercept_service(self, continuation, handler_call_details):
        try:
            response = await self._run(continuation, handler_call_details)
        except BaseException:
            # An error is already in flight: never let cleanup mask it.
            try:
                await self._cleanup()
            except Exception:
                logger.exception("Middleware cleanup failed")
            raise
        # Success path: a cleanup failure is the only error, so surface it.
        await self._cleanup()
        return response

    async def _run(self, continuation, handler_call_details):
        """Run pre-processing, the downstream call and post-processing in order."""
        try:
            await self._pre_process(handler_call_details)
        except Exception:
            logger.exception("Middleware error during pre-processing")
            raise

        response = await continuation(handler_call_details)

        try:
            await self._post_process(response)
        except Exception:
            logger.exception("Middleware error during post-processing")
            raise
        return response

    async def _pre_process(self, handler_call_details):
        """Override in subclasses."""
        pass

    async def _post_process(self, response):
        """Override in subclasses."""
        pass

    async def _cleanup(self):
        """Override in subclasses."""
        pass
Middleware provides powerful capabilities for cross-cutting concerns in your RPC services. Use the patterns and examples above to build robust, secure, and performant RPC applications with consistent behavior across all methods.