Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ async def _prepare_run_context(
and not opts.get("store")
and not (getattr(self.client, "STORES_BY_DEFAULT", False) and opts.get("store") is not False)
):
self.context_providers.append(InMemoryHistoryProvider("memory"))
self.context_providers.append(InMemoryHistoryProvider())

session_context, chat_options = await self._prepare_session_and_messages(
session=session,
Expand Down
35 changes: 35 additions & 0 deletions python/packages/core/agent_framework/_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,43 @@ class InMemoryHistoryProvider(BaseHistoryProvider):

This is the default provider auto-added by the agent when no providers
are configured and ``conversation_id`` or ``store=True`` is set.

Attributes:
DEFAULT_SOURCE_ID: Default source_id used when none is provided ("in-memory").
"""

DEFAULT_SOURCE_ID = "in-memory"

def __init__(
self,
source_id: str | None = None,
*,
load_messages: bool = True,
store_inputs: bool = True,
store_context_messages: bool = False,
store_context_from: set[str] | None = None,
store_outputs: bool = True,
):
"""Initialize the in-memory history provider.

Args:
source_id: Unique identifier for this provider instance.
Defaults to DEFAULT_SOURCE_ID ("in-memory") if not provided.
load_messages: Whether to load messages before invocation.
store_inputs: Whether to store input messages.
store_context_messages: Whether to store context from other providers.
store_context_from: If set, only store context from these source_ids.
store_outputs: Whether to store response messages.
"""
super().__init__(
source_id=source_id or self.DEFAULT_SOURCE_ID,
load_messages=load_messages,
store_inputs=store_inputs,
store_context_messages=store_context_messages,
store_context_from=store_context_from,
store_outputs=store_outputs,
)

async def get_messages(
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
) -> list[Message]:
Expand Down
2 changes: 1 addition & 1 deletion python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(

resolved_context_providers = list(context_providers) if context_providers is not None else []
if not resolved_context_providers:
resolved_context_providers.append(InMemoryHistoryProvider("memory"))
resolved_context_providers.append(InMemoryHistoryProvider())

super().__init__(
id=id,
Expand Down
8 changes: 5 additions & 3 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ async def test_chat_client_agent_create_session(client: SupportsChatGetResponse)
async def test_chat_client_agent_prepare_session_and_messages(client: SupportsChatGetResponse) -> None:
from agent_framework._sessions import InMemoryHistoryProvider

agent = Agent(client=client, context_providers=[InMemoryHistoryProvider("memory")])
agent = Agent(client=client, context_providers=[InMemoryHistoryProvider()])
message = Message(role="user", text="Hello")
session = AgentSession()
session.state["memory"] = {"messages": [message]}
session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID] = {"messages": [message]}

session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
session=session,
Expand Down Expand Up @@ -261,6 +261,8 @@ async def test_chat_client_agent_update_session_id_streaming_does_not_use_respon


async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
from agent_framework._sessions import InMemoryHistoryProvider

agent = Agent(client=client)
session = agent.create_session()

Expand All @@ -269,7 +271,7 @@ async def test_chat_client_agent_update_session_messages(client: SupportsChatGet

assert session.service_session_id is None

chat_messages: list[Message] = session.state.get("memory", {}).get("messages", [])
chat_messages: list[Message] = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get("messages", [])

assert chat_messages is not None
assert len(chat_messages) == 2
Expand Down
13 changes: 9 additions & 4 deletions python/packages/core/tests/core/test_middleware_with_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
chat_middleware,
function_middleware,
)
from agent_framework._sessions import InMemoryHistoryProvider

from .conftest import MockBaseChatClient, MockChatClient

Expand Down Expand Up @@ -1416,8 +1417,10 @@ class SessionTrackingMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture state before next() call
thread_messages = []
if context.session and context.session.state.get("memory"):
thread_messages = context.session.state.get("memory", {}).get("messages", [])
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
thread_messages = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get(
"messages", []
)

before_state = {
"before_next": True,
Expand All @@ -1432,8 +1435,10 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable

# Capture state after next() call
thread_messages_after = []
if context.session and context.session.state.get("memory"):
thread_messages_after = context.session.state.get("memory", {}).get("messages", [])
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
thread_messages_after = context.session.state.get(
InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}
).get("messages", [])

after_state = {
"before_next": False,
Expand Down
33 changes: 33 additions & 0 deletions python/packages/core/tests/core/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,36 @@ async def test_source_id_attribution(self) -> None:
ctx = SessionContext(session_id="s1", input_messages=[])
ctx.extend_messages("custom-source", [Message(role="user", contents=["test"])])
assert "custom-source" in ctx.context_messages

async def test_default_source_id(self) -> None:
"""Test that InMemoryHistoryProvider uses default source_id when none provided."""
provider = InMemoryHistoryProvider()
assert provider.source_id == InMemoryHistoryProvider.DEFAULT_SOURCE_ID
assert provider.source_id == "in-memory"

async def test_default_source_id_class_attribute(self) -> None:
"""Test that DEFAULT_SOURCE_ID is accessible as a class attribute."""
assert InMemoryHistoryProvider.DEFAULT_SOURCE_ID == "in-memory"
# Can be used to access session state
session = AgentSession()
session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID] = {"messages": []}
assert InMemoryHistoryProvider.DEFAULT_SOURCE_ID in session.state

async def test_default_source_id_works_with_session(self) -> None:
"""Test that default provider works with session state."""
from agent_framework import AgentResponse

provider = InMemoryHistoryProvider() # Use default source_id
session = AgentSession()

# First run: store messages
input_msg = Message(role="user", contents=["test"])
ctx = SessionContext(session_id="s1", input_messages=[input_msg])
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])])
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]

# Verify messages are stored under the default key
assert InMemoryHistoryProvider.DEFAULT_SOURCE_ID in session.state
assert "messages" in session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID]
assert len(session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID]["messages"]) == 2