feat(*): add RAG support
This commit is contained in:
252
frontend/src/lib/convex/rag.ts
Normal file
252
frontend/src/lib/convex/rag.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { RAG } from '@convex-dev/rag';
|
||||
import { v } from 'convex/values';
|
||||
import { api, components } from './_generated/api';
|
||||
import { action, mutation, query } from './_generated/server';
|
||||
|
||||
function createRagInstance(apiKey: string) {
|
||||
const google = createGoogleGenerativeAI({ apiKey });
|
||||
return new RAG(components.rag, {
|
||||
textEmbeddingModel: google.embedding('text-embedding-004'),
|
||||
embeddingDimension: 768
|
||||
});
|
||||
}
|
||||
|
||||
function buildNamespace(userId: string, dbName: string): string {
|
||||
return `user_${userId}/${dbName}`;
|
||||
}
|
||||
|
||||
export const createDatabase = mutation({
|
||||
args: { userId: v.id('users'), name: v.string() },
|
||||
returns: v.id('ragDatabases'),
|
||||
handler: async (ctx, args) => {
|
||||
const existing = await ctx.db
|
||||
.query('ragDatabases')
|
||||
.withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
|
||||
.unique();
|
||||
|
||||
if (existing) {
|
||||
return existing._id;
|
||||
}
|
||||
|
||||
return await ctx.db.insert('ragDatabases', {
|
||||
userId: args.userId,
|
||||
name: args.name,
|
||||
createdAt: Date.now()
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
export const getDatabase = query({
|
||||
args: { userId: v.id('users'), name: v.string() },
|
||||
returns: v.union(
|
||||
v.object({
|
||||
_id: v.id('ragDatabases'),
|
||||
_creationTime: v.number(),
|
||||
userId: v.id('users'),
|
||||
name: v.string(),
|
||||
createdAt: v.number()
|
||||
}),
|
||||
v.null()
|
||||
),
|
||||
handler: async (ctx, args) => {
|
||||
return await ctx.db
|
||||
.query('ragDatabases')
|
||||
.withIndex('by_user_id_and_name', (q) => q.eq('userId', args.userId).eq('name', args.name))
|
||||
.unique();
|
||||
}
|
||||
});
|
||||
|
||||
export const getDatabaseById = query({
|
||||
args: { ragDatabaseId: v.id('ragDatabases') },
|
||||
returns: v.union(
|
||||
v.object({
|
||||
_id: v.id('ragDatabases'),
|
||||
_creationTime: v.number(),
|
||||
userId: v.id('users'),
|
||||
name: v.string(),
|
||||
createdAt: v.number()
|
||||
}),
|
||||
v.null()
|
||||
),
|
||||
handler: async (ctx, args) => {
|
||||
return await ctx.db.get(args.ragDatabaseId);
|
||||
}
|
||||
});
|
||||
|
||||
export const listDatabases = query({
|
||||
args: { userId: v.id('users') },
|
||||
returns: v.array(
|
||||
v.object({
|
||||
_id: v.id('ragDatabases'),
|
||||
_creationTime: v.number(),
|
||||
userId: v.id('users'),
|
||||
name: v.string(),
|
||||
createdAt: v.number()
|
||||
})
|
||||
),
|
||||
handler: async (ctx, args) => {
|
||||
return await ctx.db
|
||||
.query('ragDatabases')
|
||||
.withIndex('by_user_id', (q) => q.eq('userId', args.userId))
|
||||
.collect();
|
||||
}
|
||||
});
|
||||
|
||||
export const deleteDatabase = action({
|
||||
args: { userId: v.id('users'), name: v.string(), apiKey: v.string() },
|
||||
returns: v.boolean(),
|
||||
handler: async (ctx, args) => {
|
||||
const db = await ctx.runQuery(api.rag.getDatabase, {
|
||||
userId: args.userId,
|
||||
name: args.name
|
||||
});
|
||||
|
||||
if (!db) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const connections = await ctx.runQuery(api.ragConnections.getByRagDatabaseId, {
|
||||
ragDatabaseId: db._id
|
||||
});
|
||||
|
||||
for (const conn of connections) {
|
||||
await ctx.runMutation(api.ragConnections.deleteConnection, {
|
||||
connectionId: conn._id
|
||||
});
|
||||
}
|
||||
|
||||
await ctx.runMutation(api.rag.deleteDatabaseRecord, {
|
||||
ragDatabaseId: db._id
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
});
|
||||
|
||||
export const deleteDatabaseRecord = mutation({
|
||||
args: { ragDatabaseId: v.id('ragDatabases') },
|
||||
returns: v.null(),
|
||||
handler: async (ctx, args) => {
|
||||
await ctx.db.delete(args.ragDatabaseId);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
export const addContent = action({
|
||||
args: {
|
||||
userId: v.id('users'),
|
||||
ragDatabaseId: v.id('ragDatabases'),
|
||||
apiKey: v.string(),
|
||||
text: v.string(),
|
||||
key: v.optional(v.string())
|
||||
},
|
||||
returns: v.null(),
|
||||
handler: async (ctx, args) => {
|
||||
const db = await ctx.runQuery(api.rag.getDatabaseById, {
|
||||
ragDatabaseId: args.ragDatabaseId
|
||||
});
|
||||
|
||||
if (!db) {
|
||||
throw new Error('RAG database not found');
|
||||
}
|
||||
|
||||
const rag = createRagInstance(args.apiKey);
|
||||
const namespace = buildNamespace(args.userId, db.name);
|
||||
|
||||
await rag.add(ctx, {
|
||||
namespace,
|
||||
text: args.text,
|
||||
key: args.key
|
||||
});
|
||||
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
export const search = action({
|
||||
args: {
|
||||
userId: v.id('users'),
|
||||
dbName: v.string(),
|
||||
apiKey: v.string(),
|
||||
query: v.string(),
|
||||
limit: v.optional(v.number())
|
||||
},
|
||||
returns: v.object({
|
||||
text: v.string(),
|
||||
results: v.array(
|
||||
v.object({
|
||||
text: v.string(),
|
||||
score: v.number()
|
||||
})
|
||||
)
|
||||
}),
|
||||
handler: async (ctx, args) => {
|
||||
const rag = createRagInstance(args.apiKey);
|
||||
const namespace = buildNamespace(args.userId, args.dbName);
|
||||
|
||||
const { results, text } = await rag.search(ctx, {
|
||||
namespace,
|
||||
query: args.query,
|
||||
limit: args.limit ?? 5
|
||||
});
|
||||
|
||||
return {
|
||||
text: text ?? '',
|
||||
results: results.map((r) => ({
|
||||
text: r.content.map((c) => c.text).join('\n'),
|
||||
score: r.score
|
||||
}))
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
export const searchMultiple = action({
|
||||
args: {
|
||||
userId: v.id('users'),
|
||||
dbNames: v.array(v.string()),
|
||||
apiKey: v.string(),
|
||||
query: v.string(),
|
||||
limit: v.optional(v.number())
|
||||
},
|
||||
returns: v.object({
|
||||
text: v.string(),
|
||||
results: v.array(
|
||||
v.object({
|
||||
text: v.string(),
|
||||
score: v.number(),
|
||||
dbName: v.string()
|
||||
})
|
||||
)
|
||||
}),
|
||||
handler: async (ctx, args) => {
|
||||
const rag = createRagInstance(args.apiKey);
|
||||
const allResults: Array<{ text: string; score: number; dbName: string }> = [];
|
||||
|
||||
for (const dbName of args.dbNames) {
|
||||
const namespace = buildNamespace(args.userId, dbName);
|
||||
const { results } = await rag.search(ctx, {
|
||||
namespace,
|
||||
query: args.query,
|
||||
limit: args.limit ?? 5
|
||||
});
|
||||
|
||||
for (const r of results) {
|
||||
allResults.push({
|
||||
text: r.content.map((c) => c.text).join('\n'),
|
||||
score: r.score,
|
||||
dbName
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
allResults.sort((a, b) => b.score - a.score);
|
||||
const topResults = allResults.slice(0, args.limit ?? 5);
|
||||
const combinedText = topResults.map((r) => r.text).join('\n\n---\n\n');
|
||||
|
||||
return {
|
||||
text: combinedText,
|
||||
results: topResults
|
||||
};
|
||||
}
|
||||
});
|
||||
Reference in New Issue
Block a user