feat(*): add RAG support

This commit is contained in:
h
2026-01-25 16:44:59 +01:00
parent 5b1f50a6f6
commit a992e3f0c2
20 changed files with 1412 additions and 17 deletions

View File

@@ -0,0 +1,252 @@
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { RAG } from '@convex-dev/rag';
import { v } from 'convex/values';
import { api, components } from './_generated/api';
import { action, mutation, query } from './_generated/server';
function createRagInstance(apiKey: string) {
const google = createGoogleGenerativeAI({ apiKey });
return new RAG(components.rag, {
textEmbeddingModel: google.embedding('text-embedding-004'),
embeddingDimension: 768
});
}
function buildNamespace(userId: string, dbName: string): string {
return `user_${userId}/${dbName}`;
}
export const createDatabase = mutation({
args: { userId: v.id('users'), name: v.string() },
returns: v.id('ragDatabases'),
handler: async (ctx, args) => {
const existing = await ctx.db
.query('ragDatabases')
.withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
.unique();
if (existing) {
return existing._id;
}
return await ctx.db.insert('ragDatabases', {
userId: args.userId,
name: args.name,
createdAt: Date.now()
});
}
});
export const getDatabase = query({
args: { userId: v.id('users'), name: v.string() },
returns: v.union(
v.object({
_id: v.id('ragDatabases'),
_creationTime: v.number(),
userId: v.id('users'),
name: v.string(),
createdAt: v.number()
}),
v.null()
),
handler: async (ctx, args) => {
return await ctx.db
.query('ragDatabases')
.withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
.unique();
}
});
export const getDatabaseById = query({
args: { ragDatabaseId: v.id('ragDatabases') },
returns: v.union(
v.object({
_id: v.id('ragDatabases'),
_creationTime: v.number(),
userId: v.id('users'),
name: v.string(),
createdAt: v.number()
}),
v.null()
),
handler: async (ctx, args) => {
return await ctx.db.get(args.ragDatabaseId);
}
});
export const listDatabases = query({
args: { userId: v.id('users') },
returns: v.array(
v.object({
_id: v.id('ragDatabases'),
_creationTime: v.number(),
userId: v.id('users'),
name: v.string(),
createdAt: v.number()
})
),
handler: async (ctx, args) => {
return await ctx.db
.query('ragDatabases')
.withIndex('by_user_id', (q) => q.eq('userId', args.userId))
.collect();
}
});
export const deleteDatabase = action({
args: { userId: v.id('users'), name: v.string(), apiKey: v.string() },
returns: v.boolean(),
handler: async (ctx, args) => {
const db = await ctx.runQuery(api.rag.getDatabase, {
userId: args.userId,
name: args.name
});
if (!db) {
return false;
}
const connections = await ctx.runQuery(api.ragConnections.getByRagDatabaseId, {
ragDatabaseId: db._id
});
for (const conn of connections) {
await ctx.runMutation(api.ragConnections.deleteConnection, {
connectionId: conn._id
});
}
await ctx.runMutation(api.rag.deleteDatabaseRecord, {
ragDatabaseId: db._id
});
return true;
}
});
export const deleteDatabaseRecord = mutation({
args: { ragDatabaseId: v.id('ragDatabases') },
returns: v.null(),
handler: async (ctx, args) => {
await ctx.db.delete(args.ragDatabaseId);
return null;
}
});
export const addContent = action({
args: {
userId: v.id('users'),
ragDatabaseId: v.id('ragDatabases'),
apiKey: v.string(),
text: v.string(),
key: v.optional(v.string())
},
returns: v.null(),
handler: async (ctx, args) => {
const db = await ctx.runQuery(api.rag.getDatabaseById, {
ragDatabaseId: args.ragDatabaseId
});
if (!db) {
throw new Error('RAG database not found');
}
const rag = createRagInstance(args.apiKey);
const namespace = buildNamespace(args.userId, db.name);
await rag.add(ctx, {
namespace,
text: args.text,
key: args.key
});
return null;
}
});
export const search = action({
args: {
userId: v.id('users'),
dbName: v.string(),
apiKey: v.string(),
query: v.string(),
limit: v.optional(v.number())
},
returns: v.object({
text: v.string(),
results: v.array(
v.object({
text: v.string(),
score: v.number()
})
)
}),
handler: async (ctx, args) => {
const rag = createRagInstance(args.apiKey);
const namespace = buildNamespace(args.userId, args.dbName);
const { results, text } = await rag.search(ctx, {
namespace,
query: args.query,
limit: args.limit ?? 5
});
return {
text: text ?? '',
results: results.map((r) => ({
text: r.content.map((c) => c.text).join('\n'),
score: r.score
}))
};
}
});
export const searchMultiple = action({
args: {
userId: v.id('users'),
dbNames: v.array(v.string()),
apiKey: v.string(),
query: v.string(),
limit: v.optional(v.number())
},
returns: v.object({
text: v.string(),
results: v.array(
v.object({
text: v.string(),
score: v.number(),
dbName: v.string()
})
)
}),
handler: async (ctx, args) => {
const rag = createRagInstance(args.apiKey);
const allResults: Array<{ text: string; score: number; dbName: string }> = [];
for (const dbName of args.dbNames) {
const namespace = buildNamespace(args.userId, dbName);
const { results } = await rag.search(ctx, {
namespace,
query: args.query,
limit: args.limit ?? 5
});
for (const r of results) {
allResults.push({
text: r.content.map((c) => c.text).join('\n'),
score: r.score,
dbName
});
}
}
allResults.sort((a, b) => b.score - a.score);
const topResults = allResults.slice(0, args.limit ?? 5);
const combinedText = topResults.map((r) => r.text).join('\n\n---\n\n');
return {
text: combinedText,
results: topResults
};
}
});