diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index d2ac82cd990..be6e7d608c0 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -337,32 +337,17 @@ const addEndpoint = (m: Awaited>) => ({ }, }); -const hasInferenceAPI = async (m: Awaited>) => { - if (!isHuggingChat) { - return false; - } - - let r: Response; - try { - r = await fetch(`https://huggingface.co/api/models/${m.id}`); - } catch (e) { - console.log(e); - return false; - } - - if (!r.ok) { - logger.warn(`Failed to check if ${m.id} has inference API: ${r.statusText}`); - return false; - } - - const json = await r.json(); - - if (json.cardData.inference === false) { - return false; - } - - return true; -}; +const inferenceApiIds = isHuggingChat + ? await fetch( + "https://huggingface.co/api/models?pipeline_tag=text-generation&inference=warm&filter=conversational" + ) + .then((r) => r.json()) + .then((json) => json.map((r: { id: string }) => r.id)) + .catch((err) => { + logger.error(err, "Failed to fetch inference API ids"); + return []; + }) + : []; export const models = await Promise.all( modelsRaw.map((e) => @@ -370,7 +355,7 @@ export const models = await Promise.all( .then(addEndpoint) .then(async (m) => ({ ...m, - hasInferenceAPI: await hasInferenceAPI(m), + hasInferenceAPI: inferenceApiIds.includes(m.id ?? m.name), })) ) );