Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve User-Added Labels in Pull Requests #433

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,20 @@ def set_custom_labels(variables):
final_labels += f" - {k} ({v['description']})\n"
variables["custom_labels"] = final_labels
variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"


def get_user_labels(current_labels):
## Only keep labels that has been added by the user
if current_labels is None:
current_labels = []
user_labels = []
for label in current_labels:
if label in ['Bug fix', 'Tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']:
continue
if get_settings().config.enable_custom_labels:
if label in get_settings().custom_labels:
continue
user_labels.append(label)
if user_labels:
get_logger().info(f"Keeping user labels: {user_labels}")
return user_labels
3 changes: 2 additions & 1 deletion pr_agent/settings/pr_description_prompts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ PR Type:
...
{%- if enable_custom_labels %}
PR Labels:
{{ custom_labels_examples }}
- ...
- ...
{%- endif %}
PR Description: |-
...
Expand Down
8 changes: 4 additions & 4 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
Expand Down Expand Up @@ -98,9 +98,9 @@ async def run(self):
self.git_provider.publish_description(pr_title, pr_body)
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
current_labels = self.git_provider.get_labels()
if current_labels is None:
current_labels = []
self.git_provider.publish_labels(pr_labels + current_labels)
user_labels = get_user_labels(current_labels)

self.git_provider.publish_labels(pr_labels + user_labels)
self.git_provider.remove_initial_comment()
except Exception as e:
get_logger().error(f"Error generating PR description {self.pr_id}: {e}")
Expand Down
7 changes: 3 additions & 4 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
Expand Down Expand Up @@ -84,9 +84,8 @@ async def run(self):
get_logger().info(f"Pushing labels {self.pr_id}")

current_labels = self.git_provider.get_labels()
if current_labels is None:
current_labels = []
pr_labels = pr_labels + current_labels
user_labels = get_user_labels(current_labels)
pr_labels = pr_labels + user_labels

if self.git_provider.is_supported("get_labels"):
self.git_provider.publish_labels(pr_labels)
Expand Down
Loading