import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { RAG } from '@convex-dev/rag';
import { v } from 'convex/values';
import { api, components } from './_generated/api';
import { action, mutation, query } from './_generated/server';

// NOTE(review): every public function below trusts the caller-supplied
// `userId` argument; nothing in this module compares it with the
// authenticated identity (e.g. `ctx.auth.getUserIdentity()`). Confirm that
// authorization is enforced elsewhere, otherwise any client can read, write,
// or delete another user's databases.

/** Number of search hits returned when the caller omits `limit`. */
const DEFAULT_SEARCH_LIMIT = 5;

/** Validator for a complete `ragDatabases` document, shared by the read queries. */
const ragDatabaseDoc = v.object({
  _id: v.id('ragDatabases'),
  _creationTime: v.number(),
  userId: v.id('users'),
  name: v.string(),
  createdAt: v.number()
});

/**
 * Builds a RAG client that embeds text with Google's `gemini-embedding-001`
 * model, authenticated with the caller-supplied API key.
 *
 * @param apiKey - Google Generative AI API key used for embedding calls.
 */
function createRagInstance(apiKey: string) {
  const google = createGoogleGenerativeAI({ apiKey });
  return new RAG(components.rag, {
    textEmbeddingModel: google.embedding('gemini-embedding-001'),
    // NOTE(review): Google documents gemini-embedding-001 as returning
    // 3072-dimensional vectors unless a smaller output dimensionality is
    // requested, and nothing here requests 768. Confirm the vectors actually
    // produced match this value.
    embeddingDimension: 768
  });
}

/**
 * Returns the RAG namespace holding one user's database content. The prefix
 * is derived from `userId`, so different users' databases of the same name
 * get different namespaces.
 */
function buildNamespace(userId: string, dbName: string): string {
  return `user_${userId}/${dbName}`;
}

/** Joins the text of every chunk in a single search hit, one chunk per line. */
function joinResultContent(result: { content: ReadonlyArray<{ text: string }> }): string {
  return result.content.map((chunk) => chunk.text).join('\n');
}

/**
 * Creates a database record named `name` for `userId`.
 *
 * If the user already has a database with that name, its id is returned and
 * no duplicate is inserted.
 *
 * @returns The id of the new or existing `ragDatabases` document.
 */
export const createDatabase = mutation({
  args: { userId: v.id('users'), name: v.string() },
  returns: v.id('ragDatabases'),
  handler: async (ctx, args) => {
    const existing = await ctx.db
      .query('ragDatabases')
      .withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
      .unique();
    if (existing) {
      return existing._id;
    }
    return await ctx.db.insert('ragDatabases', {
      userId: args.userId,
      name: args.name,
      createdAt: Date.now()
    });
  }
});

/** Looks up a user's database by name; resolves to `null` when none exists. */
export const getDatabase = query({
  args: { userId: v.id('users'), name: v.string() },
  returns: v.union(ragDatabaseDoc, v.null()),
  handler: async (ctx, args) => {
    return await ctx.db
      .query('ragDatabases')
      .withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
      .unique();
  }
});

/** Fetches a database record by id; resolves to `null` when it does not exist. */
export const getDatabaseById = query({
  args: { ragDatabaseId: v.id('ragDatabases') },
  returns: v.union(ragDatabaseDoc, v.null()),
  handler: async (ctx, args) => {
    return await ctx.db.get(args.ragDatabaseId);
  }
});

/** Lists every database record owned by `userId`. */
export const listDatabases = query({
  args: { userId: v.id('users') },
  returns: v.array(ragDatabaseDoc),
  handler: async (ctx, args) => {
    return await ctx.db
      .query('ragDatabases')
      .withIndex('by_user_id', (q) => q.eq('userId', args.userId))
      .collect();
  }
});

/**
 * Deletes the user's database named `name`, removing every connection that
 * references it before deleting the record itself.
 *
 * NOTE(review): `apiKey` is accepted but unused, and nothing here removes the
 * entries previously added to this database's RAG namespace. Because `search`
 * and `searchMultiple` derive the namespace from the name alone, they will
 * keep returning that old content for this name, including after a
 * same-named database is recreated. Confirm whether cleanup happens elsewhere.
 *
 * @returns `false` when no such database exists, `true` after deletion.
 */
export const deleteDatabase = action({
  args: { userId: v.id('users'), name: v.string(), apiKey: v.string() },
  returns: v.boolean(),
  handler: async (ctx, args) => {
    const db = await ctx.runQuery(api.rag.getDatabase, { userId: args.userId, name: args.name });
    if (!db) {
      return false;
    }
    const connections = await ctx.runQuery(api.ragConnections.getByRagDatabaseId, {
      ragDatabaseId: db._id
    });
    for (const conn of connections) {
      await ctx.runMutation(api.ragConnections.deleteConnection, { connectionId: conn._id });
    }
    await ctx.runMutation(api.rag.deleteDatabaseRecord, { ragDatabaseId: db._id });
    return true;
  }
});

/**
 * Deletes a single `ragDatabases` document by id.
 *
 * NOTE(review): this is a public mutation with no ownership check, so any
 * client can delete any record. Consider making it internal; `deleteDatabase`
 * would then have to call it through `internal.rag` instead of `api.rag`.
 */
export const deleteDatabaseRecord = mutation({
  args: { ragDatabaseId: v.id('ragDatabases') },
  returns: v.null(),
  handler: async (ctx, args) => {
    await ctx.db.delete(args.ragDatabaseId);
    return null;
  }
});

/**
 * Embeds `text` and adds it to the RAG namespace of the given database.
 *
 * @throws Error('RAG database not found') when the database does not exist or
 *   is not owned by `userId`.
 */
export const addContent = action({
  args: {
    userId: v.id('users'),
    ragDatabaseId: v.id('ragDatabases'),
    apiKey: v.string(),
    text: v.string(),
    key: v.optional(v.string())
  },
  returns: v.null(),
  handler: async (ctx, args) => {
    const db = await ctx.runQuery(api.rag.getDatabaseById, { ragDatabaseId: args.ragDatabaseId });
    // The namespace is built from `args.userId`. A database owned by someone
    // else would therefore route content into a namespace that no database
    // record of `args.userId` describes, so treat a mismatch as "not found".
    if (!db || db.userId !== args.userId) {
      throw new Error('RAG database not found');
    }
    const rag = createRagInstance(args.apiKey);
    const namespace = buildNamespace(args.userId, db.name);
    await rag.add(ctx, { namespace, text: args.text, key: args.key });
    return null;
  }
});

/**
 * Runs a semantic search over one of the user's databases.
 *
 * @returns The RAG component's combined `text` (or `''`) plus each hit's
 *   joined chunk text and score. At most `limit` hits are requested (default 5).
 */
export const search = action({
  args: {
    userId: v.id('users'),
    dbName: v.string(),
    apiKey: v.string(),
    query: v.string(),
    limit: v.optional(v.number())
  },
  returns: v.object({
    text: v.string(),
    results: v.array(v.object({ text: v.string(), score: v.number() }))
  }),
  handler: async (ctx, args) => {
    const rag = createRagInstance(args.apiKey);
    const namespace = buildNamespace(args.userId, args.dbName);
    const { results, text } = await rag.search(ctx, {
      namespace,
      query: args.query,
      limit: args.limit ?? DEFAULT_SEARCH_LIMIT
    });
    return {
      text: text ?? '',
      results: results.map((r) => ({ text: joinResultContent(r), score: r.score }))
    };
  }
});

/**
 * Searches several of the user's databases and merges the hits.
 *
 * Each database is asked for up to `limit` hits (default 5). The merged hits
 * are ranked by descending score and truncated to `limit` overall.
 *
 * @returns The surviving hits tagged with their `dbName`, plus their texts
 *   joined by `---` separators.
 */
export const searchMultiple = action({
  args: {
    userId: v.id('users'),
    dbNames: v.array(v.string()),
    apiKey: v.string(),
    query: v.string(),
    limit: v.optional(v.number())
  },
  returns: v.object({
    text: v.string(),
    results: v.array(v.object({ text: v.string(), score: v.number(), dbName: v.string() }))
  }),
  handler: async (ctx, args) => {
    const rag = createRagInstance(args.apiKey);
    const limit = args.limit ?? DEFAULT_SEARCH_LIMIT;
    // The namespaces are independent, so query them concurrently instead of
    // one at a time. Promise.all resolves in `dbNames` order, so after
    // flattening, the stable sort breaks equal scores exactly as a sequential
    // loop would.
    const perDatabase = await Promise.all(
      args.dbNames.map(async (dbName) => {
        const { results } = await rag.search(ctx, {
          namespace: buildNamespace(args.userId, dbName),
          query: args.query,
          limit
        });
        return results.map((r) => ({ text: joinResultContent(r), score: r.score, dbName }));
      })
    );
    const topResults = perDatabase
      .flat()
      .sort((a, b) => b.score - a.score)
      .slice(0, limit);
    const combinedText = topResults.map((r) => r.text).join('\n\n---\n\n');
    return { text: combinedText, results: topResults };
  }
});