diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py
index 0e86c61bc7..9ed106d330 100644
--- a/xinference/model/embedding/custom.py
+++ b/xinference/model/embedding/custom.py
@@ -26,6 +26,7 @@
class CustomEmbeddingModelSpec(EmbeddingModelSpec):
+ model_id: Optional[str] # type: ignore
model_revision: Optional[str] # type: ignore
model_uri: Optional[str]
diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js
index ab9a6b1161..1a1a8a15b6 100644
--- a/xinference/web/ui/src/scenes/register_model/index.js
+++ b/xinference/web/ui/src/scenes/register_model/index.js
@@ -1,3 +1,4 @@
+import { TabContext, TabList, TabPanel } from '@mui/lab'
import {
Box,
Checkbox,
@@ -6,6 +7,7 @@ import {
FormHelperText,
Radio,
RadioGroup,
+ Tab,
} from '@mui/material'
import Alert from '@mui/material/Alert'
import AlertTitle from '@mui/material/AlertTitle'
@@ -17,6 +19,7 @@ import { ApiContext } from '../../components/apiContext'
import ErrorMessageSnackBar from '../../components/errorMessageSnackBar'
import Title from '../../components/Title'
import { useMode } from '../../theme'
+import RegisterEmbeddingModel from './register_embedding'
const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' }
const SUPPORTED_FEATURES = ['Generate', 'Chat']
@@ -44,6 +47,7 @@ const RegisterModel = () => {
})
const [promptStyleLabel, setPromptStyleLabel] = useState('vicuna')
const [promptStyles, setPromptStyles] = useState([])
+ const [tabValue, setTabValue] = React.useState('1')
// model name must be
// 1. Starts with an alphanumeric character (a letter or a digit).
@@ -233,262 +237,283 @@ const RegisterModel = () => {
-
-
- {/* Base Information */}
-
-
- setFormData({ ...formData, model_name: event.target.value })
- }
- />
-
+
+
+ {
+ setTabValue(v)
+ }}
+ aria-label="tabs"
+ >
+
+
+
+
+
+
+ {/* Base Information */}
+
+
+ setFormData({ ...formData, model_name: event.target.value })
+ }
+ />
+
-
+
- {
- setModelFormat(e.target.value)
- }}
- >
-
-
- }
- label="PyTorch"
- />
-
-
- }
- label="GGML"
- />
-
-
- }
- label="GGUF"
- />
-
-
-
-
-
- {
- let value = event.target.value
- // Remove leading zeros
- if (/^0+/.test(value)) {
- value = value.replace(/^0+/, '') || '0'
- }
- // Ensure it's a positive integer, if not set it to the minimum
- if (!/^\d+$/.test(value) || parseInt(value) < 0) {
- value = '0'
- }
- // Update with the processed value
- setFormData({
- ...formData,
- context_length: Number(value),
- })
- }}
- />
-
+ {
+ setModelFormat(e.target.value)
+ }}
+ >
+
+
+ }
+ label="PyTorch"
+ />
+
+
+ }
+ label="GGML"
+ />
+
+
+ }
+ label="GGUF"
+ />
+
+
+
+
- {
- let value = e.target.value
- // Remove leading zeros
- if (/^0+/.test(value)) {
- value = value.replace(/^0+/, '') || '0'
- }
- // Ensure it's a positive integer, if not set it to the minimum
- if (!/^\d+$/.test(value) || parseInt(value) < 0) {
- value = '0'
- }
- setModelSize(Number(value))
- }}
- />
-
+ {
+ let value = event.target.value
+ // Remove leading zeros
+ if (/^0+/.test(value)) {
+ value = value.replace(/^0+/, '') || '0'
+ }
+ // Ensure it's a positive integer, if not set it to the minimum
+ if (!/^\d+$/.test(value) || parseInt(value) < 0) {
+ value = '0'
+ }
+ // Update with the processed value
+ setFormData({
+ ...formData,
+ context_length: Number(value),
+ })
+ }}
+ />
+
- {
- setModelUri(e.target.value)
- }}
- helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path."
- />
-
+ {
+ let value = e.target.value
+ // Remove leading zeros
+ if (/^0+/.test(value)) {
+ value = value.replace(/^0+/, '') || '0'
+ }
+ // Ensure it's a positive integer, if not set it to the minimum
+ if (!/^\d+$/.test(value) || parseInt(value) < 0) {
+ value = '0'
+ }
+ setModelSize(Number(value))
+ }}
+ />
+
-
- setFormData({ ...formData, model_description: event.target.value })
- }
- />
-
+ {
+ setModelUri(e.target.value)
+ }}
+ helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path."
+ />
+
-
-
- {SUPPORTED_LANGUAGES.map((lang) => (
-
- toggleLanguage(lang)}
- name={lang}
- sx={
- errorLanguage
- ? {
- 'color': ERROR_COLOR,
- '&.Mui-checked': {
- color: ERROR_COLOR,
- },
- }
- : {}
- }
- />
- }
- label={SUPPORTED_LANGUAGES_DICT[lang]}
- style={{
- paddingLeft: 10,
- color: errorLanguage ? ERROR_COLOR : 'inherit',
- }}
- />
-
- ))}
-
-
+
+ setFormData({
+ ...formData,
+ model_description: event.target.value,
+ })
+ }
+ />
+
-
-
- {SUPPORTED_FEATURES.map((ability) => (
-
- toggleAbility(ability.toLowerCase())}
- name={ability}
- sx={
- errorAbility
- ? {
- 'color': ERROR_COLOR,
- '&.Mui-checked': {
- color: ERROR_COLOR,
- },
- }
- : {}
+
+
+ {SUPPORTED_LANGUAGES.map((lang) => (
+
+ toggleLanguage(lang)}
+ name={lang}
+ sx={
+ errorLanguage
+ ? {
+ 'color': ERROR_COLOR,
+ '&.Mui-checked': {
+ color: ERROR_COLOR,
+ },
+ }
+ : {}
+ }
+ />
}
+ label={SUPPORTED_LANGUAGES_DICT[lang]}
+ style={{
+ paddingLeft: 10,
+ color: errorLanguage ? ERROR_COLOR : 'inherit',
+ }}
/>
- }
- label={ability}
- style={{
- paddingLeft: 10,
- color: errorAbility ? ERROR_COLOR : 'inherit',
- }}
- />
+
+ ))}
- ))}
-
-
-
+
- {formData.model_ability.includes('chat') && (
-
-
-
- Select a prompt style that aligns with the training data of your
- model.
-
- {
- setPromptStyleLabel(e.target.value)
- }}
- >
+
- {promptStyles.map((p) => (
-
+ {SUPPORTED_FEATURES.map((ability) => (
+
}
- label={p.name}
+ control={
+ toggleAbility(ability.toLowerCase())}
+ name={ability}
+ sx={
+ errorAbility
+ ? {
+ 'color': ERROR_COLOR,
+ '&.Mui-checked': {
+ color: ERROR_COLOR,
+ },
+ }
+ : {}
+ }
+ />
+ }
+ label={ability}
+ style={{
+ paddingLeft: 10,
+ color: errorAbility ? ERROR_COLOR : 'inherit',
+ }}
/>
))}
-
-
- )}
+
+
-
- {successMsg !== '' && (
-
- Success
- {successMsg}
-
- )}
-
-
+ {formData.model_ability.includes('chat') && (
+
+
+
+ Select a prompt style that aligns with the training data of your
+ model.
+
+ {
+ setPromptStyleLabel(e.target.value)
+ }}
+ >
+
+ {promptStyles.map((p) => (
+
+ }
+ label={p.name}
+ />
+
+ ))}
+
+
+
+ )}
+
+
+ {successMsg !== '' && (
+
+ Success
+ {successMsg}
+
+ )}
+
+
+
+
+
+
+
)
}
diff --git a/xinference/web/ui/src/scenes/register_model/register_embedding.js b/xinference/web/ui/src/scenes/register_model/register_embedding.js
new file mode 100644
index 0000000000..c54a8be699
--- /dev/null
+++ b/xinference/web/ui/src/scenes/register_model/register_embedding.js
@@ -0,0 +1,244 @@
+import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material'
+import Alert from '@mui/material/Alert'
+import AlertTitle from '@mui/material/AlertTitle'
+import Button from '@mui/material/Button'
+import TextField from '@mui/material/TextField'
+import React, { useContext, useState } from 'react'
+
+import { ApiContext } from '../../components/apiContext'
+import { useMode } from '../../theme'
+
+const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' }
+// Convert dictionary of supported languages into list
+const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT)
+
+const RegisterEmbeddingModel = () => {
+ const ERROR_COLOR = useMode()
+ const endPoint = useContext(ApiContext).endPoint
+ const { setErrorMsg } = useContext(ApiContext)
+ const [successMsg, setSuccessMsg] = useState('')
+ const [formData, setFormData] = useState({
+ model_name: 'custom-embedding',
+ dimensions: 768,
+ max_tokens: 512,
+ language: ['en'],
+ model_uri: '/path/to/embedding-model',
+ })
+
+ // model name must be
+ // 1. Starts with an alphanumeric character (a letter or a digit).
+ // 2. Followed by any number of alphanumeric characters, underscores (_), or hyphens (-).
+ const errorModelName = !/^[A-Za-z0-9][A-Za-z0-9_-]*$/.test(
+ formData.model_name
+ )
+
+ const errorDimensions = formData.dimensions < 0
+ const errorMaxTokens = formData.max_tokens < 0
+ const errorLanguage =
+ formData.language === undefined || formData.language.length === 0
+
+ const handleClick = async () => {
+ const errorAny =
+ errorModelName || errorDimensions || errorMaxTokens || errorLanguage
+
+ if (errorAny) {
+ setErrorMsg('Please fill in valid value for all fields')
+ return
+ }
+
+ try {
+ const response = await fetch(
+ endPoint + '/v1/model_registrations/embedding',
+ {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: JSON.stringify(formData),
+ persist: true,
+ }),
+ }
+ )
+ if (!response.ok) {
+ const errorData = await response.json() // Assuming the server returns error details in JSON format
+ setErrorMsg(
+ `Server error: ${response.status} - ${
+ errorData.detail || 'Unknown error'
+ }`
+ )
+ } else {
+ setSuccessMsg(
+ 'Model has been registered successfully! Navigate to launch model page to proceed.'
+ )
+ }
+ } catch (error) {
+ console.error('There was a problem with the fetch operation:', error)
+ setErrorMsg(error.message || 'An unexpected error occurred.')
+ }
+ }
+
+ const toggleLanguage = (lang) => {
+ if (formData.language.includes(lang)) {
+ setFormData({
+ ...formData,
+ language: formData.language.filter((l) => l !== lang),
+ })
+ } else {
+ setFormData({
+ ...formData,
+ language: [...formData.language, lang],
+ })
+ }
+ }
+
+ return (
+
+
+ {/* Base Information */}
+
+
+ setFormData({ ...formData, model_name: event.target.value })
+ }
+ />
+
+
+ {
+ setFormData({
+ ...formData,
+ dimensions: parseInt(event.target.value, 10),
+ })
+ }}
+ />
+
+
+ {
+ setFormData({
+ ...formData,
+ max_tokens: parseInt(event.target.value, 10),
+ })
+ }}
+ />
+
+
+ {
+ setFormData({
+ ...formData,
+ model_uri: e.target.value,
+ })
+ }}
+ helperText="Provide the model directory path."
+ />
+
+
+
+
+ {SUPPORTED_LANGUAGES.map((lang) => (
+
+ toggleLanguage(lang)}
+ name={lang}
+ sx={
+ errorLanguage
+ ? {
+ 'color': ERROR_COLOR,
+ '&.Mui-checked': {
+ color: ERROR_COLOR,
+ },
+ }
+ : {}
+ }
+ />
+ }
+ label={SUPPORTED_LANGUAGES_DICT[lang]}
+ style={{
+ paddingLeft: 10,
+ color: errorLanguage ? ERROR_COLOR : 'inherit',
+ }}
+ />
+
+ ))}
+
+
+
+
+
+ {successMsg !== '' && (
+
+ Success
+ {successMsg}
+
+ )}
+
+
+
+ )
+}
+
+export default RegisterEmbeddingModel
+
+const styles = {
+ baseFormControl: {
+ width: '100%',
+ margin: 'normal',
+ size: 'small',
+ },
+ checkboxWrapper: {
+ display: 'flex',
+ flexWrap: 'wrap',
+ maxWidth: '80%',
+ },
+ labelPaddingLeft: {
+ paddingLeft: 5,
+ },
+ formControlLabelPaddingLeft: {
+ paddingLeft: 10,
+ },
+ buttonBox: {
+ width: '100%',
+ margin: '20px',
+ },
+ error: {
+ fontWeight: 'bold',
+ margin: '5px 0',
+ padding: '1px',
+ borderRadius: '5px',
+ },
+}