Skip to content

Commit b8081ad

Browse files
authored
Make it easy to point zeta2 at ollama (#42329)
I wanted to be able to work offline, so I made it a little more convenient to point zeta2 at Ollama.

* For zeta2, don't require that request IDs be UUIDs.
* Add an env var `ZED_ZETA2_OLLAMA` that sets the edit prediction URL and model ID to work with Ollama.

Release Notes:

- N/A
1 parent 35c5815 commit b8081ad

File tree

4 files changed

+39
-25
lines changed

4 files changed

+39
-25
lines changed

crates/cloud_llm_client/src/cloud_llm_client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,13 @@ pub struct PredictEditsGitInfo {
183183

184184
#[derive(Debug, Clone, Serialize, Deserialize)]
185185
pub struct PredictEditsResponse {
186-
pub request_id: Uuid,
186+
pub request_id: String,
187187
pub output_excerpt: String,
188188
}
189189

190190
#[derive(Debug, Clone, Serialize, Deserialize)]
191191
pub struct AcceptEditPredictionBody {
192-
pub request_id: Uuid,
192+
pub request_id: String,
193193
}
194194

195195
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]

crates/zeta/src/zeta.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ impl Zeta {
652652
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
653653
.body(
654654
serde_json::to_string(&AcceptEditPredictionBody {
655-
request_id: request_id.0,
655+
request_id: request_id.0.to_string(),
656656
})?
657657
.into(),
658658
)?)
@@ -735,6 +735,8 @@ impl Zeta {
735735
return anyhow::Ok(None);
736736
};
737737

738+
let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?;
739+
738740
let edit_preview = edit_preview.await;
739741

740742
Ok(Some(EditPrediction {
@@ -2162,7 +2164,7 @@ mod tests {
21622164
.status(200)
21632165
.body(
21642166
serde_json::to_string(&PredictEditsResponse {
2165-
request_id: Uuid::new_v4(),
2167+
request_id: Uuid::new_v4().to_string(),
21662168
output_excerpt: completion_response.lock().clone(),
21672169
})
21682170
.unwrap()

crates/zeta2/src/prediction.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
use std::{ops::Range, sync::Arc};
22

3-
use gpui::{AsyncApp, Entity};
3+
use gpui::{AsyncApp, Entity, SharedString};
44
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
5-
use uuid::Uuid;
65

7-
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
8-
pub struct EditPredictionId(pub Uuid);
9-
10-
impl Into<Uuid> for EditPredictionId {
11-
fn into(self) -> Uuid {
12-
self.0
13-
}
14-
}
6+
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
7+
pub struct EditPredictionId(pub SharedString);
158

169
impl From<EditPredictionId> for gpui::ElementId {
1710
fn from(value: EditPredictionId) -> Self {
18-
gpui::ElementId::Uuid(value.0)
11+
gpui::ElementId::Name(value.0)
1912
}
2013
}
2114

@@ -149,7 +142,7 @@ mod tests {
149142
.await;
150143

151144
let prediction = EditPrediction {
152-
id: EditPredictionId(Uuid::new_v4()),
145+
id: EditPredictionId("prediction-1".into()),
153146
edits,
154147
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
155148
buffer: buffer.clone(),

crates/zeta2/src/zeta2.rs

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ use project::Project;
3030
use release_channel::AppVersion;
3131
use serde::de::DeserializeOwned;
3232
use std::collections::{VecDeque, hash_map};
33-
use uuid::Uuid;
3433

34+
use std::env;
3535
use std::ops::Range;
3636
use std::path::Path;
3737
use std::str::FromStr as _;
@@ -88,8 +88,24 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
8888
buffer_change_grouping_interval: Duration::from_secs(1),
8989
};
9090

91-
static MODEL_ID: LazyLock<String> =
92-
LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()));
91+
static USE_OLLAMA: LazyLock<bool> =
92+
LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
93+
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
94+
env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
95+
"qwen3-coder:30b".to_string()
96+
} else {
97+
"yqvev8r3".to_string()
98+
})
99+
});
100+
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
101+
env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
102+
if *USE_OLLAMA {
103+
Some("http://localhost:11434/v1/chat/completions".into())
104+
} else {
105+
None
106+
}
107+
})
108+
});
93109

94110
pub struct Zeta2FeatureFlag;
95111

@@ -567,13 +583,13 @@ impl Zeta {
567583
let Some(prediction) = project_state.current_prediction.take() else {
568584
return;
569585
};
570-
let request_id = prediction.prediction.id.into();
586+
let request_id = prediction.prediction.id.to_string();
571587

572588
let client = self.client.clone();
573589
let llm_token = self.llm_token.clone();
574590
let app_version = AppVersion::global(cx);
575591
cx.spawn(async move |this, cx| {
576-
let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
592+
let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
577593
http_client::Url::parse(&predict_edits_url)?
578594
} else {
579595
client
@@ -585,7 +601,10 @@ impl Zeta {
585601
.background_spawn(Self::send_api_request::<()>(
586602
move |builder| {
587603
let req = builder.uri(url.as_ref()).body(
588-
serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
604+
serde_json::to_string(&AcceptEditPredictionBody {
605+
request_id: request_id.clone(),
606+
})?
607+
.into(),
589608
);
590609
Ok(req?)
591610
},
@@ -875,7 +894,7 @@ impl Zeta {
875894
None
876895
};
877896

878-
if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
897+
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
879898
if let Some(debug_response_tx) = debug_response_tx {
880899
debug_response_tx
881900
.send((Err("Request skipped".to_string()), TimeDelta::zero()))
@@ -923,7 +942,7 @@ impl Zeta {
923942
}
924943

925944
let (res, usage) = response?;
926-
let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
945+
let request_id = EditPredictionId(res.id.clone().into());
927946
let Some(output_text) = text_from_response(res) else {
928947
return Ok((None, usage))
929948
};
@@ -980,7 +999,7 @@ impl Zeta {
980999
app_version: SemanticVersion,
9811000
request: open_ai::Request,
9821001
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
983-
let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
1002+
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
9841003
http_client::Url::parse(&predict_edits_url)?
9851004
} else {
9861005
client

0 commit comments

Comments
 (0)