Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Middleware auth, traces page efficiency #184

Merged
merged 4 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions app-server/src/db/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,31 @@ pub async fn count_traces(
date_range: &Option<DateRange>,
text_search_filter: Option<String>,
) -> Result<i64> {
let mut query = QueryBuilder::<Postgres>::new("WITH ");
add_traces_info_expression(&mut query, date_range, project_id)?;
let mut query = QueryBuilder::<Postgres>::new(
"WITH traces_info AS (
SELECT
id,
start_time,
end_time,
version,
release,
user_id,
session_id,
metadata,
project_id,
input_token_count,
output_token_count,
total_token_count,
input_cost,
output_cost,
cost,
success,
trace_type,
EXTRACT(EPOCH FROM (end_time - start_time)) as latency,
CASE WHEN success = true THEN 'Success' ELSE 'Failed' END status
FROM traces
WHERE start_time IS NOT NULL AND end_time IS NOT NULL AND trace_type = 'DEFAULT')",
);
query.push(
"
SELECT
Expand All @@ -471,6 +494,7 @@ pub async fn count_traces(
}
query.push(" WHERE project_id = ");
query.push_bind(project_id);
add_date_range_to_query(&mut query, date_range, "start_time", Some("end_time"))?;

add_filters_to_traces_query(&mut query, &filters);

Expand All @@ -483,22 +507,6 @@ pub async fn count_traces(
Ok(count)
}

/// `count_traces` with filters adds a lot of information to the query and joins on the events (in order to filter)
/// This function is a simpler version of `count_traces` that only counts the traces without any additional information
/// and is more efficient.
pub async fn count_all_traces_in_project(pool: &PgPool, project_id: Uuid) -> Result<i64> {
let count = sqlx::query_as::<_, TotalCount>(
"SELECT COUNT(*) as total_count
FROM traces
WHERE project_id = $1",
)
.bind(project_id)
.fetch_one(pool)
.await?;

Ok(count.total_count)
}

pub async fn get_single_trace(pool: &PgPool, id: Uuid) -> Result<Trace> {
let trace = sqlx::query_as::<_, Trace>(
"SELECT
Expand Down
11 changes: 3 additions & 8 deletions app-server/src/routes/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,9 @@ pub async fn get_traces(
)
.await
.unwrap_or(0) as u64;
let any_in_project = if total_count == 0 {
db::trace::count_all_traces_in_project(&db.pool, project_id)
.await
.unwrap_or(1)
> 0
} else {
true
};
// this is checked in the frontend, and we temporarily return true here,
// while we migrate other `PaginatedGet` queries to drizzle
let any_in_project = true;
(total_count, any_in_project)
});

Expand Down
1 change: 1 addition & 0 deletions app-server/src/traces/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ impl Span {
.get("SpanAttributes.LLM_PROMPTS.0.content")
.is_some()
{
// handling the LiteLLM auto-instrumentation
let input_messages = input_chat_messages_from_prompt_content(
&attributes,
"SpanAttributes.LLM_PROMPTS",
Expand Down
21 changes: 21 additions & 0 deletions frontend/app/api/auth/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { isCurrentUserMemberOfProject } from "@/lib/db/utils";
import { NextRequest, NextResponse } from "next/server";

export async function POST(req: NextRequest) {

// check if bearer token is valid
const token = req.headers.get('Authorization')?.split(' ')[1];
if (!token || token !== process.env.SHARED_SECRET_TOKEN) {
return NextResponse.json({ message: 'Unauthorized' }, { status: 401 });
}

const body = await req.json();

const { projectId } = body;

if (!await isCurrentUserMemberOfProject(projectId)) {
return NextResponse.json({ message: 'Unauthorized' }, { status: 401 });
}

return NextResponse.json({ message: 'Authorized' });
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@ import { db } from '@/lib/db/drizzle';
import { datasetDatapoints } from '@/lib/db/migrations/schema';
import { and, eq } from 'drizzle-orm';
import { z } from 'zod';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';


export async function POST(
req: Request,
{
params
}: { params: { projectId: string; datasetId: string; datapointId: string } }
): Promise<Response> {
const projectId = params.projectId;

if (!await isCurrentUserMemberOfProject(projectId)) {
return new Response('Unauthorized', { status: 401 });
}

const datasetId = params.datasetId;
const datapointId = params.datapointId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ import { authOptions } from '@/lib/auth';
import { fetcher } from '@/lib/utils';
import { NextRequest } from 'next/server';
import { db } from '@/lib/db/drizzle';
import { datasetDatapoints } from '@/lib/db/migrations/schema';
import { datasetDatapoints, datapointToSpan } from '@/lib/db/migrations/schema';
import { and, inArray, eq } from 'drizzle-orm';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';

import { z } from 'zod';

export async function GET(
req: NextRequest,
Expand All @@ -27,28 +28,64 @@ export async function GET(
);
}

const CreateDatapointsSchema = z.object({
datapoints: z.array(z.object({
data: z.unknown(),
target: z.any().optional(),
metadata: z.record(z.any()).optional(),
})),
sourceSpanId: z.string().optional(),
});

export async function POST(
req: Request,
{ params }: { params: { projectId: string; datasetId: string } }
): Promise<Response> {
const projectId = params.projectId;
const datasetId = params.datasetId;
const session = await getServerSession(authOptions);
const user = session!.user;



const body = await req.json();

return await fetcher(
`/projects/${projectId}/datasets/${datasetId}/datapoints`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${user.apiKey}`
},
body: JSON.stringify(body)
}
);
// Validate request body
const parseResult = CreateDatapointsSchema.required().safeParse(body);
if (!parseResult.success) {
return new Response(
JSON.stringify({
error: "Invalid request body",
details: parseResult.error.issues
}),
{ status: 400 }
);
}

const { datapoints, sourceSpanId } = parseResult.data;

const res = await db.insert(datasetDatapoints).values(
datapoints.map((datapoint) => ({
...datapoint,
data: datapoint.data,
createdAt: new Date().toUTCString(),
datasetId
}))
).returning();

if (sourceSpanId && res.length > 0) {
await db.insert(datapointToSpan).values(
res.map((datapoint) => ({
spanId: sourceSpanId,
datapointId: datapoint.id,
projectId,
}))
).returning();
}

if (res.length === 0) {
return new Response('Error creating datasetDatapoints', { status: 500 });
}

return new Response('datasetDatapoints created successfully', { status: 200 });
}

export async function DELETE(
Expand All @@ -58,9 +95,7 @@ export async function DELETE(
const projectId = params.projectId;
const datasetId = params.datasetId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const { searchParams } = new URL(req.url);
const datapointIds = searchParams.get('datapointIds')?.split(',');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { db } from '@/lib/db/drizzle';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';

import { asc, and, eq } from 'drizzle-orm';
import { datasetDatapoints, datasets, evaluationResults, evaluations } from '@/lib/db/migrations/schema';
import { evaluationScores } from '@/lib/db/migrations/schema';
Expand All @@ -13,9 +13,7 @@ export async function GET(
params: { projectId: string; datasetId: string; };
}
): Promise<Response> {
if (!(await isCurrentUserMemberOfProject(params.projectId))) {
return Response.json({ error: 'Unauthorized' }, { status: 401 });
}


const projectId = params.projectId;
const datasetId = params.datasetId;
Expand Down
6 changes: 2 additions & 4 deletions frontend/app/api/projects/[projectId]/datasets/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { getServerSession } from 'next-auth';
import { authOptions } from '@/lib/auth';
import { fetcher } from '@/lib/utils';
import { datasets } from '@/lib/db/migrations/schema';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';

import { eq, inArray } from 'drizzle-orm';
import { and } from 'drizzle-orm';
import { db } from '@/lib/db/drizzle';
Expand Down Expand Up @@ -55,9 +55,7 @@ export async function DELETE(
): Promise<Response> {
const projectId = params.projectId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const { searchParams } = new URL(req.url);
const datasetIds = searchParams.get('datasetIds')?.split(',');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { db } from '@/lib/db/drizzle';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';

import { asc, and, eq } from 'drizzle-orm';
import { evaluationResults, evaluations } from '@/lib/db/migrations/schema';
import { evaluationScores } from '@/lib/db/migrations/schema';
Expand All @@ -14,9 +14,7 @@ export async function GET(
params: { projectId: string; evaluationId: string; };
}
): Promise<Response> {
if (!(await isCurrentUserMemberOfProject(params.projectId))) {
return Response.json({ error: 'Unauthorized' }, { status: 401 });
}


const projectId = params.projectId;
const evaluationId = params.evaluationId;
Expand Down
8 changes: 2 additions & 6 deletions frontend/app/api/projects/[projectId]/evaluations/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ export async function GET(
): Promise<Response> {
const projectId = params.projectId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const result = await paginatedGet<any, Evaluation>({
table: evaluations,
Expand All @@ -29,9 +27,7 @@ export async function DELETE(
): Promise<Response> {
const projectId = params.projectId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const { searchParams } = new URL(req.url);
const evaluationIds = searchParams.get('evaluationIds')?.split(',');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { db } from '@/lib/db/drizzle';
import { labelClassesForPath } from '@/lib/db/migrations/schema';
import { eq } from 'drizzle-orm';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';


export async function DELETE(
req: Request,
Expand All @@ -13,9 +13,7 @@ export async function DELETE(
const projectId = params.projectId;
const id = params.id;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const affectedRows = await db.delete(labelClassesForPath).where(eq(labelClassesForPath.id, id)).returning();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { getServerSession } from 'next-auth';
import { authOptions } from '@/lib/auth';
import { fetcher } from '@/lib/utils';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';

import { labelClasses } from '@/lib/db/migrations/schema';
import { db } from '@/lib/db/drizzle';
import { and, eq } from 'drizzle-orm';
Expand Down Expand Up @@ -34,9 +34,7 @@ export async function DELETE(
const projectId = params.projectId;
const labelClassId = params.labelClassId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const affectedRows = await db.delete(labelClasses).where(
and(
Expand Down
10 changes: 3 additions & 7 deletions frontend/app/api/projects/[projectId]/label-classes/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { db } from '@/lib/db/drizzle';
import { labelClasses } from '@/lib/db/migrations/schema';
import { eq, desc } from 'drizzle-orm';
import { isCurrentUserMemberOfProject } from '@/lib/db/utils';



export async function GET(
Expand All @@ -10,9 +10,7 @@ export async function GET(
): Promise<Response> {
const projectId = params.projectId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const res = await db
.select()
Expand All @@ -30,9 +28,7 @@ export async function POST(
): Promise<Response> {
const projectId = params.projectId;

if (!(await isCurrentUserMemberOfProject(projectId))) {
return new Response(JSON.stringify({ error: "User is not a member of the project" }), { status: 403 });
}


const body = await req.json();

Expand Down
Loading