Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/codestral/src/codestral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
Self::api_key(cx).is_some()
}

fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_request.is_some()
}

Expand Down
2 changes: 1 addition & 1 deletion crates/copilot/src/copilot_completion_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
false
}

fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_refresh.is_some() && self.completions.is_empty()
}

Expand Down
4 changes: 2 additions & 2 deletions crates/edit_prediction/src/edit_prediction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized {
cursor_position: language::Anchor,
cx: &App,
) -> bool;
fn is_refreshing(&self) -> bool;
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
&mut self,
buffer: Entity<Buffer>,
Expand Down Expand Up @@ -200,7 +200,7 @@ where
}

fn is_refreshing(&self, cx: &App) -> bool {
self.read(cx).is_refreshing()
self.read(cx).is_refreshing(cx)
}

fn refresh(
Expand Down
4 changes: 2 additions & 2 deletions crates/editor/src/edit_prediction_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
true
}

fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
false
}

Expand Down Expand Up @@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
true
}

fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
false
}

Expand Down
2 changes: 1 addition & 1 deletion crates/supermaven/src/supermaven_completion_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
self.supermaven.read(cx).is_enabled()
}

fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_refresh.is_some() && self.completion_id.is_none()
}

Expand Down
1 change: 1 addition & 0 deletions crates/util/src/rel_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ impl PartialEq<str> for RelPath {
}
}

#[derive(Default)]
pub struct RelPathComponents<'a>(&'a str);

pub struct RelPathAncestors<'a>(Option<&'a str>);
Expand Down
2 changes: 1 addition & 1 deletion crates/zeta/src/zeta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
) -> bool {
true
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
!self.pending_completions.is_empty()
}

Expand Down
3 changes: 2 additions & 1 deletion crates/zeta2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
open_ai.workspace = true
pretty_assertions.workspace = true
project.workspace = true
release_channel.workspace = true
serde.workspace = true
Expand All @@ -44,7 +46,6 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
pretty_assertions.workspace = true

[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
Expand Down
89 changes: 12 additions & 77 deletions crates/zeta2/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
use std::{
cmp,
sync::Arc,
time::{Duration, Instant},
};
use std::{cmp, sync::Arc, time::Duration};

use arrayvec::ArrayVec;
use client::{Client, UserStore};
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
use gpui::{App, Entity, Task, prelude::*};
use gpui::{App, Entity, prelude::*};
use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;

use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};

pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
last_request_timestamp: Instant,
project: Entity<Project>,
}

Expand All @@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider {
project: Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut App,
cx: &mut Context<Self>,
) -> Self {
let zeta = Zeta::global(client, user_store, cx);
zeta.update(cx, |zeta, cx| {
zeta.register_project(&project, cx);
});

cx.observe(&zeta, |_this, _zeta, cx| {
cx.notify();
})
.detach();

Self {
zeta,
next_pending_prediction_id: 0,
pending_predictions: ArrayVec::new(),
last_request_timestamp: Instant::now(),
project: project,
zeta,
}
}
}

struct PendingPrediction {
id: usize,
_task: Task<()>,
}

impl EditPredictionProvider for ZetaEditPredictionProvider {
fn name() -> &'static str {
"zed-predict2"
Expand Down Expand Up @@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}

fn is_refreshing(&self) -> bool {
!self.pending_predictions.is_empty()
fn is_refreshing(&self, cx: &App) -> bool {
self.zeta.read(cx).is_refreshing(&self.project)
}

fn refresh(
Expand All @@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {

self.zeta.update(cx, |zeta, cx| {
zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});

let pending_prediction_id = self.next_pending_prediction_id;
self.next_pending_prediction_id += 1;
let last_request_timestamp = self.last_request_timestamp;

let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
{
cx.background_executor().timer(timeout).await;
}

let refresh_task = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
})
});

if let Some(refresh_task) = refresh_task.ok() {
refresh_task.await.log_err();
}

this.update(cx, |this, cx| {
if this.pending_predictions[0].id == pending_prediction_id {
this.pending_predictions.remove(0);
} else {
this.pending_predictions.clear();
}

cx.notify();
})
.ok();
});

// We always maintain at most two pending predictions. When we already
// have two, we replace the newest one.
if self.pending_predictions.len() <= 1 {
self.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
} else if self.pending_predictions.len() == 2 {
self.pending_predictions.pop();
self.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
}

cx.notify();
}

fn cycle(
Expand All @@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
self.zeta.update(cx, |zeta, cx| {
zeta.accept_current_prediction(&self.project, cx);
});
self.pending_predictions.clear();
}

fn discard(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
});
self.pending_predictions.clear();
}

fn suggest(
Expand Down
Loading
Loading