Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions dotnet/src/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,22 @@ await InvokeRpcAsync<object>(
"session.abort", [new SessionAbortRequest { SessionId = SessionId }], cancellationToken);
}

/// <summary>
/// Changes the model for this session.
/// The new model takes effect for the next message. Conversation history is preserved.
/// </summary>
/// <param name="model">Model ID to switch to (e.g., "gpt-4.1").</param>
/// <param name="cancellationToken">Optional cancellation token.</param>
/// <example>
/// <code>
/// await session.SetModelAsync("gpt-4.1");
/// </code>
/// </example>
public async Task SetModelAsync(string model, CancellationToken cancellationToken = default)
{
await Rpc.Model.SwitchToAsync(model, cancellationToken);
}

/// <summary>
/// Disposes the <see cref="CopilotSession"/> and releases all associated resources.
/// </summary>
Expand Down
15 changes: 15 additions & 0 deletions dotnet/test/SessionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -441,4 +441,19 @@ public async Task Should_Create_Session_With_Custom_Config_Dir()
Assert.NotNull(assistantMessage);
Assert.Contains("2", assistantMessage!.Data.Content);
}

[Fact]
public async Task Should_Set_Model_On_Existing_Session()
{
var session = await CreateSessionAsync();

// Subscribe for the model change event before calling SetModelAsync
var modelChangedTask = TestHelper.GetNextEventOfTypeAsync<SessionModelChangeEvent>(session);

await session.SetModelAsync("gpt-4.1");

// Verify a model_change event was emitted with the new model
var modelChanged = await modelChangedTask;
Assert.Equal("gpt-4.1", modelChanged.Data.NewModel);
}
}
17 changes: 17 additions & 0 deletions go/internal/e2e/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@ func TestSessionRpc(t *testing.T) {
}
})

// session.model.switchTo is defined in schema but not yet implemented in CLI
t.Run("should call session.SetModel", func(t *testing.T) {
t.Skip("session.model.switchTo not yet implemented in CLI")

session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
Model: "claude-sonnet-4.5",
})
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}

if err := session.SetModel(t.Context(), "gpt-4.1"); err != nil {
t.Fatalf("SetModel returned error: %v", err)
}
})

t.Run("should get and set session mode", func(t *testing.T) {
session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll})
if err != nil {
Expand Down
17 changes: 17 additions & 0 deletions go/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,3 +576,20 @@ func (s *Session) Abort(ctx context.Context) error {

return nil
}

// SetModel changes the model for this session.
// The new model takes effect for the next message. Conversation history is preserved.
//
// Example:
//
// if err := session.SetModel(context.Background(), "gpt-4.1"); err != nil {
// log.Printf("Failed to set model: %v", err)
// }
func (s *Session) SetModel(ctx context.Context, model string) error {
_, err := s.RPC.Model.SwitchTo(ctx, &rpc.SessionModelSwitchToParams{ModelID: model})
if err != nil {
return fmt.Errorf("failed to set model: %w", err)
}

return nil
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test coverage gap: The Go SDK is missing a test for SetModel(), while Node.js, Python, and .NET all have tests for this new method.

Consider adding a test in go/client_test.go (for unit testing with mocks) or go/internal/e2e/session_test.go (for E2E testing) to match the test coverage in the other SDKs.

Example pattern (based on Python/Node.js tests):

func TestSession_SetModel(t *testing.T) {
    // Mock the RPC call and verify session.model.switchTo is called
    // with correct sessionId and modelId parameters
}

This would ensure cross-SDK test parity and prevent regressions.

AI generated by SDK Consistency Review Agent for #621

15 changes: 15 additions & 0 deletions nodejs/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,19 @@ export class CopilotSession {
sessionId: this.sessionId,
});
}

/**
* Change the model for this session.
* The new model takes effect for the next message. Conversation history is preserved.
*
* @param model - Model ID to switch to
*
* @example
* ```typescript
* await session.setModel("gpt-4.1");
* ```
*/
async setModel(model: string): Promise<void> {
await this.rpc.model.switchTo({ modelId: model });
}
}
26 changes: 26 additions & 0 deletions nodejs/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ describe("CopilotClient", () => {
);
});

it("sends session.model.switchTo RPC with correct params", async () => {
const client = new CopilotClient();
await client.start();
onTestFinished(() => client.forceStop());

const session = await client.createSession({ onPermissionRequest: approveAll });

// Mock sendRequest to capture the call without hitting the runtime
const spy = vi
.spyOn((client as any).connection!, "sendRequest")
.mockImplementation(async (method: string, _params: any) => {
if (method === "session.model.switchTo") return {};
// Fall through for other methods (shouldn't be called)
throw new Error(`Unexpected method: ${method}`);
});

await session.setModel("gpt-4.1");

expect(spy).toHaveBeenCalledWith("session.model.switchTo", {
sessionId: session.sessionId,
modelId: "gpt-4.1",
});

spy.mockRestore();
});

describe("URL parsing", () => {
it("should parse port-only URL format", () => {
const client = new CopilotClient({
Expand Down
20 changes: 19 additions & 1 deletion python/copilot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Callable
from typing import Any, cast

from .generated.rpc import SessionRpc
from .generated.rpc import SessionModelSwitchToParams, SessionRpc
from .generated.session_events import SessionEvent, SessionEventType, session_event_from_dict
from .types import (
MessageOptions,
Expand Down Expand Up @@ -520,3 +520,21 @@ async def abort(self) -> None:
>>> await session.abort()
"""
await self._client.request("session.abort", {"sessionId": self.session_id})

async def set_model(self, model: str) -> None:
"""
Change the model for this session.

The new model takes effect for the next message. Conversation history
is preserved.

Args:
model: Model ID to switch to (e.g., "gpt-4.1", "claude-sonnet-4").

Raises:
Exception: If the session has been destroyed or the connection fails.

Example:
>>> await session.set_model("gpt-4.1")
"""
await self.rpc.model.switch_to(SessionModelSwitchToParams(model_id=model))
26 changes: 26 additions & 0 deletions python/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,29 @@ async def mock_request(method, params):
assert captured["session.resume"]["clientName"] == "my-app"
finally:
await client.force_stop()

@pytest.mark.asyncio
async def test_set_model_sends_correct_rpc(self):
client = CopilotClient({"cli_path": CLI_PATH})
await client.start()

try:
session = await client.create_session(
{"on_permission_request": PermissionHandler.approve_all}
)

captured = {}
original_request = client._client.request

async def mock_request(method, params):
captured[method] = params
if method == "session.model.switchTo":
return {}
return await original_request(method, params)

client._client.request = mock_request
await session.set_model("gpt-4.1")
assert captured["session.model.switchTo"]["sessionId"] == session.session_id
assert captured["session.model.switchTo"]["modelId"] == "gpt-4.1"
finally:
await client.force_stop()
Loading