core: Add ruff rules for comprehensions (C4) (#26829)
cbornet authored Sep 25, 2024
1 parent 7e5a9c3 commit 3a1b925
Showing 34 changed files with 259 additions and 265 deletions.
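
For context, the C4 rules come from flake8-comprehensions: they flag builder calls such as dict(), set((...)), and identity comprehensions that have a cheaper literal or constructor equivalent. As a minimal sketch of the before/after shapes that recur throughout this diff (the variable names below are illustrative, not taken from the codebase):

# Illustrative only: the rewrite patterns applied throughout this commit.
params = {"stop": None, "temperature": 0.7}
keys, values = ["a", "b"], [1, 2]

# C401: generator passed to set() -> set comprehension.
names = {k.upper() for k in params}   # was: set(k.upper() for k in params)

# C405: set() over a literal sequence -> set literal.
reserved = {"url", "path", "detail"}  # was: set(("url", "path", "detail"))

# C408: empty dict()/tuple() calls -> literals.
secrets: dict = {}                    # was: dict()
bases: tuple = ()                     # was: tuple()

# C416: identity comprehensions -> the constructor itself.
pairs = dict(zip(keys, values))       # was: {k: v for k, v in zip(keys, values)}
indexed = dict(enumerate(values))     # was: {i: v for i, v in enumerate(values)}
copied = dict(params.items())         # was: {k: v for k, v in params.items()}
listed = list(iter(keys))             # was: [k for k in iter(keys)]

# The same pattern inside a call: the comprehension adds nothing.
ordered = sorted(params.items())      # was: sorted([(k, v) for k, v in params.items()])
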
6 changes: 3 additions & 3 deletions libs/core/langchain_core/beta/runnables/context.py
@@ -86,9 +86,9 @@ def _config_with_context(
         )
     }
     deps_by_key = {
-        key: set(
+        key: {
             _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
-        )
+        }
         for key, group in grouped_by_key.items()
     }

@@ -198,7 +198,7 @@ async def ainvoke(
         configurable = config.get("configurable", {})
         if isinstance(self.key, list):
             values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
-            return {key: value for key, value in zip(self.key, values)}
+            return dict(zip(self.key, values))
         else:
             return await configurable[self.ids[0]]()

4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -551,7 +551,7 @@ def _get_ls_params(
     def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
-            param_string = str(sorted([(k, v) for k, v in params.items()]))
+            param_string = str(sorted(params.items()))
             # This code is not super efficient as it goes back and forth between
             # json and dict.
             serialized_repr = self._serialized
@@ -561,7 +561,7 @@ def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
             params = {**params, **kwargs}
-            return str(sorted([(k, v) for k, v in params.items()]))
+            return str(sorted(params.items()))
 
     def generate(
         self,
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
@@ -166,7 +166,7 @@ def get_prompts(
     Raises:
         ValueError: If the cache is not set and cache is True.
     """
-    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    llm_string = str(sorted(params.items()))
     missing_prompts = []
     missing_prompt_idxs = []
     existing_prompts = {}
@@ -202,7 +202,7 @@ async def aget_prompts(
     Raises:
         ValueError: If the cache is not set and cache is True.
     """
-    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    llm_string = str(sorted(params.items()))
     missing_prompts = []
     missing_prompt_idxs = []
     existing_prompts = {}
6 changes: 3 additions & 3 deletions libs/core/langchain_core/load/load.py
@@ -67,14 +67,14 @@ def __init__(
             Defaults to None.
         """
         self.secrets_from_env = secrets_from_env
-        self.secrets_map = secrets_map or dict()
+        self.secrets_map = secrets_map or {}
         # By default, only support langchain, but user can pass in additional namespaces
         self.valid_namespaces = (
             [*DEFAULT_NAMESPACES, *valid_namespaces]
             if valid_namespaces
             else DEFAULT_NAMESPACES
         )
-        self.additional_import_mappings = additional_import_mappings or dict()
+        self.additional_import_mappings = additional_import_mappings or {}
         self.import_mappings = (
             {
                 **ALL_SERIALIZABLE_MAPPINGS,
@@ -146,7 +146,7 @@ def __call__(self, value: dict[str, Any]) -> Any:
 
             # We don't need to recurse on kwargs
             # as json.loads will do that for us.
-            kwargs = value.get("kwargs", dict())
+            kwargs = value.get("kwargs", {})
             return cls(**kwargs)
 
         return value
4 changes: 2 additions & 2 deletions libs/core/langchain_core/load/serializable.py
@@ -138,7 +138,7 @@ def lc_secrets(self) -> dict[str, str]:
         For example,
             {"openai_api_key": "OPENAI_API_KEY"}
         """
-        return dict()
+        return {}
 
     @property
     def lc_attributes(self) -> dict:
@@ -188,7 +188,7 @@ def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
         if not self.is_lc_serializable():
             return self.to_json_not_implemented()
 
-        secrets = dict()
+        secrets = {}
         # Get latest values for kwargs if there is an attribute with same name
         lc_kwargs = {}
         for k, v in self:
2 changes: 1 addition & 1 deletion libs/core/langchain_core/output_parsers/json.py
@@ -108,7 +108,7 @@ def get_format_instructions(self) -> str:
             return "Return a JSON object."
         else:
             # Copy schema to avoid altering original Pydantic schema.
-            schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()}
+            schema = dict(self._get_schema(self.pydantic_object).items())
 
             # Remove extraneous fields.
             reduced_schema = schema
2 changes: 1 addition & 1 deletion libs/core/langchain_core/output_parsers/pydantic.py
@@ -90,7 +90,7 @@ def get_format_instructions(self) -> str:
             The format instructions for the JSON output.
         """
         # Copy schema to avoid altering original Pydantic schema.
-        schema = {k: v for k, v in self.pydantic_object.model_json_schema().items()}
+        schema = dict(self.pydantic_object.model_json_schema().items())
 
         # Remove extraneous fields.
         reduced_schema = schema
2 changes: 1 addition & 1 deletion libs/core/langchain_core/outputs/llm_result.py
@@ -76,7 +76,7 @@ def flatten(self) -> list[LLMResult]:
             else:
                 if self.llm_output is not None:
                     llm_output = deepcopy(self.llm_output)
-                    llm_output["token_usage"] = dict()
+                    llm_output["token_usage"] = {}
                 else:
                     llm_output = None
             llm_results.append(
10 changes: 5 additions & 5 deletions libs/core/langchain_core/prompts/chat.py
@@ -1007,11 +1007,11 @@ def __init__(
                 input_vars.update(_message.input_variables)
 
         kwargs = {
-            **dict(
-                input_variables=sorted(input_vars),
-                optional_variables=sorted(optional_variables),
-                partial_variables=partial_vars,
-            ),
+            **{
+                "input_variables": sorted(input_vars),
+                "optional_variables": sorted(optional_variables),
+                "partial_variables": partial_vars,
+            },
             **kwargs,
         }
         cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/image.py
@@ -18,7 +18,7 @@ def __init__(self, **kwargs: Any) -> None:
         if "input_variables" not in kwargs:
             kwargs["input_variables"] = []
 
-        overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
+        overlap = set(kwargs["input_variables"]) & {"url", "path", "detail"}
         if overlap:
             raise ValueError(
                 "input_variables for the image template cannot contain"
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/prompt.py
@@ -144,7 +144,7 @@ def __add__(self, other: Any) -> PromptTemplate:
             template = self.template + other.template
             # If any do not want to validate, then don't
             validate_template = self.validate_template and other.validate_template
-            partial_variables = {k: v for k, v in self.partial_variables.items()}
+            partial_variables = dict(self.partial_variables.items())
             for k, v in other.partial_variables.items():
                 if k in partial_variables:
                     raise ValueError("Cannot have same variable partialed twice.")
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/base.py
@@ -3778,7 +3778,7 @@ async def _ainvoke_step(
                     for key, step in steps.items()
                 )
             )
-            output = {key: value for key, value in zip(steps, results)}
+            output = dict(zip(steps, results))
         # finish the root run
         except BaseException as e:
             await run_manager.on_chain_error(e)
4 changes: 2 additions & 2 deletions libs/core/langchain_core/runnables/fallbacks.py
@@ -294,7 +294,7 @@ def batch(
         ]
 
         to_return: dict[int, Any] = {}
-        run_again = {i: input for i, input in enumerate(inputs)}
+        run_again = dict(enumerate(inputs))
         handled_exceptions: dict[int, BaseException] = {}
         first_to_raise = None
         for runnable in self.runnables:
@@ -388,7 +388,7 @@ async def abatch(
         )
 
         to_return = {}
-        run_again = {i: input for i, input in enumerate(inputs)}
+        run_again = dict(enumerate(inputs))
         handled_exceptions: dict[int, BaseException] = {}
         first_to_raise = None
         for runnable in self.runnables:
2 changes: 1 addition & 1 deletion libs/core/langchain_core/runnables/retry.py
@@ -117,7 +117,7 @@ def get_lc_namespace(cls) -> list[str]:
 
     @property
     def _kwargs_retrying(self) -> dict[str, Any]:
-        kwargs: dict[str, Any] = dict()
+        kwargs: dict[str, Any] = {}
 
         if self.max_attempt_number:
             kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
4 changes: 2 additions & 2 deletions libs/core/langchain_core/sys_info.py
@@ -10,7 +10,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
     from importlib import metadata
 
     sub_deps = set()
-    _underscored_packages = set(pkg.replace("-", "_") for pkg in packages)
+    _underscored_packages = {pkg.replace("-", "_") for pkg in packages}
 
     for pkg in packages:
         try:
@@ -33,7 +33,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
     return sorted(sub_deps, key=lambda x: x.lower())
 
 
-def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None:
+def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
     """Print information about the environment for debugging purposes.
 
     Args:
10 changes: 4 additions & 6 deletions libs/core/langchain_core/tools/base.py
@@ -975,7 +975,7 @@ def _get_all_basemodel_annotations(
             ) and name not in fields:
                 continue
             annotations[name] = param.annotation
-        orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
+        orig_bases: tuple = getattr(cls, "__orig_bases__", ())
     # cls has subscript: cls = FooBar[int]
     else:
         annotations = _get_all_basemodel_annotations(
@@ -1007,11 +1007,9 @@ def _get_all_basemodel_annotations(
         # parent_origin = Baz,
         # generic_type_vars = (type vars in Baz)
         # generic_map = {type var in Baz: str}
-        generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
-        generic_map = {
-            type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
-        }
-        for field in getattr(parent_origin, "__annotations__", dict()):
+        generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
+        generic_map = dict(zip(generic_type_vars, get_args(parent)))
+        for field in getattr(parent_origin, "__annotations__", {}):
             annotations[field] = _replace_type_vars(
                 annotations[field], generic_map, default_to_bound
             )
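
Side note on the hunk above: dict(zip(generic_type_vars, get_args(parent))) keeps the original behavior — it pairs each type variable declared on the generic parent with the concrete type it was subscripted with. A standalone sketch, assuming a made-up Box class (not from the codebase):

from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")

class Box(Generic[T]):  # hypothetical stand-in for a generic model class
    content: T

parent = Box[int]
origin = get_origin(parent)                            # Box
type_vars = getattr(origin, "__parameters__", ())      # (T,)
generic_map = dict(zip(type_vars, get_args(parent)))   # {T: int}
assert generic_map[T] is int
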
4 changes: 1 addition & 3 deletions libs/core/langchain_core/utils/function_calling.py
@@ -233,9 +233,7 @@ def _convert_any_typed_dicts_to_pydantic(
                 new_arg_type = _convert_any_typed_dicts_to_pydantic(
                     annotated_args[0], depth=depth + 1, visited=visited
                 )
-                field_kwargs = {
-                    k: v for k, v in zip(("default", "description"), annotated_args[1:])
-                }
+                field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
                 if (field_desc := field_kwargs.get("description")) and not isinstance(
                     field_desc, str
                 ):
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
@@ -44,7 +44,7 @@ python = ">=3.12.4"
 [tool.poetry.extras]
 
 [tool.ruff.lint]
-select = [ "B", "E", "F", "I", "N", "T201", "UP",]
+select = [ "B", "C4", "E", "F", "I", "N", "T201", "UP",]
 ignore = [ "UP007",]
 
 [tool.coverage.run]
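
With "C4" added to select, these checks run alongside the existing rules, and most of the violations in this commit are auto-fixable; something like ruff check --select C4 --fix libs/core (the exact invocation depends on the repo's lint setup) reproduces the bulk of the mechanical edits above.
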
@@ -54,5 +54,5 @@ def test_lazy_load() -> None:
         expected.append(
             Document(example.inputs["first"]["second"].upper(), metadata=metadata)
         )
-    actual = [doc for doc in loader.lazy_load()]
+    actual = list(loader.lazy_load())
     assert expected == actual
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -55,7 +55,7 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1
 
-    chunks = [chunk for chunk in model.stream("meow")]
+    chunks = list(model.stream("meow"))
     assert chunks == [
         _any_id_ai_message_chunk(content="hello"),
         _any_id_ai_message_chunk(content=" "),
32 changes: 16 additions & 16 deletions libs/core/tests/unit_tests/indexing/test_indexing.py
@@ -185,11 +185,11 @@ def test_index_simple_delete_full(
     ):
         indexing_result = index(loader, record_manager, vector_store, cleanup="full")
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}
 
     assert indexing_result == {
@@ -267,11 +267,11 @@ async def test_aindex_simple_delete_full(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}
 
     # Attempt to index again verify that nothing changes
@@ -558,11 +558,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}
 
     # Attempt to index again verify that nothing changes
@@ -617,11 +617,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {
         "mutated document 1",
         "mutated document 2",
@@ -685,11 +685,11 @@ def test_incremental_indexing_with_batch_size(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}


@@ -735,11 +735,11 @@ def test_incremental_delete_with_batch_size(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}
 
     # Attempt to index again verify that nothing changes
@@ -880,11 +880,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}
 
     # Attempt to index again verify that nothing changes
@@ -939,11 +939,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }
 
-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
    assert doc_texts == {
         "mutated document 1",
         "mutated document 2",