diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5f5fefe..bd31a09 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: - name: Install package run: | - pip install ".[dev]" + pip install ".[tensorflow,dev]" - name: Lint with flake8 run: | @@ -45,7 +45,7 @@ jobs: - name: Upload coverage to Coveralls if: ${{ github.ref == 'refs/heads/main' }} run: | - coveralls --service=github + coveralls -i --service=github env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COVERALLS_FLAG_NAME: ${{ matrix.python-version }} @@ -74,6 +74,6 @@ jobs: - name: Finished if: ${{ github.ref == 'refs/heads/main' }} run: | - coveralls --service=github --finish + coveralls -i --service=github --finish env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index 8f87765..6ec3f44 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,11 @@ You can use `pip` to install the latest version of UniSim: pip install unisim ``` -By default, UniSim uses [Onnx](https://github.com/onnx/onnx) as the runtime. You can switch to using TensorFlow by setting the `BACKEND` environment variable (e.g. `os.environ["BACKEND"] = "tf"`). +By default, UniSim uses [Onnx](https://github.com/onnx/onnx) when running on CPU, and [TensorFlow](https://www.tensorflow.org/) for GPU acceleration. You can switch backends by setting the `BACKEND` environment variable (e.g. `os.environ["BACKEND"] = "tf"` or `"onnx"`). If you have a GPU, you can additionally install TensorFlow using: +2 +``` +pip install unisim[tensorflow] +``` ## Text UniSim (TextSim) diff --git a/notebooks/unisim_text_demo.ipynb b/notebooks/unisim_text_demo.ipynb index 8743a64..def2d6a 100644 --- a/notebooks/unisim_text_demo.ipynb +++ b/notebooks/unisim_text_demo.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Text UniSim Demo\n", + "## Text UniSim Demo -- Fuzzy Address Matching\n", "\n", "This demo showcases how to use Text UniSim (TextSim) for efficient fuzzy string matching, near-duplicate detection, and string similarity using a real-world entity matching dataset.\n", "\n", @@ -21,7 +21,7 @@ "output_type": "stream", "text": [ "INFO: Loaded backend\n", - "INFO: Using TF with GPU\n" + "INFO: Using ONNX with CPU\n" ] }, { @@ -88,8 +88,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.10/dist-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.\n", - " warnings.warn(\"Repo card metadata block was not found. Setting CardData to empty.\")\n" + "Repo card metadata block was not found. Setting CardData to empty.\n" ] }, { @@ -150,21 +149,13 @@ "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer RandomNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", - " warnings.warn(\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n", "INFO: UniSim is storing a copy of the indexed data\n", - "INFO: If you are using large data corpus, consider disabling this behavior using store_data=False\n" + "INFO: If you are using large data corpus, consider disabling this behavior using store_data=False\n", + "INFO: Accelerator is not available, using CPU\n" ] } ], @@ -227,35 +218,35 @@ " 15 Romolo (415) 398-1359 15 Romolo Place, San ...\n", " 15 Romolo (415) 398-1359 15 Romolo Pl, San Fra...\n", " 1\n", - " 0.971069\n", + " 0.971120\n", " \n", " \n", " 1\n", " 456 Shanghai Cuisine 1261 69 Mott Street, New ...\n", " Shanghai Asian Manor (212) 766-6311 21 Mott St...\n", " 0\n", - " 0.779613\n", + " 0.779699\n", " \n", " \n", " 2\n", " 5A5 Steak Lounge (415) 989-2539 244 Jackson St...\n", " Delicious Dim Sum (415) 781-0721 752 Jackson S...\n", " 0\n", - " 0.687863\n", + " 0.687844\n", " \n", " \n", " 3\n", " 9th Street Pizza (213) 627-7798 231 E 9th St, ...\n", " Han Bat Sul Lung Tang (213) 383-9499 4163 W 5t...\n", " 0\n", - " 0.610497\n", + " 0.610611\n", " \n", " \n", " 4\n", " 9th Street Pizza (213) 627-7798 231 E 9th St, ...\n", " Jun Won Restaurant (213) 383-8855 3100 W 8th S...\n", " 0\n", - " 0.617363\n", + " 0.617366\n", " \n", " \n", "\n", @@ -270,11 +261,11 @@ "4 9th Street Pizza (213) 627-7798 231 E 9th St, ... \n", "\n", " text2 match_label similarity \n", - "0 15 Romolo (415) 398-1359 15 Romolo Pl, San Fra... 1 0.971069 \n", - "1 Shanghai Asian Manor (212) 766-6311 21 Mott St... 0 0.779613 \n", - "2 Delicious Dim Sum (415) 781-0721 752 Jackson S... 0 0.687863 \n", - "3 Han Bat Sul Lung Tang (213) 383-9499 4163 W 5t... 0 0.610497 \n", - "4 Jun Won Restaurant (213) 383-8855 3100 W 8th S... 0 0.617363 " + "0 15 Romolo (415) 398-1359 15 Romolo Pl, San Fra... 1 0.971120 \n", + "1 Shanghai Asian Manor (212) 766-6311 21 Mott St... 0 0.779699 \n", + "2 Delicious Dim Sum (415) 781-0721 752 Jackson S... 0 0.687844 \n", + "3 Han Bat Sul Lung Tang (213) 383-9499 4163 W 5t... 0 0.610611 \n", + "4 Jun Won Restaurant (213) 383-8855 3100 W 8th S... 0 0.617366 " ] }, "metadata": {}, @@ -352,21 +343,21 @@ " 0\n", " Shanghai asia manor (212)-766-6311 21 Mott street, New York, NY 94133\n", " Shanghai Asian Manor (212) 766-6311 21 Mott St, New York, NY 10013\n", - " 0.906984\n", + " 0.907003\n", " True\n", " \n", " \n", " 1\n", " Googleplex (650) 253-0000 1600 Amphitheatre Pkwy, Mountain View, CA 94043\n", " 15 Romolo (415) 398-1359 15 Romolo Pl, San Francisco, CA 94133\n", - " 0.466477\n", + " 0.466519\n", " False\n", " \n", " \n", " 2\n", " Sino-american books & arts (415) 421-3345 751 Jackson St, San Francisco, CA 94133\n", " Delicious Dim Sum (415) 781-0721 752 Jackson St, San Francisco, CA 94133\n", - " 0.746303\n", + " 0.746313\n", " False\n", " \n", " \n", @@ -385,9 +376,9 @@ "2 Delicious Dim Sum (415) 781-0721 752 Jackson St, San Francisco, CA 94133 \n", "\n", " similarity is_match \n", - "0 0.906984 True \n", - "1 0.466477 False \n", - "2 0.746303 False " + "0 0.907003 True \n", + "1 0.466519 False \n", + "2 0.746313 False " ] }, "metadata": {}, @@ -434,11 +425,11 @@ "output_type": "stream", "text": [ "Dataset examples:\n", - "Thai Noodles (608) 270-9527 5957 Mckee Rd, Fitchburg, WI\n", - "Baby Blues BBQ (415) 896-4250 3149 Mission Street, San Francisco, CA\n", - "Capriotti's Sandwich Shop (608) 255-2227 902 Regent St, Madison, WI\n", - "True Coffee Roasters (608) 277-1455 6250 Nesbitt Rd, Fitchburg, WI\n", - "Carmine's (212) 221-3800 200 W 44th Street, New York, NY\n" + "Batter & Berries (773) 248-7710 2748 N. Lincoln Avenue, Chicago, IL\n", + "Joey's Seafood & Grill (608) 829-0093 6604 Mineral Pt Rd, Madison, WI\n", + "Flaming Wok (608) 240-1085 4237 Lien Rd Ste H, Madison, WI\n", + "Shallots Bistro (847) 677-3463 7016 Carpenter Road, Skokie, IL\n", + "Yang Chow (213) 625-0811 819 N Broadway, Los Angeles, CA\n" ] } ], @@ -485,72 +476,72 @@ " \n", " \n", " 0\n", - " North and South Seafood & Smokehouse (608) 829-0093 6604 Mineral Point Rd, Madison, WI 53705\n", - " Joey's Seafood & Grill (608) 829-0093 6604 Mineral Pt Rd, Madison, WI\n", - " 0.759907\n", + " Hong Kong Buffet (608) 240-0762 2817 E Washington Ave, Madison, WI 53704\n", + " Brink Lounge (608) 661-8599 701 E Washington Ave, Madison, WI\n", + " 0.768695\n", " False\n", " \n", " \n", " 1\n", - " BAR Ama (213) 687-8002 118 W 4th St, Los Angeles, CA 90013\n", - " The Edison (213) 613-0000 108 W 2nd St, Los Angeles, CA\n", - " 0.792568\n", + " Ruay Thai (212) 545-7829 625 2nd Ave, New York, NY 10016\n", + " Ruay Thai Restaurant (212) 545-7829 625 Second Avenue, New York, NY, NY\n", + " 0.876181\n", " False\n", " \n", " \n", " 2\n", - " Casellula (212) 247-8137 401 W 52nd St, New York, NY 10019\n", - " Casellula Cheese and Wine Cafe (212) 247-8137 401 W 52nd Street, New York, NY\n", - " 0.850503\n", - " False\n", + " Kabul Restaurant (608) 256-6322 540 State St, Madison, WI 53703\n", + " Kabul Afghanistan Restaurant (608) 256-6322 540 State St, Madison, WI\n", + " 0.916992\n", + " True\n", " \n", " \n", " 3\n", " Saigon Sandwich (415) 474-5698 560 Larkin St, San Francisco, CA 94102\n", " Saigon Sandwich (415) 474-5698 560 Larkin Street, San Francisco, CA\n", - " 0.945792\n", + " 0.945782\n", " True\n", " \n", " \n", " 4\n", - " Via Lima (773) 348-4900 4024 N Lincoln Ave, Chicago, IL 60618\n", - " Via Lima (773) 348-4900 4024 N. Lincoln Avenue, Chicago, IL\n", - " 0.941480\n", - " True\n", + " Ricks Olde Gold (608) 257-7280 1314 Williamson St Frnt, Madison, WI 53703\n", + " Cafe Costa Rica (608) 256-9830 1133 Williamson St, Madison, WI\n", + " 0.719272\n", + " False\n", " \n", " \n", " 5\n", - " Toe Bang Cafe (213) 387-4905 3465 W 6th St, Ste 110, Los Angeles, CA 90020\n", - " The Escondite (213) 626-1800 410 Boyd St, Los Angeles, CA\n", - " 0.769392\n", + " Maharaja Restaurant (608) 246-8525 1707 Thierer Rd, Madison, WI 53704\n", + " Maharana (608) 246-8525 1707 Thierer Rd, Madison, WI\n", + " 0.886549\n", " False\n", " \n", " \n", " 6\n", - " Flaming Wok (608) 240-1085 4237 Lien Rd, Madison, WI 53704\n", - " Flaming Wok (608) 240-1085 4237 Lien Rd Ste H, Madison, WI\n", - " 0.946072\n", - " True\n", + " J J Fish (773) 533-1995 816 N Kedzie Ave, Chicago, IL 60651\n", + " El Cid (773) 395-0505 2645 N. Kedzie Avenue, Chicago, IL\n", + " 0.757653\n", + " False\n", " \n", " \n", " 7\n", - " Lazy Janes (608) 257-5263 1358 Williamson St, Madison, WI 53703\n", - " Willalby's Cafe (608) 259-9032 1351 Williamson St, Madison, WI\n", - " 0.755470\n", + " Clark St. Ale House (877) 637-7133 742 N Clark St, Chicago, IL 60654\n", + " Roka Akor (312) 477-7652 456 N. Clark Street, Chicago, IL\n", + " 0.754152\n", " False\n", " \n", " \n", " 8\n", - " Eno Vino Wine Bar & Bistro (608) 664-9565 601 Junction Rd, Madison, WI 53717\n", - " Eno Vino Wine Bar & Bistro (608) 664-9565 601 Junction Rd, Madison, WI\n", - " 0.982187\n", - " True\n", + " The Taco Bros (608) 422-5075 604 E University Ave, Madison, WI 53715\n", + " The Taco Shop (608) 250-8226 604 University Ave, Madison, WI\n", + " 0.828169\n", + " False\n", " \n", " \n", " 9\n", - " Hong Kong Station (608) 661-8288 1441 Regent St, Madison, WI 53711\n", - " Rising Sons (608) 661-4334 617 State St, Madison, WI\n", - " 0.781205\n", + " Au Lac DTLA (213) 617-2533 710 W 1st St, Los Angeles, CA 90012\n", + " The Little Door (323) 951-1210 8164 W 3rd St, Los Angeles, CA\n", + " 0.760763\n", " False\n", " \n", " \n", @@ -558,41 +549,41 @@ "" ], "text/plain": [ - " query \\\n", - "0 North and South Seafood & Smokehouse (608) 829-0093 6604 Mineral Point Rd, Madison, WI 53705 \n", - "1 BAR Ama (213) 687-8002 118 W 4th St, Los Angeles, CA 90013 \n", - "2 Casellula (212) 247-8137 401 W 52nd St, New York, NY 10019 \n", - "3 Saigon Sandwich (415) 474-5698 560 Larkin St, San Francisco, CA 94102 \n", - "4 Via Lima (773) 348-4900 4024 N Lincoln Ave, Chicago, IL 60618 \n", - "5 Toe Bang Cafe (213) 387-4905 3465 W 6th St, Ste 110, Los Angeles, CA 90020 \n", - "6 Flaming Wok (608) 240-1085 4237 Lien Rd, Madison, WI 53704 \n", - "7 Lazy Janes (608) 257-5263 1358 Williamson St, Madison, WI 53703 \n", - "8 Eno Vino Wine Bar & Bistro (608) 664-9565 601 Junction Rd, Madison, WI 53717 \n", - "9 Hong Kong Station (608) 661-8288 1441 Regent St, Madison, WI 53711 \n", + " query \\\n", + "0 Hong Kong Buffet (608) 240-0762 2817 E Washington Ave, Madison, WI 53704 \n", + "1 Ruay Thai (212) 545-7829 625 2nd Ave, New York, NY 10016 \n", + "2 Kabul Restaurant (608) 256-6322 540 State St, Madison, WI 53703 \n", + "3 Saigon Sandwich (415) 474-5698 560 Larkin St, San Francisco, CA 94102 \n", + "4 Ricks Olde Gold (608) 257-7280 1314 Williamson St Frnt, Madison, WI 53703 \n", + "5 Maharaja Restaurant (608) 246-8525 1707 Thierer Rd, Madison, WI 53704 \n", + "6 J J Fish (773) 533-1995 816 N Kedzie Ave, Chicago, IL 60651 \n", + "7 Clark St. Ale House (877) 637-7133 742 N Clark St, Chicago, IL 60654 \n", + "8 The Taco Bros (608) 422-5075 604 E University Ave, Madison, WI 53715 \n", + "9 Au Lac DTLA (213) 617-2533 710 W 1st St, Los Angeles, CA 90012 \n", "\n", - " target \\\n", - "0 Joey's Seafood & Grill (608) 829-0093 6604 Mineral Pt Rd, Madison, WI \n", - "1 The Edison (213) 613-0000 108 W 2nd St, Los Angeles, CA \n", - "2 Casellula Cheese and Wine Cafe (212) 247-8137 401 W 52nd Street, New York, NY \n", - "3 Saigon Sandwich (415) 474-5698 560 Larkin Street, San Francisco, CA \n", - "4 Via Lima (773) 348-4900 4024 N. Lincoln Avenue, Chicago, IL \n", - "5 The Escondite (213) 626-1800 410 Boyd St, Los Angeles, CA \n", - "6 Flaming Wok (608) 240-1085 4237 Lien Rd Ste H, Madison, WI \n", - "7 Willalby's Cafe (608) 259-9032 1351 Williamson St, Madison, WI \n", - "8 Eno Vino Wine Bar & Bistro (608) 664-9565 601 Junction Rd, Madison, WI \n", - "9 Rising Sons (608) 661-4334 617 State St, Madison, WI \n", + " target \\\n", + "0 Brink Lounge (608) 661-8599 701 E Washington Ave, Madison, WI \n", + "1 Ruay Thai Restaurant (212) 545-7829 625 Second Avenue, New York, NY, NY \n", + "2 Kabul Afghanistan Restaurant (608) 256-6322 540 State St, Madison, WI \n", + "3 Saigon Sandwich (415) 474-5698 560 Larkin Street, San Francisco, CA \n", + "4 Cafe Costa Rica (608) 256-9830 1133 Williamson St, Madison, WI \n", + "5 Maharana (608) 246-8525 1707 Thierer Rd, Madison, WI \n", + "6 El Cid (773) 395-0505 2645 N. Kedzie Avenue, Chicago, IL \n", + "7 Roka Akor (312) 477-7652 456 N. Clark Street, Chicago, IL \n", + "8 The Taco Shop (608) 250-8226 604 University Ave, Madison, WI \n", + "9 The Little Door (323) 951-1210 8164 W 3rd St, Los Angeles, CA \n", "\n", " similarity is_match \n", - "0 0.759907 False \n", - "1 0.792568 False \n", - "2 0.850503 False \n", - "3 0.945792 True \n", - "4 0.941480 True \n", - "5 0.769392 False \n", - "6 0.946072 True \n", - "7 0.755470 False \n", - "8 0.982187 True \n", - "9 0.781205 False " + "0 0.768695 False \n", + "1 0.876181 False \n", + "2 0.916992 True \n", + "3 0.945782 True \n", + "4 0.719272 False \n", + "5 0.886549 False \n", + "6 0.757653 False \n", + "7 0.754152 False \n", + "8 0.828169 False \n", + "9 0.760763 False " ] }, "metadata": {}, @@ -672,16 +663,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Query 0: \"North and South Seafood & Smokehouse (608) 829-0093 6604 Mineral Point Rd, Madison, WI 53705\"\n", + "Query 0: \"Hong Kong Buffet (608) 240-0762 2817 E Washington Ave, Madison, WI 53704\"\n", "Most similar matches:\n", "\n", " idx is_match similarity text\n", "----- ---------- ------------ ----------------------------------------------------------------\n", - " 19 False 0.76 Joey's Seafood & Grill (608) 829-0093 6604 Mineral Pt Rd, Madiso\n", - " 172 False 0.72 Noodles & Company (608) 829-0202 7050 Mineral Point Rd, Madison,\n", - " 34 False 0.7 Silly Yak Bakery (608) 833-5965 7866 Mineral Point Road, Madison\n", - " 52 False 0.66 Swagat Restaurant (608) 836-9399 707 N High Point Rd, Madison, W\n", - " 229 False 0.64 Pho Nam (608) 836-7040 610 Junction Rd Suite 109, Madison, WI\n" + " 179 False 0.77 Brink Lounge (608) 661-8599 701 E Washington Ave, Madison, WI\n", + " 289 False 0.76 Barriques Coffee (608) 268-6264 127 W Washington Ave, Madison, W\n", + " 231 False 0.76 Athens Gyros (608) 246-7733 1860 E Washington Ave, Madison, WI\n", + " 51 False 0.76 Taco Bell (608) 249-7312 4120 E Washington Ave, Madison, WI\n", + " 98 False 0.73 Einstein Bros Bagels (608) 242-9889 3904 E Washington Ave, Madis\n" ] } ], @@ -701,16 +692,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Query 3: \"Saigon Sandwich (415) 474-5698 560 Larkin St, San Francisco, CA 94102\"\n", + "Query 2: \"Kabul Restaurant (608) 256-6322 540 State St, Madison, WI 53703\"\n", "Most similar matches:\n", "\n", " idx is_match similarity text\n", "----- ---------- ------------ ----------------------------------------------------------------\n", - " 37 True 0.95 Saigon Sandwich (415) 474-5698 560 Larkin Street, San Francisco,\n", - " 228 False 0.72 Rin's (415) 821-4776 4301 24th Street, San Francisco, CA\n", - " 26 False 0.72 La Santaneca (415) 648-1034 3781 Mission Street, San Francisco,\n", - " 176 False 0.71 Pakwan (415) 255-2440 3182 16th Street, San Francisco, CA\n", - " 48 False 0.7 Rosamunde Sausage Grill (415) 437-6851 545 Haight Street, San Fr\n" + " 17 True 0.92 Kabul Afghanistan Restaurant (608) 256-6322 540 State St, Madiso\n", + " 66 False 0.79 State Street Brats (608) 255-5544 603 State St, Madison, WI\n", + " 35 False 0.79 Parthenon Gyros Restaurant (608) 251-6311 316 State St, Madison,\n", + " 109 False 0.78 Wasabi Japanese Restaurant (608) 255-5020 449 State St Ste 2G, M\n", + " 186 False 0.75 Mediterranean Cafe (608) 251-8510 625 State St, Madison, WI\n" ] } ], @@ -748,10 +739,10 @@ " idx is_match similarity text\n", "----- ---------- ------------ ----------------------------------------------------------------\n", " 304 True 0.92 Googleplex (650) 253-0000 1600 Amphitheatre Parkway, Mountain Vi\n", - " 135 False 0.52 KFC (608) 849-5004 600 W Main St, Waunakee, WI\n", - " 115 False 0.51 Gus's Diner (608) 318-0900 630 N Westmount Dr, Sun Prairie, WI\n", - " 168 False 0.5 Sweet Maple (415) 655-9169 2101 Sutter Street, San Francisco, CA\n", - " 229 False 0.49 Pho Nam (608) 836-7040 610 Junction Rd Suite 109, Madison, WI\n" + " 150 False 0.52 KFC (608) 849-5004 600 W Main St, Waunakee, WI\n", + " 299 False 0.51 Gus's Diner (608) 318-0900 630 N Westmount Dr, Sun Prairie, WI\n", + " 116 False 0.5 Sweet Maple (415) 655-9169 2101 Sutter Street, San Francisco, CA\n", + " 267 False 0.49 Pho Nam (608) 836-7040 610 Junction Rd Suite 109, Madison, WI\n" ] } ], @@ -770,7 +761,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.12 64-bit", + "display_name": "Python 3.10.6 64-bit", "language": "python", "name": "python3" }, @@ -784,7 +775,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.6" }, "orig_nbformat": 4, "vscode": { diff --git a/setup.py b/setup.py index c31af56..0fdddab 100644 --- a/setup.py +++ b/setup.py @@ -41,14 +41,14 @@ def get_version(rel_path): "tabulate", "numpy", "tqdm", - "onnx", "jaxtyping", - "onnxruntime-gpu", + "onnx", + "onnxruntime", "pandas", - "tensorflow>=2.11,<2.16", "usearch>=2.6.0", ], extras_require={ + "tensorflow": ["tensorflow>=2.11,<2.16"], "dev": [ "datasets", "mypy", diff --git a/unisim/__init__.py b/unisim/__init__.py index 73b9c69..6cf614f 100644 --- a/unisim/__init__.py +++ b/unisim/__init__.py @@ -4,5 +4,5 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. -__version__ = "0.0.2" +__version__ = "1.0.0" from .textsim import TextSim # noqa: F401 diff --git a/unisim/backend/load_backend.py b/unisim/backend/load_backend.py index 6f6301b..e212654 100644 --- a/unisim/backend/load_backend.py +++ b/unisim/backend/load_backend.py @@ -30,6 +30,18 @@ except ImportError: TF_AVAILABLE = False +# detect accelerator +if TF_AVAILABLE or get_backend() == BackendType.tf: + devices_types = [d.device_type for d in tf.config.list_physical_devices()] + + if "GPU" in devices_types: + set_accelerator(AcceleratorType.gpu) + else: + set_accelerator(AcceleratorType.cpu) + +else: + set_accelerator(AcceleratorType.cpu) + # choose backend if not set by user accel = get_accelerator() backend = get_backend() @@ -41,7 +53,7 @@ elif accel == AcceleratorType.cpu: # on CPU always onnx set_backend(BackendType.onnx) - elif TF_AVAILABLE: + elif TF_AVAILABLE and accel == AcceleratorType.gpu: # on GPU use TF by default set_backend(BackendType.tf) else: @@ -50,15 +62,10 @@ # post detection if get_backend() == BackendType.onnx: - import onnxruntime as rt - from .onnx import * # noqa: F403, F401 - # FIXME onnx accelerator type support - if rt.get_device() == "GPU": - set_accelerator(AcceleratorType.gpu) - else: - set_accelerator(AcceleratorType.cpu) + # FIXME(marinazh): onnx accelerator type support + set_accelerator(AcceleratorType.cpu) elif get_backend() == BackendType.tf: from .tf import * # type: ignore # noqa: F403, F401 diff --git a/unisim/backend/tf.py b/unisim/backend/tf.py index 936f678..503d1ed 100644 --- a/unisim/backend/tf.py +++ b/unisim/backend/tf.py @@ -12,7 +12,6 @@ from tensorflow import Tensor from tensorflow.keras import Model -# typing from ..types import BatchEmbeddings diff --git a/unisim/textsim.py b/unisim/textsim.py index 841742d..74705dc 100644 --- a/unisim/textsim.py +++ b/unisim/textsim.py @@ -4,6 +4,8 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. +from __future__ import annotations + from typing import Any, Dict, List, Sequence from pandas import DataFrame diff --git a/unisim/types.py b/unisim/types.py index 0f90086..cb23abb 100644 --- a/unisim/types.py +++ b/unisim/types.py @@ -4,15 +4,17 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. -from __future__ import annotations - -from typing import Union +from typing import TYPE_CHECKING, Union from jaxtyping import Float32 -from numpy import ndarray -from tensorflow import Tensor -Array = Union[Tensor, ndarray] +if TYPE_CHECKING: + from numpy import ndarray + from tensorflow import Tensor + + Array = Union[Tensor, ndarray] +else: + from numpy import ndarray as Array # Embeddings Embedding = Float32[Array, "embedding"]