Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve created_by info for local to cloud trace #2232

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/promptflow/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import traceback
from datetime import datetime
from typing import Callable

from flask import request
from google.protobuf.json_format import MessageToJson
Expand All @@ -28,7 +29,15 @@
from promptflow._utils.thread_utils import ThreadWithContextVars


def trace_collector(logger: logging.Logger):
def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger):
huaiyan marked this conversation as resolved.
Show resolved Hide resolved
zhengfeiwang marked this conversation as resolved.
Show resolved Hide resolved
"""
huaiyan marked this conversation as resolved.
Show resolved Hide resolved
This function is target to be reused in other places, so pass in get_created_by_info_with_cache and logger to avoid
app related dependencies.
Args:
get_created_by_info_with_cache (Callable): A function that retrieves information about the creator of the trace.
logger (logging.Logger): The logger object used for logging.
"""
content_type = request.headers.get("Content-Type")
# binary protobuf encoding
if "application/x-protobuf" in content_type:
Expand All @@ -55,15 +64,17 @@ def trace_collector(logger: logging.Logger):
all_spans.append(span)

# Create a new thread to write trace to cosmosdb to avoid blocking the main thread
ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start()
ThreadWithContextVars(
target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger)
).start()
return "Traces received", 200

# JSON protobuf encoding
elif "application/json" in content_type:
raise NotImplementedError


def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger):
if not all_spans:
return
try:
Expand All @@ -78,31 +89,37 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()
from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE
from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
from promptflow.azure._storage.cosmosdb.summary import Summary

# Load span and summary clients first time may slow.
# So, we load 2 client in parallel for warm up.
span_thread = ThreadWithContextVars(
span_client_thread = ThreadWithContextVars(
target=get_client, args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
)
span_thread.start()
span_client_thread.start()

# Load created_by info first time may slow. So, we load it in parallel for warm up.
created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache)
created_by_thread.start()

get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

span_thread.join()
span_client_thread.join()
created_by_thread.join()

created_by = get_created_by_info_with_cache()

for span in all_spans:
span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client)
result = SpanCosmosDB(span, created_by).persist(span_client)
# None means the span already exists, then we don't need to persist the summary also.
if result is not None:
line_summary_client = get_client(
CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name
)
Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client)
Summary(span, created_by, logger).persist(line_summary_client)
logger.info(
(
f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}."
Expand Down
71 changes: 39 additions & 32 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------
import logging
import sys
import threading
import time
from datetime import datetime, timedelta
from logging.handlers import RotatingFileHandler
Expand Down Expand Up @@ -42,9 +43,6 @@ def heartbeat():
return jsonify(response)


CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE = {}


def create_app():
app = Flask(__name__)

Expand All @@ -54,7 +52,9 @@ def create_app():
CORS(app)

app.add_url_rule("/heartbeat", view_func=heartbeat)
app.add_url_rule("/v1/traces", view_func=lambda: trace_collector(app.logger), methods=["POST"])
app.add_url_rule(
"/v1/traces", view_func=lambda: trace_collector(get_created_by_info_with_cache, app.logger), methods=["POST"]
)
with app.app_context():
api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0")

Expand Down Expand Up @@ -86,33 +86,6 @@ def create_app():
# Set app logger to the only one RotatingFileHandler to avoid duplicate logs
app.logger.handlers = [handler]

def initialize_created_by_info():
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider

trace_provider = Configuration.get_instance().get_trace_provider()
if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
return
try:
import jwt
from azure.identity import DefaultAzureCredential

from promptflow.azure._utils.general import get_arm_token

default_credential = DefaultAzureCredential()

token = get_arm_token(credential=default_credential)
decoded_token = jwt.decode(token, options={"verify_signature": False})
user_object_id, user_tenant_id = decoded_token["oid"], decoded_token["tid"]
CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE.update(
{
"object_id": user_object_id,
"tenant_id": user_tenant_id,
}
)
except Exception as e:
current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")

# Basic error handler
@api.errorhandler(Exception)
def handle_exception(e):
Expand Down Expand Up @@ -167,8 +140,42 @@ def monitor_request():
kill_exist_service(port)
break

initialize_created_by_info()
if not sys.executable.endswith("pfcli.exe"):
monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True)
monitor_thread.start()
return app, api


# Module-level cache for the created_by info attached to traces uploaded from
# local to cloud. Empty until the first successful fetch in
# get_created_by_info_with_cache(); an empty dict means "not yet resolved".
created_by_for_local_to_cloud_trace = {}
# Guards the one-time population of the cache above (double-checked locking).
created_by_for_local_to_cloud_trace_lock = threading.Lock()


def get_created_by_info_with_cache():
    """Return cached created_by info for local-to-cloud trace, fetching it once.

    On the first call this resolves the caller's identity from an ARM token
    (object id, tenant id and display name) and memoizes it in the module-level
    dict; subsequent calls return the cached dict without any network work.
    Uses double-checked locking so concurrent first calls fetch only once.
    On failure the error is logged and the (still empty) dict is returned, so
    a later call will retry. Intended to run inside a Flask app context.
    """
    # Fast path: cache already populated, no lock needed.
    if created_by_for_local_to_cloud_trace:
        return created_by_for_local_to_cloud_trace
    with created_by_for_local_to_cloud_trace_lock:
        # Re-check under the lock: another thread may have filled the cache
        # while we were waiting.
        if created_by_for_local_to_cloud_trace:
            return created_by_for_local_to_cloud_trace
        try:
            # The total time of collecting info is about 3s.
            import jwt
            from azure.identity import DefaultAzureCredential

            from promptflow.azure._utils.general import get_arm_token

            credential = DefaultAzureCredential()
            arm_token = get_arm_token(credential=credential)
            # Signature verification is skipped: we only need the claims.
            claims = jwt.decode(arm_token, options={"verify_signature": False})
            info = {
                "object_id": claims["oid"],
                "tenant_id": claims["tid"],
                # Use appid as fallback for service principal scenario.
                "name": claims.get("name", claims.get("appid", "")),
            }
            created_by_for_local_to_cloud_trace.update(info)
        except Exception as e:
            # This function is only target to be used in Flask app.
            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
        return created_by_for_local_to_cloud_trace
Loading