Skip to content

Commit

Permalink
ref: Add ruff rules for boolean trap (FBT) (langflow-ai#4126)
Browse files Browse the repository at this point in the history
Add ruff rules for boolean trap (FBT)
  • Loading branch information
cbornet authored and smatiolids committed Oct 15, 2024
1 parent eb34484 commit 787beea
Show file tree
Hide file tree
Showing 56 changed files with 158 additions and 102 deletions.
16 changes: 9 additions & 7 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def get_letter_from_version(version: str):


def build_version_notice(current_version: str, package_name: str) -> str:
latest_version = fetch_latest_version(package_name, langflow_is_pre_release(current_version))
latest_version = fetch_latest_version(package_name, include_prerelease=langflow_is_pre_release(current_version))
if latest_version and pkg_version.parse(current_version) < pkg_version.parse(latest_version):
release_type = "pre-release" if langflow_is_pre_release(latest_version) else "version"
return f"A new {release_type} of {package_name} is available: {latest_version}"
Expand All @@ -302,7 +302,7 @@ def generate_pip_command(package_names, is_pre_release):
return f"{base_command} {' '.join(package_names)} -U"


def stylize_text(text: str, to_style: str, is_prerelease: bool) -> str:
def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str:
color = "#42a7f5" if is_prerelease else "#6e42f5"
# return "".join(f"[{color}]{char}[/]" for char in text)
styled_text = f"[{color}]{to_style}[/]"
Expand All @@ -322,7 +322,7 @@ def print_banner(host: str, port: int):
is_pre_release |= langflow_is_pre_release(langflow_version) # Update pre-release status

notice = build_version_notice(langflow_version, package_name)
notice = stylize_text(notice, package_name, is_pre_release)
notice = stylize_text(notice, package_name, is_prerelease=is_pre_release)
if notice:
notices.append(notice)
package_names.append(package_name)
Expand All @@ -335,7 +335,9 @@ def print_banner(host: str, port: int):
notices.append(f"Run '{pip_command}' to update.")

styled_notices = [f"[bold]{notice}[/bold]" for notice in notices if notice]
styled_package_name = stylize_text(package_name, package_name, any("pre-release" in notice for notice in notices))
styled_package_name = stylize_text(
package_name, package_name, is_prerelease=any("pre-release" in notice for notice in notices)
)

title = f"[bold]Welcome to :chains: {styled_package_name}[/bold]\n"
info_text = (
Expand Down Expand Up @@ -445,9 +447,9 @@ def copy_db():

@app.command()
def migration(
test: bool = typer.Option(True, help="Run migrations in test mode."),
fix: bool = typer.Option(
False,
test: bool = typer.Option(default=True, help="Run migrations in test mode."), # noqa: FBT001
fix: bool = typer.Option( # noqa: FBT001
default=False,
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
),
):
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def retrieve_vertices_order(

@router.post("/build/{flow_id}/flow")
async def build_flow(
*,
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
Expand Down
5 changes: 5 additions & 0 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

@router.get("/all", dependencies=[Depends(get_current_active_user)])
async def get_all(
*,
settings_service=Depends(get_settings_service),
):
from langflow.interface.types import get_and_cache_all_types_dict
Expand Down Expand Up @@ -96,6 +97,7 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
async def simple_run_flow(
flow: Flow,
input_request: SimplifiedAPIRequest,
*,
stream: bool = False,
api_key_user: User | None = None,
):
Expand Down Expand Up @@ -144,6 +146,7 @@ async def simple_run_flow(
async def simple_run_flow_task(
flow: Flow,
input_request: SimplifiedAPIRequest,
*,
stream: bool = False,
api_key_user: User | None = None,
):
Expand All @@ -162,6 +165,7 @@ async def simple_run_flow_task(

@router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True) # noqa: RUF100, FAST003
async def simplified_run_flow(
*,
background_tasks: BackgroundTasks,
flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)],
input_request: SimplifiedAPIRequest | None = None,
Expand Down Expand Up @@ -361,6 +365,7 @@ async def webhook_run_flow(

@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def experimental_run_flow(
*,
session: Annotated[Session, Depends(get_session)],
flow_id: UUID,
inputs: list[InputValueRequest] | None = None,
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def update_shared_component(

@router.get("/components/", response_model=ListComponentResponseModel)
async def get_components(
*,
component_id: Annotated[str | None, Query()] = None,
search: Annotated[str | None, Query()] = None,
private: Annotated[bool | None, Query()] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _validate_outputs(self):
msg = f"Method '{method_name}' must be defined."
raise ValueError(msg)

def get_agent_kwargs(self, flatten: bool = False) -> dict:
def get_agent_kwargs(self, *, flatten: bool = False) -> dict:
base = {
"handle_parsing_errors": self.handle_parsing_errors,
"verbose": self.verbose,
Expand Down
8 changes: 5 additions & 3 deletions src/backend/base/langflow/base/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def is_hidden(path: Path) -> bool:

def retrieve_file_paths(
path: str,
*,
load_hidden: bool,
recursive: bool,
depth: int,
Expand Down Expand Up @@ -74,7 +75,7 @@ def walk_level(directory: Path, max_depth: int):
return [str(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)]


def partition_file_to_data(file_path: str, silent_errors: bool) -> Data | None:
def partition_file_to_data(file_path: str, *, silent_errors: bool) -> Data | None:
# Use the partition function to load the file
from unstructured.partition.auto import partition

Expand Down Expand Up @@ -122,7 +123,7 @@ def parse_pdf_to_text(file_path: str) -> str:
return "\n\n".join([page.extract_text() for page in reader.pages])


def parse_text_file_to_data(file_path: str, silent_errors: bool) -> Data | None:
def parse_text_file_to_data(file_path: str, *, silent_errors: bool) -> Data | None:
try:
if file_path.endswith(".pdf"):
text = parse_pdf_to_text(file_path)
Expand Down Expand Up @@ -172,13 +173,14 @@ def parse_text_file_to_data(file_path: str, silent_errors: bool) -> Data | None:

def parallel_load_data(
file_paths: list[str],
*,
silent_errors: bool,
max_concurrency: int,
load_function: Callable = parse_text_file_to_data,
) -> list[Data | None]:
with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
loaded_files = executor.map(
lambda file_path: load_function(file_path, silent_errors),
lambda file_path: load_function(file_path, silent_errors=silent_errors),
file_paths,
)
# loaded_files is an iterator, so we need to convert it to a list
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _stream_message(self, message: Message, message_id: str) -> str:

def build_with_data(
self,
*,
sender: str | None = "User",
sender_name: str | None = "User",
input_value: str | Data | Message | None = None,
Expand Down
7 changes: 5 additions & 2 deletions src/backend/base/langflow/base/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def text_response(self) -> Message:
stream = self.stream
system_message = self.system_message
output = self.build_model()
result = self.get_chat_result(output, stream, input_value, system_message)
result = self.get_chat_result(
runnable=output, stream=stream, input_value=input_value, system_message=system_message
)
self.status = result
return result

def get_result(self, runnable: LLM, stream: bool, input_value: str):
def get_result(self, *, runnable: LLM, stream: bool, input_value: str):
"""Retrieves the result from the output of a Runnable object.
Args:
Expand Down Expand Up @@ -139,6 +141,7 @@ def build_status_message(self, message: AIMessage):

def get_chat_result(
self,
*,
runnable: LanguageModel,
stream: bool,
input_value: str | Message,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/prompts/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _check_input_variables(input_variables):
return fixed_variables


def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
def validate_prompt(prompt_template: str, *, silent_errors: bool = False) -> list[str]:
input_variables = extract_input_variables_from_prompt(prompt_template)

# Check if there are invalid characters in the input_variables
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/chains/RetrievalQA.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def invoke_chain(self) -> Message:

result = runnable.invoke({"query": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})

source_docs = self.to_data(result.get("source_documents", []))
source_docs = self.to_data(result.get("source_documents", keys=[]))
result_str = str(result.get("result", ""))
if self.return_source_documents and len(source_docs):
references_str = self.create_references_from_data(source_docs)
Expand Down
8 changes: 5 additions & 3 deletions src/backend/base/langflow/components/data/Directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,19 @@ def load_directory(self) -> list[Data]:
use_multithreading = self.use_multithreading

resolved_path = self.resolve_path(path)
file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth, types)
file_paths = retrieve_file_paths(
resolved_path, load_hidden=load_hidden, recursive=recursive, depth=depth, types=types
)

if types:
file_paths = [fp for fp in file_paths if any(fp.endswith(ext) for ext in types)]

loaded_data = []

if use_multithreading:
loaded_data = parallel_load_data(file_paths, silent_errors, max_concurrency)
loaded_data = parallel_load_data(file_paths, silent_errors=silent_errors, max_concurrency=max_concurrency)
else:
loaded_data = [parse_text_file_to_data(file_path, silent_errors) for file_path in file_paths]
loaded_data = [parse_text_file_to_data(file_path, silent_errors=silent_errors) for file_path in file_paths]
loaded_data = list(filter(None, loaded_data))
self.status = loaded_data
return loaded_data # type: ignore[return-value]
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/data/File.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ def load_file(self) -> Data:
msg = f"Unsupported file type: {extension}"
raise ValueError(msg)

data = parse_text_file_to_data(resolved_path, silent_errors)
data = parse_text_file_to_data(resolved_path, silent_errors=silent_errors)
self.status = data or "No data"
return data or Data()
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/data/Gmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class GmailLoaderComponent(Component):
def load_emails(self) -> Data:
class CustomGMailLoader(GMailLoader):
def __init__(
self, creds: Any, n: int = 100, label_ids: list[str] | None = None, raise_error: bool = False
self, creds: Any, *, n: int = 100, label_ids: list[str] | None = None, raise_error: bool = False
) -> None:
super().__init__(creds, n, raise_error)
self.label_ids = label_ids if label_ids is not None else ["SENT"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExtractKeyFromDataComponent(CustomComponent):
},
}

def build(self, data: Data, keys: list[str], silent_error: bool = True) -> Data:
def build(self, data: Data, keys: list[str], *, silent_error: bool = True) -> Data:
"""Extracts the keys from a data.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class SelectivePassThroughComponent(Component):
Output(display_name="Passed Output", name="passed_output", method="pass_through"),
]

def evaluate_condition(self, input_value: str, comparison_value: str, operator: str, case_sensitive: bool) -> bool:
def evaluate_condition(
self, input_value: str, comparison_value: str, operator: str, *, case_sensitive: bool
) -> bool:
if not case_sensitive:
input_value = input_value.lower()
comparison_value = comparison_value.lower()
Expand All @@ -68,7 +70,7 @@ def pass_through(self) -> Text:
value_to_pass = self.value_to_pass
case_sensitive = self.case_sensitive

if self.evaluate_condition(input_value, comparison_value, operator, case_sensitive):
if self.evaluate_condition(input_value, comparison_value, operator, case_sensitive=case_sensitive):
self.status = value_to_pass
return value_to_pass
self.status = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def build(
app = FirecrawlApp(api_key=api_key)
crawl_result = app.crawl_url(
url,
{
params={
"crawlerOptions": crawler_options_dict,
"pageOptions": page_options_dict,
},
True,
int(timeout / 1000),
idempotency_key,
wait_until_done=True,
poll_interval=int(timeout / 1000),
idempotency_key=idempotency_key,
)

return Data(data={"results": crawl_result})
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ConditionalRouterComponent(Component):
Output(display_name="False Route", name="false_result", method="false_response"),
]

def evaluate_condition(self, input_text: str, match_text: str, operator: str, case_sensitive: bool) -> bool:
def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool:
if not case_sensitive:
input_text = input_text.lower()
match_text = match_text.lower()
Expand All @@ -65,15 +65,19 @@ def evaluate_condition(self, input_text: str, match_text: str, operator: str, ca
return False

def true_response(self) -> Message:
result = self.evaluate_condition(self.input_text, self.match_text, self.operator, self.case_sensitive)
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if result:
self.status = self.message
return self.message
self.stop("true_result")
return None # type: ignore[return-value]

def false_response(self) -> Message:
result = self.evaluate_condition(self.input_text, self.match_text, self.operator, self.case_sensitive)
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if not result:
self.status = self.message
return self.message
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/prototypes/Notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def build_config(self):
},
}

def build(self, name: str, data: Data | None = None, append: bool = False) -> Data:
def build(self, name: str, *, data: Data | None = None, append: bool = False) -> Data:
if data and not isinstance(data, Data):
if isinstance(data, str):
data = Data(text=data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def build(
self,
query: str,
database_url: str,
*,
include_columns: bool = False,
passthrough: bool = False,
add_error: bool = False,
Expand Down
17 changes: 10 additions & 7 deletions src/backend/base/langflow/components/tools/TavilyAISearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,19 @@ class TavilySearchSchema(BaseModel):
search_depth: str = Field("basic", description="The depth of the search.")
topic: str = Field("general", description="The category of the search.")
max_results: int = Field(5, description="The maximum number of search results to return.")
include_images: bool = Field(False, description="Include a list of query-related images in the response.")
include_answer: bool = Field(False, description="Include a short answer to original query.")
include_images: bool = Field(
default=False, description="Include a list of query-related images in the response."
)
include_answer: bool = Field(default=False, description="Include a short answer to original query.")

def run_model(self) -> list[Data]:
return self._tavily_search(
self.query,
self.search_depth,
self.topic,
self.max_results,
self.include_images,
self.include_answer,
search_depth=self.search_depth,
topic=self.topic,
max_results=self.max_results,
include_images=self.include_images,
include_answer=self.include_answer,
)

def build_tool(self) -> Tool:
Expand All @@ -102,6 +104,7 @@ def build_tool(self) -> Tool:
def _tavily_search(
self,
query: str,
*,
search_depth: str = "basic",
topic: str = "general",
max_results: int = 5,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def to_frontend_node(self):
field_config = self.get_template_config(self)
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
for key in self._inputs:
frontend_node.set_field_load_from_db_in_template(key, False)
frontend_node.set_field_load_from_db_in_template(key, value=False)
self._map_parameters_on_frontend_node(frontend_node)

frontend_node_dict = frontend_node.to_dict(keep_name=False)
Expand Down
Loading

0 comments on commit 787beea

Please sign in to comment.