diff --git a/backend/app/api/docs/api_keys/verify.md b/backend/app/api/docs/api_keys/verify.md new file mode 100644 index 000000000..2b3886d16 --- /dev/null +++ b/backend/app/api/docs/api_keys/verify.md @@ -0,0 +1,3 @@ +Verify the provided API key and return the resolved auth context. + +This endpoint validates the `X-API-KEY` header and returns `user_id`, `organization_id`, and `project_id` for the authenticated key. diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 723eecc84..95e660f17 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -3,7 +3,12 @@ from app.api.deps import SessionDep, AuthContextDep from app.crud.api_key import APIKeyCrud -from app.models import APIKeyPublic, APIKeyCreateResponse, Message +from app.models import ( + APIKeyPublic, + APIKeyCreateResponse, + APIKeyVerifyResponse, + Message, +) from app.utils import APIResponse, load_description from app.api.permissions import Permission, require_permission @@ -71,3 +76,21 @@ def delete_api_key_route( api_key_crud.delete(key_id=key_id) return APIResponse.success_response(Message(message="API Key deleted successfully")) + + +@router.get( + "/verify", + response_model=APIResponse[APIKeyVerifyResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + description=load_description("api_keys/verify.md"), +) +def verify_api_key_route( + current_user: AuthContextDep, +): + return APIResponse.success_response( + APIKeyVerifyResponse( + user_id=current_user.user.id, + organization_id=current_user.organization_.id, + project_id=current_user.project_.id, + ) + ) diff --git a/backend/app/crud/config/config.py b/backend/app/crud/config/config.py index 69d4bcedb..0a2ed2138 100644 --- a/backend/app/crud/config/config.py +++ b/backend/app/crud/config/config.py @@ -47,7 +47,7 @@ def create_or_raise( version = ConfigVersion( config_id=config.id, version=1, - config_blob=config_create.config_blob.model_dump(), + config_blob=config_create.config_blob.model_dump(mode="json"), commit_message=config_create.commit_message, ) diff --git a/backend/app/crud/config/version.py b/backend/app/crud/config/version.py index 915d1b18d..b3da74f14 100644 --- a/backend/app/crud/config/version.py +++ b/backend/app/crud/config/version.py @@ -79,7 +79,7 @@ def create_or_raise(self, version_create: ConfigVersionUpdate) -> ConfigVersion: version = ConfigVersion( config_id=self.config_id, version=next_version, - config_blob=validated_blob.model_dump(), + config_blob=validated_blob.model_dump(mode="json"), commit_message=version_create.commit_message, ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index a0149a820..2c28d7b4f 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -2,7 +2,13 @@ from .auth import AuthContext, Token, TokenPayload -from .api_key import APIKey, APIKeyBase, APIKeyPublic, APIKeyCreateResponse +from .api_key import ( + APIKey, + APIKeyBase, + APIKeyPublic, + APIKeyCreateResponse, + APIKeyVerifyResponse, +) from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index 516073f2d..e8bd6c1b0 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -45,6 +45,14 @@ class APIKeyCreateResponse(APIKeyPublic): key: str +class APIKeyVerifyResponse(SQLModel): + """Response model for API key verification.""" + + user_id: int + organization_id: int + project_id: int 
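For reference, a minimal client-side sketch of the new verify route (not part of the patch): the base URL and key below are placeholders, while the /apikeys/verify path, the X-API-KEY header, and the success/data envelope follow the route definition and the tests added in this diff.

import httpx

BASE_URL = "https://kaapi.example.com/api/v1"   # placeholder deployment URL
API_KEY = "ApiKey-replace-with-a-real-key"      # placeholder key

resp = httpx.get(f"{BASE_URL}/apikeys/verify", headers={"X-API-KEY": API_KEY})
resp.raise_for_status()

payload = resp.json()
if payload["success"]:
    ctx = payload["data"]  # resolved auth context for the key
    print(ctx["user_id"], ctx["organization_id"], ctx["project_id"])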
+ + class APIKey(APIKeyBase, table=True): """Database model for API keys.""" diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 79ce10bf1..b90fb6229 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -210,10 +210,24 @@ def validate_params(self): ] +class Validator(SQLModel): + validator_config_id: UUID + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + + input_guardrails: list[Validator] | None = Field( + default=None, + description="Guardrails applied to validate/sanitize the input before the LLM call", + ) + + output_guardrails: list[Validator] | None = Field( + default=None, + description="Guardrails applied to validate/sanitize the output after the LLM call", + ) # Future additions: # classifier: ClassifierConfig | None = None # pre_filter: PreFilterConfig | None = None @@ -298,20 +312,6 @@ class LLMCallRequest(SQLModel): "in production, always use the id + version." ), ) - input_guardrails: list[dict[str, Any]] | None = Field( - default=None, - description=( - "Optional guardrails configuration to apply input validation. " - "If not provided, no guardrails will be applied." - ), - ) - output_guardrails: list[dict[str, Any]] | None = Field( - default=None, - description=( - "Optional guardrails configuration to apply output validation. " - "If not provided, no guardrails will be applied." - ), - ) callback_url: HttpUrl | None = Field( default=None, description="Webhook URL for async response delivery" ) diff --git a/backend/app/services/llm/guardrails.py b/backend/app/services/llm/guardrails.py index 37f0d1ebf..7ba8d72fe 100644 --- a/backend/app/services/llm/guardrails.py +++ b/backend/app/services/llm/guardrails.py @@ -5,12 +5,18 @@ import httpx from app.core.config import settings +from app.models.llm.request import Validator logger = logging.getLogger(__name__) -def call_guardrails( - input_text: str, guardrail_config: list[dict], job_id: UUID +def run_guardrails_validation( + input_text: str, + guardrail_config: list[Validator | dict[str, Any]], + job_id: UUID, + project_id: int | None, + organization_id: int | None, + suppress_pass_logs: bool = True, ) -> dict[str, Any]: """ Call the Kaapi guardrails service to validate and process input text. @@ -19,14 +25,26 @@ def call_guardrails( input_text: Text to validate and process. guardrail_config: List of validator configurations to apply. job_id: Unique identifier for the request. + project_id: Project identifier expected by guardrails API. + organization_id: Organization identifier expected by guardrails API. + suppress_pass_logs: Whether to suppress successful validation logs in guardrails service. Returns: JSON response from the guardrails service with validation results. 
""" + validators = [ + validator.model_dump(mode="json") + if isinstance(validator, Validator) + else validator + for validator in guardrail_config + ] + payload = { "request_id": str(job_id), + "project_id": project_id, + "organization_id": organization_id, "input": input_text, - "validators": guardrail_config, + "validators": validators, } headers = { @@ -38,8 +56,9 @@ def call_guardrails( try: with httpx.Client(timeout=10.0) as client: response = client.post( - settings.KAAPI_GUARDRAILS_URL, + f"{settings.KAAPI_GUARDRAILS_URL}/", json=payload, + params={"suppress_pass_logs": str(suppress_pass_logs).lower()}, headers=headers, ) @@ -47,7 +66,7 @@ def call_guardrails( return response.json() except Exception as e: logger.warning( - f"[call_guardrails] Service unavailable. Bypassing guardrails. job_id={job_id}. error={e}" + f"[run_guardrails_validation] Service unavailable. Bypassing guardrails. job_id={job_id}. error={e}" ) return { @@ -58,3 +77,93 @@ def call_guardrails( "rephrase_needed": False, }, } + + +def list_validators_config( + organization_id: int | None, + project_id: int | None, + input_validator_configs: list[Validator] | None, + output_validator_configs: list[Validator] | None, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Fetch validator configurations by IDs for input and output guardrails. + + Calls: + GET /validators/configs/?organization_id={organization_id}&project_id={project_id}&ids={uuid} + """ + input_validator_config_ids = [ + validator_config.validator_config_id + for validator_config in (input_validator_configs or []) + ] + output_validator_config_ids = [ + validator_config.validator_config_id + for validator_config in (output_validator_configs or []) + ] + + if not input_validator_config_ids and not output_validator_config_ids: + return [], [] + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}", + "Content-Type": "application/json", + } + + endpoint = f"{settings.KAAPI_GUARDRAILS_URL}/validators/configs/" + + def _build_params(validator_ids: list[UUID]) -> dict[str, Any]: + params = { + "organization_id": organization_id, + "project_id": project_id, + "ids": [str(validator_config_id) for validator_config_id in validator_ids], + } + return {key: value for key, value in params.items() if value is not None} + + try: + with httpx.Client(timeout=10.0) as client: + + def _fetch_by_ids(validator_ids: list[UUID]) -> list[dict[str, Any]]: + if not validator_ids: + return [] + + response = client.get( + endpoint, + params=_build_params(validator_ids), + headers=headers, + ) + response.raise_for_status() + + payload = response.json() + if not isinstance(payload, dict): + raise ValueError( + "Invalid validators response format: expected JSON object." + ) + + if not payload.get("success", False): + raise ValueError( + "Validator config fetch failed: `success` is false." + ) + + validators = payload.get("data", []) + if not isinstance(validators, list): + raise ValueError( + "Invalid validators response format: `data` must be a list." + ) + + return [ + validator for validator in validators if isinstance(validator, dict) + ] + + input_guardrails = _fetch_by_ids(input_validator_config_ids) + output_guardrails = _fetch_by_ids(output_validator_config_ids) + return input_guardrails, output_guardrails + + except Exception as e: + logger.warning( + "[list_validators_config] Guardrails service unavailable or invalid response. " + "Proceeding without input/output guardrails. 
" + f"input_validator_config_ids={input_validator_config_ids}, output_validator_config_ids={output_validator_config_ids}, " + f"organization_id={organization_id}, " + f"project_id={project_id}, endpoint={endpoint}, error={e}" + ) + return [], [] diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index adc2051eb..c6997a084 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -13,8 +13,17 @@ from app.crud.jobs import JobCrud from app.crud.llm import create_llm_call, update_llm_call_response from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, Job -from app.models.llm.request import ConfigBlob, LLMCallConfig, KaapiCompletionConfig -from app.services.llm.guardrails import call_guardrails +from app.models.llm.request import ( + ConfigBlob, + LLMCallConfig, + KaapiCompletionConfig, + TextInput, +) +from app.models.llm.response import TextOutput +from app.services.llm.guardrails import ( + list_validators_config, + run_guardrails_validation, +) from app.services.llm.providers.registry import get_llm_provider from app.services.llm.mappers import transform_kaapi_config_to_native from app.services.llm.input_resolver import resolve_input, cleanup_temp_file @@ -144,13 +153,11 @@ def execute_job( request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) - # one of (id, version) or blob is guaranteed to be present due to prior validation config = request.config - input_query = request.query.input - input_guardrails = request.input_guardrails - output_guardrails = request.output_guardrails callback_response = None config_blob: ConfigBlob | None = None + input_guardrails: list[dict] = [] + output_guardrails: list[dict] = [] llm_call_id: UUID | None = None # Track the LLM call record logger.info( @@ -158,38 +165,6 @@ def execute_job( ) try: - if input_guardrails: - safe_input = call_guardrails(input_query, input_guardrails, job_id) - - logger.info( - f"[execute_job] Input guardrail validation | success={safe_input['success']}." - ) - - if safe_input.get("bypassed"): - logger.info("[execute_job] Guardrails bypassed (service unavailable)") - - elif safe_input["success"]: - # Update the text value within the QueryInput structure - request.query.input.content.value = safe_input["data"]["safe_text"] - - if safe_input["data"]["rephrase_needed"]: - callback_response = APIResponse.failure_response( - error=safe_input["data"]["safe_text"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - else: - # Update the text value with error message - request.query.input.content.value = safe_input["error"] - - callback_response = APIResponse.failure_response( - error=safe_input["error"], - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - with Session(engine) as session: # Update job status to PROCESSING job_crud = JobCrud(session=session) @@ -236,6 +211,55 @@ def execute_job( else: config_blob = config.blob + if config_blob is not None: + if config_blob.input_guardrails or config_blob.output_guardrails: + input_guardrails, output_guardrails = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=config_blob.input_guardrails, + output_validator_configs=config_blob.output_guardrails, + ) + + if input_guardrails: + if not isinstance(request.query.input, TextInput): + logger.info( + "[execute_job] Skipping input guardrails for non-text input. 
" + f"job_id={job_id}, input_type={getattr(request.query.input, 'type', type(request.query.input).__name__)}" + ) + else: + safe_input = run_guardrails_validation( + request.query.input.content.value, + input_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) + + logger.info( + f"[execute_job] Input guardrail validation | success={safe_input['success']}." + ) + + if safe_input.get("bypassed"): + logger.info( + "[execute_job] Guardrails bypassed (service unavailable)" + ) + + elif safe_input["success"]: + request.query.input.content.value = safe_input["data"][ + "safe_text" + ] + else: + # Update the text value with error message + request.query.input.content.value = safe_input["error"] + + callback_response = APIResponse.failure_response( + error=safe_input["error"], + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) user_sent_config_provider = "" try: @@ -266,7 +290,11 @@ def execute_job( # Create LLM call record before execution try: # Rebuild ConfigBlob with transformed native config - resolved_config_blob = ConfigBlob(completion=completion_config) + resolved_config_blob = ConfigBlob( + completion=completion_config, + input_guardrails=config_blob.input_guardrails, + output_guardrails=config_blob.output_guardrails, + ) llm_call = create_llm_call( session, @@ -347,43 +375,56 @@ def execute_job( if response: if output_guardrails: - output_text = response.response.output.content.value - safe_output = call_guardrails(output_text, output_guardrails, job_id) - - logger.info( - f"[execute_job] Output guardrail validation | success={safe_output['success']}." - ) + if not isinstance(response.response.output, TextOutput): + logger.info( + "[execute_job] Skipping output guardrails for non-text output. " + f"job_id={job_id}, output_type={getattr(response.response.output, 'type', type(response.response.output).__name__)}" + ) + else: + output_text = response.response.output.content.value + safe_output = run_guardrails_validation( + output_text, + output_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) - if safe_output.get("bypassed"): logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" + f"[execute_job] Output guardrail validation | success={safe_output['success']}." 
) - elif safe_output["success"]: - response.response.output.content.value = safe_output["data"][ - "safe_text" - ] + if safe_output.get("bypassed"): + logger.info( + "[execute_job] Guardrails bypassed (service unavailable)" + ) + + elif safe_output["success"]: + response.response.output.content.value = safe_output["data"][ + "safe_text" + ] + + if safe_output["data"]["rephrase_needed"] == True: + callback_response = APIResponse.failure_response( + error=request.query.input, + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) + + else: + response.response.output.content.value = safe_output["error"] - if safe_output["data"]["rephrase_needed"] == True: callback_response = APIResponse.failure_response( - error=request.query.input, + error=safe_output["error"], metadata=request.request_metadata, ) return handle_job_error( job_id, request.callback_url, callback_response ) - else: - response.response.output.text = safe_output["error"] - - callback_response = APIResponse.failure_response( - error=safe_output["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - callback_response = APIResponse.success_response( data=response, metadata=request.request_metadata ) diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index ee3231c0b..dfcffa92f 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -4,6 +4,7 @@ from sqlmodel import Session from app.core.config import settings +from app.models import Organization, Project, User from app.tests.utils.auth import TestAuthContext from app.tests.utils.test_data import create_test_api_key, create_test_project from app.tests.utils.user import create_random_user @@ -112,3 +113,89 @@ def test_delete_api_key_nonexistent( headers={"X-API-KEY": user_api_key.key}, ) assert response.status_code == 404 + + +def test_verify_api_key( + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test API key verification endpoint with a valid API key.""" + response = client.get( + f"{settings.API_V1_STR}/apikeys/verify", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["data"]["user_id"] == user_api_key.user_id + assert payload["data"]["organization_id"] == user_api_key.organization_id + assert payload["data"]["project_id"] == user_api_key.project_id + + +def test_verify_api_key_invalid_key(client: TestClient) -> None: + """Test API key verification endpoint with an invalid API key.""" + response = client.get( + f"{settings.API_V1_STR}/apikeys/verify", + headers={"X-API-KEY": "ApiKey InvalidKeyThatDoesNotExist123456789"}, + ) + assert response.status_code == 401 + + +def test_verify_api_key_missing_auth(client: TestClient) -> None: + """Test API key verification endpoint without any authentication.""" + response = client.get(f"{settings.API_V1_STR}/apikeys/verify") + assert response.status_code == 401 + + +def test_verify_api_key_inactive_user( + db: Session, + client: TestClient, +) -> None: + """Test API key verification fails when the user is inactive.""" + api_key = create_test_api_key(db) + user = db.get(User, api_key.user_id) + user.is_active = False + db.add(user) + db.commit() + + response = client.get( + f"{settings.API_V1_STR}/apikeys/verify", + headers={"X-API-KEY": api_key.key}, + ) 
+ assert response.status_code == 403 + + +def test_verify_api_key_inactive_organization( + db: Session, + client: TestClient, +) -> None: + """Test API key verification fails when the organization is inactive.""" + api_key = create_test_api_key(db) + organization = db.get(Organization, api_key.organization_id) + organization.is_active = False + db.add(organization) + db.commit() + + response = client.get( + f"{settings.API_V1_STR}/apikeys/verify", + headers={"X-API-KEY": api_key.key}, + ) + assert response.status_code == 403 + + +def test_verify_api_key_inactive_project( + db: Session, + client: TestClient, +) -> None: + """Test API key verification fails when the project is inactive.""" + api_key = create_test_api_key(db) + project = db.get(Project, api_key.project_id) + project.is_active = False + db.add(project) + db.commit() + + response = client.get( + f"{settings.API_V1_STR}/apikeys/verify", + headers={"X-API-KEY": api_key.key}, + ) + assert response.status_code == 403 diff --git a/backend/app/tests/api/routes/test_llm.py b/backend/app/tests/api/routes/test_llm.py index cc543ba69..748755b9d 100644 --- a/backend/app/tests/api/routes/test_llm.py +++ b/backend/app/tests/api/routes/test_llm.py @@ -173,21 +173,9 @@ def test_llm_call_success_with_guardrails( ) -> None: """Test successful LLM call when guardrails are enabled (no validators).""" - with ( - patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job, - patch("app.services.llm.guardrails.call_guardrails") as mock_guardrails, - ): + with patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job: mock_start_job.return_value = "test-task-id" - mock_guardrails.return_value = { - "success": True, - "bypassed": False, - "data": { - "safe_text": "What is the capital of France?", - "rephrase_needed": False, - }, - } - payload = LLMCallRequest( query=QueryParams(input="What is the capital of France?"), config=LLMCallConfig( @@ -202,8 +190,6 @@ def test_llm_call_success_with_guardrails( ) ) ), - input_guardrails=[], - output_guardrails=[], callback_url="https://example.com/callback", ) @@ -220,7 +206,6 @@ def test_llm_call_success_with_guardrails( assert "response is being generated" in body["data"]["message"] mock_start_job.assert_called_once() - mock_guardrails.assert_not_called() def test_llm_call_guardrails_bypassed_still_succeeds( @@ -229,21 +214,9 @@ def test_llm_call_guardrails_bypassed_still_succeeds( ) -> None: """If guardrails service is unavailable (bypassed), request should still succeed.""" - with ( - patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job, - patch("app.services.llm.guardrails.call_guardrails") as mock_guardrails, - ): + with patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job: mock_start_job.return_value = "test-task-id" - mock_guardrails.return_value = { - "success": True, - "bypassed": True, - "data": { - "safe_text": "What is the capital of France?", - "rephrase_needed": False, - }, - } - payload = LLMCallRequest( query=QueryParams(input="What is the capital of France?"), config=LLMCallConfig( @@ -258,8 +231,6 @@ def test_llm_call_guardrails_bypassed_still_succeeds( ) ) ), - input_guardrails=[{"type": "pii_remover"}], - output_guardrails=[], callback_url="https://example.com/callback", ) diff --git a/backend/app/tests/services/llm/test_guardrails.py b/backend/app/tests/services/llm/test_guardrails.py index 4443aecad..161056980 100644 --- a/backend/app/tests/services/llm/test_guardrails.py +++ 
b/backend/app/tests/services/llm/test_guardrails.py @@ -1,20 +1,25 @@ import uuid from unittest.mock import MagicMock, patch -import pytest import httpx -from app.services.llm.guardrails import call_guardrails from app.core.config import settings +from app.models.llm.request import Validator +from app.services.llm.guardrails import ( + list_validators_config, + run_guardrails_validation, +) TEST_JOB_ID = uuid.uuid4() TEST_TEXT = "hello world" TEST_CONFIG = [{"type": "pii_remover"}] +TEST_PROJECT_ID = 1 +TEST_ORGANIZATION_ID = 1 @patch("app.services.llm.guardrails.httpx.Client") -def test_call_guardrails_success(mock_client_cls) -> None: +def test_run_guardrails_validation_success(mock_client_cls) -> None: mock_response = MagicMock() mock_response.json.return_value = {"success": True} mock_response.raise_for_status.return_value = None @@ -23,23 +28,30 @@ def test_call_guardrails_success(mock_client_cls) -> None: mock_client.post.return_value = mock_response mock_client_cls.return_value.__enter__.return_value = mock_client - result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + result = run_guardrails_validation( + TEST_TEXT, + TEST_CONFIG, + TEST_JOB_ID, + TEST_PROJECT_ID, + TEST_ORGANIZATION_ID, + ) assert result == {"success": True} mock_client.post.assert_called_once() - args, kwargs = mock_client.post.call_args - + _, kwargs = mock_client.post.call_args assert kwargs["json"]["input"] == TEST_TEXT assert kwargs["json"]["validators"] == TEST_CONFIG assert kwargs["json"]["request_id"] == str(TEST_JOB_ID) - + assert kwargs["json"]["project_id"] == TEST_PROJECT_ID + assert kwargs["json"]["organization_id"] == TEST_ORGANIZATION_ID + assert kwargs["params"]["suppress_pass_logs"] == "true" assert kwargs["headers"]["Authorization"].startswith("Bearer ") assert kwargs["headers"]["Content-Type"] == "application/json" @patch("app.services.llm.guardrails.httpx.Client") -def test_call_guardrails_http_error_bypasses(mock_client_cls) -> None: +def test_run_guardrails_validation_http_error_bypasses(mock_client_cls) -> None: mock_response = MagicMock() mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( "bad", request=None, response=None @@ -49,7 +61,13 @@ def test_call_guardrails_http_error_bypasses(mock_client_cls) -> None: mock_client.post.return_value = mock_response mock_client_cls.return_value.__enter__.return_value = mock_client - result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + result = run_guardrails_validation( + TEST_TEXT, + TEST_CONFIG, + TEST_JOB_ID, + TEST_PROJECT_ID, + TEST_ORGANIZATION_ID, + ) assert result["success"] is False assert result["bypassed"] is True @@ -57,42 +75,181 @@ def test_call_guardrails_http_error_bypasses(mock_client_cls) -> None: @patch("app.services.llm.guardrails.httpx.Client") -def test_call_guardrails_network_failure_bypasses(mock_client_cls) -> None: +def test_run_guardrails_validation_uses_settings(mock_client_cls) -> None: + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {"ok": True} + mock_client = MagicMock() - mock_client.post.side_effect = httpx.ConnectError("failed") + mock_client.post.return_value = mock_response mock_client_cls.return_value.__enter__.return_value = mock_client - result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + run_guardrails_validation( + TEST_TEXT, + TEST_CONFIG, + TEST_JOB_ID, + TEST_PROJECT_ID, + TEST_ORGANIZATION_ID, + ) - assert result["bypassed"] is True - assert result["data"]["safe_text"] == TEST_TEXT + _, 
kwargs = mock_client.post.call_args + assert ( + kwargs["headers"]["Authorization"] == f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}" + ) @patch("app.services.llm.guardrails.httpx.Client") -def test_call_guardrails_timeout_bypasses(mock_client_cls) -> None: +def test_run_guardrails_validation_serializes_validator_models(mock_client_cls) -> None: + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {"success": True} + mock_client = MagicMock() - mock_client.post.side_effect = httpx.TimeoutException("timeout") + mock_client.post.return_value = mock_response mock_client_cls.return_value.__enter__.return_value = mock_client - result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + vid = uuid.uuid4() + run_guardrails_validation( + TEST_TEXT, + [Validator(validator_config_id=vid)], + TEST_JOB_ID, + TEST_PROJECT_ID, + TEST_ORGANIZATION_ID, + ) - assert result["bypassed"] is True + _, kwargs = mock_client.post.call_args + assert kwargs["json"]["validators"] == [{"validator_config_id": str(vid)}] @patch("app.services.llm.guardrails.httpx.Client") -def test_call_guardrails_uses_settings(mock_client_cls) -> None: +def test_run_guardrails_validation_allows_disable_suppress_pass_logs( + mock_client_cls, +) -> None: mock_response = MagicMock() mock_response.raise_for_status.return_value = None - mock_response.json.return_value = {"ok": True} + mock_response.json.return_value = {"success": True} mock_client = MagicMock() mock_client.post.return_value = mock_response mock_client_cls.return_value.__enter__.return_value = mock_client - call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + run_guardrails_validation( + TEST_TEXT, + TEST_CONFIG, + TEST_JOB_ID, + TEST_PROJECT_ID, + TEST_ORGANIZATION_ID, + suppress_pass_logs=False, + ) _, kwargs = mock_client.post.call_args + assert kwargs["params"]["suppress_pass_logs"] == "false" - assert ( - kwargs["headers"]["Authorization"] == f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}" + +@patch("app.services.llm.guardrails.httpx.Client") +def test_list_validators_config_fetches_input_and_output_by_refs( + mock_client_cls, +) -> None: + input_validator_configs = [Validator(validator_config_id=uuid.uuid4())] + output_validator_configs = [Validator(validator_config_id=uuid.uuid4())] + + input_response = MagicMock() + input_response.raise_for_status.return_value = None + input_response.json.return_value = { + "success": True, + "data": [{"type": "uli_slur_match", "config": {"severity": "high"}}], + } + output_response = MagicMock() + output_response.raise_for_status.return_value = None + output_response.json.return_value = { + "success": True, + "data": [{"type": "gender_assumption_bias"}], + } + + mock_client = MagicMock() + mock_client.get.side_effect = [input_response, output_response] + mock_client_cls.return_value.__enter__.return_value = mock_client + + input_guardrails, output_guardrails = list_validators_config( + input_validator_configs=input_validator_configs, + output_validator_configs=output_validator_configs, + organization_id=1, + project_id=1, + ) + + assert input_guardrails == [ + {"type": "uli_slur_match", "config": {"severity": "high"}} + ] + assert output_guardrails == [{"type": "gender_assumption_bias"}] + assert mock_client.get.call_count == 2 + + first_call_kwargs = mock_client.get.call_args_list[0].kwargs + second_call_kwargs = mock_client.get.call_args_list[1].kwargs + assert first_call_kwargs["params"]["ids"] == [ + str(v.validator_config_id) for v in input_validator_configs + ] + 
assert second_call_kwargs["params"]["ids"] == [ + str(v.validator_config_id) for v in output_validator_configs + ] + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_list_validators_config_empty_short_circuits_without_http( + mock_client_cls, +) -> None: + input_guardrails, output_guardrails = list_validators_config( + input_validator_configs=[], + output_validator_configs=[], + organization_id=1, + project_id=1, + ) + + assert input_guardrails == [] + assert output_guardrails == [] + mock_client_cls.assert_not_called() + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_list_validators_config_omits_none_query_params(mock_client_cls) -> None: + input_validator_configs = [Validator(validator_config_id=uuid.uuid4())] + + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {"success": True, "data": []} + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client_cls.return_value.__enter__.return_value = mock_client + + list_validators_config( + input_validator_configs=input_validator_configs, + output_validator_configs=[], + organization_id=None, + project_id=None, + ) + + _, kwargs = mock_client.get.call_args + assert kwargs["params"]["ids"] == [ + str(v.validator_config_id) for v in input_validator_configs + ] + assert "organization_id" not in kwargs["params"] + assert "project_id" not in kwargs["params"] + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_list_validators_config_network_error_fails_open(mock_client_cls) -> None: + input_validator_configs = [Validator(validator_config_id=uuid.uuid4())] + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.ConnectError("Network is unreachable") + mock_client_cls.return_value.__enter__.return_value = mock_client + + input_guardrails, output_guardrails = list_validators_config( + input_validator_configs=input_validator_configs, + output_validator_configs=[], + organization_id=1, + project_id=1, ) + + assert input_guardrails == [] + assert output_guardrails == [] diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 15fc9a3cb..27bb0384e 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import patch, MagicMock +from uuid import UUID, uuid4 from fastapi import HTTPException from sqlmodel import Session, select @@ -17,6 +18,8 @@ Usage, TextOutput, TextContent, + AudioOutput, + AudioContent, # KaapiLLMParams, KaapiCompletionConfig, ) @@ -30,6 +33,9 @@ from app.tests.utils.utils import get_project from app.tests.utils.test_data import create_test_config +VALIDATOR_CONFIG_ID_1 = "00000000-0000-0000-0000-000000000001" +VALIDATOR_CONFIG_ID_2 = "00000000-0000-0000-0000-000000000002" + class TestStartJob: """Test cases for the start_job function.""" @@ -215,6 +221,15 @@ def test_handle_job_error_callback_failure_still_updates_job(self, db: Session): class TestExecuteJob: """Test suite for execute_job.""" + @pytest.fixture(autouse=True) + def mock_llm_call_crud(self): + with ( + patch("app.services.llm.jobs.create_llm_call") as mock_create_llm_call, + patch("app.services.llm.jobs.update_llm_call_response"), + ): + mock_create_llm_call.return_value = MagicMock(id=uuid4()) + yield + @pytest.fixture def job_for_execution(self, db: Session): job = JobCrud(session=db).create( @@ -749,7 +764,10 @@ def test_guardrails_sanitize_input_before_provider( 
unsafe_input = "My credit card is 4111 1111 1111 1111" - with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails: + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): mock_guardrails.return_value = { "success": True, "bypassed": False, @@ -758,6 +776,10 @@ def test_guardrails_sanitize_input_before_provider( "rephrase_needed": False, }, } + mock_fetch_configs.return_value = ( + [{"type": "pii_remover", "stage": "input"}], + [], + ) request_data = { "query": {"input": unsafe_input}, @@ -767,15 +789,16 @@ def test_guardrails_sanitize_input_before_provider( "provider": "openai-native", "type": "text", "params": {"model": "gpt-4"}, - } + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], } }, - "input_guardrails": [{"type": "pii_remover"}], - "output_guardrails": [], "include_provider_raw_response": False, "callback_url": None, } - result = self._execute_job(job_for_execution, db, request_data) provider_query = env["provider"].execute.call_args[0][1] @@ -784,6 +807,52 @@ def test_guardrails_sanitize_input_before_provider( assert result["success"] + def test_guardrails_skip_input_validation_for_audio_input( + self, db, job_env, job_for_execution + ): + env = job_env + env["provider"].execute.return_value = (env["mock_llm_response"], None) + + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): + mock_fetch_configs.return_value = ( + [{"type": "pii_remover", "stage": "input"}], + [], + ) + + request_data = { + "query": { + "input": { + "type": "audio", + "content": { + "format": "base64", + "value": "UklGRiQAAABXQVZFZm10IA==", + "mime_type": "audio/wav", + }, + } + }, + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], + } + }, + } + result = self._execute_job(job_for_execution, db, request_data) + + assert result["success"] is True + env["provider"].execute.assert_called_once() + mock_guardrails.assert_not_called() + def test_guardrails_sanitize_output_after_provider( self, db, job_env, job_for_execution ): @@ -792,7 +861,10 @@ def test_guardrails_sanitize_output_after_provider( env["mock_llm_response"].response.output.content.value = "Aadhar no 123-45-6789" env["provider"].execute.return_value = (env["mock_llm_response"], None) - with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails: + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): mock_guardrails.return_value = { "success": True, "bypassed": False, @@ -801,6 +873,10 @@ def test_guardrails_sanitize_output_after_provider( "rephrase_needed": False, }, } + mock_fetch_configs.return_value = ( + [], + [{"type": "pii_remover", "stage": "output"}], + ) request_data = { "query": {"input": "hello"}, @@ -810,17 +886,63 @@ def test_guardrails_sanitize_output_after_provider( "provider": "openai-native", "type": "text", "params": {"model": "gpt-4"}, - } + }, + "input_guardrails": [], + "output_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_2} + ], } }, - "input_guardrails": [], - "output_guardrails": [{"type": 
"pii_remover"}], } - result = self._execute_job(job_for_execution, db, request_data) assert "REDACTED" in result["data"]["response"]["output"]["content"]["value"] + def test_guardrails_skip_output_validation_for_audio_output( + self, db, job_env, job_for_execution + ): + env = job_env + + env["mock_llm_response"].response.output = AudioOutput( + content=AudioContent( + value="UklGRiQAAABXQVZFZm10IA==", + mime_type="audio/wav", + ) + ) + env["provider"].execute.return_value = (env["mock_llm_response"], None) + + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): + mock_fetch_configs.return_value = ( + [], + [{"type": "safety_filter", "stage": "output"}], + ) + + request_data = { + "query": {"input": "hello"}, + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + }, + "input_guardrails": [], + "output_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_2} + ], + } + }, + } + result = self._execute_job(job_for_execution, db, request_data) + + assert result["success"] is True + assert result["data"]["response"]["output"]["type"] == "audio" + env["provider"].execute.assert_called_once() + mock_guardrails.assert_not_called() + def test_guardrails_bypass_does_not_modify_input( self, db, job_env, job_for_execution ): @@ -830,7 +952,10 @@ def test_guardrails_bypass_does_not_modify_input( unsafe_input = "4111 1111 1111 1111" - with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails: + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): mock_guardrails.return_value = { "success": True, "bypassed": True, @@ -839,6 +964,10 @@ def test_guardrails_bypass_does_not_modify_input( "rephrase_needed": False, }, } + mock_fetch_configs.return_value = ( + [{"type": "pii_remover", "stage": "input"}], + [], + ) request_data = { "query": {"input": unsafe_input}, @@ -848,12 +977,14 @@ def test_guardrails_bypass_does_not_modify_input( "provider": "openai-native", "type": "text", "params": {"model": "gpt-4"}, - } + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], } }, - "input_guardrails": [{"type": "pii_remover"}], } - self._execute_job(job_for_execution, db, request_data) provider_query = env["provider"].execute.call_args[0][1] @@ -864,11 +995,18 @@ def test_guardrails_validation_failure_blocks_job( ): env = job_env - with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails: + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): mock_guardrails.return_value = { "success": False, "error": "Unsafe content detected", } + mock_fetch_configs.return_value = ( + [{"type": "uli_slur_match", "stage": "input"}], + [], + ) request_data = { "query": {"input": "bad input"}, @@ -878,24 +1016,30 @@ def test_guardrails_validation_failure_blocks_job( "provider": "openai-native", "type": "text", "params": {"model": "gpt-4"}, - } + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], } }, - "input_guardrails": [{"type": "uli_slur_match"}], } - result = self._execute_job(job_for_execution, db, request_data) assert not result["success"] assert "Unsafe content" in 
result["error"] env["provider"].execute.assert_not_called() - def test_guardrails_rephrase_needed_blocks_job( + def test_guardrails_rephrase_needed_allows_job_with_sanitized_input( self, db, job_env, job_for_execution ): env = job_env + env["provider"].execute.return_value = (env["mock_llm_response"], None) - with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails: + with ( + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + ): mock_guardrails.return_value = { "success": True, "bypassed": False, @@ -904,6 +1048,10 @@ def test_guardrails_rephrase_needed_blocks_job( "rephrase_needed": True, }, } + mock_fetch_configs.return_value = ( + [{"type": "policy", "stage": "input"}], + [], + ) request_data = { "query": {"input": "unsafe text"}, @@ -913,16 +1061,99 @@ def test_guardrails_rephrase_needed_blocks_job( "provider": "openai-native", "type": "text", "params": {"model": "gpt-4"}, - } + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [], } }, - "input_guardrails": [{"type": "policy"}], } + result = self._execute_job(job_for_execution, db, request_data) + + assert result["success"] is True + env["provider"].execute.assert_called_once() + provider_query = env["provider"].execute.call_args[0][1] + assert provider_query.input.content.value == "Rephrased text" + + def test_execute_job_fetches_validator_configs_from_blob_refs( + self, db, job_env, job_for_execution + ): + env = job_env + env["provider"].execute.return_value = (env["mock_llm_response"], None) + with patch( + "app.services.llm.jobs.list_validators_config" + ) as mock_fetch_configs: + mock_fetch_configs.return_value = ([], []) + + request_data = { + "query": {"input": "hello"}, + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_2} + ], + } + }, + } result = self._execute_job(job_for_execution, db, request_data) - assert not result["success"] - env["provider"].execute.assert_not_called() + assert result["success"] + mock_fetch_configs.assert_called_once() + _, kwargs = mock_fetch_configs.call_args + input_validator_configs = kwargs["input_validator_configs"] + output_validator_configs = kwargs["output_validator_configs"] + assert [v.validator_config_id for v in input_validator_configs] == [ + UUID(VALIDATOR_CONFIG_ID_1) + ] + assert [v.validator_config_id for v in output_validator_configs] == [ + UUID(VALIDATOR_CONFIG_ID_2) + ] + + def test_execute_job_continues_when_no_validator_configs_resolved( + self, db, job_env, job_for_execution + ): + env = job_env + env["provider"].execute.return_value = (env["mock_llm_response"], None) + + with ( + patch("app.services.llm.jobs.list_validators_config") as mock_fetch_configs, + patch("app.services.llm.jobs.run_guardrails_validation") as mock_guardrails, + ): + mock_fetch_configs.return_value = ([], []) + + request_data = { + "query": {"input": "hello"}, + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + }, + "input_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_1} + ], + "output_guardrails": [ + {"validator_config_id": VALIDATOR_CONFIG_ID_2} + ], + } + }, + } + result = self._execute_job(job_for_execution, 
db, request_data) + + assert result["success"] is True + env["provider"].execute.assert_called_once() + mock_guardrails.assert_not_called() class TestResolveConfigBlob: @@ -955,6 +1186,36 @@ def test_resolve_config_blob_success(self, db: Session): assert resolved_blob.completion.params["model"] == "gpt-4" assert resolved_blob.completion.params["temperature"] == 0.8 + def test_resolve_config_blob_keeps_validator_refs(self, db: Session): + project = get_project(db) + config_blob = ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ), + input_guardrails=[{"validator_config_id": VALIDATOR_CONFIG_ID_1}], + output_guardrails=[{"validator_config_id": VALIDATOR_CONFIG_ID_2}], + ) + config = create_test_config(db, project_id=project.id, config_blob=config_blob) + db.commit() + + config_crud = ConfigVersionCrud( + session=db, project_id=project.id, config_id=config.id + ) + llm_call_config = LLMCallConfig(id=str(config.id), version=1) + + resolved_blob, error = resolve_config_blob(config_crud, llm_call_config) + + assert error is None + assert resolved_blob is not None + assert [v.model_dump() for v in (resolved_blob.input_guardrails or [])] == [ + {"validator_config_id": UUID(VALIDATOR_CONFIG_ID_1)} + ] + assert [v.model_dump() for v in (resolved_blob.output_guardrails or [])] == [ + {"validator_config_id": UUID(VALIDATOR_CONFIG_ID_2)} + ] + def test_resolve_config_blob_version_not_found(self, db: Session): """Test resolve_config_blob when version doesn't exist.""" project = get_project(db)
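To summarize how the pieces added in this diff fit together, here is a condensed sketch (not part of the patch) of the config-driven guardrails flow: a ConfigBlob now carries Validator references, execute_job resolves them via list_validators_config, and run_guardrails_validation sanitizes text input before the provider call. The UUIDs and organization/project ids are placeholders, the NativeCompletionConfig import path is assumed, and error handling, output guardrails, and the provider call are elided.

from uuid import uuid4

from app.models.llm.request import ConfigBlob, NativeCompletionConfig, Validator
from app.services.llm.guardrails import (
    list_validators_config,
    run_guardrails_validation,
)

# A stored config version now embeds guardrail references instead of the caller
# passing raw validator configs on LLMCallRequest.
blob = ConfigBlob(
    completion=NativeCompletionConfig(
        provider="openai-native", type="text", params={"model": "gpt-4"}
    ),
    input_guardrails=[Validator(validator_config_id=uuid4())],   # placeholder ids
    output_guardrails=[Validator(validator_config_id=uuid4())],
)

# execute_job resolves the references to concrete validator configs; this call
# fails open and returns ([], []) if the guardrails service is unreachable.
input_guardrails, output_guardrails = list_validators_config(
    organization_id=1,   # placeholder ids
    project_id=1,
    input_validator_configs=blob.input_guardrails,
    output_validator_configs=blob.output_guardrails,
)

# For text inputs only, the resolved validators run before the provider call.
if input_guardrails:
    safe_input = run_guardrails_validation(
        "My credit card is 4111 1111 1111 1111",   # text to sanitize
        input_guardrails,
        uuid4(),                                   # job_id
        project_id=1,
        organization_id=1,
    )
    if safe_input["success"] and not safe_input.get("bypassed"):
        sanitized_text = safe_input["data"]["safe_text"]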