Skip to content

Commit

Permalink
Add ruff rules for boolean trap (FBT)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Oct 14, 2024
1 parent 3e181b9 commit 7aeb375
Show file tree
Hide file tree
Showing 57 changed files with 158 additions and 101 deletions.
16 changes: 9 additions & 7 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_letter_from_version(version: str):


def build_version_notice(current_version: str, package_name: str) -> str:
latest_version = fetch_latest_version(package_name, langflow_is_pre_release(current_version))
latest_version = fetch_latest_version(package_name, include_prerelease=langflow_is_pre_release(current_version))
if latest_version and pkg_version.parse(current_version) < pkg_version.parse(latest_version):
release_type = "pre-release" if langflow_is_pre_release(latest_version) else "version"
return f"A new {release_type} of {package_name} is available: {latest_version}"
Expand All @@ -317,7 +317,7 @@ def generate_pip_command(package_names, is_pre_release):
return f"{base_command} {' '.join(package_names)} -U"


def stylize_text(text: str, to_style: str, is_prerelease: bool) -> str:
def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str:
color = "#42a7f5" if is_prerelease else "#6e42f5"
# return "".join(f"[{color}]{char}[/]" for char in text)
styled_text = f"[{color}]{to_style}[/]"
Expand All @@ -337,7 +337,7 @@ def print_banner(host: str, port: int):
is_pre_release |= langflow_is_pre_release(langflow_version) # Update pre-release status

notice = build_version_notice(langflow_version, package_name)
notice = stylize_text(notice, package_name, is_pre_release)
notice = stylize_text(notice, package_name, is_prerelease=is_pre_release)
if notice:
notices.append(notice)
package_names.append(package_name)
Expand All @@ -350,7 +350,9 @@ def print_banner(host: str, port: int):
notices.append(f"Run '{pip_command}' to update.")

styled_notices = [f"[bold]{notice}[/bold]" for notice in notices if notice]
styled_package_name = stylize_text(package_name, package_name, any("pre-release" in notice for notice in notices))
styled_package_name = stylize_text(
package_name, package_name, is_prerelease=any("pre-release" in notice for notice in notices)
)

title = f"[bold]Welcome to :chains: {styled_package_name}[/bold]\n"
info_text = (
Expand Down Expand Up @@ -466,9 +468,9 @@ def copy_db():

@app.command()
def migration(
test: bool = typer.Option(True, help="Run migrations in test mode."),
fix: bool = typer.Option(
False,
test: bool = typer.Option(default=True, help="Run migrations in test mode."), # noqa: FBT001
fix: bool = typer.Option( # noqa: FBT001
default=False,
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
),
):
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def retrieve_vertices_order(

@router.post("/build/{flow_id}/flow")
async def build_flow(
*,
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
Expand Down
6 changes: 6 additions & 0 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@

@router.get("/all", dependencies=[Depends(get_current_active_user)])
async def get_all(
*,
settings_service=Depends(get_settings_service),
cache_service: CacheService = Depends(dependency=get_cache_service),
force_refresh: bool = False,
Expand Down Expand Up @@ -105,6 +106,7 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
async def simple_run_flow(
flow: Flow,
input_request: SimplifiedAPIRequest,
*,
stream: bool = False,
api_key_user: User | None = None,
):
Expand Down Expand Up @@ -153,6 +155,7 @@ async def simple_run_flow(
async def simple_run_flow_task(
flow: Flow,
input_request: SimplifiedAPIRequest,
*,
stream: bool = False,
api_key_user: User | None = None,
):
Expand All @@ -173,6 +176,7 @@ async def simple_run_flow_task(

@router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True) # noqa: RUF100, FAST003
async def simplified_run_flow(
*,
background_tasks: BackgroundTasks,
flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)],
input_request: SimplifiedAPIRequest | None = None,
Expand Down Expand Up @@ -372,6 +376,7 @@ async def webhook_run_flow(

@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def experimental_run_flow(
*,
session: Annotated[Session, Depends(get_session)],
flow_id: UUID,
inputs: list[InputValueRequest] | None = None,
Expand Down Expand Up @@ -500,6 +505,7 @@ async def experimental_run_flow(
response_model=ProcessResponse,
)
async def process(
*,
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: list[dict] | dict | None = None,
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def update_shared_component(

@router.get("/components/", response_model=ListComponentResponseModel)
async def get_components(
*,
component_id: Annotated[str | None, Query()] = None,
search: Annotated[str | None, Query()] = None,
private: Annotated[bool | None, Query()] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _validate_outputs(self):
msg = f"Method '{method_name}' must be defined."
raise ValueError(msg)

def get_agent_kwargs(self, flatten: bool = False) -> dict:
def get_agent_kwargs(self, *, flatten: bool = False) -> dict:
base = {
"handle_parsing_errors": self.handle_parsing_errors,
"verbose": self.verbose,
Expand Down
8 changes: 5 additions & 3 deletions src/backend/base/langflow/base/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def is_hidden(path: Path) -> bool:

def retrieve_file_paths(
path: str,
*,
load_hidden: bool,
recursive: bool,
depth: int,
Expand Down Expand Up @@ -74,7 +75,7 @@ def walk_level(directory: Path, max_depth: int):
return [str(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)]


def partition_file_to_data(file_path: str, silent_errors: bool) -> Data | None:
def partition_file_to_data(file_path: str, *, silent_errors: bool) -> Data | None:
# Use the partition function to load the file
from unstructured.partition.auto import partition

Expand Down Expand Up @@ -122,7 +123,7 @@ def parse_pdf_to_text(file_path: str) -> str:
return "\n\n".join([page.extract_text() for page in reader.pages])


def parse_text_file_to_data(file_path: str, silent_errors: bool) -> Data | None:
def parse_text_file_to_data(file_path: str, *, silent_errors: bool) -> Data | None:
try:
if file_path.endswith(".pdf"):
text = parse_pdf_to_text(file_path)
Expand Down Expand Up @@ -172,13 +173,14 @@ def parse_text_file_to_data(file_path: str, silent_errors: bool) -> Data | None:

def parallel_load_data(
file_paths: list[str],
*,
silent_errors: bool,
max_concurrency: int,
load_function: Callable = parse_text_file_to_data,
) -> list[Data | None]:
with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
loaded_files = executor.map(
lambda file_path: load_function(file_path, silent_errors),
lambda file_path: load_function(file_path, silent_errors=silent_errors),
file_paths,
)
# loaded_files is an iterator, so we need to convert it to a list
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/flow_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def build_data_from_run_outputs(run_outputs: RunOutputs) -> list[Data]:
return data


def build_data_from_result_data(result_data: ResultData, get_final_results_only: bool = True) -> list[Data]:
def build_data_from_result_data(result_data: ResultData, *, get_final_results_only: bool = True) -> list[Data]:
"""
Build a list of data from the given ResultData.
Expand Down
1 change: 1 addition & 0 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _stream_message(self, message: Message, message_id: str) -> str:

def build_with_data(
self,
*,
sender: str | None = "User",
sender_name: str | None = "User",
input_value: str | Data | Message | None = None,
Expand Down
7 changes: 5 additions & 2 deletions src/backend/base/langflow/base/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def text_response(self) -> Message:
stream = self.stream
system_message = self.system_message
output = self.build_model()
result = self.get_chat_result(output, stream, input_value, system_message)
result = self.get_chat_result(
runnable=output, stream=stream, input_value=input_value, system_message=system_message
)
self.status = result
return result

def get_result(self, runnable: LLM, stream: bool, input_value: str):
def get_result(self, *, runnable: LLM, stream: bool, input_value: str):
"""
Retrieves the result from the output of a Runnable object.
Expand Down Expand Up @@ -141,6 +143,7 @@ def build_status_message(self, message: AIMessage):

def get_chat_result(
self,
*,
runnable: LanguageModel,
stream: bool,
input_value: str | Message,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/prompts/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _check_input_variables(input_variables):
return fixed_variables


def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
def validate_prompt(prompt_template: str, *, silent_errors: bool = False) -> list[str]:
input_variables = extract_input_variables_from_prompt(prompt_template)

# Check if there are invalid characters in the input_variables
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/chains/RetrievalQA.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def invoke_chain(self) -> Message:

result = runnable.invoke({"query": self.input_value}, config={"callbacks": self.get_langchain_callbacks()})

source_docs = self.to_data(result.get("source_documents", []))
source_docs = self.to_data(result.get("source_documents", keys=[]))
result_str = str(result.get("result", ""))
if self.return_source_documents and len(source_docs):
references_str = self.create_references_from_data(source_docs)
Expand Down
8 changes: 5 additions & 3 deletions src/backend/base/langflow/components/data/Directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,19 @@ def load_directory(self) -> list[Data]:
use_multithreading = self.use_multithreading

resolved_path = self.resolve_path(path)
file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth, types)
file_paths = retrieve_file_paths(
resolved_path, load_hidden=load_hidden, recursive=recursive, depth=depth, types=types
)

if types:
file_paths = [fp for fp in file_paths if any(fp.endswith(ext) for ext in types)]

loaded_data = []

if use_multithreading:
loaded_data = parallel_load_data(file_paths, silent_errors, max_concurrency)
loaded_data = parallel_load_data(file_paths, silent_errors=silent_errors, max_concurrency=max_concurrency)
else:
loaded_data = [parse_text_file_to_data(file_path, silent_errors) for file_path in file_paths]
loaded_data = [parse_text_file_to_data(file_path, silent_errors=silent_errors) for file_path in file_paths]
loaded_data = list(filter(None, loaded_data))
self.status = loaded_data
return loaded_data # type: ignore[return-value]
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/data/File.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ def load_file(self) -> Data:
msg = f"Unsupported file type: {extension}"
raise ValueError(msg)

data = parse_text_file_to_data(resolved_path, silent_errors)
data = parse_text_file_to_data(resolved_path, silent_errors=silent_errors)
self.status = data or "No data"
return data or Data()
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/data/Gmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class GmailLoaderComponent(Component):
def load_emails(self) -> Data:
class CustomGMailLoader(GMailLoader):
def __init__(
self, creds: Any, n: int = 100, label_ids: list[str] | None = None, raise_error: bool = False
self, creds: Any, *, n: int = 100, label_ids: list[str] | None = None, raise_error: bool = False
) -> None:
super().__init__(creds, n, raise_error)
self.label_ids = label_ids if label_ids is not None else ["SENT"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExtractKeyFromDataComponent(CustomComponent):
},
}

def build(self, data: Data, keys: list[str], silent_error: bool = True) -> Data:
def build(self, data: Data, keys: list[str], *, silent_error: bool = True) -> Data:
"""
Extracts the keys from a data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class SelectivePassThroughComponent(Component):
Output(display_name="Passed Output", name="passed_output", method="pass_through"),
]

def evaluate_condition(self, input_value: str, comparison_value: str, operator: str, case_sensitive: bool) -> bool:
def evaluate_condition(
self, input_value: str, comparison_value: str, operator: str, *, case_sensitive: bool
) -> bool:
if not case_sensitive:
input_value = input_value.lower()
comparison_value = comparison_value.lower()
Expand All @@ -68,7 +70,7 @@ def pass_through(self) -> Text:
value_to_pass = self.value_to_pass
case_sensitive = self.case_sensitive

if self.evaluate_condition(input_value, comparison_value, operator, case_sensitive):
if self.evaluate_condition(input_value, comparison_value, operator, case_sensitive=case_sensitive):
self.status = value_to_pass
return value_to_pass
self.status = ""
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/components/deactivated/SubFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def build_config(self):
},
}

async def build(self, flow_name: str, get_final_results_only: bool = True, **kwargs) -> list[Data]:
async def build(self, flow_name: str, *, get_final_results_only: bool = True, **kwargs) -> list[Data]:
tweaks = {key: {"input_value": value} for key, value in kwargs.items()}
run_outputs: list[RunOutputs | None] = await self.run_flow(
tweaks=tweaks,
Expand All @@ -118,7 +118,7 @@ async def build(self, flow_name: str, get_final_results_only: bool = True, **kwa
if run_output is not None:
for output in run_output.outputs:
if output:
data.extend(build_data_from_result_data(output, get_final_results_only))
data.extend(build_data_from_result_data(output, get_final_results_only=get_final_results_only))

self.status = data
logger.debug(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def build(
app = FirecrawlApp(api_key=api_key)
crawl_result = app.crawl_url(
url,
{
params={
"crawlerOptions": crawler_options_dict,
"pageOptions": page_options_dict,
},
True,
int(timeout / 1000),
idempotency_key,
wait_until_done=True,
poll_interval=int(timeout / 1000),
idempotency_key=idempotency_key,
)

return Data(data={"results": crawl_result})
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ConditionalRouterComponent(Component):
Output(display_name="False Route", name="false_result", method="false_response"),
]

def evaluate_condition(self, input_text: str, match_text: str, operator: str, case_sensitive: bool) -> bool:
def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool:
if not case_sensitive:
input_text = input_text.lower()
match_text = match_text.lower()
Expand All @@ -65,15 +65,19 @@ def evaluate_condition(self, input_text: str, match_text: str, operator: str, ca
return False

def true_response(self) -> Message:
result = self.evaluate_condition(self.input_text, self.match_text, self.operator, self.case_sensitive)
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if result:
self.status = self.message
return self.message
self.stop("true_result")
return None # type: ignore[return-value]

def false_response(self) -> Message:
result = self.evaluate_condition(self.input_text, self.match_text, self.operator, self.case_sensitive)
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if not result:
self.status = self.message
return self.message
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/prototypes/Notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def build_config(self):
},
}

def build(self, name: str, data: Data | None = None, append: bool = False) -> Data:
def build(self, name: str, *, data: Data | None = None, append: bool = False) -> Data:
if data and not isinstance(data, Data):
if isinstance(data, str):
data = Data(text=data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def build(
self,
query: str,
database_url: str,
*,
include_columns: bool = False,
passthrough: bool = False,
add_error: bool = False,
Expand Down
Loading

0 comments on commit 7aeb375

Please sign in to comment.