diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index 4557c4584eba..475bc44cd4cb 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +## 2.0.0b2 (2026-04-17) + +### Features Added + +- `InboundRequestLoggingMiddleware` — pure-ASGI middleware wired automatically by `AgentServerHost` that logs every inbound HTTP request. Logs method, path (no query string), status code, duration in milliseconds, and correlation headers (`x-request-id`, `x-ms-client-request-id`). Status codes >= 400 are logged at WARNING; unhandled exceptions are logged as status 500 at WARNING. OpenTelemetry trace ID is included when an active trace exists. + ## 2.0.0b1 (2026-04-14) This is a major architectural rewrite. The package has been redesigned as a lightweight hosting diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py index 4bf1436b7d01..37776153f3a8 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py @@ -26,6 +26,7 @@ from ._base import AgentServerHost from ._config import AgentConfig from ._errors import create_error_response +from ._middleware import InboundRequestLoggingMiddleware from ._server_version import build_server_version from ._tracing import ( configure_observability, @@ -41,6 +42,7 @@ __all__ = [ "AgentConfig", "AgentServerHost", + "InboundRequestLoggingMiddleware", "build_server_version", "configure_observability", "create_error_response", diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py index 0d7fe958061a..d9832a64ca05 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_base.py @@ -17,6 +17,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from . import _config, _tracing +from ._middleware import InboundRequestLoggingMiddleware from ._server_version import build_server_version from ._version import VERSION as _CORE_VERSION @@ -210,6 +211,7 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF routes=all_routes, lifespan=_lifespan, middleware=[ + Middleware(InboundRequestLoggingMiddleware), Middleware(_PlatformHeaderMiddleware, get_server_version=self._build_server_version), ], **kwargs, diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py new file mode 100644 index 000000000000..a2e6973b3134 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_middleware.py @@ -0,0 +1,135 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Inbound request logging middleware for Azure AI Agent Server hosts. + +A pure-ASGI middleware that logs every inbound HTTP request at INFO level +(start) and at INFO or WARNING level (completion, depending on status code). + +Behaviour: +- Logs method + path (no query string) on start. +- Logs method + path + status code + duration on completion. +- Correlation headers (``x-request-id``, ``x-ms-client-request-id``) are + included when present. +- OTel trace ID is included when an active trace exists. +- Status >= 400 → WARNING; otherwise → INFO. +- Unhandled exceptions → forced status 500, WARNING. +""" + +from __future__ import annotations + +import logging +import time +from typing import Any, MutableMapping + +from starlette.types import ASGIApp, Receive, Scope, Send + +logger = logging.getLogger("azure.ai.agentserver") + + +def _extract_header(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None: + """Extract a header value from raw ASGI headers. + + :param headers: Raw ASGI header tuples. + :type headers: list[tuple[bytes, bytes]] + :param name: Lower-case header name to look up. + :type name: bytes + :return: Decoded header value, or ``None`` if not found. + :rtype: str | None + """ + for key, value in headers: + if key == name: + return value.decode("latin-1") + return None + + +def _get_trace_id() -> str | None: + """Return the current OTel trace ID hex string, or ``None``. + + :return: Hex-encoded trace ID from the current OTel span, or ``None``. + :rtype: str | None + """ + try: + from opentelemetry import trace as _trace # pylint: disable=import-outside-toplevel + + span = _trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.trace_id: + return format(ctx.trace_id, "032x") + except Exception: # pylint: disable=broad-exception-caught + pass + return None + + +class InboundRequestLoggingMiddleware: + """Pure-ASGI middleware that logs inbound HTTP requests. + + Unlike ``BaseHTTPMiddleware``, this passes the ``receive`` callable + through to the inner application untouched, preserving + ``request.is_disconnected()`` behaviour. + + Wired automatically by :class:`AgentServerHost` so that all protocol + hosts (responses, invocations, etc.) get consistent inbound logging. + + :param app: The inner ASGI application. + :type app: ASGIApp + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + method: str = scope.get("method", "?") + path: str = scope.get("path", "/") + raw_headers: list[tuple[bytes, bytes]] = scope.get("headers", []) + + request_id = _extract_header(raw_headers, b"x-request-id") + client_request_id = _extract_header(raw_headers, b"x-ms-client-request-id") + trace_id = _get_trace_id() + + extra_parts: list[str] = [] + if request_id: + extra_parts.append(f"x-request-id={request_id}") + if client_request_id: + extra_parts.append(f"x-ms-client-request-id={client_request_id}") + if trace_id: + extra_parts.append(f"trace_id={trace_id}") + extra_str = f" [{', '.join(extra_parts)}]" if extra_parts else "" + + logger.info("Inbound %s %s started%s", method, path, extra_str) + + status_code: int | None = None + start = time.monotonic() + + async def _send_wrapper(message: MutableMapping[str, Any]) -> None: + nonlocal status_code + if message["type"] == "http.response.start": + status_code = message.get("status", 0) + await send(message) + + try: + await self.app(scope, receive, _send_wrapper) + except Exception: + elapsed_ms = (time.monotonic() - start) * 1000 + logger.warning( + "Inbound %s %s failed with status 500 in %.1fms%s", + method, path, elapsed_ms, extra_str, + ) + raise + + elapsed_ms = (time.monotonic() - start) * 1000 + + if status_code is not None and status_code >= 400: + logger.warning( + "Inbound %s %s completed with status %d in %.1fms%s", + method, path, status_code, elapsed_ms, extra_str, + ) + else: + logger.info( + "Inbound %s %s completed with status %s in %.1fms%s", + method, path, status_code, elapsed_ms, extra_str, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py index 71775f48670c..5469bb13da3d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "2.0.0b1" +VERSION = "2.0.0b2" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md index 6ac960dd0162..2c235349f9de 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-invocations/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +## 1.0.0b2 (2026-04-17) + +### Features Added + +- Inbound request logging — `InboundRequestLoggingMiddleware` from `azure-ai-agentserver-core` is now wired automatically by `AgentServerHost`. All inbound HTTP requests are logged at INFO level (start) and at INFO or WARNING level (completion) with method, path, status code, duration, and correlation headers. + ## 1.0.0b1 (2026-04-14) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_version.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_version.py index 67d209a8cafd..58bdf80c74ff 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b2" diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml index a820342e758d..5d2d38a8402a 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-invocations/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ keywords = ["azure", "azure sdk", "agent", "agentserver", "invocations"] dependencies = [ - "azure-ai-agentserver-core>=2.0.0b1", + "azure-ai-agentserver-core>=2.0.0b2", ] [dependency-groups] diff --git a/sdk/agentserver/azure-ai-agentserver-responses/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-responses/CHANGELOG.md index 9df7afc88b2d..e2448b300264 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-responses/CHANGELOG.md @@ -1,5 +1,27 @@ # Release History +## 1.0.0b2 (2026-04-17) + +### Features Added + +- `InboundRequestLoggingMiddleware` moved to `azure-ai-agentserver-core` — pure-ASGI middleware that logs every inbound HTTP request at INFO level (start) and at INFO or WARNING level (completion). Now wired automatically by `AgentServerHost` so all protocol hosts get consistent inbound logging. Includes method, path (no query string), status code, duration in milliseconds, and correlation headers (`x-request-id`, `x-ms-client-request-id`). Status codes >= 400 are logged at WARNING; unhandled exceptions are logged as status 500 at WARNING. OpenTelemetry trace ID is included when an active trace exists. +- Handler-level diagnostic logging — all five endpoint handlers (`POST /responses`, `GET /responses/{id}`, `DELETE /responses/{id}`, `POST /responses/{id}/cancel`, `GET /responses/{id}/input_items`) now emit INFO-level logs at entry and on success, including response ID, status, and output count where applicable. +- Orchestrator handler invocation logging — logs the handler function name and response ID at INFO level before each handler invocation. +- Chat isolation key enforcement — when a response is created with an `x-agent-chat-isolation-key` header, subsequent GET, DELETE, Cancel, and InputItems requests must include the same key. Mismatched or missing keys return an indistinguishable 404 to prevent cross-chat information leakage. Backward-compatible: no enforcement when the response was created without a key. +- Malformed response ID validation — all endpoints that accept a `response_id` path parameter now reject malformed IDs (wrong prefix, too short) with HTTP 400 (`error.type: "invalid_request_error"`, `error.code: "invalid_parameters"`, `param: "responseId{}"`) before touching storage. The `previous_response_id` field in POST body is also validated. +- `FoundryStorageLoggingPolicy` — Azure Core per-retry pipeline policy that logs Foundry storage HTTP calls (method, URI, status code, duration, correlation headers) at the `azure.ai.agentserver` logger. Replaces the built-in `HttpLoggingPolicy` in the Foundry pipeline to provide single-line summaries with duration timing and log-level escalation (WARNING for 4xx/5xx). + +### Bugs Fixed + +- Error `code` field now uses spec-compliant values: `"invalid_request_error"` for 400/404 errors (was `"invalid_request"`, `"not_found"`, or `"invalid_mode"`), `"server_error"` for 500 errors (was `"internal_error"`). +- `RequestValidationError` default code updated from `"invalid_request"` to `"invalid_request_error"`. +- Error responses for deleted resources now correctly return HTTP 404 (was 400). Affects `GET /responses/{id}`, `GET /responses/{id}/input_items`, and `DELETE /responses/{id}` (second delete) on previously deleted responses. +- Cancel on a response in terminal state now returns the spec-compliant message `"Cannot cancel a response in terminal state."` (was `"Cannot cancel an incomplete response."`). +- SSE replay rejection messages now use spec-compliant wording: + - Non-background responses: `"This response cannot be streamed because it was not created with background=true."` + - Background non-stream responses: `"This response cannot be streamed because it was not created with stream=true."` +- Foundry storage errors (`FoundryResourceNotFoundError`, `FoundryBadRequestError`, `FoundryApiError`) are now explicitly caught in endpoint handlers and mapped to appropriate HTTP status codes instead of being swallowed by broad exception handlers. + ## 1.0.0b1 (2026-04-14) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/_version.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/_version.py index cf584760eb91..a19da98e823b 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/_version.py @@ -4,4 +4,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b2" diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_endpoint_handler.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_endpoint_handler.py index 5de67951bc98..26f94c2fb0ab 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_endpoint_handler.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_endpoint_handler.py @@ -31,9 +31,11 @@ from .._options import ResponsesServerOptions from .._response_context import IsolationContext, ResponseContext from ..models._helpers import get_input_expanded, to_output_item +from .._id_generator import IdGenerator from ..models.errors import RequestValidationError from ..models.runtime import ResponseExecution, ResponseModeFlags, build_cancelled_response, build_failed_response from ..store._base import ResponseProviderProtocol, ResponseStreamProviderProtocol +from ..store._foundry_errors import FoundryApiError, FoundryBadRequestError, FoundryResourceNotFoundError from ..streaming._helpers import _encode_sse from ..streaming._sse import encode_sse_any_event from ..streaming._state_machine import _normalize_lifecycle_events @@ -65,6 +67,9 @@ from ._validation import ( invalid_mode_response as _invalid_mode, ) +from ._validation import ( + invalid_parameters_response as _invalid_parameters, +) from ._validation import ( invalid_request_response as _invalid_request, ) @@ -120,6 +125,30 @@ def _extract_isolation(request: Request) -> IsolationContext: ) +def _validate_response_id_format(response_id: str, headers: dict[str, str] | None = None) -> Response | None: + """Validate that a response_id path parameter has the expected ID format. + + Returns a 400 error response if the ID is malformed, or ``None`` if valid. + The error shape follows spec rule B40: ``code: "invalid_parameters"``, + ``param: "responseId{}"``. + + :param response_id: The response ID from the URL path. + :type response_id: str + :param headers: Optional HTTP headers to include on the error response. + :type headers: dict[str, str] | None + :return: A 400 error response if invalid, or ``None`` if valid. + :rtype: Response | None + """ + is_valid, _ = IdGenerator.is_valid(response_id, allowed_prefixes=["caresp"]) + if not is_valid: + return _invalid_parameters( + "Malformed identifier.", + headers or {}, + param=f"responseId{{{response_id}}}", + ) + return None + + # Structured log scope context variables (spec §7.4) _response_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("ResponseId", default="") _conversation_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("ConversationId", default="") @@ -241,6 +270,33 @@ def _safe_set_attrs(span: Any, attrs: dict[str, str]) -> None: except Exception: # pylint: disable=broad-exception-caught logger.debug("Failed to set span attributes: %s", list(attrs.keys()), exc_info=True) + # ------------------------------------------------------------------ + # §8: Session ID response header helper + # ------------------------------------------------------------------ + + def _session_headers(self, session_id: str | None = None) -> dict[str, str]: + """Build response headers including ``x-agent-session-id``. + + Merges the base ``_response_headers`` with the session ID header. + For POST /responses the caller passes the per-request resolved + session ID; other endpoints use the ``FOUNDRY_AGENT_SESSION_ID`` + environment variable via the host config (resolved lazily so the + value is available even when the handler is constructed before the + base class ``__init__``). + + :param session_id: Per-request session ID (overrides env var). + :type session_id: str | None + :return: Headers dict with ``x-agent-session-id`` when available. + :rtype: dict[str, str] + """ + sid = session_id or ( + getattr(getattr(self._host, "config", None), "session_id", "") or "" + ) + headers = dict(self._response_headers) + if sid: + headers["x-agent-session-id"] = sid + return headers + # ------------------------------------------------------------------ # Streaming response helpers # ------------------------------------------------------------------ @@ -440,7 +496,7 @@ async def handle_create(self, request: Request) -> Response: # pylint: disable= :rtype: Response """ if self._is_draining: - return _service_unavailable("Server is shutting down.", {}) + return _service_unavailable("Server is shutting down.", self._session_headers()) # Also maintain CreateSpanHook for backward compat (tests etc.) span = start_create_span( @@ -458,7 +514,7 @@ async def handle_create(self, request: Request) -> Response: # pylint: disable= logger.error("Failed to parse/validate create request", exc_info=exc) captured_error = exc span.end(captured_error) - return _error_response(exc, {}) + return _error_response(exc, self._session_headers()) try: response_id, agent_reference = _resolve_identity_fields( @@ -469,7 +525,7 @@ async def handle_create(self, request: Request) -> Response: # pylint: disable= logger.error("Failed to resolve identity fields", exc_info=exc) captured_error = exc span.end(captured_error) - return _error_response(exc, {}) + return _error_response(exc, self._session_headers()) # B39: Resolve session ID config_session_id = getattr(getattr(self._host, "config", None), "session_id", "") or "" @@ -486,6 +542,18 @@ async def handle_create(self, request: Request) -> Response: # pylint: disable= request=request, ) + logger.info( + "Creating response %s: streaming=%s background=%s store=%s model=%s " + "conversation_id=%s previous_response_id=%s", + ctx.response_id, + ctx.stream, + ctx.background, + ctx.store, + ctx.model, + ctx.conversation_id, + ctx.previous_response_id, + ) + # Extract X-Request-Id header for request ID propagation (truncated to 256 chars). request_id = extract_request_id(request.headers) _project_id = getattr(getattr(self._host, "config", None), "project_id", "") or "" @@ -545,7 +613,7 @@ async def _iter_with_cleanup(): # type: ignore[return] sse_response = StreamingResponse( body_iter, media_type="text/event-stream", - headers=self._sse_headers, + headers={**self._sse_headers, **self._session_headers(agent_session_id)}, ) wrapped = self._wrap_streaming_response(sse_response, otel_span) return wrapped @@ -554,8 +622,14 @@ async def _iter_with_cleanup(): # type: ignore[return] disconnect_task = asyncio.create_task(self._monitor_disconnect(request, ctx.cancellation_signal)) try: snapshot = await self._orchestrator.run_sync(ctx) + logger.info( + "Response %s completed: status=%s output_count=%d", + ctx.response_id, + snapshot.get("status"), + len(snapshot.get("output", [])), + ) end_span(otel_span) - return JSONResponse(snapshot, status_code=200) + return JSONResponse(snapshot, status_code=200, headers=self._session_headers(agent_session_id)) except _HandlerError as exc: logger.error( "Handler error in sync create (response_id=%s)", @@ -575,17 +649,22 @@ async def _iter_with_cleanup(): # type: ignore[return] "error": { "message": "internal server error", "type": "server_error", - "code": "internal_error", + "code": "server_error", "param": None, } } - return JSONResponse(err_body, status_code=500) + return JSONResponse(err_body, status_code=500, headers=self._session_headers(agent_session_id)) finally: disconnect_task.cancel() snapshot = await self._orchestrator.run_background(ctx) + logger.info( + "Background response created for %s: status=%s", + ctx.response_id, + snapshot.get("status"), + ) end_span(otel_span) - return JSONResponse(snapshot, status_code=200, headers=self._response_headers) + return JSONResponse(snapshot, status_code=200, headers=self._session_headers(agent_session_id)) except _HandlerError as exc: logger.error("Handler error in create (response_id=%s)", ctx.response_id, exc_info=exc.original) self._safe_set_attrs( @@ -601,14 +680,14 @@ async def _iter_with_cleanup(): # type: ignore[return] "error": { "message": "internal server error", "type": "server_error", - "code": "internal_error", + "code": "server_error", "param": None, } } return JSONResponse( err_body, status_code=500, - headers=self._response_headers, + headers=self._session_headers(agent_session_id), ) except Exception as exc: # pylint: disable=broad-exception-caught logger.error("Unexpected error in create (response_id=%s)", ctx.response_id, exc_info=exc) @@ -635,7 +714,7 @@ async def _iter_with_cleanup(): # type: ignore[return] except ValueError: pass - async def handle_get(self, request: Request) -> Response: + async def handle_get(self, request: Request) -> Response: # pylint: disable=too-many-branches """Route handler for ``GET /responses/{response_id}``. Returns the response snapshot or replays SSE events if @@ -647,55 +726,124 @@ async def handle_get(self, request: Request) -> Response: :rtype: Response """ response_id = request.path_params["response_id"] + _hdrs = self._session_headers() + format_error = _validate_response_id_format(response_id, _hdrs) + if format_error is not None: + return format_error + + stream_replay_param = request.query_params.get("stream", "false").lower() == "true" + if stream_replay_param: + logger.info("Getting response %s with SSE replay", response_id) + else: + logger.info("Getting response %s", response_id) + + _isolation = _extract_isolation(request) record = await self._runtime_state.get(response_id) if record is None: if await self._runtime_state.is_deleted(response_id): - return _deleted_response(response_id, {}) + return _deleted_response(response_id, _hdrs) - _isolation = _extract_isolation(request) - stream_replay = request.query_params.get("stream", "false").lower() == "true" + # Chat isolation enforcement for evicted/restarted responses + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) + + stream_replay = stream_replay_param if not stream_replay: # Provider fallback: serve completed responses that are no longer in runtime state # (e.g., after a process restart). try: response_obj = await self._provider.get_response(response_id, isolation=_isolation) snapshot = response_obj.as_dict() - return JSONResponse(snapshot, status_code=200) + logger.info( + "Retrieved response %s: status=%s output_count=%d", + response_id, + snapshot.get("status"), + len(snapshot.get("output", [])), + ) + return JSONResponse(snapshot, status_code=200, headers=_hdrs) + except FoundryResourceNotFoundError: + pass # Fall through to 404 below + except FoundryBadRequestError as exc: + return _invalid_request(str(exc), _hdrs, param="response_id") + except FoundryApiError as exc: + logger.error("Storage API error for GET response_id=%s: %s", response_id, exc, exc_info=True) + return _error_response(exc, _hdrs) except Exception: # pylint: disable=broad-exception-caught logger.warning("Provider fallback failed for GET response_id=%s", response_id, exc_info=True) else: + # Validate starting_after cursor early — invalid cursors must + # always get param=starting_after regardless of stream availability. + parsed_cursor = self._parse_starting_after(request) + if isinstance(parsed_cursor, Response): + return parsed_cursor + # Stream provider fallback: replay persisted SSE events when runtime state is gone. replay_response = await self._try_replay_persisted_stream(request, response_id, isolation=_isolation) if replay_response is not None: return replay_response - # Response may exist in storage but wasn't replay-eligible - # (e.g., created without background=true, stream=true, store=true). + # No stream events available. Check the persisted response's + # background flag; if not bg, give the clear non-bg error. + # Otherwise, we can't distinguish bg+non-stream from + # bg+stream-with-expired-TTL (we don't persist the stream flag), + # so use a combined message. try: - await self._provider.get_response(response_id, isolation=_isolation) + persisted = await self._provider.get_response(response_id, isolation=_isolation) + persisted_dict = persisted.as_dict() + # B2: SSE replay requires background mode. + if persisted_dict.get("background") is not True: + return _invalid_mode( + "This response cannot be streamed because it was not created with background=true.", + _hdrs, + param="stream", + ) + # TODO: The container spec prescribes distinct error messages for + # "not created with stream=true" vs "stream TTL expired", but after + # eager eviction the persisted response does not carry the stream + # mode flag — we cannot distinguish the two cases. Until the + # provider surfaces the reason, we use a combined message. return _invalid_mode( - "stream replay is not available for this response; to enable SSE replay, " - + "create the response with background=true, stream=true, and store=true", - {}, + "This response cannot be streamed because it was not created " + "with stream=true or the stream TTL has expired.", + _hdrs, param="stream", ) + except FoundryResourceNotFoundError: + pass # Response doesn't exist in provider either — fall through to 404 + except FoundryBadRequestError as exc: + return _invalid_request(str(exc), _hdrs, param="response_id") + except FoundryApiError as exc: + logger.error( + "Storage API error for GET SSE replay response_id=%s: %s", + response_id, exc, exc_info=True, + ) + return _error_response(exc, _hdrs) except Exception: # pylint: disable=broad-exception-caught pass # Response doesn't exist in provider either — fall through to 404 - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) + + # Chat isolation enforcement on in-flight response + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) _refresh_background_status(record) - stream_replay = request.query_params.get("stream", "false").lower() == "true" + stream_replay = stream_replay_param if stream_replay: # B14: store=false responses are never persisted — return 404. if not record.mode_flags.store: - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) if not record.replay_enabled: + if not record.mode_flags.background: + return _invalid_mode( + "This response cannot be streamed because it was not created with background=true.", + _hdrs, + param="stream", + ) return _invalid_mode( - "stream replay is not available for this response; to enable SSE replay, " - + "create the response with background=true, stream=true, and store=true", - {}, + "This response cannot be streamed because it was not created with stream=true.", + _hdrs, param="stream", ) @@ -706,9 +854,16 @@ async def handle_get(self, request: Request) -> Response: return self._build_live_stream_response(record, parsed_cursor) if not record.visible_via_get: - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) - return JSONResponse(_RuntimeState.to_snapshot(record), status_code=200, headers=self._response_headers) + snapshot = _RuntimeState.to_snapshot(record) + logger.info( + "Retrieved response %s: status=%s output_count=%d", + response_id, + snapshot.get("status"), + len(snapshot.get("output", [])), + ) + return JSONResponse(snapshot, status_code=200, headers=_hdrs) @staticmethod def _parse_starting_after(request: Request) -> int | Response: @@ -798,26 +953,81 @@ async def handle_delete(self, request: Request) -> Response: :rtype: Response """ response_id = request.path_params["response_id"] + _hdrs = self._session_headers() + format_error = _validate_response_id_format(response_id, _hdrs) + if format_error is not None: + return format_error + + logger.info("Deleting response %s", response_id) + + _isolation = _extract_isolation(request) record = await self._runtime_state.get(response_id) if record is None: - return _not_found(response_id, {}) + # Provider fallback: response may have been evicted from memory after + # reaching terminal state, or the server restarted since creation. + if await self._runtime_state.is_deleted(response_id): + return _not_found(response_id, _hdrs) + + # Chat isolation enforcement for evicted/restarted responses + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) + + try: + await self._provider.delete_response(response_id, isolation=_isolation) + # Clean up persisted stream events + if self._stream_provider is not None: + try: + await self._stream_provider.delete_stream_events(response_id, isolation=_isolation) + except Exception: # pylint: disable=broad-exception-caught + logger.debug( + "Best-effort stream event delete failed for response_id=%s", + response_id, + exc_info=True, + ) + # Mark as deleted in runtime state so subsequent requests get 404 + await self._runtime_state.mark_deleted(response_id) + logger.info("Deleted response %s", response_id) + return JSONResponse( + {"id": response_id, "object": "response", "deleted": True}, + status_code=200, + headers=_hdrs, + ) + except FoundryResourceNotFoundError: + pass # Fall through to 404 below + except FoundryBadRequestError as exc: + return _invalid_request(str(exc), _hdrs, param="response_id") + except FoundryApiError as exc: + logger.error("Storage API error for DELETE response_id=%s: %s", response_id, exc, exc_info=True) + return _error_response(exc, _hdrs) + except Exception: # pylint: disable=broad-exception-caught + logger.debug( + "Provider fallback failed for DELETE response_id=%s", + response_id, + exc_info=True, + ) + + return _not_found(response_id, _hdrs) + + # Chat isolation enforcement + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) # store=false responses are not deletable (FR-014) if not record.mode_flags.store: - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) _refresh_background_status(record) if record.mode_flags.background and record.status in {"queued", "in_progress"}: return _invalid_request( "Cannot delete an in-flight response.", - {}, + _hdrs, param="response_id", ) deleted = await self._runtime_state.delete(response_id) if not deleted: - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) if record.mode_flags.store: try: @@ -838,9 +1048,11 @@ async def handle_delete(self, request: Request) -> Response: exc_info=True, ) + logger.info("Deleted response %s", response_id) return JSONResponse( {"id": response_id, "object": "response", "deleted": True}, status_code=200, + headers=_hdrs, ) async def handle_cancel(self, request: Request) -> Response: @@ -852,52 +1064,88 @@ async def handle_cancel(self, request: Request) -> Response: :rtype: Response """ response_id = request.path_params["response_id"] + _hdrs = self._session_headers() + format_error = _validate_response_id_format(response_id, _hdrs) + if format_error is not None: + return format_error + + logger.info("Cancelling response %s", response_id) + + _isolation = _extract_isolation(request) record = await self._runtime_state.get(response_id) if record is None: # Provider fallback: after a restart, stored terminal responses lose # their runtime records. Check the provider so we return the correct # 400 error instead of a misleading 404. + + # Chat isolation enforcement for evicted/restarted responses + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) + try: - response_obj = await self._provider.get_response(response_id, isolation=_extract_isolation(request)) - stored_status = response_obj.as_dict().get("status") + response_obj = await self._provider.get_response(response_id, isolation=_isolation) + persisted = response_obj.as_dict() + + # B1: background check comes first — non-bg responses always + # get the "synchronous" message regardless of terminal status. + if persisted.get("background") is not True: + return _invalid_request( + "Cannot cancel a synchronous response.", + _hdrs, + param="response_id", + ) + + stored_status = persisted.get("status") if stored_status == "completed": return _invalid_request( "Cannot cancel a completed response.", - {}, + _hdrs, param="response_id", ) if stored_status == "failed": return _invalid_request( "Cannot cancel a failed response.", - {}, + _hdrs, param="response_id", ) if stored_status == "cancelled": - return _invalid_request( - "Cannot cancel an already cancelled response.", - {}, - param="response_id", + # Idempotent: already cancelled — return the stored snapshot + return JSONResponse( + persisted, + status_code=200, + headers=_hdrs, ) if stored_status == "incomplete": return _invalid_request( - "Cannot cancel an incomplete response.", - {}, + "Cannot cancel a response in terminal state.", + _hdrs, param="response_id", ) + except FoundryResourceNotFoundError: + pass # Fall through to 404 below + except FoundryBadRequestError as exc: + return _invalid_request(str(exc), _hdrs, param="response_id") + except FoundryApiError as exc: + logger.error("Storage API error for cancel response_id=%s: %s", response_id, exc, exc_info=True) + return _error_response(exc, _hdrs) except Exception: # pylint: disable=broad-exception-caught logger.debug( "Provider fallback failed for cancel response_id=%s", response_id, exc_info=True, ) - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) + + # Chat isolation enforcement on in-flight response + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) _refresh_background_status(record) if not record.mode_flags.background: return _invalid_request( "Cannot cancel a synchronous response.", - {}, + _hdrs, param="response_id", ) @@ -906,26 +1154,26 @@ async def handle_cancel(self, request: Request) -> Response: record.set_response_snapshot( build_cancelled_response(record.response_id, record.agent_reference, record.model) ) - return JSONResponse(_RuntimeState.to_snapshot(record), status_code=200, headers=self._response_headers) + return JSONResponse(_RuntimeState.to_snapshot(record), status_code=200, headers=_hdrs) if record.status == "completed": return _invalid_request( "Cannot cancel a completed response.", - {}, + _hdrs, param="response_id", ) if record.status == "failed": return _invalid_request( "Cannot cancel a failed response.", - {}, + _hdrs, param="response_id", ) if record.status == "incomplete": return _invalid_request( - "Cannot cancel an incomplete response.", - {}, + "Cannot cancel a response in terminal state.", + _hdrs, param="response_id", ) @@ -943,6 +1191,10 @@ async def handle_cancel(self, request: Request) -> Response: # Set cancelled snapshot and transition record.set_response_snapshot(build_cancelled_response(record.response_id, record.agent_reference, record.model)) + # Stamp mode flags so the provider fallback can enforce B1/B2 checks + # after eager eviction removes the in-memory record. + if record.response is not None: + record.response.background = record.mode_flags.background record.transition_to("cancelled") # Persist cancelled state to durable store (B11: cancellation always wins) @@ -952,7 +1204,14 @@ async def handle_cancel(self, request: Request) -> Response: except Exception: # pylint: disable=broad-exception-caught logger.debug("Best-effort cancel persist failed for response_id=%s", record.response_id, exc_info=True) - return JSONResponse(_RuntimeState.to_snapshot(record), status_code=200, headers=self._response_headers) + # Build snapshot before eviction removes the record from memory + snapshot = _RuntimeState.to_snapshot(record) + + # Eager eviction: free memory now that the terminal state is persisted + await self._runtime_state.try_evict(record.response_id) + + logger.info("Cancelled response %s, status=%s", response_id, snapshot.get("status")) + return JSONResponse(snapshot, status_code=200, headers=_hdrs) async def handle_input_items(self, request: Request) -> Response: """Route handler for ``GET /responses/{response_id}/input_items``. @@ -965,37 +1224,55 @@ async def handle_input_items(self, request: Request) -> Response: :rtype: Response """ response_id = request.path_params["response_id"] + _hdrs = self._session_headers() + format_error = _validate_response_id_format(response_id, _hdrs) + if format_error is not None: + return format_error + + logger.info("Getting input items for response %s", response_id) + + # Chat isolation enforcement + _isolation = _extract_isolation(request) + if not self._runtime_state.check_chat_isolation(response_id, _isolation.chat_key): + return _not_found(response_id, _hdrs) limit_raw = request.query_params.get("limit", "20") try: limit = int(limit_raw) except ValueError: - return _invalid_request("limit must be an integer between 1 and 100", {}, param="limit") + return _invalid_request("limit must be an integer between 1 and 100", _hdrs, param="limit") if limit < 1 or limit > 100: - return _invalid_request("limit must be between 1 and 100", {}, param="limit") + return _invalid_request("limit must be between 1 and 100", _hdrs, param="limit") order = request.query_params.get("order", "desc").lower() if order not in {"asc", "desc"}: - return _invalid_request("order must be 'asc' or 'desc'", {}, param="order") + return _invalid_request("order must be 'asc' or 'desc'", _hdrs, param="order") after = request.query_params.get("after") before = request.query_params.get("before") try: items = await self._provider.get_input_items( - response_id, limit=100, ascending=True, isolation=_extract_isolation(request) + response_id, limit=100, ascending=True, isolation=_isolation ) except ValueError: - return _deleted_response(response_id, {}) + return _deleted_response(response_id, _hdrs) + except FoundryResourceNotFoundError: + return _not_found(response_id, _hdrs) + except FoundryBadRequestError as exc: + return _invalid_request(str(exc), _hdrs, param="response_id") + except FoundryApiError as exc: + logger.error("Storage API error for input_items response_id=%s: %s", response_id, exc, exc_info=True) + return _error_response(exc, _hdrs) except KeyError: # Fall back to runtime_state for in-flight responses not yet persisted to provider try: items = await self._runtime_state.get_input_items(response_id) except ValueError: - return _deleted_response(response_id, {}) + return _deleted_response(response_id, _hdrs) except KeyError: - return _not_found(response_id, {}) + return _not_found(response_id, _hdrs) ordered_items = items if order == "asc" else list(reversed(items)) ordered_dicts: list[dict[str, Any]] = [ @@ -1020,6 +1297,7 @@ async def handle_input_items(self, request: Request) -> Response: "has_more": has_more, }, status_code=200, + headers=_hdrs, ) async def handle_shutdown(self) -> None: diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_orchestrator.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_orchestrator.py index bd3b033ce0e3..7095b871752a 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_orchestrator.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_orchestrator.py @@ -180,6 +180,7 @@ async def _run_background_non_stream( # pylint: disable=too-many-locals,too-man agent_session_id: str | None = None, conversation_id: str | None = None, history_limit: int = 100, + runtime_state: _RuntimeState | None = None, ) -> None: """Execute a non-stream handler in the background and update the execution record. @@ -213,6 +214,8 @@ async def _run_background_non_stream( # pylint: disable=too-many-locals,too-man :keyword type conversation_id: str | None :keyword history_limit: Maximum number of history items to include. :keyword type history_limit: int + :keyword runtime_state: Runtime state tracker for eager eviction after persist. + :keyword type runtime_state: _RuntimeState | None :return: None :rtype: None """ @@ -401,6 +404,9 @@ async def _run_background_non_stream( # pylint: disable=too-many-locals,too-man agent_session_id=agent_session_id, conversation_id=conversation_id, ) + # Stamp background so the provider fallback can enforce B1 checks + # after eager eviction removes the in-memory record. + response_payload["background"] = record.mode_flags.background resolved_status = response_payload.get("status") if record.status != "cancelled": @@ -415,6 +421,11 @@ async def _run_background_non_stream( # pylint: disable=too-many-locals,too-man finally: # Always unblock run_background (idempotent if already set) record.response_created_signal.set() + # Stamp mode flags so the provider fallback can enforce B1/B2 checks + # after eager eviction removes the in-memory record. This covers + # all code paths (normal completion, handler failure, cancellation). + if record.response is not None: + record.response.background = record.mode_flags.background # Persist terminal state update via provider (bg non-stream: update after runner completes) if store and provider is not None and record.status not in {"cancelled"} and record.response is not None: try: @@ -433,6 +444,9 @@ async def _run_background_non_stream( # pylint: disable=too-many-locals,too-man response_id, exc_info=True, ) + # Eager eviction: free memory once terminal state is persisted (or store=False). + if runtime_state is not None and record.is_terminal: + await runtime_state.try_evict(response_id) def _refresh_background_status(record: ResponseExecution) -> None: @@ -675,6 +689,9 @@ async def _register_bg_execution( agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, ) + # Stamp mode flags so the provider fallback can enforce B1/B2 checks + # after eager eviction removes the in-memory record. + initial_payload["background"] = True initial_status = initial_payload.get("status") if not isinstance(initial_status, str): initial_status = "in_progress" @@ -687,6 +704,7 @@ async def _register_bg_execution( cancel_signal=ctx.cancellation_signal, agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, + chat_isolation_key=ctx.chat_isolation_key, ) execution.set_response_snapshot(generated_models.ResponseObject(initial_payload)) execution.subject = _ResponseEventSubject() @@ -978,6 +996,7 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) # --- Path A: BG with pre-existing record (normal bg+stream completion) --- if ctx.background and ctx.store and state.bg_record is not None: record = state.bg_record + events: list[generated_models.ResponseStreamEvent] = [] # B11: When status is already "cancelled" (set by the cancel endpoint), # skip snapshot/status update — cancellation always wins. But still @@ -1024,6 +1043,9 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) # Always persist — including cancelled state — so the durable store # reflects the final status. if record.mode_flags.store and record.response is not None: + # Stamp mode flags so the provider fallback can enforce B1/B2 checks + # after eager eviction removes the in-memory record. + record.response.background = record.mode_flags.background _isolation = ctx.context.isolation if ctx.context else None try: await self._provider.update_response(record.response, isolation=_isolation) @@ -1033,11 +1055,13 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) ctx.response_id, exc_info=True, ) - # Persist SSE events for replay after process restart (not needed for cancelled) - if record.status != "cancelled" and self._stream_provider is not None and state.handler_events: + # Persist SSE events for replay after process restart (not needed for cancelled). + # Use ``events`` (not ``state.handler_events``) so that fallback events + # generated by ``_build_events`` are saved when the handler yielded nothing. + if record.status != "cancelled" and self._stream_provider is not None and events: try: await self._stream_provider.save_stream_events( - ctx.response_id, state.handler_events, isolation=_isolation + ctx.response_id, events, isolation=_isolation ) except Exception: # pylint: disable=broad-exception-caught logger.warning( @@ -1054,6 +1078,9 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) await record.subject.complete() except Exception: # pylint: disable=broad-exception-caught pass # best effort + # Eager eviction: free memory once terminal state is persisted. + if record.is_terminal: + await self._runtime_state.try_evict(ctx.response_id) return # --- Path B: No pre-existing record --- @@ -1086,6 +1113,9 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, ) + # Stamp background so the provider fallback can enforce B1 checks + # after eager eviction removes the in-memory record. + response_payload["background"] = ctx.background resolved_status = response_payload.get("status") final_status: ResponseStatus = ( cast(ResponseStatus, resolved_status) if isinstance(resolved_status, str) else "completed" @@ -1109,13 +1139,14 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) cancel_signal=ctx.cancellation_signal if ctx.background else None, agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, + chat_isolation_key=ctx.chat_isolation_key, ) execution.set_response_snapshot(generated_models.ResponseObject(response_payload)) await self._runtime_state.add(execution) if ctx.store: + _isolation = ctx.context.isolation if ctx.context else None try: - _isolation = ctx.context.isolation if ctx.context else None _history_ids = ( await self._provider.get_history_item_ids( ctx.previous_response_id, @@ -1139,8 +1170,25 @@ async def _finalize_stream(self, ctx: _ExecutionContext, state: _PipelineState) exc_info=True, ) + # Persist SSE events for replay after eager eviction (bg+stream only). + if ctx.background and self._stream_provider is not None and events: + try: + await self._stream_provider.save_stream_events( + ctx.response_id, events, isolation=_isolation + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Best-effort stream event persistence failed (response_id=%s)", + ctx.response_id, + exc_info=True, + ) + ctx.span.end(state.captured_error) + # Eager eviction: free memory once terminal state is persisted (or store=False). + if execution.is_terminal: + await self._runtime_state.try_evict(ctx.response_id) + # ------------------------------------------------------------------ # Public execution methods # ------------------------------------------------------------------ @@ -1178,6 +1226,8 @@ async def _live_stream(self, ctx: _ExecutionContext) -> AsyncIterator[str]: """ new_stream_counter() state = _PipelineState() + _handler_name = getattr(self._create_fn, "__qualname__", None) or getattr(self._create_fn, "__name__", "unknown") + logger.info("Invoking handler %s for response %s", _handler_name, ctx.response_id) handler_iterator = self._create_fn(ctx.parsed, ctx.context, ctx.cancellation_signal) # Helper: route to the right finalize method based on the request semantics @@ -1341,6 +1391,8 @@ async def run_sync(self, ctx: _ExecutionContext) -> dict[str, Any]: :raises _HandlerError: If the handler raises during iteration. """ state = _PipelineState() + _handler_name = getattr(self._create_fn, "__qualname__", None) or getattr(self._create_fn, "__name__", "unknown") + logger.info("Invoking handler %s for response %s", _handler_name, ctx.response_id) handler_iterator = self._create_fn(ctx.parsed, ctx.context, ctx.cancellation_signal) # _process_handler_events handles all error paths (B8, S-035, S-015, B11). # run_sync only needs to exhaust the generator for state.handler_events side-effects. @@ -1375,6 +1427,9 @@ async def run_sync(self, ctx: _ExecutionContext) -> dict[str, Any]: agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, ) + # Stamp background so the provider fallback can enforce B1 checks + # after eager eviction removes the in-memory record. + response_payload["background"] = ctx.background resolved_status = response_payload.get("status") status = cast(ResponseStatus, resolved_status) if isinstance(resolved_status, str) else "completed" @@ -1387,6 +1442,7 @@ async def run_sync(self, ctx: _ExecutionContext) -> dict[str, Any]: response_context=ctx.context, agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, + chat_isolation_key=ctx.chat_isolation_key, ) record.set_response_snapshot(generated_models.ResponseObject(response_payload)) @@ -1423,6 +1479,10 @@ async def run_sync(self, ctx: _ExecutionContext) -> dict[str, Any]: exc_info=True, ) + # Eager eviction: free memory once terminal state is persisted (or store=False). + if record.is_terminal: + await self._runtime_state.try_evict(ctx.response_id) + ctx.span.end(None) return _RuntimeState.to_snapshot(record) @@ -1452,6 +1512,7 @@ async def run_background(self, ctx: _ExecutionContext) -> dict[str, Any]: initial_agent_reference=ctx.agent_reference, agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, + chat_isolation_key=ctx.chat_isolation_key, ) # Register so GET can observe in-flight state @@ -1481,6 +1542,7 @@ async def _shielded_runner() -> None: agent_session_id=ctx.agent_session_id, conversation_id=ctx.conversation_id, history_limit=self._runtime_options.default_fetch_history_count, + runtime_state=self._runtime_state, ) except asyncio.CancelledError: pass # event-loop teardown; background work already done diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_runtime_state.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_runtime_state.py index 9777f7bfc05d..541c4165c64a 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_runtime_state.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_runtime_state.py @@ -20,6 +20,7 @@ def __init__(self) -> None: """Initialize the runtime state with empty record and deletion sets.""" self._records: dict[str, ResponseExecution] = {} self._deleted_response_ids: set[str] = set() + self._chat_isolation_keys: dict[str, str] = {} self._lock = asyncio.Lock() async def add(self, record: ResponseExecution) -> None: @@ -33,6 +34,8 @@ async def add(self, record: ResponseExecution) -> None: async with self._lock: self._records[record.response_id] = record self._deleted_response_ids.discard(record.response_id) + if record.chat_isolation_key is not None: + self._chat_isolation_keys[record.response_id] = record.chat_isolation_key async def get(self, response_id: str) -> ResponseExecution | None: """Look up an execution record by response ID. @@ -69,8 +72,72 @@ async def delete(self, response_id: str) -> bool: if record is None: return False self._deleted_response_ids.add(response_id) + self._chat_isolation_keys.pop(response_id, None) return True + _TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled", "incomplete"}) + + async def try_evict(self, response_id: str) -> bool: + """Evict a terminal record from in-memory state to free memory. + + Unlike :meth:`delete`, eviction does **not** mark the response as + deleted — it simply removes the runtime record so that subsequent + requests fall through to the durable provider (storage). + + Only records in a terminal status are evicted. Non-terminal records + are left untouched so that in-flight operations remain correct. + + :param response_id: The response ID to evict. + :type response_id: str + :return: ``True`` if the record was evicted, ``False`` otherwise. + :rtype: bool + """ + async with self._lock: + record = self._records.get(response_id) + if record is None: + return False + if record.status not in self._TERMINAL_STATUSES: + return False + del self._records[response_id] + # NOTE: chat isolation keys are intentionally preserved so that + # provider fallback paths can still enforce isolation after eviction. + return True + + async def mark_deleted(self, response_id: str) -> None: + """Mark a response ID as deleted without requiring a runtime record. + + Used by the delete handler's provider fallback path when the record + has already been evicted from memory but still exists in durable storage. + + :param response_id: The response ID to mark as deleted. + :type response_id: str + :return: None + :rtype: None + """ + async with self._lock: + self._deleted_response_ids.add(response_id) + self._chat_isolation_keys.pop(response_id, None) + + def check_chat_isolation(self, response_id: str, request_chat_key: str | None) -> bool: + """Check whether the request chat key matches the creation-time key. + + Returns ``True`` if the request is allowed, ``False`` if it should be + rejected as not-found to prevent cross-chat information leakage. + + No enforcement when the response was created without a key (backward compat). + + :param response_id: The response ID to check. + :type response_id: str + :param request_chat_key: The chat key from the incoming request, or ``None``. + :type request_chat_key: str | None + :return: ``True`` if allowed, ``False`` if isolation mismatch. + :rtype: bool + """ + stored_key = self._chat_isolation_keys.get(response_id) + if stored_key is None: + return True # No enforcement when created without a key + return stored_key == request_chat_key + async def get_input_items(self, response_id: str) -> list[OutputItem]: """Retrieve the full input item chain for a response, including ancestors. diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_validation.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_validation.py index 23512821553b..aaee6180396e 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_validation.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_validation.py @@ -8,6 +8,7 @@ from starlette.responses import JSONResponse +from azure.ai.agentserver.responses._id_generator import IdGenerator from azure.ai.agentserver.responses._options import ResponsesServerOptions from azure.ai.agentserver.responses.models._generated import ApiErrorResponse, CreateResponse, Error from azure.ai.agentserver.responses.models._generated._validators import validate_CreateResponse @@ -124,6 +125,17 @@ def validate_create_response(request: CreateResponse) -> None: param="metadata", ) + # Validate previous_response_id format (must be a valid caresp ID) + prev_id = getattr(request, "previous_response_id", None) + if isinstance(prev_id, str) and prev_id: + is_valid, _ = IdGenerator.is_valid(prev_id, allowed_prefixes=["caresp"]) + if not is_valid: + raise RequestValidationError( + "Malformed identifier.", + code="invalid_parameters", + param="previous_response_id", + ) + def parse_and_validate_create_response( payload: Mapping[str, Any], @@ -199,7 +211,7 @@ def build_not_found_error_response( """ return build_api_error_response( message=f"{resource_name} '{resource_id}' was not found", - code="not_found", + code="invalid_request_error", param=param, error_type="invalid_request_error", ) @@ -221,7 +233,7 @@ def build_invalid_mode_error_response( """ return build_api_error_response( message=message, - code="invalid_mode", + code="invalid_request_error", param=param, error_type="invalid_request_error", ) @@ -241,13 +253,13 @@ def to_api_error_response(error: Exception) -> ApiErrorResponse: if isinstance(error, ValueError): return build_api_error_response( message=str(error) or "invalid request", - code="invalid_request", + code="invalid_request_error", error_type="invalid_request_error", ) return build_api_error_response( message="internal server error", - code="internal_error", + code="server_error", error_type="server_error", ) @@ -327,7 +339,7 @@ def not_found_response(response_id: str, headers: dict[str, str]) -> JSONRespons """ return _api_error( message=f"Response with id '{response_id}' not found.", - code="invalid_request", + code="invalid_request_error", param="response_id", error_type="invalid_request_error", status_code=404, @@ -348,7 +360,30 @@ def invalid_request_response(message: str, headers: dict[str, str], *, param: st """ return _api_error( message=message, - code="invalid_request", + code="invalid_request_error", + param=param, + error_type="invalid_request_error", + status_code=400, + headers=headers, + ) + + +def invalid_parameters_response(message: str, headers: dict[str, str], *, param: str | None = None) -> JSONResponse: + """Build a 400 Bad Request error response with ``code: "invalid_parameters"``. + + Used for malformed identifier validation (spec rule B40). + + :param message: Human-readable error message. + :type message: str + :param headers: HTTP headers to include in the response. + :type headers: dict[str, str] + :keyword param: Optional parameter name associated with the error. + :return: A 400 JSONResponse. + :rtype: JSONResponse + """ + return _api_error( + message=message, + code="invalid_parameters", param=param, error_type="invalid_request_error", status_code=400, @@ -392,17 +427,15 @@ def service_unavailable_response(message: str, headers: dict[str, str]) -> JSONR def deleted_response(response_id: str, headers: dict[str, str]) -> JSONResponse: - """Build a 400 error response indicating the response has been deleted. + """Build a 404 error response indicating the response has been deleted. + + Per spec, all endpoints treat deleted responses as not-found (HTTP 404). :param response_id: The ID of the deleted response. :type response_id: str :param headers: HTTP headers to include in the response. :type headers: dict[str, str] - :return: A 400 JSONResponse. + :return: A 404 JSONResponse. :rtype: JSONResponse """ - return invalid_request_response( - f"Response with id '{response_id}' has been deleted.", - headers, - param="response_id", - ) + return not_found_response(response_id, headers) diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/errors.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/errors.py index 9a92ec5bef38..31280a457768 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/errors.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/errors.py @@ -16,7 +16,7 @@ def __init__( self, message: str, *, - code: str = "invalid_request", + code: str = "invalid_request_error", param: str | None = None, error_type: str = "invalid_request_error", debug_info: dict[str, Any] | None = None, diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/runtime.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/runtime.py index 5219d294907c..6dcbbf4a443e 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/runtime.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/models/runtime.py @@ -101,6 +101,7 @@ def __init__( initial_agent_reference: AgentReference | dict[str, Any] | None = None, agent_session_id: str | None = None, conversation_id: str | None = None, + chat_isolation_key: str | None = None, ) -> None: self.response_id = response_id self.mode_flags = mode_flags @@ -122,6 +123,7 @@ def __init__( self.initial_agent_reference = initial_agent_reference or {} self.agent_session_id = agent_session_id self.conversation_id = conversation_id + self.chat_isolation_key = chat_isolation_key self.response_created_signal: asyncio.Event = asyncio.Event() self.response_failed_before_events: bool = False diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_logging_policy.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_logging_policy.py new file mode 100644 index 000000000000..e9a6ac587726 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_logging_policy.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Logging policy for Foundry storage HTTP calls. + +Logs method, URI, status code, duration, and correlation headers for +each outbound storage request at the ``azure.ai.agentserver`` logger. + +Provides consistent observability for storage operations. +""" + +from __future__ import annotations + +import logging +import time +from typing import cast + +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.policies import AsyncHTTPPolicy +from azure.core.rest import HttpResponse + +logger = logging.getLogger("azure.ai.agentserver") + +# Correlation headers to extract and log +_CLIENT_REQUEST_ID_HEADER = "x-ms-client-request-id" +_SERVER_REQUEST_ID_HEADER = "x-ms-request-id" + + +class FoundryStorageLoggingPolicy(AsyncHTTPPolicy[PipelineRequest, PipelineResponse]): + """Azure Core per-retry pipeline policy that logs Foundry storage calls. + + Logs the HTTP method, URI, response status code, duration in milliseconds, + and correlation headers (``x-ms-client-request-id``, ``x-ms-request-id``) + for observability of storage operations. + """ + + async def send(self, request: PipelineRequest) -> PipelineResponse: # type: ignore[override] + """Send the request and log the operation details. + + :param request: The pipeline request. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + http_request = request.http_request + method = http_request.method + url = http_request.url + + client_request_id = http_request.headers.get(_CLIENT_REQUEST_ID_HEADER, "") + + start = time.monotonic() + try: + response = await self.next.send(request) + except Exception: + elapsed_ms = (time.monotonic() - start) * 1000 + logger.warning( + "Foundry storage %s %s failed after %.1fms (client-request-id=%s)", + method, + url, + elapsed_ms, + client_request_id, + ) + raise + + elapsed_ms = (time.monotonic() - start) * 1000 + http_response = cast(HttpResponse, response.http_response) + status_code = http_response.status_code + server_request_id = http_response.headers.get(_SERVER_REQUEST_ID_HEADER, "") + + log_level = logging.INFO if 200 <= status_code < 400 else logging.WARNING + logger.log( + log_level, + "Foundry storage %s %s -> %d (%.1fms, client-request-id=%s, request-id=%s)", + method, + url, + status_code, + elapsed_ms, + client_request_id, + server_request_id, + ) + + return response diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_provider.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_provider.py index f8161a7a3b35..c1a9d5ca88b2 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_provider.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/store/_foundry_provider.py @@ -14,6 +14,7 @@ from ..models._generated import OutputItem, ResponseObject # type: ignore[attr-defined] from ._foundry_errors import raise_for_storage_error +from ._foundry_logging_policy import FoundryStorageLoggingPolicy from ._foundry_serializer import ( deserialize_history_ids, deserialize_items_array, @@ -96,9 +97,9 @@ def __init__( credential, _FOUNDRY_TOKEN_SCOPE, ), + FoundryStorageLoggingPolicy(), policies.ContentDecodePolicy(), policies.DistributedTracingPolicy(), - policies.HttpLoggingPolicy(), ], ) diff --git a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/streaming/_sse.py b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/streaming/_sse.py index db0c08250546..41f48b89d065 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/streaming/_sse.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/streaming/_sse.py @@ -7,6 +7,7 @@ import itertools import json from contextvars import ContextVar +from datetime import datetime, date, time, timedelta from typing import Any, Mapping from ..models._generated import ResponseStreamEvent @@ -14,6 +15,31 @@ _stream_counter_var: ContextVar[itertools.count] = ContextVar("_stream_counter_var") +def _json_default(o: Any) -> Any: + """JSON encoder default for datetime and bytes. + + Handles datetime objects that leak through model ``as_dict()`` calls + by serializing to ISO-8601 strings (or Unix timestamps for datetime). + + :param o: The object to encode. + :type o: Any + :returns: A JSON-serializable representation. + :rtype: Any + :raises TypeError: If the object type is not supported. + """ + if isinstance(o, datetime): + return int(o.timestamp()) + if isinstance(o, (date, time)): + return o.isoformat() + if isinstance(o, timedelta): + return o.total_seconds() + if isinstance(o, (bytes, bytearray)): + import base64 # pylint: disable=import-outside-toplevel + + return base64.b64encode(o).decode("ascii") + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") + + def new_stream_counter() -> None: """Initialize a fresh per-stream SSE sequence number counter for the current context. @@ -104,7 +130,7 @@ def _build_sse_frame(event_type: str, payload: dict[str, Any]) -> str: # Sanitize event_type to prevent SSE response splitting via newline injection event_type = event_type.replace("\n", "").replace("\r", "") lines = [f"event: {event_type}"] - lines.append(f"data: {json.dumps(payload)}") + lines.append(f"data: {json.dumps(payload, default=_json_default)}") lines.append("") lines.append("") return "\n".join(lines) diff --git a/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml index 9ea7892570c8..d87105a1deb2 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-responses/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "azure-ai-agentserver-core>=2.0.0b1", + "azure-ai-agentserver-core>=2.0.0b2", "azure-core>=1.30.0", "isodate>=0.6.1", "aiohttp>=3.10.0,<4.0.0", diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cancel_endpoint.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cancel_endpoint.py index 64ee1589d956..dcc51c724d30 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cancel_endpoint.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cancel_endpoint.py @@ -183,6 +183,7 @@ def _assert_error( expected_status: int, expected_type: str, expected_message: str | None = None, + expected_code: str | None = None, ) -> None: assert response.status_code == expected_status payload = response.json() @@ -190,6 +191,10 @@ def _assert_error( assert payload["error"].get("type") == expected_type if expected_message is not None: assert payload["error"].get("message") == expected_message + if expected_code is not None: + assert payload["error"].get("code") == expected_code, ( + f"Expected error.code={expected_code!r}, got {payload['error'].get('code')!r}" + ) def test_cancel__cancels_background_response_and_clears_output() -> None: @@ -237,6 +242,7 @@ def test_cancel__returns_400_for_completed_background_response() -> None: expected_status=400, expected_type="invalid_request_error", expected_message="Cannot cancel a completed response.", + expected_code="invalid_request_error", ) @@ -420,6 +426,8 @@ def test_cancel__returns_400_for_incomplete_background_response() -> None: cancel_response, expected_status=400, expected_type="invalid_request_error", + expected_message="Cannot cancel a response in terminal state.", + expected_code="invalid_request_error", ) @@ -444,7 +452,11 @@ def test_cancel__returns_400_for_synchronous_response() -> None: cancel_response, expected_status=400, expected_type="invalid_request_error", + # After eager eviction the in-memory record is gone. The provider + # fallback loads the persisted response and checks background first (B1), + # returning the correct "synchronous" message. expected_message="Cannot cancel a synchronous response.", + expected_code="invalid_request_error", ) @@ -483,6 +495,7 @@ def _issue_sync_create() -> None: cancel_response, expected_status=404, expected_type="invalid_request_error", + expected_code="invalid_request_error", ) release_gate.set() @@ -497,13 +510,17 @@ def _issue_sync_create() -> None: def test_cancel__returns_404_for_unknown_response_id() -> None: + from azure.ai.agentserver.responses._id_generator import IdGenerator + client = _build_client() + unknown_id = IdGenerator.new_response_id() - cancel_response = client.post("/responses/resp_does_not_exist/cancel") + cancel_response = client.post(f"/responses/{unknown_id}/cancel") _assert_error( cancel_response, expected_status=404, expected_type="invalid_request_error", + expected_code="invalid_request_error", ) @@ -616,6 +633,7 @@ def test_cancel__provider_fallback_returns_400_for_completed_after_restart() -> expected_status=400, expected_type="invalid_request_error", expected_message="Cannot cancel a completed response.", + expected_code="invalid_request_error", ) @@ -643,6 +661,7 @@ def test_cancel__provider_fallback_returns_400_for_failed_after_restart() -> Non expected_status=400, expected_type="invalid_request_error", expected_message="Cannot cancel a failed response.", + expected_code="invalid_request_error", ) diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_chat_isolation_enforcement.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_chat_isolation_enforcement.py new file mode 100644 index 000000000000..5859274fe036 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_chat_isolation_enforcement.py @@ -0,0 +1,448 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Contract tests for chat isolation key enforcement across all endpoints. + +When a response is created with an ``x-agent-chat-isolation-key`` header, +all subsequent GET, Cancel, DELETE, and InputItems requests must include +the same key. Mismatched or missing keys return an indistinguishable 404 +to prevent cross-chat information leakage. + +Backward-compatible: no enforcement when the response was created without a key. +""" + +from __future__ import annotations + +import asyncio +import json as _json +from typing import Any + +import pytest +from starlette.testclient import TestClient + +from azure.ai.agentserver.responses import ResponsesAgentServerHost +from azure.ai.agentserver.responses._id_generator import IdGenerator +from azure.ai.agentserver.responses.streaming._event_stream import ResponseEventStream +from tests._helpers import poll_until + + +# ── Shared helpers (sync, for GET / DELETE / INPUT_ITEMS) ── + +def _noop_handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + if False: # pragma: no cover + yield None + + return _events() + + +def _make_client(handler=_noop_handler) -> TestClient: + host = ResponsesAgentServerHost() + host.response_handler(handler) + return TestClient(host) + + +def _create_response( + client: TestClient, *, chat_key: str | None = None, **overrides +) -> dict[str, Any]: + """Create a response and return the parsed JSON body.""" + payload = { + "model": "m", + "input": [{"role": "user", "content": "hi"}], + **overrides, + } + headers: dict[str, str] = {} + if chat_key is not None: + headers["x-agent-chat-isolation-key"] = chat_key + r = client.post("/responses", json=payload, headers=headers) + assert r.status_code == 200, f"create failed: {r.status_code} {r.text}" + return r.json() + + +def _wait_for_terminal( + client: TestClient, response_id: str, **headers: str +) -> dict[str, Any]: + latest: dict[str, Any] = {} + terminal = {"completed", "failed", "incomplete", "cancelled"} + + def _check() -> bool: + nonlocal latest + r = client.get(f"/responses/{response_id}", headers=headers) + if r.status_code != 200: + return False + latest = r.json() + return latest.get("status") in terminal + + poll_until(_check, timeout_s=5.0, interval_s=0.05, label="wait_terminal") + return latest + + +# ── Async ASGI client (for cancel tests — needs event loop) ── + + +class _AsgiResponse: + def __init__( + self, status_code: int, body: bytes, headers: list[tuple[bytes, bytes]] + ) -> None: + self.status_code = status_code + self.body = body + self.headers = headers + + def json(self) -> Any: + return _json.loads(self.body) + + +class _AsyncAsgiClient: + """Lightweight async ASGI client that supports custom headers.""" + + def __init__(self, app: Any) -> None: + self._app = app + + @staticmethod + def _build_scope( + method: str, + path: str, + body: bytes, + headers: list[tuple[bytes, bytes]] | None = None, + ) -> dict[str, Any]: + hdr: list[tuple[bytes, bytes]] = list(headers or []) + query_string = b"" + if "?" in path: + path, qs = path.split("?", 1) + query_string = qs.encode() + if body: + hdr += [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ] + return { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "headers": hdr, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "server": ("localhost", 80), + "client": ("127.0.0.1", 123), + "root_path": "", + } + + async def request( + self, + method: str, + path: str, + *, + json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + body = _json.dumps(json_body).encode() if json_body else b"" + raw_headers = ( + [(k.lower().encode(), v.encode()) for k, v in headers.items()] + if headers + else [] + ) + scope = self._build_scope(method, path, body, raw_headers) + status_code: int | None = None + response_headers: list[tuple[bytes, bytes]] = [] + body_parts: list[bytes] = [] + request_sent = False + response_done = asyncio.Event() + + async def receive() -> dict[str, Any]: + nonlocal request_sent + if not request_sent: + request_sent = True + return {"type": "http.request", "body": body, "more_body": False} + await response_done.wait() + return {"type": "http.disconnect"} + + async def send(message: dict[str, Any]) -> None: + nonlocal status_code, response_headers + if message["type"] == "http.response.start": + status_code = message["status"] + response_headers = message.get("headers", []) + elif message["type"] == "http.response.body": + chunk = message.get("body", b"") + if chunk: + body_parts.append(chunk) + if not message.get("more_body", False): + response_done.set() + + await self._app(scope, receive, send) + assert status_code is not None + return _AsgiResponse( + status_code=status_code, + body=b"".join(body_parts), + headers=response_headers, + ) + + async def get( + self, path: str, *, headers: dict[str, str] | None = None + ) -> _AsgiResponse: + return await self.request("GET", path, headers=headers) + + async def post( + self, + path: str, + *, + json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + return await self.request("POST", path, json_body=json_body, headers=headers) + + +def _make_cancellable_bg_handler() -> Any: + """Handler that emits created+in_progress, then blocks until cancelled.""" + started = asyncio.Event() + + def handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + stream = ResponseEventStream( + response_id=context.response_id, + model=getattr(request, "model", None), + ) + yield stream.emit_created() + yield stream.emit_in_progress() + started.set() + while not cancellation_signal.is_set(): + await asyncio.sleep(0.01) + + return _events() + + handler.started = started # type: ignore[attr-defined] + return handler + + +def _build_async_client(handler: Any) -> _AsyncAsgiClient: + app = ResponsesAgentServerHost() + app.response_handler(handler) + return _AsyncAsgiClient(app) + + +# ── GET with isolation ──────────────────────────────────── + +class TestGetChatIsolation: + """GET /responses/{id} with chat isolation key enforcement.""" + + def test_get_matching_key_returns_200(self) -> None: + """GET with the same chat key that was used at creation → 200.""" + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "key_A"}) + assert r.status_code == 200 + + def test_get_mismatched_key_returns_404(self) -> None: + """GET with a different chat key → 404 (indistinguishable from not found).""" + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "key_B"}) + assert r.status_code == 404 + + def test_get_missing_key_when_created_with_key_returns_404(self) -> None: + """GET without chat key when response was created with one → 404.""" + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get(f"/responses/{resp['id']}") + assert r.status_code == 404 + + def test_get_created_without_key_any_request_returns_200(self) -> None: + """GET with or without key when response was created without one → 200 (backward compat).""" + client = _make_client() + resp = _create_response(client) + _wait_for_terminal(client, resp["id"]) + # With a key + r = client.get(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "any_key"}) + assert r.status_code == 200 + # Without a key + r = client.get(f"/responses/{resp['id']}") + assert r.status_code == 200 + + def test_get_404_error_body_is_standard(self) -> None: + """404 from isolation mismatch has the standard error body shape.""" + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "key_WRONG"}) + assert r.status_code == 404 + body = r.json() + assert "error" in body + assert body["error"]["code"] == "invalid_request_error" + + +# ── DELETE with isolation ──────────────────────────────── + +class TestDeleteChatIsolation: + """DELETE /responses/{id} with chat isolation key enforcement.""" + + def test_delete_matching_key_returns_200(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.delete(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "key_A"}) + assert r.status_code == 200 + + def test_delete_mismatched_key_returns_404(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.delete(f"/responses/{resp['id']}", headers={"x-agent-chat-isolation-key": "key_B"}) + assert r.status_code == 404 + + def test_delete_missing_key_when_created_with_key_returns_404(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.delete(f"/responses/{resp['id']}") + assert r.status_code == 404 + + +# ── CANCEL with isolation (async — needs real event loop) ── + + +class TestCancelChatIsolation: + """POST /responses/{id}/cancel with chat isolation key enforcement. + + Cancel tests must use async ASGI client because the handler runs as a + background asyncio task that needs the event loop to start before the + cancel request can observe it. + """ + + @pytest.mark.asyncio + async def test_cancel_matching_key_succeeds(self) -> None: + """Cancel with matching key on a background in-flight response → 200.""" + handler = _make_cancellable_bg_handler() + client = _build_async_client(handler) + response_id = IdGenerator.new_response_id() + + post_task = asyncio.create_task( + client.post( + "/responses", + json_body={ + "response_id": response_id, + "model": "test", + "background": True, + "stream": True, + }, + headers={"x-agent-chat-isolation-key": "key_A"}, + ) + ) + try: + await asyncio.wait_for(handler.started.wait(), timeout=5.0) + r = await client.post( + f"/responses/{response_id}/cancel", + headers={"x-agent-chat-isolation-key": "key_A"}, + ) + assert r.status_code == 200 + finally: + handler.started.set() # unblock if needed + if not post_task.done(): + post_task.cancel() + try: + await post_task + except (asyncio.CancelledError, Exception): + pass + + @pytest.mark.asyncio + async def test_cancel_mismatched_key_returns_404(self) -> None: + """Cancel with wrong key → 404.""" + handler = _make_cancellable_bg_handler() + client = _build_async_client(handler) + response_id = IdGenerator.new_response_id() + + post_task = asyncio.create_task( + client.post( + "/responses", + json_body={ + "response_id": response_id, + "model": "test", + "background": True, + "stream": True, + }, + headers={"x-agent-chat-isolation-key": "key_A"}, + ) + ) + try: + await asyncio.wait_for(handler.started.wait(), timeout=5.0) + r = await client.post( + f"/responses/{response_id}/cancel", + headers={"x-agent-chat-isolation-key": "key_B"}, + ) + assert r.status_code == 404 + finally: + handler.started.set() + if not post_task.done(): + post_task.cancel() + try: + await post_task + except (asyncio.CancelledError, Exception): + pass + + @pytest.mark.asyncio + async def test_cancel_missing_key_when_created_with_key_returns_404(self) -> None: + """Cancel without any key when response was created with one → 404.""" + handler = _make_cancellable_bg_handler() + client = _build_async_client(handler) + response_id = IdGenerator.new_response_id() + + post_task = asyncio.create_task( + client.post( + "/responses", + json_body={ + "response_id": response_id, + "model": "test", + "background": True, + "stream": True, + }, + headers={"x-agent-chat-isolation-key": "key_A"}, + ) + ) + try: + await asyncio.wait_for(handler.started.wait(), timeout=5.0) + r = await client.post(f"/responses/{response_id}/cancel") + assert r.status_code == 404 + finally: + handler.started.set() + if not post_task.done(): + post_task.cancel() + try: + await post_task + except (asyncio.CancelledError, Exception): + pass + + +# ── INPUT_ITEMS with isolation ──────────────────────────── + +class TestInputItemsChatIsolation: + """GET /responses/{id}/input_items with chat isolation key enforcement.""" + + def test_input_items_matching_key_returns_200(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get( + f"/responses/{resp['id']}/input_items", + headers={"x-agent-chat-isolation-key": "key_A"}, + ) + assert r.status_code == 200 + + def test_input_items_mismatched_key_returns_404(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get( + f"/responses/{resp['id']}/input_items", + headers={"x-agent-chat-isolation-key": "key_B"}, + ) + assert r.status_code == 404 + + def test_input_items_missing_key_when_created_with_key_returns_404(self) -> None: + client = _make_client() + resp = _create_response(client, chat_key="key_A") + _wait_for_terminal(client, resp["id"], **{"x-agent-chat-isolation-key": "key_A"}) + r = client.get(f"/responses/{resp['id']}/input_items") + assert r.status_code == 404 diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_create_endpoint.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_create_endpoint.py index 82a79ead306a..88488e125131 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_create_endpoint.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_create_endpoint.py @@ -466,6 +466,7 @@ def test_create__returns_400_for_empty_body() -> None: payload = response.json() assert isinstance(payload.get("error"), dict) assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" def test_create__returns_400_for_invalid_json_body() -> None: @@ -482,6 +483,7 @@ def test_create__returns_400_for_invalid_json_body() -> None: payload = response.json() assert isinstance(payload.get("error"), dict) assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" def test_create__ignores_unknown_fields_in_request_body() -> None: @@ -535,6 +537,11 @@ async def _events(): ) assert response.status_code == 500 + # Server errors must return a structured error envelope + payload = response.json() + assert isinstance(payload.get("error"), dict), "500 must include a JSON error envelope" + assert payload["error"].get("type") == "server_error" + assert payload["error"].get("code") == "server_error" def test_sync_no_terminal_event_still_completes() -> None: diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e.py index c93a170501d9..42a759101132 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e.py @@ -362,9 +362,9 @@ def test_ephemeral_store_false_cross_api_returns_404(self, stream: bool, operati def test_ephemeral_store_false_cancel_rejected(self, stream: bool) -> None: """B1, B14 — store=false response not bg, cancel rejected. - With unconditional runtime-state registration, - the cancel endpoint finds the record and returns 400 "Cannot cancel a - synchronous response." for non-bg requests. + After eager eviction, the runtime record is removed. Since store=false, + nothing was persisted — the provider throws ResourceNotFound → 404. + This matches .NET's ``Ephemeral_StoreFalse_Cancel_Returns404``. """ handler = _simple_text_handler if stream else _noop_handler client = _build_client(handler) @@ -388,8 +388,8 @@ def test_ephemeral_store_false_cancel_rejected(self, stream: bool) -> None: response_id = r.json()["id"] result = client.post(f"/responses/{response_id}/cancel") - # Contract: record found in runtime state → 400 (cannot cancel synchronous). - assert result.status_code == 400 + # Contract: evicted + never persisted → 404. + assert result.status_code == 404 # ════════════════════════════════════════════════════════════ @@ -843,6 +843,9 @@ def test_e39_bg_handler_incomplete_then_cancel_returns_400(self) -> None: cancel_resp = client.post(f"/responses/{response_id}/cancel") assert cancel_resp.status_code == 400 + error = cancel_resp.json()["error"] + assert error.get("code") == "invalid_request_error" + assert "Cannot cancel a response in terminal state" in error["message"] def test_e44_bg_progressive_polling_output_grows(self) -> None: """B5, B10 — background poll shows progressive output accumulation. diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e_async.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e_async.py index 6cd7d3ddd667..a7be40f5ca06 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e_async.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_cross_api_e2e_async.py @@ -721,22 +721,13 @@ async def test_e26_bg_stream_cancel_then_sse_replay_terminal_event(self) -> None finally: await _ensure_task_done(post_task, handler) - # SSE replay after cancel → should have response.failed terminal event + # SSE replay after cancel: with eager eviction, the in-memory SSE subject + # is gone and cancelled responses don't persist stream events. + # The provider fallback returns 400 with the combined "stream=true or TTL" + # message (we cannot distinguish bg+non-stream from bg+stream-cancelled + # after eviction — see TODO in _endpoint_handler.py). replay_resp = await client.get(f"/responses/{response_id}?stream=true") - assert replay_resp.status_code == 200 - - replay_events = _parse_sse_events(replay_resp.body.decode()) - assert len(replay_events) >= 1, "Replay should have at least 1 event" - - # B26: terminal event for cancelled response is response.failed - last_event = replay_events[-1] - assert last_event["type"] == "response.failed", ( - f"Expected response.failed terminal in replay, got: {last_event['type']}" - ) - - # The response inside should have status: cancelled - if "response" in last_event["data"]: - assert last_event["data"]["response"]["status"] == "cancelled" + assert replay_resp.status_code == 400 async def test_e43_bg_stream_get_during_stream_item_lifecycle(self) -> None: """B5, B23 — GET mid-stream returns progressive item lifecycle. diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_delete_endpoint.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_delete_endpoint.py index ad9bda2137d1..f00cfe7b9c72 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_delete_endpoint.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_delete_endpoint.py @@ -135,16 +135,21 @@ def test_delete__returns_400_for_background_in_flight_response() -> None: assert delete_response.status_code == 400 payload = delete_response.json() assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" assert payload["error"].get("message") == "Cannot delete an in-flight response." def test_delete__returns_404_for_unknown_response_id() -> None: + from azure.ai.agentserver.responses._id_generator import IdGenerator + client = _build_client() + unknown_id = IdGenerator.new_response_id() - delete_response = client.delete("/responses/resp_does_not_exist") + delete_response = client.delete(f"/responses/{unknown_id}") assert delete_response.status_code == 404 payload = delete_response.json() assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" def test_delete__returns_404_for_store_false_response() -> None: @@ -167,9 +172,11 @@ def test_delete__returns_404_for_store_false_response() -> None: assert delete_response.status_code == 404 payload = delete_response.json() assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" -def test_delete__get_returns_400_after_deletion() -> None: +def test_delete__get_returns_404_after_deletion() -> None: + """Post-delete: GET on a deleted response returns 404 per spec.""" client = _build_client() create_response = client.post( @@ -189,10 +196,10 @@ def test_delete__get_returns_400_after_deletion() -> None: assert delete_response.status_code == 200 get_response = client.get(f"/responses/{response_id}") - assert get_response.status_code == 400 + assert get_response.status_code == 404 payload = get_response.json() assert payload["error"].get("type") == "invalid_request_error" - assert "deleted" in (payload["error"].get("message") or "").lower() + assert payload["error"].get("code") == "invalid_request_error" def test_delete__cancel_returns_404_after_deletion() -> None: @@ -218,6 +225,7 @@ def test_delete__cancel_returns_404_after_deletion() -> None: assert cancel_response.status_code == 404 payload = cancel_response.json() assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" def _make_blocking_sync_response_handler(started_gate: EventGate, release_gate: threading.Event): @@ -423,6 +431,7 @@ def test_delete__second_delete_returns_404() -> None: ) payload = second_delete.json() assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" def test_delete__deletes_completed_background_response() -> None: diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_eager_eviction.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_eager_eviction.py new file mode 100644 index 000000000000..951f6951c4fd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_eager_eviction.py @@ -0,0 +1,383 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Contract tests for eager eviction of terminal response records. + +Once a response reaches terminal status (completed, failed, cancelled, +incomplete) and has been persisted to durable storage, the in-memory +runtime record should be immediately evicted. Subsequent operations +fall through to the provider (storage) path, freeing server memory. + +Key invariants: +- After terminal + persist, ``_RuntimeState.get(id)`` returns ``None``. +- GET on the evicted response still returns 200 (via provider fallback). +- DELETE on the evicted response still works (via provider fallback). +- ``store=False`` responses are also evicted (nothing to fall back to → 404). +- Eviction does not break input_items history chains for other responses. +""" + +from __future__ import annotations + +import asyncio +import json as _json +from typing import Any + +import pytest + +from azure.ai.agentserver.responses import ResponsesAgentServerHost +from azure.ai.agentserver.responses._id_generator import IdGenerator +from azure.ai.agentserver.responses.streaming._event_stream import ResponseEventStream +from tests._helpers import poll_until + + +# ── Helpers ─────────────────────────────────────────────── + +def _noop_handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + if False: # pragma: no cover + yield None + + return _events() + + +# ── Async ASGI client (needed for background requests) ─── + +class _AsgiResponse: + def __init__(self, status_code: int, body: bytes, headers: list[tuple[bytes, bytes]]) -> None: + self.status_code = status_code + self.body = body + self.headers = headers + + def json(self) -> Any: + return _json.loads(self.body) + + +class _AsyncAsgiClient: + def __init__(self, app: Any) -> None: + self._app = app + + @staticmethod + def _build_scope( + method: str, path: str, body: bytes, + headers: list[tuple[bytes, bytes]] | None = None, + ) -> dict[str, Any]: + hdr: list[tuple[bytes, bytes]] = list(headers or []) + query_string = b"" + if "?" in path: + path, qs = path.split("?", 1) + query_string = qs.encode() + if body: + hdr += [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ] + return { + "type": "http", "asgi": {"version": "3.0"}, "http_version": "1.1", + "method": method, "headers": hdr, "scheme": "http", + "path": path, "raw_path": path.encode(), + "query_string": query_string, + "server": ("localhost", 80), "client": ("127.0.0.1", 123), + "root_path": "", + } + + async def request( + self, method: str, path: str, *, + json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + body = _json.dumps(json_body).encode() if json_body else b"" + raw_headers = ( + [(k.lower().encode(), v.encode()) for k, v in headers.items()] + if headers else [] + ) + scope = self._build_scope(method, path, body, raw_headers) + status_code: int | None = None + response_headers: list[tuple[bytes, bytes]] = [] + body_parts: list[bytes] = [] + request_sent = False + response_done = asyncio.Event() + + async def receive() -> dict[str, Any]: + nonlocal request_sent + if not request_sent: + request_sent = True + return {"type": "http.request", "body": body, "more_body": False} + await response_done.wait() + return {"type": "http.disconnect"} + + async def send(message: dict[str, Any]) -> None: + nonlocal status_code, response_headers + if message["type"] == "http.response.start": + status_code = message["status"] + response_headers = message.get("headers", []) + elif message["type"] == "http.response.body": + chunk = message.get("body", b"") + if chunk: + body_parts.append(chunk) + if not message.get("more_body", False): + response_done.set() + + await self._app(scope, receive, send) + assert status_code is not None + return _AsgiResponse(status_code=status_code, body=b"".join(body_parts), headers=response_headers) + + async def get(self, path: str, *, headers: dict[str, str] | None = None) -> _AsgiResponse: + return await self.request("GET", path, headers=headers) + + async def post( + self, path: str, *, json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + return await self.request("POST", path, json_body=json_body, headers=headers) + + async def delete(self, path: str, *, headers: dict[str, str] | None = None) -> _AsgiResponse: + return await self.request("DELETE", path, headers=headers) + + +# ── Sync test helpers (Starlette TestClient) ────────────── + +from starlette.testclient import TestClient # noqa: E402 + + +def _make_client(handler=_noop_handler) -> TestClient: + host = ResponsesAgentServerHost() + host.response_handler(handler) + return TestClient(host) + + +def _create_and_complete(client: TestClient, *, store: bool = True) -> str: + """Create a sync response and return the response_id.""" + r = client.post("/responses", json={ + "model": "m", + "input": [{"role": "user", "content": "hi"}], + "store": store, + }) + assert r.status_code == 200 + body = r.json() + assert body["status"] in {"completed", "failed", "incomplete"} + return body["id"] + + +def _wait_for_terminal(client: TestClient, response_id: str) -> dict[str, Any]: + latest: dict[str, Any] = {} + terminal = {"completed", "failed", "incomplete", "cancelled"} + + def _check() -> bool: + nonlocal latest + r = client.get(f"/responses/{response_id}") + if r.status_code != 200: + return False + latest = r.json() + return latest.get("status") in terminal + + poll_until(_check, timeout_s=5.0, interval_s=0.05, label="wait_terminal") + return latest + + +# ══════════════════════════════════════════════════════════ +# Sync path: completed responses should be evicted +# ══════════════════════════════════════════════════════════ + + +class TestSyncEviction: + """After sync execution, terminal records with store=True are evicted.""" + + def test_sync_completed_response_get_still_returns_200(self) -> None: + """After sync completion + persist, GET returns 200 via provider fallback.""" + client = _make_client() + rid = _create_and_complete(client, store=True) + # GET should still work — either in-memory or provider fallback + r = client.get(f"/responses/{rid}") + assert r.status_code == 200 + assert r.json()["status"] in {"completed", "failed", "incomplete"} + + def test_sync_completed_response_delete_still_works(self) -> None: + """After sync completion + persist, DELETE returns 200 via provider fallback.""" + client = _make_client() + rid = _create_and_complete(client, store=True) + r = client.delete(f"/responses/{rid}") + assert r.status_code == 200 + + def test_sync_store_false_also_evicted(self) -> None: + """store=False responses are also evicted — no provider fallback needed + since GET/DELETE/cancel don't work for them anyway.""" + client = _make_client() + rid = _create_and_complete(client, store=False) + # After eviction, GET falls through to provider which has nothing → 404 + r = client.get(f"/responses/{rid}") + assert r.status_code == 404 + + +# ══════════════════════════════════════════════════════════ +# Background path: completed responses should be evicted +# ══════════════════════════════════════════════════════════ + + +def _make_cancellable_bg_handler() -> Any: + """Handler that emits created + completed after a brief delay.""" + started = asyncio.Event() + + def handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + stream = ResponseEventStream( + response_id=context.response_id, + model=getattr(request, "model", None), + ) + yield stream.emit_created() + yield stream.emit_in_progress() + started.set() + # Wait briefly for cancel, then complete + while not cancellation_signal.is_set(): + await asyncio.sleep(0.01) + + return _events() + + handler.started = started # type: ignore[attr-defined] + return handler + + +def _build_async_client(handler: Any) -> tuple[_AsyncAsgiClient, ResponsesAgentServerHost]: + app = ResponsesAgentServerHost() + app.response_handler(handler) + return _AsyncAsgiClient(app), app + + +class TestBackgroundEviction: + """After background execution completes + persists, records are evicted.""" + + @pytest.mark.asyncio + async def test_bg_completed_response_get_returns_200(self) -> None: + """After bg handler completes and persists, GET returns 200 via provider.""" + handler = _make_cancellable_bg_handler() + client, app = _build_async_client(handler) + response_id = IdGenerator.new_response_id() + + # Start background response + post_task = asyncio.create_task( + client.post("/responses", json_body={ + "response_id": response_id, + "model": "test", + "background": True, + "stream": True, + }) + ) + + # Wait for handler to start + await asyncio.wait_for(handler.started.wait(), timeout=5.0) + + # Cancel to bring it to terminal + cancel_resp = await client.post(f"/responses/{response_id}/cancel") + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["status"] == "cancelled" + + # Wait for POST to finish + handler.started.set() + try: + await asyncio.wait_for(post_task, timeout=5.0) + except (asyncio.CancelledError, Exception): + pass + + # Allow async cleanup + await asyncio.sleep(0.3) + + # GET should return 200 — either still in memory or via provider + get_resp = await client.get(f"/responses/{response_id}") + assert get_resp.status_code == 200 + + +# ══════════════════════════════════════════════════════════ +# Unit-level: _RuntimeState.try_evict +# ══════════════════════════════════════════════════════════ + + +class TestTryEvict: + """Direct unit tests for _RuntimeState.try_evict method.""" + + @pytest.mark.asyncio + async def test_try_evict_removes_terminal_record(self) -> None: + """try_evict on a terminal record removes it from _records.""" + from azure.ai.agentserver.responses.hosting._runtime_state import _RuntimeState + from azure.ai.agentserver.responses.models.runtime import ResponseExecution, ResponseModeFlags + + state = _RuntimeState() + record = ResponseExecution( + response_id="caresp_test123456789012345678901234567890", + mode_flags=ResponseModeFlags(stream=False, store=True, background=False), + status="completed", + ) + await state.add(record) + assert await state.get(record.response_id) is not None + + evicted = await state.try_evict(record.response_id) + assert evicted is True + assert await state.get(record.response_id) is None + + @pytest.mark.asyncio + async def test_try_evict_unknown_id_returns_false(self) -> None: + """try_evict on a non-existent ID returns False.""" + from azure.ai.agentserver.responses.hosting._runtime_state import _RuntimeState + + state = _RuntimeState() + evicted = await state.try_evict("caresp_unknown99999999999999999999999999999") + assert evicted is False + + @pytest.mark.asyncio + async def test_try_evict_non_terminal_returns_false(self) -> None: + """try_evict on an in-progress record returns False (not evicted).""" + from azure.ai.agentserver.responses.hosting._runtime_state import _RuntimeState + from azure.ai.agentserver.responses.models.runtime import ResponseExecution, ResponseModeFlags + + state = _RuntimeState() + record = ResponseExecution( + response_id="caresp_test123456789012345678901234567890", + mode_flags=ResponseModeFlags(stream=False, store=True, background=False), + status="in_progress", + ) + await state.add(record) + + evicted = await state.try_evict(record.response_id) + assert evicted is False + # Record should still be there + assert await state.get(record.response_id) is not None + + @pytest.mark.asyncio + async def test_try_evict_does_not_mark_as_deleted(self) -> None: + """Eviction must NOT add the ID to _deleted_response_ids. + + Eviction != deletion. Evicted responses are still retrievable + from the provider. Only explicit DELETE marks as deleted. + """ + from azure.ai.agentserver.responses.hosting._runtime_state import _RuntimeState + from azure.ai.agentserver.responses.models.runtime import ResponseExecution, ResponseModeFlags + + state = _RuntimeState() + record = ResponseExecution( + response_id="caresp_test123456789012345678901234567890", + mode_flags=ResponseModeFlags(stream=False, store=True, background=False), + status="completed", + ) + await state.add(record) + await state.try_evict(record.response_id) + + assert await state.is_deleted(record.response_id) is False + + @pytest.mark.asyncio + async def test_try_evict_preserves_isolation_key(self) -> None: + """Eviction preserves chat isolation keys so provider fallback can still enforce them.""" + from azure.ai.agentserver.responses.hosting._runtime_state import _RuntimeState + from azure.ai.agentserver.responses.models.runtime import ResponseExecution, ResponseModeFlags + + state = _RuntimeState() + rid = "caresp_test123456789012345678901234567890" + record = ResponseExecution( + response_id=rid, + mode_flags=ResponseModeFlags(stream=False, store=True, background=False), + status="completed", + chat_isolation_key="my_key", + ) + await state.add(record) + assert state.check_chat_isolation(rid, "my_key") is True + + await state.try_evict(rid) + # After eviction, isolation key is preserved for provider fallback enforcement + assert state.check_chat_isolation(rid, "my_key") is True + assert state.check_chat_isolation(rid, "wrong_key") is False diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_get_endpoint.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_get_endpoint.py index b7d1f595bc2f..ef2fd796ebc9 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_get_endpoint.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_get_endpoint.py @@ -129,12 +129,22 @@ def test_get__returns_latest_snapshot_for_existing_response() -> None: def test_get__returns_404_for_unknown_response_id() -> None: + from azure.ai.agentserver.responses._id_generator import IdGenerator + client = _build_client() + unknown_id = IdGenerator.new_response_id() - get_response = client.get("/responses/resp_does_not_exist") + get_response = client.get(f"/responses/{unknown_id}") assert get_response.status_code == 404 payload = get_response.json() assert isinstance(payload.get("error"), dict) + assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" + # 404 message must reference the requested response ID + error_message = payload["error"].get("message", "") + assert unknown_id in error_message, ( + f"404 error message should reference the response ID, got: {error_message!r}" + ) def test_get__returns_snapshot_for_stored_non_background_stream_response_after_completion() -> None: @@ -159,7 +169,12 @@ def test_get_replay__rejects_request_when_replay_preconditions_are_not_met() -> payload = replay_response.json() assert isinstance(payload.get("error"), dict) assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" assert payload["error"].get("param") == "stream" + error_message = payload["error"].get("message", "") + assert "background=true" in error_message, ( + f"SSE replay rejection for non-bg response must mention 'background=true', got: {error_message!r}" + ) def test_get_replay__rejects_invalid_starting_after_cursor_type() -> None: @@ -197,7 +212,12 @@ def test_get_replay__starting_after_returns_events_after_cursor() -> None: def test_get_replay__rejects_bg_non_stream_response() -> None: - """B2 — SSE replay requires stream=true at creation. background=true, stream=false → 400.""" + """B2 — SSE replay on bg+non-stream after eviction → 400 with combined message. + + After eager eviction the persisted response doesn't carry the stream mode + flag, so the server cannot distinguish bg+non-stream from bg+stream with + expired TTL. The error uses a combined message matching .NET's SseReplayResult. + """ client = _build_client() create_response = client.post( @@ -217,6 +237,12 @@ def test_get_replay__rejects_bg_non_stream_response() -> None: assert replay_response.status_code == 400 payload = replay_response.json() assert payload["error"]["type"] == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" + error_message = payload["error"].get("message", "") + assert "stream=true" in error_message, ( + f"SSE replay rejection must mention 'stream=true', got: {error_message!r}" + ) + assert payload["error"].get("param") == "stream" # ══════════════════════════════════════════════════════════ @@ -251,6 +277,8 @@ def test_get_replay__rejection_message_hints_at_background_true() -> None: assert "background=true" in error_message, ( f"Error message should hint at 'background=true' to guide the client, but got: {error_message!r}" ) + assert payload["error"].get("code") == "invalid_request_error" + assert payload["error"].get("param") == "stream" # ════════════════════════════════════════════════════════ @@ -554,6 +582,21 @@ def test_get__sse_replay_store_false_returns_404() -> None: assert replay.status_code == 404 +def test_get__sse_replay_unknown_id_returns_404_with_error_shape() -> None: + """SSE replay on a completely unknown response ID returns 404 with proper error envelope.""" + client = _build_client() + + from azure.ai.agentserver.responses._id_generator import IdGenerator + + unknown_id = IdGenerator.new_response_id() + replay_response = client.get(f"/responses/{unknown_id}?stream=true") + assert replay_response.status_code == 404 + payload = replay_response.json() + assert isinstance(payload.get("error"), dict) + assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" + + def test_get__stream_false_returns_json_snapshot() -> None: """Explicit ?stream=false returns a JSON snapshot, not SSE.""" client = _build_client() diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_inbound_request_logging.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_inbound_request_logging.py new file mode 100644 index 000000000000..51f421bf2b71 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_inbound_request_logging.py @@ -0,0 +1,388 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Contract tests for inbound request logging middleware and handler diagnostic logging. + +Validates that: +- InboundRequestLoggingMiddleware logs request start and completion at INFO. +- Status >= 400 triggers WARNING on completion. +- Correlation headers (x-request-id, x-ms-client-request-id) appear in log. +- Query strings are NOT logged (path only). +- Handler-level diagnostic logs fire at INFO for each endpoint. +- Orchestrator logs handler invocation with handler name. +""" + +from __future__ import annotations + +import asyncio +import json as _json +import logging +from typing import Any + +import pytest + +from azure.ai.agentserver.responses import ResponsesAgentServerHost + + +# ── Helpers ─────────────────────────────────────────────── + +LOGGER_NAME = "azure.ai.agentserver" + +# A valid-format response ID that will never exist in state/storage. +_NONEXISTENT_ID = "caresp_00000000000000000000000000000000000000000000000000" + + +def _make_app(handler=None): + """Create a host with a simple handler.""" + app = ResponsesAgentServerHost(configure_observability=None) + + @app.response_handler + def _default_handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + if False: # pragma: no cover + yield None + return _events() + + if handler is not None: + app.response_handler(handler) + return app + + +class _AsgiResponse: + def __init__(self, status_code: int, body: bytes, headers: list[tuple[bytes, bytes]]) -> None: + self.status_code = status_code + self.body = body + self.headers = headers + + def json(self) -> Any: + return _json.loads(self.body) + + +class _AsyncAsgiClient: + def __init__(self, app: Any) -> None: + self._app = app + + @staticmethod + def _build_scope( + method: str, path: str, body: bytes, + headers: list[tuple[bytes, bytes]] | None = None, + ) -> dict[str, Any]: + hdr: list[tuple[bytes, bytes]] = list(headers or []) + query_string = b"" + if "?" in path: + path, qs = path.split("?", 1) + query_string = qs.encode() + if body: + hdr += [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ] + return { + "type": "http", "asgi": {"version": "3.0"}, "http_version": "1.1", + "method": method, "headers": hdr, "scheme": "http", + "path": path, "raw_path": path.encode(), + "query_string": query_string, + "server": ("localhost", 80), "client": ("127.0.0.1", 123), + "root_path": "", + } + + async def request( + self, method: str, path: str, *, + json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + body = _json.dumps(json_body).encode() if json_body else b"" + raw_headers = ( + [(k.lower().encode(), v.encode()) for k, v in headers.items()] + if headers else [] + ) + scope = self._build_scope(method, path, body, raw_headers) + status_code: int | None = None + response_headers: list[tuple[bytes, bytes]] = [] + body_parts: list[bytes] = [] + request_sent = False + response_done = asyncio.Event() + + async def receive() -> dict[str, Any]: + nonlocal request_sent + if not request_sent: + request_sent = True + return {"type": "http.request", "body": body, "more_body": False} + await response_done.wait() + return {"type": "http.disconnect"} + + async def send(message: dict[str, Any]) -> None: + nonlocal status_code, response_headers + if message["type"] == "http.response.start": + status_code = message["status"] + response_headers = message.get("headers", []) + elif message["type"] == "http.response.body": + chunk = message.get("body", b"") + if chunk: + body_parts.append(chunk) + if not message.get("more_body", False): + response_done.set() + + await self._app(scope, receive, send) + assert status_code is not None + return _AsgiResponse(status_code=status_code, body=b"".join(body_parts), headers=response_headers) + + async def get(self, path: str, *, headers: dict[str, str] | None = None) -> _AsgiResponse: + return await self.request("GET", path, headers=headers) + + async def post( + self, path: str, *, json_body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> _AsgiResponse: + return await self.request("POST", path, json_body=json_body, headers=headers) + + async def delete(self, path: str, *, headers: dict[str, str] | None = None) -> _AsgiResponse: + return await self.request("DELETE", path, headers=headers) + + +# ── Middleware Tests ────────────────────────────────────── + + +class TestInboundRequestLoggingMiddleware: + """Inbound request logging middleware tests.""" + + @pytest.mark.asyncio + async def test_logs_request_start_and_completion_info(self, caplog: pytest.LogCaptureFixture): + """Middleware logs start and completion at INFO for successful requests.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.post("/responses", json_body={"model": "m"}) + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + start_msgs = [m for m in messages if "Inbound POST /responses started" in m] + assert len(start_msgs) >= 1, f"Expected start log, got: {messages}" + + completion_msgs = [m for m in messages if "Inbound POST /responses completed" in m] + assert len(completion_msgs) >= 1, f"Expected completion log, got: {messages}" + assert "200" in completion_msgs[0] + + @pytest.mark.asyncio + async def test_logs_warning_for_4xx_status(self, caplog: pytest.LogCaptureFixture): + """Middleware logs WARNING for 4xx responses.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.get(f"/responses/{_NONEXISTENT_ID}") + assert resp.status_code == 404 + + warning_records = [ + r for r in caplog.records + if r.name == LOGGER_NAME + and r.levelno == logging.WARNING + and "Inbound" in r.message + and "completed" in r.message + ] + assert len(warning_records) >= 1, "Expected WARNING for 404" + assert "404" in warning_records[0].message + + @pytest.mark.asyncio + async def test_correlation_headers_in_log(self, caplog: pytest.LogCaptureFixture): + """Middleware includes x-request-id and x-ms-client-request-id in log.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.post( + "/responses", + json_body={"model": "m"}, + headers={ + "x-request-id": "req-abc-123", + "x-ms-client-request-id": "client-xyz", + }, + ) + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + start_msgs = [m for m in messages if "Inbound" in m and "started" in m] + assert len(start_msgs) >= 1 + assert "x-request-id=req-abc-123" in start_msgs[0] + assert "x-ms-client-request-id=client-xyz" in start_msgs[0] + + @pytest.mark.asyncio + async def test_query_string_not_logged(self, caplog: pytest.LogCaptureFixture): + """Middleware logs path only, not query string.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.get(f"/responses/{_NONEXISTENT_ID}?stream=true") + # The response should be 404 but that's fine — we're checking the logs + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME and "Inbound" in r.message] + assert len(messages) >= 1 + for msg in messages: + assert "stream=true" not in msg, f"Query string leaked into log: {msg}" + # Path should be present + assert "/responses/" in msg + + @pytest.mark.asyncio + async def test_duration_in_completion_log(self, caplog: pytest.LogCaptureFixture): + """Middleware includes duration in completion log.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + await client.post("/responses", json_body={"model": "m"}) + + completion_msgs = [ + r.message for r in caplog.records + if r.name == LOGGER_NAME and "Inbound" in r.message and "completed" in r.message + ] + assert len(completion_msgs) >= 1 + assert "ms" in completion_msgs[0], f"Expected duration in ms: {completion_msgs[0]}" + + +# ── Handler Diagnostic Logging Tests ────────────────────── + + +class TestHandlerDiagnosticLogging: + """Handler-level diagnostic logging tests.""" + + @pytest.mark.asyncio + async def test_create_logs_parameters(self, caplog: pytest.LogCaptureFixture): + """POST /responses logs creation parameters.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.post("/responses", json_body={"model": "test-model"}) + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + create_msgs = [m for m in messages if "Creating response" in m] + assert len(create_msgs) >= 1, f"Expected 'Creating response' log, got: {messages}" + msg = create_msgs[0] + assert "streaming=" in msg + assert "background=" in msg + assert "store=" in msg + assert "model=" in msg + + @pytest.mark.asyncio + async def test_create_sync_logs_completion(self, caplog: pytest.LogCaptureFixture): + """Synchronous POST /responses logs response completion.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.post("/responses", json_body={"model": "m"}) + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + completed_msgs = [m for m in messages if "completed: status=" in m] + assert len(completed_msgs) >= 1, f"Expected completion log, got: {messages}" + assert "output_count=" in completed_msgs[0] + + @pytest.mark.asyncio + async def test_get_logs_response_retrieval(self, caplog: pytest.LogCaptureFixture): + """GET /responses/{id} logs entry and retrieval.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + # Create a response first + create_resp = await client.post("/responses", json_body={"model": "m"}) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.get(f"/responses/{response_id}") + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + get_msgs = [m for m in messages if f"Getting response {response_id}" in m] + assert len(get_msgs) >= 1, f"Expected GET log, got: {messages}" + + retrieved_msgs = [m for m in messages if f"Retrieved response {response_id}" in m] + assert len(retrieved_msgs) >= 1, f"Expected retrieval log, got: {messages}" + + @pytest.mark.asyncio + async def test_get_sse_replay_logs(self, caplog: pytest.LogCaptureFixture): + """GET /responses/{id}?stream=true logs SSE replay entry.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + # Non-existent but valid format — just check the log message + await client.get(f"/responses/{_NONEXISTENT_ID}?stream=true") + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + sse_msgs = [m for m in messages if "with SSE replay" in m] + assert len(sse_msgs) >= 1, f"Expected SSE replay log, got: {messages}" + + @pytest.mark.asyncio + async def test_delete_logs_entry_and_success(self, caplog: pytest.LogCaptureFixture): + """DELETE /responses/{id} logs entry and success.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + create_resp = await client.post("/responses", json_body={"model": "m"}) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.delete(f"/responses/{response_id}") + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + delete_entry = [m for m in messages if f"Deleting response {response_id}" in m] + assert len(delete_entry) >= 1, f"Expected delete entry log, got: {messages}" + + delete_success = [m for m in messages if f"Deleted response {response_id}" in m] + assert len(delete_success) >= 1, f"Expected delete success log, got: {messages}" + + @pytest.mark.asyncio + async def test_cancel_logs_entry(self, caplog: pytest.LogCaptureFixture): + """POST /responses/{id}/cancel logs cancellation entry.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + # Cancel on a non-existent response — just verify the entry log fires + await client.post(f"/responses/{_NONEXISTENT_ID}/cancel") + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + cancel_msgs = [m for m in messages if "Cancelling response" in m] + assert len(cancel_msgs) >= 1, f"Expected cancel entry log, got: {messages}" + + @pytest.mark.asyncio + async def test_input_items_logs_entry(self, caplog: pytest.LogCaptureFixture): + """GET /responses/{id}/input_items logs entry.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + await client.get(f"/responses/{_NONEXISTENT_ID}/input_items") + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + input_msgs = [m for m in messages if "Getting input items" in m] + assert len(input_msgs) >= 1, f"Expected input items log, got: {messages}" + + +# ── Orchestrator Handler Invocation Logging Tests ───────── + + +class TestOrchestratorHandlerLogging: + """Orchestrator-level handler invocation logging.""" + + @pytest.mark.asyncio + async def test_invoking_handler_logged(self, caplog: pytest.LogCaptureFixture): + """Orchestrator logs 'Invoking handler' with handler name.""" + app = _make_app() + client = _AsyncAsgiClient(app) + + with caplog.at_level(logging.INFO, logger=LOGGER_NAME): + resp = await client.post("/responses", json_body={"model": "m"}) + assert resp.status_code == 200 + + messages = [r.message for r in caplog.records if r.name == LOGGER_NAME] + handler_msgs = [m for m in messages if "Invoking handler" in m] + assert len(handler_msgs) >= 1, f"Expected handler invocation log, got: {messages}" + # Should include the response ID + response_id = resp.json()["id"] + assert response_id in handler_msgs[0] diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_input_items_endpoint.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_input_items_endpoint.py index 310b1f3dfe24..788443c588c4 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_input_items_endpoint.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_input_items_endpoint.py @@ -127,10 +127,12 @@ def test_input_items_returns_400_for_invalid_limit() -> None: low_limit = client.get(f"/responses/{response_id}/input_items?limit=0") low_payload = _assert_error_envelope(low_limit, 400) assert low_payload["error"].get("type") == "invalid_request_error" + assert low_payload["error"].get("code") == "invalid_request_error" high_limit = client.get(f"/responses/{response_id}/input_items?limit=101") high_payload = _assert_error_envelope(high_limit, 400) assert high_payload["error"].get("type") == "invalid_request_error" + assert high_payload["error"].get("code") == "invalid_request_error" def test_input_items_returns_400_for_invalid_order() -> None: @@ -141,9 +143,10 @@ def test_input_items_returns_400_for_invalid_order() -> None: response = client.get(f"/responses/{response_id}/input_items?order=invalid") payload = _assert_error_envelope(response, 400) assert payload["error"].get("type") == "invalid_request_error" + assert payload["error"].get("code") == "invalid_request_error" -def test_input_items_returns_400_for_deleted_response() -> None: +def test_input_items_returns_404_for_deleted_response() -> None: client = _build_client() response_id = _create_response(client, input_items=[_message_input("msg_001", "one")]) @@ -152,17 +155,21 @@ def test_input_items_returns_400_for_deleted_response() -> None: assert delete_response.status_code == 200 response = client.get(f"/responses/{response_id}/input_items") - payload = _assert_error_envelope(response, 400) + payload = _assert_error_envelope(response, 404) assert payload["error"].get("type") == "invalid_request_error" - assert "deleted" in (payload["error"].get("message") or "").lower() + assert payload["error"].get("code") == "invalid_request_error" def test_input_items_returns_404_for_missing_or_non_stored_response() -> None: + from azure.ai.agentserver.responses._id_generator import IdGenerator + client = _build_client() + unknown_id = IdGenerator.new_response_id() - missing_response = client.get("/responses/resp_does_not_exist/input_items") + missing_response = client.get(f"/responses/{unknown_id}/input_items") missing_payload = _assert_error_envelope(missing_response, 404) assert missing_payload["error"].get("type") == "invalid_request_error" + assert missing_payload["error"].get("code") == "invalid_request_error" non_stored_id = _create_response( client, @@ -172,6 +179,7 @@ def test_input_items_returns_404_for_missing_or_non_stored_response() -> None: non_stored_response = client.get(f"/responses/{non_stored_id}/input_items") non_stored_payload = _assert_error_envelope(non_stored_response, 404) assert non_stored_payload["error"].get("type") == "invalid_request_error" + assert non_stored_payload["error"].get("code") == "invalid_request_error" def test_input_items_default_limit_is_20_and_has_more_when_truncated() -> None: diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_malformed_id_validation.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_malformed_id_validation.py new file mode 100644 index 000000000000..a63879dcc753 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_malformed_id_validation.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Contract tests for malformed response ID validation. + +All endpoints that accept a response_id path parameter must reject +malformed IDs with 400 (``code: "invalid_parameters"``, +``param: "responseId{}"``). + +Malformed ``previous_response_id`` in the POST body must be rejected +with 400 and a ``details`` array containing the validation error. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from starlette.testclient import TestClient + +from azure.ai.agentserver.responses import ResponsesAgentServerHost +from azure.ai.agentserver.responses._id_generator import IdGenerator + + +def _noop_handler(request: Any, context: Any, cancellation_signal: Any): + async def _events(): + if False: # pragma: no cover + yield None + + return _events() + + +def _make_client() -> TestClient: + host = ResponsesAgentServerHost() + host.response_handler(_noop_handler) + return TestClient(host) + + +# ── Path parameter ID validation ────────────────────────── + +class TestMalformedPathId: + """Path parameter ``response_id`` validation on all endpoints (B40).""" + + MALFORMED_IDS = [ + "totally-invalid", + "resp_abc123", # wrong prefix + "caresp_tooshort", # correct prefix but too short + ] + + @pytest.mark.parametrize("bad_id", MALFORMED_IDS) + def test_get_malformed_id_returns_400(self, bad_id: str) -> None: + client = _make_client() + r = client.get(f"/responses/{bad_id}") + assert r.status_code == 400 + body = r.json() + assert body["error"]["code"] == "invalid_parameters" + assert body["error"]["param"] == f"responseId{{{bad_id}}}" + + @pytest.mark.parametrize("bad_id", MALFORMED_IDS) + def test_get_sse_malformed_id_returns_400(self, bad_id: str) -> None: + client = _make_client() + r = client.get(f"/responses/{bad_id}", params={"stream": "true"}) + assert r.status_code == 400 + + @pytest.mark.parametrize("bad_id", MALFORMED_IDS) + def test_cancel_malformed_id_returns_400(self, bad_id: str) -> None: + client = _make_client() + r = client.post(f"/responses/{bad_id}/cancel") + assert r.status_code == 400 + + @pytest.mark.parametrize("bad_id", MALFORMED_IDS) + def test_delete_malformed_id_returns_400(self, bad_id: str) -> None: + client = _make_client() + r = client.delete(f"/responses/{bad_id}") + assert r.status_code == 400 + + @pytest.mark.parametrize("bad_id", MALFORMED_IDS) + def test_input_items_malformed_id_returns_400(self, bad_id: str) -> None: + client = _make_client() + r = client.get(f"/responses/{bad_id}/input_items") + assert r.status_code == 400 + + def test_valid_format_unknown_id_returns_404_not_400(self) -> None: + """A well-formed but non-existent ID should return 404, not 400.""" + client = _make_client() + unknown_id = IdGenerator.new_response_id() + r = client.get(f"/responses/{unknown_id}") + assert r.status_code == 404 + + +# ── Body field ``previous_response_id`` validation ───────── + +class TestMalformedPreviousResponseId: + """``previous_response_id`` in POST body must be valid ``caresp`` format.""" + + def test_malformed_previous_response_id_returns_400_with_details(self) -> None: + client = _make_client() + r = client.post("/responses", json={ + "model": "m", + "input": [{"role": "user", "content": "hi"}], + "previous_response_id": "totally-invalid", + }) + assert r.status_code == 400 + body = r.json() + error = body["error"] + assert error["code"] == "invalid_parameters" + + def test_wrong_prefix_previous_response_id_returns_400(self) -> None: + client = _make_client() + r = client.post("/responses", json={ + "model": "m", + "input": [{"role": "user", "content": "hi"}], + "previous_response_id": "resp_abc123xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + }) + assert r.status_code == 400 + + def test_valid_format_nonexistent_previous_response_id_not_rejected_by_format(self) -> None: + """A valid-format previous_response_id that doesn't exist should pass validation + and fail later (at provider lookup), NOT at format validation.""" + client = _make_client() + valid_id = IdGenerator.new_response_id() + r = client.post("/responses", json={ + "model": "m", + "input": [{"role": "user", "content": "hi"}], + "previous_response_id": valid_id, + }) + # Should NOT be 400 from format validation — likely 200 or a different error + assert r.status_code != 400 or "Malformed" not in r.json().get("error", {}).get("message", "") diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_session_id_resolution.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_session_id_resolution.py index d157fcbbc8f1..c87f690c7e43 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_session_id_resolution.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/contract/test_session_id_resolution.py @@ -138,6 +138,8 @@ def test_default_payload_session_id_stamped_on_response(self) -> None: assert response.status_code == 200 assert response.json()["agent_session_id"] == session_id + # §8: x-agent-session-id response header echoes the resolved value + assert response.headers.get("x-agent-session-id") == session_id def test_streaming_payload_session_id_stamped_on_response(self) -> None: """B39 P1: streaming response.created and response.completed carry the payload session ID.""" @@ -154,6 +156,8 @@ def test_streaming_payload_session_id_stamped_on_response(self) -> None: }, ) as resp: assert resp.status_code == 200 + # §8: x-agent-session-id response header on streaming responses + assert resp.headers.get("x-agent-session-id") == session_id events = _collect_sse_events(resp) # Check response.created event @@ -212,6 +216,8 @@ def test_no_payload_session_id_falls_back_to_env_var(self) -> None: assert response.status_code == 200 assert response.json()["agent_session_id"] == env_session_id + # §8: x-agent-session-id response header echoes the resolved value + assert response.headers.get("x-agent-session-id") == env_session_id def test_payload_session_id_overrides_env_var(self) -> None: """B39: payload field takes precedence over env var.""" @@ -230,6 +236,8 @@ def test_payload_session_id_overrides_env_var(self) -> None: assert response.status_code == 200 assert response.json()["agent_session_id"] == payload_session_id + # §8: header echoes the resolved value (payload wins over env var) + assert response.headers.get("x-agent-session-id") == payload_session_id # ════════════════════════════════════════════════════════════ @@ -259,6 +267,8 @@ def test_no_payload_or_env_generates_session_id(self) -> None: # Verify it's a valid 63-char lowercase hex string assert len(session_id) == 63 assert re.fullmatch(r"[0-9a-f]+", session_id) + # §8: x-agent-session-id response header echoes generated value + assert response.headers.get("x-agent-session-id") == session_id def test_generated_session_id_is_different_per_request(self) -> None: """B39 P3: generated session IDs are unique per request.""" @@ -302,6 +312,9 @@ def test_streaming_no_payload_or_env_stamps_generated_session_id(self) -> None: json={"model": "test", "stream": True}, ) as resp: assert resp.status_code == 200 + # §8: header present even for generated session IDs + header_sid = resp.headers.get("x-agent-session-id") + assert header_sid is not None and header_sid != "" events = _collect_sse_events(resp) completed_events = [e for e in events if e["type"] == "response.completed"] @@ -310,6 +323,8 @@ def test_streaming_no_payload_or_env_stamps_generated_session_id(self) -> None: assert session_id is not None and session_id != "" assert len(session_id) == 63 assert re.fullmatch(r"[0-9a-f]+", session_id) + # Header must match the body session ID + assert header_sid == session_id def test_background_no_payload_or_env_stamps_generated_session_id(self) -> None: """B39: background mode generates and stamps a hex session ID.""" @@ -426,3 +441,89 @@ def test_session_id_consistent_between_create_and_sse_replay(self) -> None: assert resp_payload.get("agent_session_id") == session_id, ( f"SSE replay {event['type']} missing agent_session_id" ) + + +# ════════════════════════════════════════════════════════════ +# §8: x-agent-session-id header on non-POST endpoints +# ════════════════════════════════════════════════════════════ + + +class TestSessionIdHeaderOnNonPostEndpoints: + """x-agent-session-id header MUST appear on all protocol endpoint responses (§8). + + Non-POST endpoints resolve the session ID from the + ``FOUNDRY_AGENT_SESSION_ID`` environment variable. + """ + + def test_get_response_has_session_id_header(self) -> None: + """GET /responses/{id} includes x-agent-session-id header.""" + env_session_id = "env-session-for-get" + with patch.dict(os.environ, {"FOUNDRY_AGENT_SESSION_ID": env_session_id}): + client = _build_client() + create_resp = client.post("/responses", json={"model": "test"}) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + get_resp = client.get(f"/responses/{response_id}") + assert get_resp.status_code == 200 + assert get_resp.headers.get("x-agent-session-id") == env_session_id + + def test_delete_response_has_session_id_header(self) -> None: + """DELETE /responses/{id} includes x-agent-session-id header.""" + env_session_id = "env-session-for-delete" + with patch.dict(os.environ, {"FOUNDRY_AGENT_SESSION_ID": env_session_id}): + client = _build_client() + create_resp = client.post("/responses", json={"model": "test"}) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + delete_resp = client.delete(f"/responses/{response_id}") + assert delete_resp.status_code == 200 + assert delete_resp.headers.get("x-agent-session-id") == env_session_id + + def test_cancel_response_has_session_id_header(self) -> None: + """POST /responses/{id}/cancel includes x-agent-session-id header.""" + env_session_id = "env-session-for-cancel" + with patch.dict(os.environ, {"FOUNDRY_AGENT_SESSION_ID": env_session_id}): + client = _build_client() + create_resp = client.post( + "/responses", json={"model": "test", "background": True}, + ) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + cancel_resp = client.post(f"/responses/{response_id}/cancel") + # Cancel may return 200 (cancelled) or 400 (already completed) — + # either way the header must be present. + assert cancel_resp.headers.get("x-agent-session-id") == env_session_id + + def test_input_items_has_session_id_header(self) -> None: + """GET /responses/{id}/input_items includes x-agent-session-id header.""" + env_session_id = "env-session-for-input-items" + with patch.dict(os.environ, {"FOUNDRY_AGENT_SESSION_ID": env_session_id}): + client = _build_client() + create_resp = client.post( + "/responses", + json={ + "model": "test", + "input": [{"role": "user", "content": "hi"}], + }, + ) + assert create_resp.status_code == 200 + response_id = create_resp.json()["id"] + + items_resp = client.get(f"/responses/{response_id}/input_items") + assert items_resp.status_code == 200 + assert items_resp.headers.get("x-agent-session-id") == env_session_id + + def test_error_response_has_session_id_header(self) -> None: + """Error responses (e.g. 404) on protocol endpoints include the header.""" + env_session_id = "env-session-for-errors" + from azure.ai.agentserver.responses._id_generator import IdGenerator + + with patch.dict(os.environ, {"FOUNDRY_AGENT_SESSION_ID": env_session_id}): + client = _build_client() + unknown_id = IdGenerator.new_response_id() + get_resp = client.get(f"/responses/{unknown_id}") + assert get_resp.status_code == 404 + assert get_resp.headers.get("x-agent-session-id") == env_session_id \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/interop/test_openai_wire_compliance.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/interop/test_openai_wire_compliance.py index e9418f814f29..693ffb4cba52 100644 --- a/sdk/agentserver/azure-ai-agentserver-responses/tests/interop/test_openai_wire_compliance.py +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/interop/test_openai_wire_compliance.py @@ -516,8 +516,11 @@ def test_create_response_max_output_tokens() -> None: def test_create_response_previous_response_id() -> None: - req = _send_and_capture('{"model": "test", "previous_response_id": "resp_prev_001"}') - assert req.previous_response_id == "resp_prev_001" + from azure.ai.agentserver.responses._id_generator import IdGenerator + + valid_id = IdGenerator.new_response_id() + req = _send_and_capture(f'{{"model": "test", "previous_response_id": "{valid_id}"}}') + assert req.previous_response_id == valid_id def test_create_response_store() -> None: diff --git a/sdk/agentserver/azure-ai-agentserver-responses/tests/unit/test_foundry_logging_policy.py b/sdk/agentserver/azure-ai-agentserver-responses/tests/unit/test_foundry_logging_policy.py new file mode 100644 index 000000000000..fe1d821a8848 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-responses/tests/unit/test_foundry_logging_policy.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +"""Unit tests for FoundryStorageLoggingPolicy.""" + +from __future__ import annotations + +import logging +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from azure.ai.agentserver.responses.store._foundry_logging_policy import FoundryStorageLoggingPolicy + + +def _make_request(method: str = "GET", url: str = "https://storage.example.com/responses/r1") -> MagicMock: + http_request = MagicMock() + http_request.method = method + http_request.url = url + http_request.headers = {"x-ms-client-request-id": "test-client-id-123"} + pipeline_request = MagicMock() + pipeline_request.http_request = http_request + return pipeline_request + + +def _make_response(status_code: int = 200, headers: dict | None = None) -> MagicMock: + http_response = MagicMock() + http_response.status_code = status_code + http_response.headers = headers or {"x-ms-request-id": "server-req-456"} + response = MagicMock() + response.http_response = http_response + return response + + +@pytest.mark.asyncio +async def test_logging_policy_logs_successful_request(caplog: pytest.LogCaptureFixture) -> None: + policy = FoundryStorageLoggingPolicy() + next_policy = AsyncMock() + next_policy.send = AsyncMock(return_value=_make_response(200)) + policy.next = next_policy + + request = _make_request("GET", "https://storage.example.com/responses/r1") + + with caplog.at_level(logging.INFO, logger="azure.ai.agentserver"): + await policy.send(request) + + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelno == logging.INFO + assert "GET" in record.message + assert "200" in record.message + assert "test-client-id-123" in record.message + assert "server-req-456" in record.message + assert "ms" in record.message + + +@pytest.mark.asyncio +async def test_logging_policy_logs_error_response_at_warning(caplog: pytest.LogCaptureFixture) -> None: + policy = FoundryStorageLoggingPolicy() + next_policy = AsyncMock() + next_policy.send = AsyncMock(return_value=_make_response(500)) + policy.next = next_policy + + request = _make_request("PUT", "https://storage.example.com/responses/r1") + + with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): + await policy.send(request) + + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelno == logging.WARNING + assert "PUT" in record.message + assert "500" in record.message + + +@pytest.mark.asyncio +async def test_logging_policy_logs_transport_failure(caplog: pytest.LogCaptureFixture) -> None: + policy = FoundryStorageLoggingPolicy() + next_policy = AsyncMock() + next_policy.send = AsyncMock(side_effect=ConnectionError("network failure")) + policy.next = next_policy + + request = _make_request("POST", "https://storage.example.com/responses") + + with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): + with pytest.raises(ConnectionError): + await policy.send(request) + + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelno == logging.WARNING + assert "POST" in record.message + assert "failed" in record.message.lower() + assert "test-client-id-123" in record.message + + +@pytest.mark.asyncio +async def test_logging_policy_handles_missing_correlation_headers(caplog: pytest.LogCaptureFixture) -> None: + policy = FoundryStorageLoggingPolicy() + next_policy = AsyncMock() + next_policy.send = AsyncMock(return_value=_make_response(200, headers={})) + policy.next = next_policy + + request = _make_request("DELETE", "https://storage.example.com/responses/r1") + request.http_request.headers = {} # No correlation headers + + with caplog.at_level(logging.INFO, logger="azure.ai.agentserver"): + await policy.send(request) + + assert len(caplog.records) == 1 + assert "DELETE" in caplog.records[0].message + assert "200" in caplog.records[0].message