From 49fc1b51279013d052a9f7e77752129be0927bf0 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Tue, 3 Mar 2026 20:00:20 -0800 Subject: [PATCH 1/5] feat(web): add MCP HTTP endpoint with Streamable HTTP transport (SOU-263) - Add /api/mcp route with WebStandardStreamableHTTPServerTransport supporting SSE and JSON responses - Add Bearer token auth support to getAuthenticatedUser for programmatic MCP clients - Add session ownership validation to prevent session hijacking (per MCP security spec) - Extract chat utils to utils.server.ts to fix 'use server' module boundary violations - Add ask_codebase, search_code, read_file, list_repos, list_commits, list_tree, list_language_models MCP tools - Add withAuthV2 tests for Bearer token authentication flow Co-Authored-By: Claude Sonnet 4.6 --- packages/web/package.json | 1 + .../[domain]/askgh/[owner]/[repo]/page.tsx | 12 +- .../web/src/app/[domain]/browse/layout.tsx | 2 +- .../web/src/app/[domain]/chat/[id]/page.tsx | 3 +- packages/web/src/app/[domain]/chat/page.tsx | 3 +- packages/web/src/app/[domain]/search/page.tsx | 2 +- .../app/api/(server)/chat/blocking/route.ts | 229 +------- .../web/src/app/api/(server)/chat/route.ts | 150 +----- .../web/src/app/api/(server)/mcp/route.ts | 135 +++++ .../web/src/app/api/(server)/models/route.ts | 2 +- packages/web/src/features/chat/actions.ts | 489 +---------------- packages/web/src/features/chat/agent.ts | 138 ++++- packages/web/src/features/chat/types.ts | 33 +- .../web/src/features/chat/utils.server.ts | 445 +++++++++++++++ packages/web/src/features/chat/utils.ts | 7 +- packages/web/src/features/mcp/askCodebase.ts | 198 +++++++ packages/web/src/features/mcp/server.ts | 507 ++++++++++++++++++ packages/web/src/features/mcp/types.ts | 17 + packages/web/src/features/mcp/utils.ts | 61 +++ .../web/src/features/searchAssist/actions.ts | 6 +- packages/web/src/withAuthV2.test.ts | 185 +++++++ packages/web/src/withAuthV2.ts | 20 + yarn.lock | 3 +- 23 files changed, 1787 insertions(+), 861 deletions(-) create mode 100644 packages/web/src/app/api/(server)/mcp/route.ts create mode 100644 packages/web/src/features/chat/utils.server.ts create mode 100644 packages/web/src/features/mcp/askCodebase.ts create mode 100644 packages/web/src/features/mcp/server.ts create mode 100644 packages/web/src/features/mcp/types.ts create mode 100644 packages/web/src/features/mcp/utils.ts diff --git a/packages/web/package.json b/packages/web/package.json index e0b0e2488..cf9b57746 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -59,6 +59,7 @@ "@hookform/resolvers": "^3.9.0", "@iconify/react": "^5.1.0", "@iizukak/codemirror-lang-wgsl": "^0.3.0", + "@modelcontextprotocol/sdk": "^1.27.1", "@openrouter/ai-sdk-provider": "^2.2.3", "@opentelemetry/api-logs": "^0.203.0", "@opentelemetry/instrumentation": "^0.203.0", diff --git a/packages/web/src/app/[domain]/askgh/[owner]/[repo]/page.tsx b/packages/web/src/app/[domain]/askgh/[owner]/[repo]/page.tsx index e7de26d95..d957599a4 100644 --- a/packages/web/src/app/[domain]/askgh/[owner]/[repo]/page.tsx +++ b/packages/web/src/app/[domain]/askgh/[owner]/[repo]/page.tsx @@ -1,5 +1,5 @@ import { addGithubRepo } from "@/features/workerApi/actions"; -import { isServiceError, unwrapServiceError } from "@/lib/utils"; +import { isServiceError } from "@/lib/utils"; import { ServiceErrorException } from "@/lib/serviceError"; import { prisma } from "@/prisma"; import { SINGLE_TENANT_ORG_ID } from "@/lib/constants"; @@ -7,7 +7,7 @@ import { getRepoInfo } from "./api"; import { CustomSlateEditor } from "@/features/chat/customSlateEditor"; import { RepoIndexedGuard } from "./components/repoIndexedGuard"; import { LandingPage } from "./components/landingPage"; -import { getConfiguredLanguageModelsInfo } from "@/features/chat/actions"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { auth } from "@/auth"; interface PageProps { @@ -45,8 +45,12 @@ export default async function GitHubRepoPage(props: PageProps) { return response.repoId; })(); - const repoInfo = await unwrapServiceError(getRepoInfo(repoId)); - const languageModels = await unwrapServiceError(getConfiguredLanguageModelsInfo()); + const repoInfo = await getRepoInfo(repoId) + const languageModels = await getConfiguredLanguageModelsInfo() + + if (isServiceError(repoInfo)) { + throw new ServiceErrorException(repoInfo); + } return ( diff --git a/packages/web/src/app/[domain]/browse/layout.tsx b/packages/web/src/app/[domain]/browse/layout.tsx index b5b7d1374..3ea59b9d5 100644 --- a/packages/web/src/app/[domain]/browse/layout.tsx +++ b/packages/web/src/app/[domain]/browse/layout.tsx @@ -1,6 +1,6 @@ import { auth } from "@/auth"; import { LayoutClient } from "./layoutClient"; -import { getConfiguredLanguageModelsInfo } from "@/features/chat/actions"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; interface LayoutProps { children: React.ReactNode; diff --git a/packages/web/src/app/[domain]/chat/[id]/page.tsx b/packages/web/src/app/[domain]/chat/[id]/page.tsx index 67a057926..d1c57fe5a 100644 --- a/packages/web/src/app/[domain]/chat/[id]/page.tsx +++ b/packages/web/src/app/[domain]/chat/[id]/page.tsx @@ -1,5 +1,6 @@ import { getRepos, getSearchContexts } from '@/actions'; -import { getUserChatHistory, getConfiguredLanguageModelsInfo, getChatInfo, claimAnonymousChats, getSharedWithUsersForChat } from '@/features/chat/actions'; +import { getUserChatHistory, getChatInfo, claimAnonymousChats, getSharedWithUsersForChat } from '@/features/chat/actions'; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { ServiceErrorException } from '@/lib/serviceError'; import { isServiceError } from '@/lib/utils'; import { ChatThreadPanel } from './components/chatThreadPanel'; diff --git a/packages/web/src/app/[domain]/chat/page.tsx b/packages/web/src/app/[domain]/chat/page.tsx index dd5124cbb..75cdc220a 100644 --- a/packages/web/src/app/[domain]/chat/page.tsx +++ b/packages/web/src/app/[domain]/chat/page.tsx @@ -1,6 +1,7 @@ import { getRepos, getReposStats, getSearchContexts } from "@/actions"; import { SourcebotLogo } from "@/app/components/sourcebotLogo"; -import { getConfiguredLanguageModelsInfo, getUserChatHistory } from "@/features/chat/actions"; +import { getUserChatHistory } from "@/features/chat/actions"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { CustomSlateEditor } from "@/features/chat/customSlateEditor"; import { ServiceErrorException } from "@/lib/serviceError"; import { isServiceError, measure } from "@/lib/utils"; diff --git a/packages/web/src/app/[domain]/search/page.tsx b/packages/web/src/app/[domain]/search/page.tsx index a0667f430..b6b41ce22 100644 --- a/packages/web/src/app/[domain]/search/page.tsx +++ b/packages/web/src/app/[domain]/search/page.tsx @@ -2,7 +2,7 @@ import { env } from "@sourcebot/shared"; import { SearchLandingPage } from "./components/searchLandingPage"; import { SearchResultsPage } from "./components/searchResultsPage"; import { auth } from "@/auth"; -import { getConfiguredLanguageModelsInfo } from "@/features/chat/actions"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; interface SearchPageProps { params: Promise<{ domain: string }>; diff --git a/packages/web/src/app/api/(server)/chat/blocking/route.ts b/packages/web/src/app/api/(server)/chat/blocking/route.ts index 4e887cf0f..5aaa04d1b 100644 --- a/packages/web/src/app/api/(server)/chat/blocking/route.ts +++ b/packages/web/src/app/api/(server)/chat/blocking/route.ts @@ -1,23 +1,11 @@ -import { sew } from "@/actions"; -import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, _updateChatMessages, _generateChatNameFromMessage } from "@/features/chat/actions"; -import { LanguageModelInfo, languageModelInfoSchema, SBChatMessage, SearchScope } from "@/features/chat/types"; -import { convertLLMOutputToPortableMarkdown, getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/features/chat/utils"; -import { ErrorCode } from "@/lib/errorCodes"; -import { requestBodySchemaValidationError, ServiceError, ServiceErrorException, serviceErrorResponse } from "@/lib/serviceError"; +import { askCodebase } from "@/features/mcp/askCodebase"; +import { languageModelInfoSchema } from "@/features/chat/types"; +import { apiHandler } from "@/lib/apiHandler"; +import { requestBodySchemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; -import { withOptionalAuthV2 } from "@/withAuthV2"; -import { ChatVisibility, Prisma } from "@sourcebot/db"; -import { createLogger, env } from "@sourcebot/shared"; -import { randomUUID } from "crypto"; -import { StatusCodes } from "http-status-codes"; +import { ChatVisibility } from "@sourcebot/db"; import { NextRequest, NextResponse } from "next/server"; import { z } from "zod"; -import { createMessageStream } from "../route"; -import { InferUIMessageChunk, UITools, UIDataTypes, UIMessage } from "ai"; -import { apiHandler } from "@/lib/apiHandler"; -import { captureEvent } from "@/lib/posthog"; - -const logger = createLogger('chat-blocking-api'); /** * Request schema for the blocking chat API. @@ -40,22 +28,12 @@ const blockingChatRequestSchema = z.object({ .describe("The visibility of the chat session. If not provided, defaults to PRIVATE for authenticated users and PUBLIC for anonymous users. Set to PUBLIC to make the chat viewable by anyone with the link. Note: Anonymous users cannot create PRIVATE chats; any PRIVATE request from an unauthenticated user will be ignored and set to PUBLIC."), }); -/** - * Response schema for the blocking chat API. - */ -interface BlockingChatResponse { - answer: string; - chatId: string; - chatUrl: string; - languageModel: LanguageModelInfo; -} - /** * POST /api/chat/blocking - * + * * A blocking (non-streaming) chat endpoint designed for MCP and other integrations. * Creates a chat session, runs the agent to completion, and returns the final answer. - * + * * The chat session is persisted to the database, allowing users to view the full * conversation (including tool calls and reasoning) in the web UI. */ @@ -67,190 +45,7 @@ export const POST = apiHandler(async (request: NextRequest) => { return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); } - const { query, repos = [], languageModel: requestedLanguageModel, visibility: requestedVisibility } = parsed.data; - - const response: BlockingChatResponse | ServiceError = await sew(() => - withOptionalAuthV2(async ({ org, user, prisma }) => { - // Get all configured language models - const configuredModels = await _getConfiguredLanguageModelsFull(); - if (configuredModels.length === 0) { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: "No language models are configured. Please configure at least one language model. See: https://docs.sourcebot.dev/docs/configuration/language-model-providers", - } satisfies ServiceError; - } - - // Use the requested language model if provided, otherwise default to the first configured model - let languageModelConfig = configuredModels[0]; - if (requestedLanguageModel) { - const matchingModel = configuredModels.find( - (m) => getLanguageModelKey(m) === getLanguageModelKey(requestedLanguageModel as LanguageModelInfo) - ); - if (!matchingModel) { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: `Language model '${requestedLanguageModel.provider}/${requestedLanguageModel.model}' is not configured.`, - } satisfies ServiceError; - } - languageModelConfig = matchingModel; - } - - const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig); - const modelName = languageModelConfig.displayName ?? languageModelConfig.model; - - // Determine visibility: anonymous users cannot create private chats (they would be inaccessible) - // Only use requested visibility if user is authenticated, otherwise always use PUBLIC - const chatVisibility = (requestedVisibility && user) - ? requestedVisibility - : (user ? ChatVisibility.PRIVATE : ChatVisibility.PUBLIC); - - // Create a new chat session - const chat = await prisma.chat.create({ - data: { - orgId: org.id, - createdById: user?.id, - visibility: chatVisibility, - messages: [] as unknown as Prisma.InputJsonValue, - }, - }); - - await captureEvent('wa_chat_thread_created', { - chatId: chat.id, - isAnonymous: !user, - }); - - // Run the agent to completion - logger.debug(`Starting blocking agent for chat ${chat.id}`, { - chatId: chat.id, - query: query.substring(0, 100), - model: modelName, - }); - - // Create the initial user message - const userMessage: SBChatMessage = { - id: randomUUID(), - role: 'user', - parts: [{ type: 'text', text: query }], - }; - - const selectedRepos = (await Promise.all(repos.map(async (repo) => { - const repoDB = await prisma.repo.findFirst({ - where: { - name: repo, - }, - }); - - if (!repoDB) { - throw new ServiceErrorException({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: `Repository '${repo}' not found.`, - }) - } - - return { - type: 'repo', - value: repoDB.name, - name: repoDB.displayName ?? repoDB.name.split('/').pop() ?? repoDB.name, - codeHostType: repoDB.external_codeHostType, - } satisfies SearchScope; - }))); - - // We'll capture the final messages and usage from the stream - let finalMessages: SBChatMessage[] = []; - - await captureEvent('wa_chat_message_sent', { - chatId: chat.id, - messageCount: 1, - selectedReposCount: selectedRepos.length, - ...(env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? { - selectedRepos: selectedRepos.map(r => r.value) - } : {}), - }); - - const stream = await createMessageStream({ - chatId: chat.id, - messages: [userMessage], - metadata: { - selectedSearchScopes: selectedRepos, - }, - selectedRepos: selectedRepos.map(r => r.value), - model, - modelName, - modelProviderOptions: providerOptions, - onFinish: async ({ messages }) => { - finalMessages = messages; - }, - onError: (error) => { - if (error instanceof ServiceErrorException) { - throw error; - } - - const message = error instanceof Error ? error.message : String(error); - throw new ServiceErrorException({ - statusCode: StatusCodes.INTERNAL_SERVER_ERROR, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message, - }); - }, - }) - - const [_, name] = await Promise.all([ - // Consume the stream fully to trigger onFinish - blockStreamUntilFinish(stream), - // Generate and update the chat name - _generateChatNameFromMessage({ - message: query, - languageModelConfig, - }) - ]); - - // Persist the messages to the chat - await _updateChatMessages({ chatId: chat.id, messages: finalMessages, prisma }); - - // Update the chat name - await prisma.chat.update({ - where: { - id: chat.id, - orgId: org.id, - }, - data: { - name: name, - }, - }); - - // Extract the answer text from the assistant message - const assistantMessage = finalMessages.find(m => m.role === 'assistant'); - const answerPart = assistantMessage - ? getAnswerPartFromAssistantMessage(assistantMessage, false) - : undefined; - const answerText = answerPart?.text ?? ''; - - // Build the base URL and chat URL - const baseUrl = env.AUTH_URL; - - // Convert to portable markdown (replaces @file: references with markdown links) - const portableAnswer = convertLLMOutputToPortableMarkdown(answerText, baseUrl); - const chatUrl = `${baseUrl}/${org.domain}/chat/${chat.id}`; - - logger.debug(`Completed blocking agent for chat ${chat.id}`, { - chatId: chat.id, - }); - - return { - answer: portableAnswer, - chatId: chat.id, - chatUrl, - languageModel: { - provider: languageModelConfig.provider, - model: languageModelConfig.model, - displayName: languageModelConfig.displayName, - }, - } satisfies BlockingChatResponse; - }) - ); + const response = await askCodebase(parsed.data); if (isServiceError(response)) { return serviceErrorResponse(response); @@ -258,11 +53,3 @@ export const POST = apiHandler(async (request: NextRequest) => { return NextResponse.json(response); }); - -const blockStreamUntilFinish = async >(stream: ReadableStream>) => { - const reader = stream.getReader(); - while (true as const) { - const { done } = await reader.read(); - if (done) break; - } -} \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index 1264a57a1..6385b4106 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -1,28 +1,19 @@ import { sew } from "@/actions"; -import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, _updateChatMessages, _isOwnerOfChat } from "@/features/chat/actions"; -import { createAgentStream } from "@/features/chat/agent"; -import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SBChatMessageMetadata } from "@/features/chat/types"; -import { getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/features/chat/utils"; +import { createMessageStream } from "@/features/chat/agent"; +import { additionalChatRequestParamsSchema } from "@/features/chat/types"; +import { getLanguageModelKey } from "@/features/chat/utils"; +import { getAISDKLanguageModelAndOptions, getConfiguredLanguageModels, isOwnerOfChat, updateChatMessages } from "@/features/chat/utils.server"; import { apiHandler } from "@/lib/apiHandler"; import { ErrorCode } from "@/lib/errorCodes"; +import { captureEvent } from "@/lib/posthog"; import { notFound, requestBodySchemaValidationError, ServiceError, serviceErrorResponse } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; import { withOptionalAuthV2 } from "@/withAuthV2"; -import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; import * as Sentry from "@sentry/nextjs"; import { createLogger, env } from "@sourcebot/shared"; -import { captureEvent } from "@/lib/posthog"; import { - createUIMessageStream, - createUIMessageStreamResponse, - JSONValue, - ModelMessage, - StreamTextResult, - UIMessageStreamOnFinishCallback, - UIMessageStreamOptions, - UIMessageStreamWriter + createUIMessageStreamResponse } from "ai"; -import { randomUUID } from "crypto"; import { StatusCodes } from "http-status-codes"; import { NextRequest } from "next/server"; import { z } from "zod"; @@ -46,7 +37,7 @@ export const POST = apiHandler(async (req: NextRequest) => { // @note: a bit of type massaging is required here since the // zod schema does not enum on `model` or `provider`. // @see: chat/types.ts - const languageModel = _languageModel as LanguageModelInfo; + const languageModel = _languageModel; const response = await sew(() => withOptionalAuthV2(async ({ org, user, prisma }) => { @@ -63,7 +54,7 @@ export const POST = apiHandler(async (req: NextRequest) => { } // Check ownership - only the owner can send messages - const isOwner = await _isOwnerOfChat(chat, user); + const isOwner = await isOwnerOfChat(chat, user); if (!isOwner) { return { statusCode: StatusCodes.FORBIDDEN, @@ -75,7 +66,7 @@ export const POST = apiHandler(async (req: NextRequest) => { // From the language model ID, attempt to find the // corresponding config in `config.json`. const languageModelConfig = - (await _getConfiguredLanguageModelsFull()) + (await getConfiguredLanguageModels()) .find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel)); if (!languageModelConfig) { @@ -86,7 +77,7 @@ export const POST = apiHandler(async (req: NextRequest) => { } satisfies ServiceError; } - const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig); + const { model, providerOptions } = await getAISDKLanguageModelAndOptions(languageModelConfig); const expandedRepos = (await Promise.all(selectedSearchScopes.map(async (scope) => { if (scope.type === 'repo') return [scope.value]; @@ -118,7 +109,7 @@ export const POST = apiHandler(async (req: NextRequest) => { modelName: languageModelConfig.displayName ?? languageModelConfig.model, modelProviderOptions: providerOptions, onFinish: async ({ messages }) => { - await _updateChatMessages({ chatId: id, messages, prisma }); + await updateChatMessages({ chatId: id, messages, prisma }); }, onError: (error: unknown) => { logger.error(error); @@ -152,122 +143,3 @@ export const POST = apiHandler(async (req: NextRequest) => { return response; }); - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -const mergeStreamAsync = async (stream: StreamTextResult, writer: UIMessageStreamWriter, options: UIMessageStreamOptions = {}) => { - await new Promise((resolve) => writer.merge(stream.toUIMessageStream({ - ...options, - onFinish: async () => { - resolve(); - } - }))); -} - -interface CreateMessageStreamResponseProps { - chatId: string; - messages: SBChatMessage[]; - selectedRepos: string[]; - model: AISDKLanguageModelV3; - modelName: string; - onFinish: UIMessageStreamOnFinishCallback; - onError: (error: unknown) => string; - modelProviderOptions?: Record>; - metadata?: Partial; -} - -export const createMessageStream = async ({ - chatId, - messages, - metadata, - selectedRepos, - model, - modelName, - modelProviderOptions, - onFinish, - onError, -}: CreateMessageStreamResponseProps) => { - const latestMessage = messages[messages.length - 1]; - const sources = latestMessage.parts - .filter((part) => part.type === 'data-source') - .map((part) => part.data); - - const traceId = randomUUID(); - - // Extract user messages and assistant answers. - // We will use this as the context we carry between messages. - const messageHistory = - messages.map((message): ModelMessage | undefined => { - if (message.role === 'user') { - return { - role: 'user', - content: message.parts[0].type === 'text' ? message.parts[0].text : '', - }; - } - - if (message.role === 'assistant') { - const answerPart = getAnswerPartFromAssistantMessage(message, false); - if (answerPart) { - return { - role: 'assistant', - content: [answerPart] - } - } - } - }).filter(message => message !== undefined); - - const stream = createUIMessageStream({ - execute: async ({ writer }) => { - writer.write({ - type: 'start', - }); - - const startTime = new Date(); - - const researchStream = await createAgentStream({ - model, - providerOptions: modelProviderOptions, - inputMessages: messageHistory, - inputSources: sources, - selectedRepos, - onWriteSource: (source) => { - writer.write({ - type: 'data-source', - data: source, - }); - }, - traceId, - chatId, - }); - - await mergeStreamAsync(researchStream, writer, { - sendReasoning: true, - sendStart: false, - sendFinish: false, - }); - - const totalUsage = await researchStream.totalUsage; - - writer.write({ - type: 'message-metadata', - messageMetadata: { - totalTokens: totalUsage.totalTokens, - totalInputTokens: totalUsage.inputTokens, - totalOutputTokens: totalUsage.outputTokens, - totalResponseTimeMs: new Date().getTime() - startTime.getTime(), - modelName, - traceId, - ...metadata, - } - }); - - writer.write({ - type: 'finish', - }); - }, - onError, - originalMessages: messages, - onFinish, - }); - - return stream; -}; diff --git a/packages/web/src/app/api/(server)/mcp/route.ts b/packages/web/src/app/api/(server)/mcp/route.ts new file mode 100644 index 000000000..c387282a2 --- /dev/null +++ b/packages/web/src/app/api/(server)/mcp/route.ts @@ -0,0 +1,135 @@ +'use server'; + +import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/webStandardStreamableHttp.js'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { createMcpServer } from '@/features/mcp/server'; +import { withOptionalAuthV2 } from '@/withAuthV2'; +import { isServiceError } from '@/lib/utils'; +import { serviceErrorResponse, ServiceError } from '@/lib/serviceError'; +import { ErrorCode } from '@/lib/errorCodes'; +import { StatusCodes } from 'http-status-codes'; +import { NextRequest } from 'next/server'; +import { sew } from '@/actions'; + +// @see: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#session-management +interface McpSession { + server: McpServer; + transport: WebStandardStreamableHTTPServerTransport; + ownerId: string | null; // null for anonymous sessions +} + +const MCP_SESSION_ID_HEADER = 'MCP-Session-Id'; + +// Module-level session store. Persists across requests within the same Node.js process. +// Suitable for containerized/single-instance deployments. +const sessions = new Map(); + +export async function POST(request: NextRequest) { + const response = await sew(() => + withOptionalAuthV2(async ({ user }) => { + const ownerId = user?.id ?? null; + const sessionId = request.headers.get(MCP_SESSION_ID_HEADER); + + // Return existing session if available + if (sessionId && sessions.has(sessionId)) { + const session = sessions.get(sessionId)!; + if (session.ownerId !== ownerId) { + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: 'Session does not belong to the authenticated user.', + } satisfies ServiceError; + } + return session.transport.handleRequest(request); + } + + // Create a new session + const transport = new WebStandardStreamableHTTPServerTransport({ + sessionIdGenerator: () => crypto.randomUUID(), + onsessioninitialized: (newSessionId) => { + sessions.set(newSessionId, { server: mcpServer, transport, ownerId }); + }, + onsessionclosed: (closedSessionId) => { + sessions.delete(closedSessionId); + }, + }); + + const mcpServer = createMcpServer(); + await mcpServer.connect(transport); + + return transport.handleRequest(request); + }) + ); + + if (isServiceError(response)) { + return serviceErrorResponse(response); + } + + return response; +} + +export async function DELETE(request: NextRequest) { + const result = await sew(() => + withOptionalAuthV2(async ({ user }) => { + const ownerId = user?.id ?? null; + const sessionId = request.headers.get(MCP_SESSION_ID_HEADER); + if (!sessionId || !sessions.has(sessionId)) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.NOT_FOUND, + message: 'Session not found.', + } satisfies ServiceError; + } + + const session = sessions.get(sessionId)!; + if (session.ownerId !== ownerId) { + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: 'Session does not belong to the authenticated user.', + } satisfies ServiceError; + } + + return session.transport.handleRequest(request); + }) + ); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return result; +} + +export async function GET(request: NextRequest) { + const result = await sew(() => + withOptionalAuthV2(async ({ user }) => { + const ownerId = user?.id ?? null; + const sessionId = request.headers.get(MCP_SESSION_ID_HEADER); + if (!sessionId || !sessions.has(sessionId)) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.NOT_FOUND, + message: 'Session not found.', + } satisfies ServiceError; + } + + const session = sessions.get(sessionId)!; + if (session.ownerId !== ownerId) { + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: 'Session does not belong to the authenticated user.', + } satisfies ServiceError; + } + + return session.transport.handleRequest(request); + }) + ); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return result; +} diff --git a/packages/web/src/app/api/(server)/models/route.ts b/packages/web/src/app/api/(server)/models/route.ts index 0970ab07a..1668ed846 100644 --- a/packages/web/src/app/api/(server)/models/route.ts +++ b/packages/web/src/app/api/(server)/models/route.ts @@ -1,6 +1,6 @@ import { sew } from "@/actions"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { apiHandler } from "@/lib/apiHandler"; -import { getConfiguredLanguageModelsInfo } from "@/features/chat/actions"; import { serviceErrorResponse } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; import { withOptionalAuthV2 } from "@/withAuthV2"; diff --git a/packages/web/src/features/chat/actions.ts b/packages/web/src/features/chat/actions.ts index 6d646f6fb..f608a0d34 100644 --- a/packages/web/src/features/chat/actions.ts +++ b/packages/web/src/features/chat/actions.ts @@ -2,134 +2,19 @@ import { sew } from "@/actions"; import { getAuditService } from "@/ee/features/audit/factory"; +import { getAnonymousId, getOrCreateAnonymousId } from "@/lib/anonymousId"; import { ErrorCode } from "@/lib/errorCodes"; +import { captureEvent } from "@/lib/posthog"; import { notFound, ServiceError } from "@/lib/serviceError"; -import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; -import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; -import { createAzure } from '@ai-sdk/azure'; -import { createDeepSeek } from '@ai-sdk/deepseek'; -import { createGoogleGenerativeAI } from '@ai-sdk/google'; -import { createVertex } from '@ai-sdk/google-vertex'; -import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; -import { createMistral } from '@ai-sdk/mistral'; -import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; -import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; -import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; -import { createXai } from '@ai-sdk/xai'; -import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; -import { createOpenRouter } from '@openrouter/ai-sdk-provider'; -import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared"; +import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2"; import { ChatVisibility, Prisma } from "@sourcebot/db"; -import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; -import { Token } from "@sourcebot/schemas/v3/shared.type"; -import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai"; -import { loadConfig } from "@sourcebot/shared"; -import fs from 'fs'; +import { env } from "@sourcebot/shared"; import { StatusCodes } from "http-status-codes"; -import path from 'path'; -import { LanguageModelInfo, SBChatMessage } from "./types"; -import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2"; -import { getAnonymousId, getOrCreateAnonymousId } from "@/lib/anonymousId"; -import { Chat, PrismaClient, User } from "@sourcebot/db"; -import { captureEvent } from "@/lib/posthog"; -import { withTracing } from "@posthog/ai"; -import { createPostHogClient, tryGetPostHogDistinctId } from "@/lib/posthog"; +import { SBChatMessage } from "./types"; +import { generateChatNameFromMessage, getConfiguredLanguageModels, isChatSharedWithUser, isOwnerOfChat } from "./utils.server"; -const logger = createLogger('chat-actions'); const auditService = getAuditService(); -/** - * Checks if the current user (authenticated or anonymous) is the owner of a chat. - */ -export const _isOwnerOfChat = async (chat: Chat, user: User | undefined): Promise => { - // Authenticated user owns the chat - if (user && chat.createdById === user.id) { - return true; - } - - // Only check the anonymous cookie for unclaimed chats (createdById === null). - // Once a chat has been claimed by an authenticated user, the anonymous path - // must not grant access — even if the same browser still holds the original cookie. - if (!chat.createdById && chat.anonymousCreatorId) { - const anonymousId = await getAnonymousId(); - if (anonymousId && chat.anonymousCreatorId === anonymousId) { - return true; - } - } - - return false; -}; - - -/** - * Checks if a user has been explicitly shared access to a chat. - */ -export const _hasSharedAccess = async ({ prisma, chatId, userId }: { prisma: PrismaClient, chatId: string, userId: string | undefined }): Promise => { - if (!userId) { - return false; - } - - const share = await prisma.chatAccess.findUnique({ - where: { - chatId_userId: { - chatId, - userId, - }, - }, - }); - - return share !== null; -}; - -export const _updateChatMessages = async ({ chatId, messages, prisma }: { chatId: string, messages: SBChatMessage[], prisma: PrismaClient }) => { - await prisma.chat.update({ - where: { - id: chatId, - }, - data: { - messages: messages as unknown as Prisma.InputJsonValue, - }, - }); - - if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) { - const chatDir = path.join(env.DATA_CACHE_DIR, 'chats'); - if (!fs.existsSync(chatDir)) { - fs.mkdirSync(chatDir, { recursive: true }); - } - - const chatFile = path.join(chatDir, `${chatId}.json`); - fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2)); - } -}; - - -export const _generateChatNameFromMessage = async ({ message, languageModelConfig }: { message: string, languageModelConfig: LanguageModel }) => { - const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig); - - const prompt = `Convert this question into a short topic title (max 50 characters). - -Rules: -- Do NOT include question words (what, where, how, why, when, which) -- Do NOT end with a question mark -- Capitalize the first letter of the title -- Focus on the subject/topic being discussed -- Make it sound like a file name or category - -Examples: -"Where is the authentication code?" → "Authentication Code" -"How to setup the database?" → "Database Setup" -"What are the API endpoints?" → "API Endpoints" - -User question: ${message}`; - - const result = await generateText({ - model, - prompt, - }); - - return result.text; -} - export const createChat = async () => sew(() => withOptionalAuthV2(async ({ org, user, prisma }) => { const isGuestUser = user === undefined; @@ -188,8 +73,8 @@ export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() => return notFound(); } - const isOwner = await _isOwnerOfChat(chat, user); - const isSharedWithUser = await _hasSharedAccess({ prisma, chatId, userId: user?.id }); + const isOwner = await isOwnerOfChat(chat, user); + const isSharedWithUser = await isChatSharedWithUser({ prisma, chatId, userId: user?.id }); // Private chats can only be viewed by the owner or users it's been shared with if (chat.visibility === ChatVisibility.PRIVATE && !isOwner && !isSharedWithUser) { @@ -206,34 +91,6 @@ export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() => }) ); -export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }) => sew(() => - withOptionalAuthV2(async ({ org, user, prisma }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); - - if (!chat) { - return notFound(); - } - - const isOwner = await _isOwnerOfChat(chat, user); - - // Only the owner can modify chat messages - if (!isOwner) { - return notFound(); - } - - await _updateChatMessages({ chatId, messages, prisma }); - - return { - success: true, - } - }) -); - export const getUserChatHistory = async () => sew(() => withAuthV2(async ({ org, user, prisma }) => { const chats = await prisma.chat.findMany({ @@ -268,7 +125,7 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s return notFound(); } - const isOwner = await _isOwnerOfChat(chat, user); + const isOwner = await isOwnerOfChat(chat, user); // Only the owner can rename chats if (!isOwner) { @@ -346,13 +203,13 @@ export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageMod return notFound(); } - const isOwner = await _isOwnerOfChat(chat, user); + const isOwner = await isOwnerOfChat(chat, user); if (!isOwner) { return notFound(); } const languageModelConfig = - (await _getConfiguredLanguageModelsFull()) + (await getConfiguredLanguageModels()) .find((model) => model.model === languageModelId); if (!languageModelConfig) { @@ -363,7 +220,7 @@ export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageMod } satisfies ServiceError; } - const name = await _generateChatNameFromMessage({ message, languageModelConfig }); + const name = await generateChatNameFromMessage({ message, languageModelConfig }); await prisma.chat.update({ where: { @@ -472,8 +329,8 @@ export const duplicateChat = async ({ chatId, newName }: { chatId: string, newNa } // Check if user can access the chat (owner, shared, or public) - const isOwner = await _isOwnerOfChat(originalChat, user); - const isSharedWithUser = await _hasSharedAccess({ prisma, chatId, userId: user?.id }); + const isOwner = await isOwnerOfChat(originalChat, user); + const isSharedWithUser = await isChatSharedWithUser({ prisma, chatId, userId: user?.id }); if (originalChat.visibility === ChatVisibility.PRIVATE && !isOwner && !isSharedWithUser) { return notFound(); } @@ -654,7 +511,7 @@ export const submitFeedback = async ({ } // When a chat is private, only the creator or shared users can submit feedback. - const isSharedWithUser = await _hasSharedAccess({ prisma, chatId, userId: user?.id }); + const isSharedWithUser = await isChatSharedWithUser({ prisma, chatId, userId: user?.id }); if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id && !isSharedWithUser) { return notFound(); } @@ -691,299 +548,6 @@ export const submitFeedback = async ({ }) ) -/** - * Returns the subset of information about the configured language models - * that we can safely send to the client. - */ -export const getConfiguredLanguageModelsInfo = async (): Promise => { - const models = await _getConfiguredLanguageModelsFull(); - return models.map((model): LanguageModelInfo => ({ - provider: model.provider, - model: model.model, - displayName: model.displayName, - })); -} - -/** - * Returns the full configuration of the language models. - * - * @warning Do NOT call this function from the client, - * or pass the result of calling this function to the client. - */ -export const _getConfiguredLanguageModelsFull = async (): Promise => { - try { - const config = await loadConfig(env.CONFIG_PATH); - return config.models ?? []; - } catch (error) { - logger.error('Failed to load language model configuration', error); - return []; - } -} - -export const _getAISDKLanguageModelAndOptions = async (config: LanguageModel): Promise<{ - model: AISDKLanguageModelV3, - providerOptions?: Record>, -}> => { - const { provider, model: modelId } = config; - - const { model: _model, providerOptions } = await (async (): Promise<{ - model: AISDKLanguageModelV3, - providerOptions?: Record>, - }> => { - switch (provider) { - case 'amazon-bedrock': { - const aws = createAmazonBedrock({ - baseURL: config.baseUrl, - region: config.region ?? env.AWS_REGION, - accessKeyId: config.accessKeyId - ? await getTokenFromConfig(config.accessKeyId) - : env.AWS_ACCESS_KEY_ID, - secretAccessKey: config.accessKeySecret - ? await getTokenFromConfig(config.accessKeySecret) - : env.AWS_SECRET_ACCESS_KEY, - sessionToken: config.sessionToken - ? await getTokenFromConfig(config.sessionToken) - : env.AWS_SESSION_TOKEN, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - // Fallback to the default Node.js credential provider chain if no credentials are provided. - // See: https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-credential-providers/#fromnodeproviderchain - credentialProvider: !config.accessKeyId && !config.accessKeySecret && !config.sessionToken - ? fromNodeProviderChain() - : undefined, - }); - - return { - model: aws(modelId), - }; - } - case 'anthropic': { - const anthropic = createAnthropic({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.ANTHROPIC_API_KEY, - authToken: config.authToken - ? await getTokenFromConfig(config.authToken) - : env.ANTHROPIC_AUTH_TOKEN, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: anthropic(modelId), - providerOptions: { - anthropic: { - thinking: { - type: "enabled", - budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, - } - } satisfies AnthropicProviderOptions, - }, - }; - } - case 'azure': { - const azure = createAzure({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.AZURE_API_KEY, - apiVersion: config.apiVersion, - resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: azure(modelId), - }; - } - case 'deepseek': { - const deepseek = createDeepSeek({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.DEEPSEEK_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: deepseek(modelId), - }; - } - case 'google-generative-ai': { - const google = createGoogleGenerativeAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.GOOGLE_GENERATIVE_AI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: google(modelId), - }; - } - case 'google-vertex': { - const vertex = createVertex({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials), - } - } : {}), - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: vertex(modelId), - providerOptions: { - vertex: { - thinkingConfig: { - thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, - includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true', - } - } - }, - }; - } - case 'google-vertex-anthropic': { - const vertexAnthropic = createVertexAnthropic({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials), - } - } : {}), - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: vertexAnthropic(modelId), - }; - } - case 'mistral': { - const mistral = createMistral({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.MISTRAL_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: mistral(modelId), - }; - } - case 'openai': { - const openai = createOpenAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.OPENAI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: openai(modelId), - providerOptions: { - openai: { - reasoningEffort: config.reasoningEffort ?? 'medium', - } satisfies OpenAIResponsesProviderOptions, - }, - }; - } - case 'openai-compatible': { - const openai = createOpenAICompatible({ - baseURL: config.baseUrl, - name: config.displayName ?? modelId, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : undefined, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - queryParams: config.queryParams - ? await extractLanguageModelKeyValuePairs(config.queryParams) - : undefined, - }); - - const model = wrapLanguageModel({ - model: openai.chatModel(modelId), - middleware: [ - extractReasoningMiddleware({ - tagName: config.reasoningTag ?? 'think', - }), - ] - }); - - return { - model, - } - } - case 'openrouter': { - const openrouter = createOpenRouter({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.OPENROUTER_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: openrouter(modelId), - }; - } - case 'xai': { - const xai = createXai({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.XAI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: xai(modelId), - }; - } - } - })(); - - const posthog = await createPostHogClient(); - const distinctId = await tryGetPostHogDistinctId(); - - // Only enable posthog LLM analytics for the ask GH experiment. - const model = env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? - withTracing(_model, posthog, { - posthogDistinctId: distinctId, - }) : - _model; - - return { - model, - providerOptions, - }; - -} - export const getAskGhLoginWallData = async () => sew(async () => { const isEnabled = env.EXPERIMENT_ASK_GH_ENABLED === 'true'; if (!isEnabled) { @@ -994,26 +558,3 @@ export const getAskGhLoginWallData = async () => sew(async () => { return { isEnabled: true as const, providers: getIdentityProviderMetadata() }; }); -const extractLanguageModelKeyValuePairs = async ( - pairs: { - [k: string]: string | Token; - } -): Promise> => { - const resolvedPairs: Record = {}; - - if (!pairs) { - return resolvedPairs; - } - - for (const [key, val] of Object.entries(pairs)) { - if (typeof val === "string") { - resolvedPairs[key] = val; - continue; - } - - const value = await getTokenFromConfig(val); - resolvedPairs[key] = value; - } - - return resolvedPairs; -} diff --git a/packages/web/src/features/chat/agent.ts b/packages/web/src/features/chat/agent.ts index ec2a30758..da9fb3baa 100644 --- a/packages/web/src/features/chat/agent.ts +++ b/packages/web/src/features/chat/agent.ts @@ -1,19 +1,147 @@ +import { SBChatMessage, SBChatMessageMetadata } from "@/features/chat/types"; +import { getAnswerPartFromAssistantMessage } from "@/features/chat/utils"; import { getFileSource } from '@/features/git'; -import { isServiceError } from "@/lib/utils"; import { captureEvent } from "@/lib/posthog"; +import { isServiceError } from "@/lib/utils"; +import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; import { ProviderOptions } from "@ai-sdk/provider-utils"; import { createLogger, env } from "@sourcebot/shared"; -import { LanguageModel, ModelMessage, StopCondition, streamText } from "ai"; +import { + createUIMessageStream, JSONValue, LanguageModel, ModelMessage, StopCondition, streamText, StreamTextResult, + UIMessageStreamOnFinishCallback, + UIMessageStreamOptions, + UIMessageStreamWriter +} from "ai"; +import { randomUUID } from "crypto"; +import _dedent from "dedent"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX, toolNames } from "./constants"; -import { createCodeSearchTool, findSymbolDefinitionsTool, findSymbolReferencesTool, listReposTool, listCommitsTool, readFilesTool } from "./tools"; +import { createCodeSearchTool, findSymbolDefinitionsTool, findSymbolReferencesTool, listCommitsTool, listReposTool, readFilesTool } from "./tools"; import { Source } from "./types"; import { addLineNumbers, fileReferenceToString } from "./utils"; -import _dedent from "dedent"; const dedent = _dedent.withOptions({ alignValues: true }); const logger = createLogger('chat-agent'); +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const mergeStreamAsync = async (stream: StreamTextResult, writer: UIMessageStreamWriter, options: UIMessageStreamOptions = {}) => { + await new Promise((resolve) => writer.merge(stream.toUIMessageStream({ + ...options, + onFinish: async () => { + resolve(); + } + }))); +} + +interface CreateMessageStreamResponseProps { + chatId: string; + messages: SBChatMessage[]; + selectedRepos: string[]; + model: AISDKLanguageModelV3; + modelName: string; + onFinish: UIMessageStreamOnFinishCallback; + onError: (error: unknown) => string; + modelProviderOptions?: Record>; + metadata?: Partial; +} + +export const createMessageStream = async ({ + chatId, + messages, + metadata, + selectedRepos, + model, + modelName, + modelProviderOptions, + onFinish, + onError, +}: CreateMessageStreamResponseProps) => { + const latestMessage = messages[messages.length - 1]; + const sources = latestMessage.parts + .filter((part) => part.type === 'data-source') + .map((part) => part.data); + + const traceId = randomUUID(); + + // Extract user messages and assistant answers. + // We will use this as the context we carry between messages. + const messageHistory = + messages.map((message): ModelMessage | undefined => { + if (message.role === 'user') { + return { + role: 'user', + content: message.parts[0].type === 'text' ? message.parts[0].text : '', + }; + } + + if (message.role === 'assistant') { + const answerPart = getAnswerPartFromAssistantMessage(message, false); + if (answerPart) { + return { + role: 'assistant', + content: [answerPart] + } + } + } + }).filter(message => message !== undefined); + + const stream = createUIMessageStream({ + execute: async ({ writer }) => { + writer.write({ + type: 'start', + }); + + const startTime = new Date(); + + const researchStream = await createAgentStream({ + model, + providerOptions: modelProviderOptions, + inputMessages: messageHistory, + inputSources: sources, + selectedRepos, + onWriteSource: (source) => { + writer.write({ + type: 'data-source', + data: source, + }); + }, + traceId, + chatId, + }); + + await mergeStreamAsync(researchStream, writer, { + sendReasoning: true, + sendStart: false, + sendFinish: false, + }); + + const totalUsage = await researchStream.totalUsage; + + writer.write({ + type: 'message-metadata', + messageMetadata: { + totalTokens: totalUsage.totalTokens, + totalInputTokens: totalUsage.inputTokens, + totalOutputTokens: totalUsage.outputTokens, + totalResponseTimeMs: new Date().getTime() - startTime.getTime(), + modelName, + traceId, + ...metadata, + } + }); + + writer.write({ + type: 'finish', + }); + }, + onError, + originalMessages: messages, + onFinish, + }); + + return stream; +}; + interface AgentOptions { model: LanguageModel; providerOptions?: ProviderOptions; @@ -25,7 +153,7 @@ interface AgentOptions { chatId: string; } -export const createAgentStream = async ({ +const createAgentStream = async ({ model, providerOptions, inputMessages, diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index fbf840538..9411f850f 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -168,15 +168,36 @@ export type SetChatStatePayload = { export type LanguageModelProvider = LanguageModel['provider']; -// This is a subset of information about a configured -// language model that we can safely send to the client. -// @note: ensure this is in sync with the LanguageModelInfo type. +export const languageModelProviders = [ + "amazon-bedrock", + "anthropic", + "azure", + "deepseek", + "google-generative-ai", + "google-vertex-anthropic", + "google-vertex", + "mistral", + "openai", + "openai-compatible", + "openrouter", + "xai", +] as const satisfies readonly LanguageModelProvider[]; + +// Type-check assertion that ensure the above array is up to date +// with the LanguageModelProvider type. +type _AssertAllProviders = LanguageModelProvider extends typeof languageModelProviders[number] ? true : never; +const _assertAllProviders: _AssertAllProviders = true; +void _assertAllProviders; + export const languageModelInfoSchema = z.object({ - provider: z.string(), - model: z.string(), - displayName: z.string().optional(), + provider: z.enum(languageModelProviders).describe("The model provider (e.g., 'anthropic', 'openai')"), + model: z.string().describe("The model ID"), + displayName: z.string().optional().describe("Optional display name for the model"), }); +/** + * Client safe subset of information about a language model. + */ export type LanguageModelInfo = { provider: LanguageModelProvider, model: LanguageModel['model'], diff --git a/packages/web/src/features/chat/utils.server.ts b/packages/web/src/features/chat/utils.server.ts new file mode 100644 index 000000000..43faa4abe --- /dev/null +++ b/packages/web/src/features/chat/utils.server.ts @@ -0,0 +1,445 @@ +import 'server-only'; + +import { getAnonymousId } from '@/lib/anonymousId'; +import { createPostHogClient, tryGetPostHogDistinctId } from "@/lib/posthog"; +import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; +import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; +import { createAzure } from '@ai-sdk/azure'; +import { createDeepSeek } from '@ai-sdk/deepseek'; +import { createGoogleGenerativeAI } from '@ai-sdk/google'; +import { createVertex } from '@ai-sdk/google-vertex'; +import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; +import { createMistral } from '@ai-sdk/mistral'; +import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; +import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; +import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; +import { createXai } from '@ai-sdk/xai'; +import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; +import { createOpenRouter } from '@openrouter/ai-sdk-provider'; +import { withTracing } from "@posthog/ai"; +import { Chat, Prisma, PrismaClient, User } from '@sourcebot/db'; +import { LanguageModel } from '@sourcebot/schemas/v3/languageModel.type'; +import { Token } from "@sourcebot/schemas/v3/shared.type"; +import { env, getTokenFromConfig, loadConfig } from '@sourcebot/shared'; +import { extractReasoningMiddleware, generateText, JSONValue, wrapLanguageModel } from "ai"; +import fs from 'fs'; +import path from 'path'; +import { LanguageModelInfo, SBChatMessage } from './types'; + +/** + * Checks if the current user (authenticated or anonymous) is the owner of a chat. + */ +export const isOwnerOfChat = async (chat: Chat, user: User | undefined): Promise => { + // Authenticated user owns the chat + if (user && chat.createdById === user.id) { + return true; + } + + // Only check the anonymous cookie for unclaimed chats (createdById === null). + // Once a chat has been claimed by an authenticated user, the anonymous path + // must not grant access — even if the same browser still holds the original cookie. + if (!chat.createdById && chat.anonymousCreatorId) { + const anonymousId = await getAnonymousId(); + if (anonymousId && chat.anonymousCreatorId === anonymousId) { + return true; + } + } + + return false; +}; + +/** + * Checks if a user has been explicitly shared access to a chat. + */ +export const isChatSharedWithUser = async ({ + prisma, chatId, userId, +}: { + prisma: PrismaClient; + chatId: string; + userId?: string; +}): Promise => { + if (!userId) { + return false; + } + + const share = await prisma.chatAccess.findUnique({ + where: { + chatId_userId: { + chatId, + userId, + }, + }, + }); + + return share !== null; +}; + +export const updateChatMessages = async ({ + prisma, chatId, messages, +}: { + prisma: PrismaClient; + chatId: string; + messages: SBChatMessage[]; +}) => { + await prisma.chat.update({ + where: { + id: chatId, + }, + data: { + messages: messages as unknown as Prisma.InputJsonValue, + }, + }); + + if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) { + const chatDir = path.join(env.DATA_CACHE_DIR, 'chats'); + if (!fs.existsSync(chatDir)) { + fs.mkdirSync(chatDir, { recursive: true }); + } + + const chatFile = path.join(chatDir, `${chatId}.json`); + fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2)); + } +}; +/** + * Returns the full configuration of the language models. + * + * @warning this can contain sensitive information like environment + * variable names and base URLs. When passing information to the client, + * use getConfiguredLanguageModelsInfo instead. + */ +export const getConfiguredLanguageModels = async (): Promise => { + try { + const config = await loadConfig(env.CONFIG_PATH); + return config.models ?? []; + } catch (error) { + console.error('Failed to load language model configuration', error); + return []; + } +}; + +/** + * Returns the subset of information about the configured language models + * that we can safely send to the client. + */ +export const getConfiguredLanguageModelsInfo = async () => { + const models = await getConfiguredLanguageModels(); + return models.map((model): LanguageModelInfo => ({ + provider: model.provider, + model: model.model, + displayName: model.displayName, + })); +}; + +export const generateChatNameFromMessage = async ({ message, languageModelConfig }: { message: string, languageModelConfig: LanguageModel }) => { + const { model } = await getAISDKLanguageModelAndOptions(languageModelConfig); + + const prompt = `Convert this question into a short topic title (max 50 characters). + +Rules: +- Do NOT include question words (what, where, how, why, when, which) +- Do NOT end with a question mark +- Capitalize the first letter of the title +- Focus on the subject/topic being discussed +- Make it sound like a file name or category + +Examples: +"Where is the authentication code?" → "Authentication Code" +"How to setup the database?" → "Database Setup" +"What are the API endpoints?" → "API Endpoints" + +User question: ${message}`; + + const result = await generateText({ + model, + prompt, + }); + + return result.text; +} + +export const getAISDKLanguageModelAndOptions = async (config: LanguageModel): Promise<{ + model: AISDKLanguageModelV3, + providerOptions?: Record>, +}> => { + const { provider, model: modelId } = config; + + const { model: _model, providerOptions } = await (async (): Promise<{ + model: AISDKLanguageModelV3, + providerOptions?: Record>, + }> => { + switch (provider) { + case 'amazon-bedrock': { + const aws = createAmazonBedrock({ + baseURL: config.baseUrl, + region: config.region ?? env.AWS_REGION, + accessKeyId: config.accessKeyId + ? await getTokenFromConfig(config.accessKeyId) + : env.AWS_ACCESS_KEY_ID, + secretAccessKey: config.accessKeySecret + ? await getTokenFromConfig(config.accessKeySecret) + : env.AWS_SECRET_ACCESS_KEY, + sessionToken: config.sessionToken + ? await getTokenFromConfig(config.sessionToken) + : env.AWS_SESSION_TOKEN, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + // Fallback to the default Node.js credential provider chain if no credentials are provided. + // See: https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-credential-providers/#fromnodeproviderchain + credentialProvider: !config.accessKeyId && !config.accessKeySecret && !config.sessionToken + ? fromNodeProviderChain() + : undefined, + }); + + return { + model: aws(modelId), + }; + } + case 'anthropic': { + const anthropic = createAnthropic({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.ANTHROPIC_API_KEY, + authToken: config.authToken + ? await getTokenFromConfig(config.authToken) + : env.ANTHROPIC_AUTH_TOKEN, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: anthropic(modelId), + providerOptions: { + anthropic: { + thinking: { + type: "enabled", + budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, + } + } satisfies AnthropicProviderOptions, + }, + }; + } + case 'azure': { + const azure = createAzure({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.AZURE_API_KEY, + apiVersion: config.apiVersion, + resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: azure(modelId), + }; + } + case 'deepseek': { + const deepseek = createDeepSeek({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.DEEPSEEK_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: deepseek(modelId), + }; + } + case 'google-generative-ai': { + const google = createGoogleGenerativeAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.GOOGLE_GENERATIVE_AI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: google(modelId), + }; + } + case 'google-vertex': { + const vertex = createVertex({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? { + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials), + } + } : {}), + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: vertex(modelId), + providerOptions: { + vertex: { + thinkingConfig: { + thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, + includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true', + } + } + }, + }; + } + case 'google-vertex-anthropic': { + const vertexAnthropic = createVertexAnthropic({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? { + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials), + } + } : {}), + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: vertexAnthropic(modelId), + }; + } + case 'mistral': { + const mistral = createMistral({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.MISTRAL_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: mistral(modelId), + }; + } + case 'openai': { + const openai = createOpenAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.OPENAI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: openai(modelId), + providerOptions: { + openai: { + reasoningEffort: config.reasoningEffort ?? 'medium', + } satisfies OpenAIResponsesProviderOptions, + }, + }; + } + case 'openai-compatible': { + const openai = createOpenAICompatible({ + baseURL: config.baseUrl, + name: config.displayName ?? modelId, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : undefined, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + queryParams: config.queryParams + ? await extractLanguageModelKeyValuePairs(config.queryParams) + : undefined, + }); + + const model = wrapLanguageModel({ + model: openai.chatModel(modelId), + middleware: [ + extractReasoningMiddleware({ + tagName: config.reasoningTag ?? 'think', + }), + ] + }); + + return { + model, + } + } + case 'openrouter': { + const openrouter = createOpenRouter({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.OPENROUTER_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: openrouter(modelId), + }; + } + case 'xai': { + const xai = createXai({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.XAI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: xai(modelId), + }; + } + } + })(); + + const posthog = await createPostHogClient(); + const distinctId = await tryGetPostHogDistinctId(); + + // Only enable posthog LLM analytics for the ask GH experiment. + const model = env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? + withTracing(_model, posthog, { + posthogDistinctId: distinctId, + }) : + _model; + + return { + model, + providerOptions, + }; +} + +const extractLanguageModelKeyValuePairs = async ( + pairs: { + [k: string]: string | Token; + } +): Promise> => { + const resolvedPairs: Record = {}; + + if (!pairs) { + return resolvedPairs; + } + + for (const [key, val] of Object.entries(pairs)) { + if (typeof val === "string") { + resolvedPairs[key] = val; + continue; + } + + const value = await getTokenFromConfig(val); + resolvedPairs[key] = value; + } + + return resolvedPairs; +}; \ No newline at end of file diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index ca412618e..f77325c4d 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -1,8 +1,8 @@ +import { BrowseHighlightRange, getBrowsePath } from "@/app/[domain]/browse/hooks/utils"; +import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants"; import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai"; import { Descendant, Editor, Point, Range, Transforms } from "slate"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants"; -import { getBrowsePath, BrowseHighlightRange } from "@/app/[domain]/browse/hooks/utils"; -import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants"; import { CustomEditor, CustomText, @@ -19,6 +19,7 @@ import { Source, } from "./types"; + export const insertMention = (editor: CustomEditor, data: MentionData, target?: Range | null) => { const mention: MentionElement = { type: 'mention', @@ -365,4 +366,4 @@ export const tryResolveFileReference = (reference: FileReference, sources: FileS (source) => source.repo.endsWith(reference.repo) && source.path.endsWith(reference.path) ); -} \ No newline at end of file +}; diff --git a/packages/web/src/features/mcp/askCodebase.ts b/packages/web/src/features/mcp/askCodebase.ts new file mode 100644 index 000000000..a6543e79f --- /dev/null +++ b/packages/web/src/features/mcp/askCodebase.ts @@ -0,0 +1,198 @@ +import { sew } from "@/actions"; +import { getConfiguredLanguageModels, getAISDKLanguageModelAndOptions, generateChatNameFromMessage, updateChatMessages } from "@/features/chat/utils.server"; +import { LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types"; +import { convertLLMOutputToPortableMarkdown, getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/features/chat/utils"; +import { ErrorCode } from "@/lib/errorCodes"; +import { ServiceError, ServiceErrorException } from "@/lib/serviceError"; +import { withOptionalAuthV2 } from "@/withAuthV2"; +import { ChatVisibility, Prisma } from "@sourcebot/db"; +import { createLogger, env } from "@sourcebot/shared"; +import { randomUUID } from "crypto"; +import { StatusCodes } from "http-status-codes"; +import { InferUIMessageChunk, UIDataTypes, UIMessage, UITools } from "ai"; +import { captureEvent } from "@/lib/posthog"; +import { createMessageStream } from "../chat/agent"; + +const logger = createLogger('ask-codebase-api'); + +export type AskCodebaseParams = { + query: string; + repos?: string[]; + languageModel?: LanguageModelInfo; + visibility?: ChatVisibility; +}; + +export type AskCodebaseResult = { + answer: string; + chatId: string; + chatUrl: string; + languageModel: LanguageModelInfo; +}; + +const blockStreamUntilFinish = async >( + stream: ReadableStream> +) => { + const reader = stream.getReader(); + while (true as const) { + const { done } = await reader.read(); + if (done) break; + } +}; + +export const askCodebase = (params: AskCodebaseParams): Promise => + sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const { query, repos = [], languageModel: requestedLanguageModel, visibility: requestedVisibility } = params; + + const configuredModels = await getConfiguredLanguageModels(); + if (configuredModels.length === 0) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: "No language models are configured. Please configure at least one language model. See: https://docs.sourcebot.dev/docs/configuration/language-model-providers", + } satisfies ServiceError; + } + + let languageModelConfig = configuredModels[0]; + if (requestedLanguageModel) { + const matchingModel = configuredModels.find( + (m) => getLanguageModelKey(m) === getLanguageModelKey(requestedLanguageModel) + ); + if (!matchingModel) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Language model '${requestedLanguageModel.provider}/${requestedLanguageModel.model}' is not configured.`, + } satisfies ServiceError; + } + languageModelConfig = matchingModel; + } + + const { model, providerOptions } = await getAISDKLanguageModelAndOptions(languageModelConfig); + const modelName = languageModelConfig.displayName ?? languageModelConfig.model; + + const chatVisibility = (requestedVisibility && user) + ? requestedVisibility + : (user ? ChatVisibility.PRIVATE : ChatVisibility.PUBLIC); + + const chat = await prisma.chat.create({ + data: { + orgId: org.id, + createdById: user?.id, + visibility: chatVisibility, + messages: [] as unknown as Prisma.InputJsonValue, + }, + }); + + await captureEvent('wa_chat_thread_created', { + chatId: chat.id, + isAnonymous: !user, + }); + + logger.debug(`Starting blocking agent for chat ${chat.id}`, { + chatId: chat.id, + query: query.substring(0, 100), + model: modelName, + }); + + const userMessage: SBChatMessage = { + id: randomUUID(), + role: 'user', + parts: [{ type: 'text', text: query }], + }; + + const selectedRepos = (await Promise.all(repos.map(async (repo) => { + const repoDB = await prisma.repo.findFirst({ + where: { name: repo }, + }); + if (!repoDB) { + throw new ServiceErrorException({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Repository '${repo}' not found.`, + }); + } + return { + type: 'repo', + value: repoDB.name, + name: repoDB.displayName ?? repoDB.name.split('/').pop() ?? repoDB.name, + codeHostType: repoDB.external_codeHostType, + } satisfies SearchScope; + }))); + + let finalMessages: SBChatMessage[] = []; + + await captureEvent('wa_chat_message_sent', { + chatId: chat.id, + messageCount: 1, + selectedReposCount: selectedRepos.length, + ...(env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? { + selectedRepos: selectedRepos.map(r => r.value) + } : {}), + }); + + const stream = await createMessageStream({ + chatId: chat.id, + messages: [userMessage], + metadata: { + selectedSearchScopes: selectedRepos, + }, + selectedRepos: selectedRepos.map(r => r.value), + model, + modelName, + modelProviderOptions: providerOptions, + onFinish: async ({ messages }) => { + finalMessages = messages; + }, + onError: (error) => { + if (error instanceof ServiceErrorException) { + throw error; + } + const message = error instanceof Error ? error.message : String(error); + throw new ServiceErrorException({ + statusCode: StatusCodes.INTERNAL_SERVER_ERROR, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message, + }); + }, + }); + + const [, name] = await Promise.all([ + blockStreamUntilFinish(stream), + generateChatNameFromMessage({ + message: query, + languageModelConfig, + }) + ]); + + await updateChatMessages({ chatId: chat.id, messages: finalMessages, prisma }); + + await prisma.chat.update({ + where: { id: chat.id, orgId: org.id }, + data: { name }, + }); + + const assistantMessage = finalMessages.find(m => m.role === 'assistant'); + const answerPart = assistantMessage + ? getAnswerPartFromAssistantMessage(assistantMessage, false) + : undefined; + const answerText = answerPart?.text ?? ''; + + const baseUrl = env.AUTH_URL; + const portableAnswer = convertLLMOutputToPortableMarkdown(answerText, baseUrl); + const chatUrl = `${baseUrl}/${org.domain}/chat/${chat.id}`; + + logger.debug(`Completed blocking agent for chat ${chat.id}`, { chatId: chat.id }); + + return { + answer: portableAnswer, + chatId: chat.id, + chatUrl, + languageModel: { + provider: languageModelConfig.provider, + model: languageModelConfig.model, + displayName: languageModelConfig.displayName, + }, + } satisfies AskCodebaseResult; + }) + ); diff --git a/packages/web/src/features/mcp/server.ts b/packages/web/src/features/mcp/server.ts new file mode 100644 index 000000000..07a27ec31 --- /dev/null +++ b/packages/web/src/features/mcp/server.ts @@ -0,0 +1,507 @@ +import { listRepos } from '@/app/api/(server)/repos/listReposApi'; +import { getConfiguredLanguageModelsInfo } from "../chat/utils.server"; +import { askCodebase } from '@/features/mcp/askCodebase'; +import { + languageModelInfoSchema, +} from '@/features/chat/types'; +import { getFileSource, getTree, listCommits } from '@/features/git'; +import { search } from '@/features/search/searchApi'; +import { isServiceError } from '@/lib/utils'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { ChatVisibility } from '@sourcebot/db'; +import { SOURCEBOT_VERSION } from '@sourcebot/shared'; +import _dedent from 'dedent'; +import escapeStringRegexp from 'escape-string-regexp'; +import { z } from 'zod'; +import { + ListTreeEntry, + TextContent, +} from './types'; +import { buildTreeNodeIndex, joinTreePath, normalizeTreePath, sortTreeEntries } from './utils'; + +const dedent = _dedent.withOptions({ alignValues: true }); + +const DEFAULT_MINIMUM_TOKENS = 10000; +const DEFAULT_MATCHES = 10000; +const DEFAULT_CONTEXT_LINES = 5; + +const DEFAULT_TREE_DEPTH = 1; +const MAX_TREE_DEPTH = 10; +const DEFAULT_MAX_TREE_ENTRIES = 1000; +const MAX_MAX_TREE_ENTRIES = 10000; + +export function createMcpServer(): McpServer { + const server = new McpServer({ + name: 'sourcebot-mcp-server', + version: SOURCEBOT_VERSION, + }); + + server.registerTool( + "search_code", + { + description: dedent` + Searches for code that matches the provided search query as a substring by default, or as a regular expression if useRegex is true. Useful for exploring remote repositories by searching for exact symbols, functions, variables, or specific code patterns. To determine if a repository is indexed, use the \`list_repos\` tool. By default, searches are global and will search the default branch of all repositories. Searches can be scoped to specific repositories, languages, and branches. When referencing code outputted by this tool, always include the file's external URL as a link. This makes it easier for the user to view the file, even if they don't have it locally checked out.`, + inputSchema: { + query: z + .string() + .describe(`The search pattern to match against code contents. Do not escape quotes in your query.`) + .transform((val) => { + const escaped = val.replace(/\\/g, '\\\\').replace(/"/g, '\\"'); + return `"${escaped}"`; + }), + useRegex: z + .boolean() + .describe(`Whether to use regular expression matching. When false, substring matching is used. (default: false)`) + .optional(), + filterByRepos: z + .array(z.string()) + .describe(`Scope the search to the provided repositories.`) + .optional(), + filterByLanguages: z + .array(z.string()) + .describe(`Scope the search to the provided languages.`) + .optional(), + filterByFilepaths: z + .array(z.string()) + .describe(`Scope the search to the provided filepaths.`) + .optional(), + caseSensitive: z + .boolean() + .describe(`Whether the search should be case sensitive (default: false).`) + .optional(), + includeCodeSnippets: z + .boolean() + .describe(`Whether to include code snippets in the response. If false, only the file's URL, repository, and language will be returned. (default: false)`) + .optional(), + ref: z + .string() + .describe(`Commit SHA, branch or tag name to search on. If not provided, defaults to the default branch.`) + .optional(), + maxTokens: z + .number() + .describe(`The maximum number of tokens to return (default: ${DEFAULT_MINIMUM_TOKENS}).`) + .transform((val) => (val < DEFAULT_MINIMUM_TOKENS ? DEFAULT_MINIMUM_TOKENS : val)) + .optional(), + }, + }, + async ({ + query, + filterByRepos: repos = [], + filterByLanguages: languages = [], + filterByFilepaths: filepaths = [], + maxTokens = DEFAULT_MINIMUM_TOKENS, + includeCodeSnippets = false, + caseSensitive = false, + ref, + useRegex = false, + }: { + query: string; + useRegex?: boolean; + filterByRepos?: string[]; + filterByLanguages?: string[]; + filterByFilepaths?: string[]; + caseSensitive?: boolean; + includeCodeSnippets?: boolean; + ref?: string; + maxTokens?: number; + }) => { + if (repos.length > 0) { + query += ` (repo:${repos.map(id => escapeStringRegexp(id)).join(' or repo:')})`; + } + if (languages.length > 0) { + query += ` (lang:${languages.join(' or lang:')})`; + } + if (filepaths.length > 0) { + query += ` (file:${filepaths.map(fp => escapeStringRegexp(fp)).join(' or file:')})`; + } + if (ref) { + query += ` ( rev:${ref} )`; + } + + const response = await search({ + queryType: 'string', + query, + options: { + matches: DEFAULT_MATCHES, + contextLines: DEFAULT_CONTEXT_LINES, + isRegexEnabled: useRegex, + isCaseSensitivityEnabled: caseSensitive, + }, + }); + + if (isServiceError(response)) { + return { + content: [{ type: "text", text: `Search failed: ${response.message}` }], + }; + } + + if (response.files.length === 0) { + return { + content: [{ type: "text", text: `No results found for the query: ${query}` }], + }; + } + + const content: TextContent[] = []; + let totalTokens = 0; + let isResponseTruncated = false; + + for (const file of response.files) { + const numMatches = file.chunks.reduce((acc, chunk) => acc + chunk.matchRanges.length, 0); + let text = dedent` + file: ${file.webUrl} + num_matches: ${numMatches} + repo: ${file.repository} + language: ${file.language} + `; + + if (includeCodeSnippets) { + const snippets = file.chunks.map(chunk => `\`\`\`\n${chunk.content}\n\`\`\``).join('\n'); + text += `\n\n${snippets}`; + } + + const tokens = text.length / 4; + + if ((totalTokens + tokens) > maxTokens) { + const remainingTokens = maxTokens - totalTokens; + if (remainingTokens > 100) { + const maxLength = Math.floor(remainingTokens * 4); + content.push({ + type: "text", + text: text.substring(0, maxLength) + "\n\n...[content truncated due to token limit]", + }); + totalTokens += remainingTokens; + } + isResponseTruncated = true; + break; + } + + totalTokens += tokens; + content.push({ type: "text", text }); + } + + if (isResponseTruncated) { + content.push({ + type: "text", + text: `The response was truncated because the number of tokens exceeded the maximum limit of ${maxTokens}.`, + }); + } + + return { content }; + } + ); + + server.registerTool( + "list_commits", + { + description: dedent`Get a list of commits for a given repository.`, + inputSchema: z.object({ + repo: z.string().describe("The name of the repository to list commits for."), + query: z.string().describe("Search query to filter commits by message content (case-insensitive).").optional(), + since: z.string().describe("Show commits more recent than this date. Supports ISO 8601 or relative formats (e.g., '30 days ago').").optional(), + until: z.string().describe("Show commits older than this date. Supports ISO 8601 or relative formats (e.g., 'yesterday').").optional(), + author: z.string().describe("Filter commits by author name or email (case-insensitive).").optional(), + ref: z.string().describe("Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch.").optional(), + page: z.number().int().positive().describe("Page number for pagination (min 1). Default: 1").optional().default(1), + perPage: z.number().int().positive().max(100).describe("Results per page for pagination (min 1, max 100). Default: 50").optional().default(50), + }), + }, + async ({ repo, query, since, until, author, ref, page, perPage }) => { + const skip = (page - 1) * perPage; + const result = await listCommits({ + repo, + query, + since, + until, + author, + ref, + maxCount: perPage, + skip, + }); + + if (isServiceError(result)) { + return { + content: [{ type: "text", text: `Failed to list commits: ${result.message}` }], + }; + } + + return { content: [{ type: "text", text: JSON.stringify(result) }] }; + } + ); + + server.registerTool( + "list_repos", + { + description: dedent`Lists repositories in the organization with optional filtering and pagination.`, + inputSchema: z.object({ + query: z.string().describe("Filter repositories by name (case-insensitive)").optional(), + page: z.number().int().positive().describe("Page number for pagination (min 1). Default: 1").optional().default(1), + perPage: z.number().int().positive().max(100).describe("Results per page for pagination (min 1, max 100). Default: 30").optional().default(30), + sort: z.enum(['name', 'pushed']).describe("Sort repositories by 'name' or 'pushed' (most recent commit). Default: 'name'").optional().default('name'), + direction: z.enum(['asc', 'desc']).describe("Sort direction: 'asc' or 'desc'. Default: 'asc'").optional().default('asc'), + }) + }, + async ({ query, page, perPage, sort, direction }) => { + const result = await listRepos({ query, page, perPage, sort, direction }); + + if (isServiceError(result)) { + return { + content: [{ type: "text", text: `Failed to list repositories: ${result.message}` }], + }; + } + + return { + content: [{ + type: "text", + text: JSON.stringify({ + repos: result.data.map((repo) => ({ + name: repo.repoName, + url: repo.webUrl, + pushedAt: repo.pushedAt, + defaultBranch: repo.defaultBranch, + isFork: repo.isFork, + isArchived: repo.isArchived, + })), + totalCount: result.totalCount, + }), + }], + }; + } + ); + + server.registerTool( + "read_file", + { + description: dedent`Reads the source code for a given file.`, + inputSchema: { + repo: z.string().describe("The repository name."), + path: z.string().describe("The path to the file."), + ref: z.string().optional().describe("Commit SHA, branch or tag name to fetch the source code for. If not provided, uses the default branch of the repository."), + }, + }, + async ({ repo, path, ref }) => { + const response = await getFileSource({ repo, path, ref }); + + if (isServiceError(response)) { + return { + content: [{ type: "text", text: `Failed to read file: ${response.message}` }], + }; + } + + return { + content: [{ + type: "text", + text: JSON.stringify({ + source: response.source, + language: response.language, + path: response.path, + url: response.webUrl, + }), + }], + }; + } + ); + + server.registerTool( + "list_tree", + { + description: dedent` + Lists files and directories from a repository path. This can be used as a repo tree tool or directory listing tool. + Returns a flat list of entries with path metadata and depth relative to the requested path. + `, + inputSchema: { + repo: z.string().describe("The name of the repository to list files from."), + path: z.string().describe("Directory path (relative to repo root). If omitted, the repo root is used.").optional().default(''), + ref: z.string().describe("Commit SHA, branch or tag name to list files from. If not provided, uses the default branch.").optional().default('HEAD'), + depth: z.number().int().positive().max(MAX_TREE_DEPTH).describe(`How many directory levels to traverse below \`path\` (min 1, max ${MAX_TREE_DEPTH}, default ${DEFAULT_TREE_DEPTH}).`).optional().default(DEFAULT_TREE_DEPTH), + includeFiles: z.boolean().describe("Whether to include files in the output (default: true).").optional().default(true), + includeDirectories: z.boolean().describe("Whether to include directories in the output (default: true).").optional().default(true), + maxEntries: z.number().int().positive().max(MAX_MAX_TREE_ENTRIES).describe(`Maximum number of entries to return (min 1, max ${MAX_MAX_TREE_ENTRIES}, default ${DEFAULT_MAX_TREE_ENTRIES}).`).optional().default(DEFAULT_MAX_TREE_ENTRIES), + }, + }, + async ({ + repo, + path = '', + ref = 'HEAD', + depth = DEFAULT_TREE_DEPTH, + includeFiles = true, + includeDirectories = true, + maxEntries = DEFAULT_MAX_TREE_ENTRIES, + }: { + repo: string; + path?: string; + ref?: string; + depth?: number; + includeFiles?: boolean; + includeDirectories?: boolean; + maxEntries?: number; + }) => { + const normalizedPath = normalizeTreePath(path); + const normalizedDepth = Math.min(depth, MAX_TREE_DEPTH); + const normalizedMaxEntries = Math.min(maxEntries, MAX_MAX_TREE_ENTRIES); + + if (!includeFiles && !includeDirectories) { + return { + content: [{ + type: "text", + text: JSON.stringify({ + repo, ref, path: normalizedPath, + entries: [] as ListTreeEntry[], + totalReturned: 0, + truncated: false, + }), + }], + }; + } + + const queue: Array<{ path: string; depth: number }> = [{ path: normalizedPath, depth: 0 }]; + const queuedPaths = new Set([normalizedPath]); + const seenEntries = new Set(); + const entries: ListTreeEntry[] = []; + let truncated = false; + let treeError: string | null = null; + + while (queue.length > 0 && !truncated) { + const currentDepth = queue[0]!.depth; + const currentLevelPaths: string[] = []; + + while (queue.length > 0 && queue[0]!.depth === currentDepth) { + currentLevelPaths.push(queue.shift()!.path); + } + + const treeResult = await getTree({ + repoName: repo, + revisionName: ref, + paths: currentLevelPaths.filter(Boolean), + }); + + if (isServiceError(treeResult)) { + treeError = treeResult.message; + break; + } + + const treeNodeIndex = buildTreeNodeIndex(treeResult.tree); + + for (const currentPath of currentLevelPaths) { + const currentNode = currentPath === '' ? treeResult.tree : treeNodeIndex.get(currentPath); + if (!currentNode || currentNode.type !== 'tree') continue; + + for (const child of currentNode.children) { + if (child.type !== 'tree' && child.type !== 'blob') continue; + + const childPath = joinTreePath(currentPath, child.name); + const childDepth = currentDepth + 1; + + if (child.type === 'tree' && childDepth < normalizedDepth && !queuedPaths.has(childPath)) { + queue.push({ path: childPath, depth: childDepth }); + queuedPaths.add(childPath); + } + + if ((child.type === 'blob' && !includeFiles) || (child.type === 'tree' && !includeDirectories)) { + continue; + } + + const key = `${child.type}:${childPath}`; + if (seenEntries.has(key)) continue; + seenEntries.add(key); + + if (entries.length >= normalizedMaxEntries) { + truncated = true; + break; + } + + entries.push({ + type: child.type as 'tree' | 'blob', + path: childPath, + name: child.name, + parentPath: currentPath, + depth: childDepth, + }); + } + + if (truncated) break; + } + } + + if (treeError) { + return { + content: [{ type: "text", text: `Failed to list tree: ${treeError}` }], + }; + } + + const sortedEntries = sortTreeEntries(entries); + return { + content: [{ + type: "text", + text: JSON.stringify({ + repo, ref, path: normalizedPath, + entries: sortedEntries, + totalReturned: sortedEntries.length, + truncated, + }), + }], + }; + } + ); + + server.registerTool( + "list_language_models", + { + description: dedent`Lists the available language models configured on the Sourcebot instance. Use this to discover which models can be specified when calling ask_codebase.`, + }, + async () => { + const models = await getConfiguredLanguageModelsInfo(); + return { content: [{ type: "text", text: JSON.stringify(models) }] }; + } + ); + + server.registerTool( + "ask_codebase", + { + description: dedent` + Ask a natural language question about the codebase. This tool uses an AI agent to autonomously search code, read files, and find symbol references/definitions to answer your question. + + The agent will: + - Analyze your question and determine what context it needs + - Search the codebase using multiple strategies (code search, symbol lookup, file reading) + - Synthesize findings into a comprehensive answer with code references + + Returns a detailed answer in markdown format with code references, plus a link to view the full research session (including all tool calls and reasoning) in the Sourcebot web UI. + + When using this in shared environments (e.g., Slack), you can set the visibility parameter to 'PUBLIC' to ensure everyone can access the chat link. + + This is a blocking operation that may take 30-60+ seconds for complex questions as the agent researches the codebase. + `, + inputSchema: z.object({ + query: z.string().describe("The query to ask about the codebase."), + repos: z.array(z.string()).optional().describe("The repositories accessible to the agent. If not provided, all repositories are accessible."), + languageModel: languageModelInfoSchema.optional().describe("The language model to use. If not provided, defaults to the first model in the config."), + visibility: z.enum(['PRIVATE', 'PUBLIC']).optional().describe("The visibility of the chat session. Defaults to PRIVATE for authenticated users."), + }), + }, + async (request) => { + const result = await askCodebase({ + query: request.query, + repos: request.repos, + languageModel: request.languageModel, + visibility: request.visibility as ChatVisibility | undefined, + }); + + if (isServiceError(result)) { + return { + content: [{ type: "text", text: `Failed to ask codebase: ${result.message}` }], + }; + } + + const formattedResponse = dedent` + ${result.answer} + + --- + **View full research session:** ${result.chatUrl} + **Model used:** ${result.languageModel.model} + `; + return { content: [{ type: "text", text: formattedResponse }] }; + } + ); + + return server; +} diff --git a/packages/web/src/features/mcp/types.ts b/packages/web/src/features/mcp/types.ts new file mode 100644 index 000000000..af60fd648 --- /dev/null +++ b/packages/web/src/features/mcp/types.ts @@ -0,0 +1,17 @@ + +export type TextContent = { type: "text", text: string }; + +export type ListTreeEntry = { + type: 'tree' | 'blob'; + path: string; + name: string; + parentPath: string; + depth: number; +}; + +export type ListTreeApiNode = { + type: 'tree' | 'blob'; + path: string; + name: string; + children: ListTreeApiNode[]; +}; \ No newline at end of file diff --git a/packages/web/src/features/mcp/utils.ts b/packages/web/src/features/mcp/utils.ts new file mode 100644 index 000000000..96ef5d568 --- /dev/null +++ b/packages/web/src/features/mcp/utils.ts @@ -0,0 +1,61 @@ +import { FileTreeNode } from "../git"; +import { ServiceError } from "@/lib/serviceError"; +import { ListTreeEntry } from "./types"; + +export const isServiceError = (data: unknown): data is ServiceError => { + return typeof data === 'object' && + data !== null && + 'statusCode' in data && + 'errorCode' in data && + 'message' in data; +} + +export class ServiceErrorException extends Error { + constructor(public readonly serviceError: ServiceError) { + super(JSON.stringify(serviceError)); + } +} + +export const normalizeTreePath = (path: string): string => { + const withoutLeading = path.replace(/^\/+/, ''); + return withoutLeading.replace(/\/+$/, ''); +} + +export const joinTreePath = (parentPath: string, name: string): string => { + if (!parentPath) { + return name; + } + return `${parentPath}/${name}`; +} + +export const buildTreeNodeIndex = (root: FileTreeNode): Map => { + const nodeIndex = new Map(); + + const visit = (node: FileTreeNode, currentPath: string) => { + nodeIndex.set(currentPath, node); + for (const child of node.children) { + visit(child, joinTreePath(currentPath, child.name)); + } + }; + + visit(root, ''); + return nodeIndex; +} + +export const sortTreeEntries = (entries: ListTreeEntry[]): ListTreeEntry[] => { + const collator = new Intl.Collator(undefined, { sensitivity: 'base' }); + + return [...entries].sort((a, b) => { + const parentCompare = collator.compare(a.parentPath, b.parentPath); + if (parentCompare !== 0) return parentCompare; + + if (a.type !== b.type) { + return a.type === 'tree' ? -1 : 1; + } + + const nameCompare = collator.compare(a.name, b.name); + if (nameCompare !== 0) return nameCompare; + + return collator.compare(a.path, b.path); + }); +} diff --git a/packages/web/src/features/searchAssist/actions.ts b/packages/web/src/features/searchAssist/actions.ts index d389306db..6db31a290 100644 --- a/packages/web/src/features/searchAssist/actions.ts +++ b/packages/web/src/features/searchAssist/actions.ts @@ -1,7 +1,7 @@ 'use server'; import { sew } from "@/actions"; -import { _getAISDKLanguageModelAndOptions, _getConfiguredLanguageModelsFull } from "@/features/chat/actions"; +import { getConfiguredLanguageModels, getAISDKLanguageModelAndOptions } from "../chat/utils.server"; import { ErrorCode } from "@/lib/errorCodes"; import { ServiceError } from "@/lib/serviceError"; import { withOptionalAuthV2 } from "@/withAuthV2"; @@ -26,7 +26,7 @@ ${SEARCH_SYNTAX_DESCRIPTION} export const translateSearchQuery = async ({ prompt }: { prompt: string }) => sew(() => withOptionalAuthV2(async () => { - const models = await _getConfiguredLanguageModelsFull(); + const models = await getConfiguredLanguageModels(); if (models.length === 0) { return { @@ -36,7 +36,7 @@ export const translateSearchQuery = async ({ prompt }: { prompt: string }) => se } satisfies ServiceError; } - const { model } = await _getAISDKLanguageModelAndOptions(models[0]); + const { model } = await getAISDKLanguageModelAndOptions(models[0]); const { object } = await generateObject({ model, diff --git a/packages/web/src/withAuthV2.test.ts b/packages/web/src/withAuthV2.test.ts index 1b9360057..16e923ab1 100644 --- a/packages/web/src/withAuthV2.test.ts +++ b/packages/web/src/withAuthV2.test.ts @@ -108,6 +108,51 @@ describe('getAuthenticatedUser', () => { }); }); + test('should return a user object if a valid Bearer token is present', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).not.toBeUndefined(); + expect(user?.id).toBe(userId); + expect(prisma.apiKey.update).toHaveBeenCalledWith({ + where: { + hash: 'apikey', + }, + data: { + lastUsedAt: expect.any(Date), + }, + }); + }); + + test('should return undefined if a Bearer token is present but the API key does not exist', async () => { + prisma.apiKey.findUnique.mockResolvedValue(null); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + + test('should return undefined if a Bearer token is present but the user is not found', async () => { + prisma.user.findUnique.mockResolvedValue(null); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: 'test-user-id', + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + test('should return undefined if no session or api key is present', async () => { const user = await getAuthenticatedUser(); expect(user).toBeUndefined(); @@ -385,6 +430,76 @@ describe('withAuthV2', () => { expect(result).toEqual(undefined); }); + test('should call the callback with the auth context object if a valid Bearer token is present and the user is a member of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid Bearer token is present and the user is a member of the organization with OWNER role', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + test('should return a service error if the user is a member of the organization but does not have a valid session', async () => { const userId = 'test-user-id'; prisma.user.findUnique.mockResolvedValue({ @@ -582,6 +697,76 @@ describe('withOptionalAuthV2', () => { expect(result).toEqual(undefined); }); + test('should call the callback with the auth context object if a valid Bearer token is present and the user is a member of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid Bearer token is present and the user is a member of the organization with OWNER role', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'Authorization': 'Bearer sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + test('should return a service error if the user is a member of the organization but does not have a valid session', async () => { const userId = 'test-user-id'; prisma.user.findUnique.mockResolvedValue({ diff --git a/packages/web/src/withAuthV2.ts b/packages/web/src/withAuthV2.ts index 88cb763b1..4a500a6e3 100644 --- a/packages/web/src/withAuthV2.ts +++ b/packages/web/src/withAuthV2.ts @@ -115,6 +115,26 @@ export const getAuthenticatedUser = async () => { return user ?? undefined; } + // If not, check for a Bearer token in the Authorization header. + const authorizationHeader = (await headers()).get("Authorization") ?? undefined; + if (authorizationHeader?.startsWith("Bearer ")) { + const bearerToken = authorizationHeader.slice(7); + const apiKey = await getVerifiedApiObject(bearerToken); + if (apiKey) { + const user = await __unsafePrisma.user.findUnique({ + where: { id: apiKey.createdById }, + include: { accounts: true }, + }); + if (user) { + await __unsafePrisma.apiKey.update({ + where: { hash: apiKey.hash }, + data: { lastUsedAt: new Date() }, + }); + return user; + } + } + } + // If not, check if we have a valid API key. const apiKeyString = (await headers()).get("X-Sourcebot-Api-Key") ?? undefined; if (apiKeyString) { diff --git a/yarn.lock b/yarn.lock index fd6fda477..7c2a2b6e1 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3818,7 +3818,7 @@ __metadata: languageName: node linkType: hard -"@modelcontextprotocol/sdk@npm:^1.26.0": +"@modelcontextprotocol/sdk@npm:^1.26.0, @modelcontextprotocol/sdk@npm:^1.27.1": version: 1.27.1 resolution: "@modelcontextprotocol/sdk@npm:1.27.1" dependencies: @@ -8828,6 +8828,7 @@ __metadata: "@hookform/resolvers": "npm:^3.9.0" "@iconify/react": "npm:^5.1.0" "@iizukak/codemirror-lang-wgsl": "npm:^0.3.0" + "@modelcontextprotocol/sdk": "npm:^1.27.1" "@openrouter/ai-sdk-provider": "npm:^2.2.3" "@opentelemetry/api-logs": "npm:^0.203.0" "@opentelemetry/instrumentation": "npm:^0.203.0" From f2b5d1c85597e245c15cf6b589bbc23a5fe0d768 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Tue, 3 Mar 2026 20:18:00 -0800 Subject: [PATCH 2/5] add method to apiHandler --- packages/web/src/app/api/(server)/mcp/route.ts | 15 +++++++-------- packages/web/src/lib/apiHandler.ts | 3 ++- packages/web/src/lib/posthogEvents.ts | 1 + 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/packages/web/src/app/api/(server)/mcp/route.ts b/packages/web/src/app/api/(server)/mcp/route.ts index c387282a2..04d7b88b0 100644 --- a/packages/web/src/app/api/(server)/mcp/route.ts +++ b/packages/web/src/app/api/(server)/mcp/route.ts @@ -1,5 +1,3 @@ -'use server'; - import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/webStandardStreamableHttp.js'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { createMcpServer } from '@/features/mcp/server'; @@ -10,6 +8,7 @@ import { ErrorCode } from '@/lib/errorCodes'; import { StatusCodes } from 'http-status-codes'; import { NextRequest } from 'next/server'; import { sew } from '@/actions'; +import { apiHandler } from '@/lib/apiHandler'; // @see: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#session-management interface McpSession { @@ -24,7 +23,7 @@ const MCP_SESSION_ID_HEADER = 'MCP-Session-Id'; // Suitable for containerized/single-instance deployments. const sessions = new Map(); -export async function POST(request: NextRequest) { +export const POST = apiHandler(async (request: NextRequest) => { const response = await sew(() => withOptionalAuthV2(async ({ user }) => { const ownerId = user?.id ?? null; @@ -66,9 +65,9 @@ export async function POST(request: NextRequest) { } return response; -} +}); -export async function DELETE(request: NextRequest) { +export const DELETE = apiHandler(async (request: NextRequest) => { const result = await sew(() => withOptionalAuthV2(async ({ user }) => { const ownerId = user?.id ?? null; @@ -99,9 +98,9 @@ export async function DELETE(request: NextRequest) { } return result; -} +}); -export async function GET(request: NextRequest) { +export const GET = apiHandler(async (request: NextRequest) => { const result = await sew(() => withOptionalAuthV2(async ({ user }) => { const ownerId = user?.id ?? null; @@ -132,4 +131,4 @@ export async function GET(request: NextRequest) { } return result; -} +}); diff --git a/packages/web/src/lib/apiHandler.ts b/packages/web/src/lib/apiHandler.ts index e41a3ae21..65c76f116 100644 --- a/packages/web/src/lib/apiHandler.ts +++ b/packages/web/src/lib/apiHandler.ts @@ -43,10 +43,11 @@ export function apiHandler( const wrappedHandler = async (request: NextRequest, ...rest: unknown[]) => { if (track) { const path = request.nextUrl.pathname; + const method = request.method; const source = request.headers.get('X-Sourcebot-Client-Source') ?? 'unknown'; // Fire and forget - don't await to avoid blocking the request - captureEvent('api_request', { path, source }).catch(() => { + captureEvent('api_request', { path, method, source }).catch(() => { // Silently ignore tracking errors }); } diff --git a/packages/web/src/lib/posthogEvents.ts b/packages/web/src/lib/posthogEvents.ts index 01f86522f..1d7cc6779 100644 --- a/packages/web/src/lib/posthogEvents.ts +++ b/packages/web/src/lib/posthogEvents.ts @@ -284,6 +284,7 @@ export type PosthogEventMap = { api_request: { path: string; source: string; + method: string; }, } export type PosthogEvent = keyof PosthogEventMap; \ No newline at end of file From 316b6851c6f5f6473a931f2877a0b40c21d54db8 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Tue, 3 Mar 2026 20:21:33 -0800 Subject: [PATCH 3/5] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f151f5cd..62055d270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Added support a MCP streamable http transport hosted at `/api/mcp`. [#976](https://github.com/sourcebot-dev/sourcebot/pull/976) + ## [4.13.2] - 2026-03-02 ### Changed From 425b7bf6829202b59bf92335ea44b8d901f5bda0 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Tue, 3 Mar 2026 20:29:18 -0800 Subject: [PATCH 4/5] feedback --- packages/web/src/features/mcp/askCodebase.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/web/src/features/mcp/askCodebase.ts b/packages/web/src/features/mcp/askCodebase.ts index a6543e79f..1d8af5aea 100644 --- a/packages/web/src/features/mcp/askCodebase.ts +++ b/packages/web/src/features/mcp/askCodebase.ts @@ -103,7 +103,10 @@ export const askCodebase = (params: AskCodebaseParams): Promise { const repoDB = await prisma.repo.findFirst({ - where: { name: repo }, + where: { + name: repo, + orgId: org.id, + }, }); if (!repoDB) { throw new ServiceErrorException({ From 06c1906617add41c817303e7e629f37d3b5cff56 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Tue, 3 Mar 2026 20:36:14 -0800 Subject: [PATCH 5/5] feedback --- packages/web/src/app/api/(server)/mcp/route.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/web/src/app/api/(server)/mcp/route.ts b/packages/web/src/app/api/(server)/mcp/route.ts index 04d7b88b0..971808b48 100644 --- a/packages/web/src/app/api/(server)/mcp/route.ts +++ b/packages/web/src/app/api/(server)/mcp/route.ts @@ -48,7 +48,12 @@ export const POST = apiHandler(async (request: NextRequest) => { onsessioninitialized: (newSessionId) => { sessions.set(newSessionId, { server: mcpServer, transport, ownerId }); }, - onsessionclosed: (closedSessionId) => { + onsessionclosed: async (closedSessionId) => { + const session = sessions.get(closedSessionId); + if (session) { + await session.server.close(); + await session.transport.close(); + } sessions.delete(closedSessionId); }, });