Skip to content

Commit 07d9898

Browse files
Make the edit prediction status bar menu work correctly when using sweep (zed-industries#43203)
Release Notes: - N/A --------- Co-authored-by: Ben Kunkle <ben@zed.dev>
1 parent 8bbd101 commit 07d9898

File tree

3 files changed

+80
-35
lines changed

3 files changed

+80
-35
lines changed

crates/edit_prediction_button/src/edit_prediction_button.rs

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ impl Render for EditPredictionButton {
8383

8484
let all_language_settings = all_language_settings(None, cx);
8585

86-
match &all_language_settings.edit_predictions.provider {
87-
EditPredictionProvider::None => div().hidden(),
88-
86+
match all_language_settings.edit_predictions.provider {
8987
EditPredictionProvider::Copilot => {
9088
let Some(copilot) = Copilot::global(cx) else {
9189
return div().hidden();
@@ -302,23 +300,23 @@ impl Render for EditPredictionButton {
302300
.with_handle(self.popover_menu_handle.clone()),
303301
)
304302
}
305-
EditPredictionProvider::Experimental(provider_name) => {
306-
if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
307-
&& cx.has_flag::<SweepFeatureFlag>()
308-
{
309-
div().child(Icon::new(IconName::SweepAi))
310-
} else {
311-
div()
312-
}
313-
}
314-
315-
EditPredictionProvider::Zed => {
303+
provider @ (EditPredictionProvider::Experimental(
304+
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
305+
)
306+
| EditPredictionProvider::Zed) => {
316307
let enabled = self.editor_enabled.unwrap_or(true);
317308

318-
let zeta_icon = if enabled {
319-
IconName::ZedPredict
320-
} else {
321-
IconName::ZedPredictDisabled
309+
let is_sweep = matches!(
310+
provider,
311+
EditPredictionProvider::Experimental(
312+
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
313+
)
314+
);
315+
316+
let zeta_icon = match (is_sweep, enabled) {
317+
(true, _) => IconName::SweepAi,
318+
(false, true) => IconName::ZedPredict,
319+
(false, false) => IconName::ZedPredictDisabled,
322320
};
323321

324322
if zeta::should_show_upsell_modal() {
@@ -402,8 +400,10 @@ impl Render for EditPredictionButton {
402400

403401
let mut popover_menu = PopoverMenu::new("zeta")
404402
.menu(move |window, cx| {
405-
this.update(cx, |this, cx| this.build_zeta_context_menu(window, cx))
406-
.ok()
403+
this.update(cx, |this, cx| {
404+
this.build_zeta_context_menu(provider, window, cx)
405+
})
406+
.ok()
407407
})
408408
.anchor(Corner::BottomRight)
409409
.with_handle(self.popover_menu_handle.clone());
@@ -429,6 +429,10 @@ impl Render for EditPredictionButton {
429429

430430
div().child(popover_menu.into_any_element())
431431
}
432+
433+
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
434+
div().hidden()
435+
}
432436
}
433437
}
434438
}
@@ -487,6 +491,12 @@ impl EditPredictionButton {
487491
providers.push(EditPredictionProvider::Codestral);
488492
}
489493

494+
if cx.has_flag::<SweepFeatureFlag>() {
495+
providers.push(EditPredictionProvider::Experimental(
496+
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
497+
));
498+
}
499+
490500
providers
491501
}
492502

@@ -498,6 +508,11 @@ impl EditPredictionButton {
498508
) -> ContextMenu {
499509
let available_providers = self.get_available_providers(cx);
500510

511+
const ZED_AI_CALLOUT: &str =
512+
"Zed's edit prediction is powered by Zeta, an open-source, dataset mode.";
513+
const USE_SWEEP_API_TOKEN_CALLOUT: &str =
514+
"Set the SWEEP_API_TOKEN environment variable to use Sweep";
515+
501516
let other_providers: Vec<_> = available_providers
502517
.into_iter()
503518
.filter(|p| *p != current_provider && *p != EditPredictionProvider::None)
@@ -514,11 +529,8 @@ impl EditPredictionButton {
514529
ContextMenuEntry::new("Zed AI")
515530
.documentation_aside(
516531
DocumentationSide::Left,
517-
DocumentationEdge::Top,
518-
|_| {
519-
Label::new("Zed's edit prediction is powered by Zeta, an open-source, dataset mode.")
520-
.into_any_element()
521-
},
532+
DocumentationEdge::Bottom,
533+
|_| Label::new(ZED_AI_CALLOUT).into_any_element(),
522534
)
523535
.handler(move |_, cx| {
524536
set_completion_provider(fs.clone(), cx, provider);
@@ -539,7 +551,29 @@ impl EditPredictionButton {
539551
set_completion_provider(fs.clone(), cx, provider);
540552
})
541553
}
542-
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue,
554+
EditPredictionProvider::Experimental(
555+
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
556+
) => {
557+
let has_api_token = zeta2::Zeta::try_global(cx)
558+
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
559+
560+
let entry = ContextMenuEntry::new("Sweep")
561+
.when(!has_api_token, |this| {
562+
this.disabled(true).documentation_aside(
563+
DocumentationSide::Left,
564+
DocumentationEdge::Bottom,
565+
|_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(),
566+
)
567+
})
568+
.handler(move |_, cx| {
569+
set_completion_provider(fs.clone(), cx, provider);
570+
});
571+
572+
menu.item(entry)
573+
}
574+
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
575+
continue;
576+
}
543577
};
544578
}
545579
}
@@ -909,6 +943,7 @@ impl EditPredictionButton {
909943

910944
fn build_zeta_context_menu(
911945
&self,
946+
provider: EditPredictionProvider,
912947
window: &mut Window,
913948
cx: &mut Context<Self>,
914949
) -> Entity<ContextMenu> {
@@ -996,7 +1031,7 @@ impl EditPredictionButton {
9961031
}
9971032

9981033
let menu = self.build_language_settings_menu(menu, window, cx);
999-
let menu = self.add_provider_switching_section(menu, EditPredictionProvider::Zed, cx);
1034+
let menu = self.add_provider_switching_section(menu, provider, cx);
10001035

10011036
menu
10021037
})

crates/zed/src/zed/edit_prediction_registry.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ fn assign_edit_prediction_provider(
204204
editor.set_edit_prediction_provider(Some(provider), window, cx);
205205
}
206206
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
207+
let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
208+
207209
if let Some(project) = editor.project() {
208210
let mut worktree = None;
209211
if let Some(buffer) = &singleton_buffer
@@ -217,7 +219,6 @@ fn assign_edit_prediction_provider(
217219
&& name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
218220
&& cx.has_flag::<SweepFeatureFlag>()
219221
{
220-
let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
221222
let provider = cx.new(|cx| {
222223
zeta2::ZetaEditPredictionProvider::new(
223224
project.clone(),

crates/zeta2/src/zeta2.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,20 +402,21 @@ impl Zeta {
402402
#[cfg(feature = "eval-support")]
403403
eval_cache: None,
404404
edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
405-
sweep_api_token: None,
405+
sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
406+
.context("No SWEEP_AI_TOKEN environment variable set")
407+
.log_err(),
406408
sweep_ai_debug_info: sweep_ai::debug_info(cx),
407409
}
408410
}
409411

410412
pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
411-
if model == ZetaEditPredictionModel::Sweep {
412-
self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
413-
.context("No SWEEP_AI_TOKEN environment variable set")
414-
.log_err();
415-
}
416413
self.edit_prediction_model = model;
417414
}
418415

416+
pub fn has_sweep_api_token(&self) -> bool {
417+
self.sweep_api_token.is_some()
418+
}
419+
419420
#[cfg(feature = "eval-support")]
420421
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
421422
self.eval_cache = Some(cache);
@@ -472,7 +473,11 @@ impl Zeta {
472473
}
473474

474475
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
475-
self.user_store.read(cx).edit_prediction_usage()
476+
if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
477+
self.user_store.read(cx).edit_prediction_usage()
478+
} else {
479+
None
480+
}
476481
}
477482

478483
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
@@ -659,6 +664,10 @@ impl Zeta {
659664
}
660665

661666
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
667+
if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
668+
return;
669+
}
670+
662671
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
663672
return;
664673
};

0 commit comments

Comments
 (0)