diff --git a/src/agentscope/manager/_file.py b/src/agentscope/manager/_file.py index 8fe93b171..d49f4488b 100644 --- a/src/agentscope/manager/_file.py +++ b/src/agentscope/manager/_file.py @@ -34,13 +34,12 @@ def _get_text_embedding_record_hash( if isinstance(embedding_model, dict): # Format the dict to avoid duplicate keys embedding_model = json.dumps(embedding_model, sort_keys=True) - elif isinstance(embedding_model, str): - embedding_model_hash = _hash_string(embedding_model, hash_method) - else: + elif not isinstance(embedding_model, str): raise RuntimeError( f"The embedding model must be a string or a dict, got " f"{type(embedding_model)}.", ) + embedding_model_hash = _hash_string(embedding_model, hash_method) # Calculate the embedding id by hashing the hash codes of the # original data and the embedding model @@ -48,7 +47,6 @@ def _get_text_embedding_record_hash( original_data_hash + embedding_model_hash, hash_method, ) - return record_hash diff --git a/src/agentscope/rag/llama_index_knowledge.py b/src/agentscope/rag/llama_index_knowledge.py index 142f71068..b886825ff 100644 --- a/src/agentscope/rag/llama_index_knowledge.py +++ b/src/agentscope/rag/llama_index_knowledge.py @@ -203,8 +203,9 @@ def __init__( ) if persist_root is None: - persist_root = FileManager.get_instance().run_dir or "./" + persist_root = FileManager.get_instance().cache_dir or "./" self.persist_dir = os.path.join(persist_root, knowledge_id) + logger.info(f"** persist_dir: {self.persist_dir}") self.emb_model = emb_model self.overwrite_index = overwrite_index self.showprogress = showprogress diff --git a/src/agentscope/service/web/web_digest.py b/src/agentscope/service/web/web_digest.py index 5e39fdc4d..95864be8f 100644 --- a/src/agentscope/service/web/web_digest.py +++ b/src/agentscope/service/web/web_digest.py @@ -3,6 +3,8 @@ import json from urllib.parse import urlparse from typing import Optional, Callable, Sequence, Any +import socket +import ipaddress import requests from loguru import logger @@ -37,12 +39,46 @@ def is_valid_url(url: str) -> bool: return False # A ValueError indicates that the URL is not valid. +def is_internal_ip_address(url: str) -> bool: + """ + Check if a URL is to interal IP addresses + Args: + url (str): url to be checked + + Returns: + bool: True if url is not to interal IP addresses, + False otherwise + """ + parsed_url = urlparse(url) + hostname = parsed_url.hostname + if hostname is None: + # illegal hostname is ignore in this function + return False + + # Resolve the hostname to an IP address + ip = socket.gethostbyname(hostname) + # Check if it's localhost or within the loopback range + if ( + ip.startswith("127.") + or ip == "::1" + or ipaddress.ip_address(ip).is_private + ): + logger.warning( + f"Access to this URL {url} is " + f"restricted because it is private", + ) + return True + + return False + + def load_web( url: str, keep_raw: bool = True, html_selected_tags: Optional[Sequence[str]] = None, self_parse_func: Optional[Callable[[requests.Response], Any]] = None, timeout: int = 5, + exclude_internal_ips: bool = True, ) -> ServiceResponse: """Function for parsing and digesting the web page. @@ -62,6 +98,8 @@ def load_web( The result is stored with `self_define_func` key timeout (int): timeout parameter for requests. + exclude_internal_ips (bool): + whether prevent the function access internal_ips Returns: `ServiceResponse`: If successful, `ServiceResponse` object is returned @@ -87,6 +125,13 @@ def load_web( "selected_tags_text": xxxxx } """ + if exclude_internal_ips and is_internal_ip_address(url): + return ServiceResponse( + ServiceExecStatus.ERROR, + content=f"Access to this URL {url} is restricted " + f"because it is private", + ) + header = { "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6", "Cache-Control": "max-age=0", diff --git a/tests/web_digest_test.py b/tests/web_digest_test.py index 3b62b0e36..d08d0e489 100644 --- a/tests/web_digest_test.py +++ b/tests/web_digest_test.py @@ -58,7 +58,7 @@ def test_web_load(self, mock_get: MagicMock) -> None: mock_get.return_value = mock_response # set parameters - fake_url = "fake-url" + fake_url = "http://fake-url.com" results = load_web( url=fake_url, @@ -100,6 +100,12 @@ def format( expected_result, ) + def test_block_internal_ips(self) -> None: + """test whether can prevent internal_url successfully""" + internal_url = "http://localhost:8080/some/path" + response = load_web(internal_url) + self.assertEqual(ServiceExecStatus.ERROR, response.status) + # This allows the tests to be run from the command line if __name__ == "__main__":