diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py index 0e86c61bc7..9ed106d330 100644 --- a/xinference/model/embedding/custom.py +++ b/xinference/model/embedding/custom.py @@ -26,6 +26,7 @@ class CustomEmbeddingModelSpec(EmbeddingModelSpec): + model_id: Optional[str] # type: ignore model_revision: Optional[str] # type: ignore model_uri: Optional[str] diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js index ab9a6b1161..1a1a8a15b6 100644 --- a/xinference/web/ui/src/scenes/register_model/index.js +++ b/xinference/web/ui/src/scenes/register_model/index.js @@ -1,3 +1,4 @@ +import { TabContext, TabList, TabPanel } from '@mui/lab' import { Box, Checkbox, @@ -6,6 +7,7 @@ import { FormHelperText, Radio, RadioGroup, + Tab, } from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' @@ -17,6 +19,7 @@ import { ApiContext } from '../../components/apiContext' import ErrorMessageSnackBar from '../../components/errorMessageSnackBar' import Title from '../../components/Title' import { useMode } from '../../theme' +import RegisterEmbeddingModel from './register_embedding' const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } const SUPPORTED_FEATURES = ['Generate', 'Chat'] @@ -44,6 +47,7 @@ const RegisterModel = () => { }) const [promptStyleLabel, setPromptStyleLabel] = useState('vicuna') const [promptStyles, setPromptStyles] = useState([]) + const [tabValue, setTabValue] = React.useState('1') // model name must be // 1. Starts with an alphanumeric character (a letter or a digit). @@ -233,262 +237,283 @@ const RegisterModel = () => { <ErrorMessageSnackBar /> - <Box padding="20px"></Box> - - {/* Base Information */} - <FormControl sx={styles.baseFormControl}> - <TextField - label="Model Name" - error={errorModelName} - defaultValue={formData.model_name} - size="small" - helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." - onChange={(event) => - setFormData({ ...formData, model_name: event.target.value }) - } - /> - <Box padding="15px"></Box> + <TabContext value={tabValue}> + <Box sx={{ borderBottom: 1, borderColor: 'divider' }}> + <TabList + value={tabValue} + onChange={(e, v) => { + setTabValue(v) + }} + aria-label="tabs" + > + <Tab label="Language Model" value="1" /> + <Tab label="Embedding Model" value="2" /> + </TabList> + </Box> + <TabPanel value="1" sx={{ padding: 0 }}> + <Box padding="20px"></Box> + {/* Base Information */} + <FormControl sx={styles.baseFormControl}> + <TextField + label="Model Name" + error={errorModelName} + defaultValue={formData.model_name} + size="small" + helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." + onChange={(event) => + setFormData({ ...formData, model_name: event.target.value }) + } + /> + <Box padding="15px"></Box> - <label - style={{ - paddingLeft: 5, - }} - > - Model Format - </label> + <label + style={{ + paddingLeft: 5, + }} + > + Model Format + </label> - <RadioGroup - value={modelFormat} - onChange={(e) => { - setModelFormat(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="pytorch" - control={<Radio />} - label="PyTorch" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggmlv3" - control={<Radio />} - label="GGML" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggufv2" - control={<Radio />} - label="GGUF" - /> - </Box> - </Box> - </RadioGroup> - <Box padding="15px"></Box> - - <TextField - error={errorContextLength} - label="Context Length" - value={formData.context_length} - size="small" - onChange={(event) => { - let value = event.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - // Update with the processed value - setFormData({ - ...formData, - context_length: Number(value), - }) - }} - /> - <Box padding="15px"></Box> + <RadioGroup + value={modelFormat} + onChange={(e) => { + setModelFormat(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="pytorch" + control={<Radio />} + label="PyTorch" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggmlv3" + control={<Radio />} + label="GGML" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggufv2" + control={<Radio />} + label="GGUF" + /> + </Box> + </Box> + </RadioGroup> + <Box padding="15px"></Box> - <TextField - label="Model Size in Billions" - size="small" - error={errorModelSize} - value={modelSize} - onChange={(e) => { - let value = e.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - setModelSize(Number(value)) - }} - /> - <Box padding="15px"></Box> + <TextField + error={errorContextLength} + label="Context Length" + value={formData.context_length} + size="small" + onChange={(event) => { + let value = event.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + // Update with the processed value + setFormData({ + ...formData, + context_length: Number(value), + }) + }} + /> + <Box padding="15px"></Box> - <TextField - label="Model Path" - size="small" - value={modelUri} - onChange={(e) => { - setModelUri(e.target.value) - }} - helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." - /> - <Box padding="15px"></Box> + <TextField + label="Model Size in Billions" + size="small" + error={errorModelSize} + value={modelSize} + onChange={(e) => { + let value = e.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + setModelSize(Number(value)) + }} + /> + <Box padding="15px"></Box> - <TextField - label="Model Description (Optional)" - error={errorModelDescription} - defaultValue={formData.model_description} - size="small" - onChange={(event) => - setFormData({ ...formData, model_description: event.target.value }) - } - /> - <Box padding="15px"></Box> + <TextField + label="Model Path" + size="small" + value={modelUri} + onChange={(e) => { + setModelUri(e.target.value) + }} + helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." + /> + <Box padding="15px"></Box> - <label - style={{ - paddingLeft: 5, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - > - Model Languages - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_LANGUAGES.map((lang) => ( - <Box key={lang} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_lang.includes(lang)} - onChange={() => toggleLanguage(lang)} - name={lang} - sx={ - errorLanguage - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} - } - /> - } - label={SUPPORTED_LANGUAGES_DICT[lang]} - style={{ - paddingLeft: 10, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - /> - </Box> - ))} - </Box> - <Box padding="15px"></Box> + <TextField + label="Model Description (Optional)" + error={errorModelDescription} + defaultValue={formData.model_description} + size="small" + onChange={(event) => + setFormData({ + ...formData, + model_description: event.target.value, + }) + } + /> + <Box padding="15px"></Box> - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Abilities - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_FEATURES.map((ability) => ( - <Box key={ability} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_ability.includes( - ability.toLowerCase() - )} - onChange={() => toggleAbility(ability.toLowerCase())} - name={ability} - sx={ - errorAbility - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} + <label + style={{ + paddingLeft: 5, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + > + Model Languages + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_LANGUAGES.map((lang) => ( + <Box key={lang} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_lang.includes(lang)} + onChange={() => toggleLanguage(lang)} + name={lang} + sx={ + errorLanguage + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> } + label={SUPPORTED_LANGUAGES_DICT[lang]} + style={{ + paddingLeft: 10, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} /> - } - label={ability} - style={{ - paddingLeft: 10, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - /> + </Box> + ))} </Box> - ))} - </Box> - <Box padding="15px"></Box> - </FormControl> + <Box padding="15px"></Box> - {formData.model_ability.includes('chat') && ( - <FormControl sx={styles.baseFormControl}> - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Prompt styles - </label> - <FormHelperText> - Select a prompt style that aligns with the training data of your - model. - </FormHelperText> - <RadioGroup - value={promptStyleLabel} - onChange={(e) => { - setPromptStyleLabel(e.target.value) - }} - > + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Abilities + </label> <Box sx={styles.checkboxWrapper}> - {promptStyles.map((p) => ( - <Box sx={{ marginLeft: '10px' }}> + {SUPPORTED_FEATURES.map((ability) => ( + <Box key={ability} sx={{ marginRight: '10px' }}> <FormControlLabel - value={p.name} - control={<Radio />} - label={p.name} + control={ + <Checkbox + checked={formData.model_ability.includes( + ability.toLowerCase() + )} + onChange={() => toggleAbility(ability.toLowerCase())} + name={ability} + sx={ + errorAbility + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={ability} + style={{ + paddingLeft: 10, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} /> </Box> ))} </Box> - </RadioGroup> - </FormControl> - )} + <Box padding="15px"></Box> + </FormControl> - <Box width={'100%'}> - {successMsg !== '' && ( - <Alert severity="success"> - <AlertTitle>Success</AlertTitle> - {successMsg} - </Alert> - )} - <Button - variant="contained" - color="primary" - type="submit" - onClick={handleClick} - > - Register Model - </Button> - </Box> + {formData.model_ability.includes('chat') && ( + <FormControl sx={styles.baseFormControl}> + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Prompt styles + </label> + <FormHelperText> + Select a prompt style that aligns with the training data of your + model. + </FormHelperText> + <RadioGroup + value={promptStyleLabel} + onChange={(e) => { + setPromptStyleLabel(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {promptStyles.map((p) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={p.name} + control={<Radio />} + label={p.name} + /> + </Box> + ))} + </Box> + </RadioGroup> + </FormControl> + )} + + <Box width={'100%'}> + {successMsg !== '' && ( + <Alert severity="success"> + <AlertTitle>Success</AlertTitle> + {successMsg} + </Alert> + )} + <Button + variant="contained" + color="primary" + type="submit" + onClick={handleClick} + > + Register Model + </Button> + </Box> + </TabPanel> + <TabPanel value="2" sx={{ padding: 0 }}> + <RegisterEmbeddingModel /> + </TabPanel> + </TabContext> </Box> ) } diff --git a/xinference/web/ui/src/scenes/register_model/register_embedding.js b/xinference/web/ui/src/scenes/register_model/register_embedding.js new file mode 100644 index 0000000000..c54a8be699 --- /dev/null +++ b/xinference/web/ui/src/scenes/register_model/register_embedding.js @@ -0,0 +1,244 @@ +import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material' +import Alert from '@mui/material/Alert' +import AlertTitle from '@mui/material/AlertTitle' +import Button from '@mui/material/Button' +import TextField from '@mui/material/TextField' +import React, { useContext, useState } from 'react' + +import { ApiContext } from '../../components/apiContext' +import { useMode } from '../../theme' + +const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } +// Convert dictionary of supported languages into list +const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) + +const RegisterEmbeddingModel = () => { + const ERROR_COLOR = useMode() + const endPoint = useContext(ApiContext).endPoint + const { setErrorMsg } = useContext(ApiContext) + const [successMsg, setSuccessMsg] = useState('') + const [formData, setFormData] = useState({ + model_name: 'custom-embedding', + dimensions: 768, + max_tokens: 512, + language: ['en'], + model_uri: '/path/to/embedding-model', + }) + + // model name must be + // 1. Starts with an alphanumeric character (a letter or a digit). + // 2. Followed by any number of alphanumeric characters, underscores (_), or hyphens (-). + const errorModelName = !/^[A-Za-z0-9][A-Za-z0-9_-]*$/.test( + formData.model_name + ) + + const errorDimensions = formData.dimensions < 0 + const errorMaxTokens = formData.max_tokens < 0 + const errorLanguage = + formData.language === undefined || formData.language.length === 0 + + const handleClick = async () => { + const errorAny = + errorModelName || errorDimensions || errorMaxTokens || errorLanguage + + if (errorAny) { + setErrorMsg('Please fill in valid value for all fields') + return + } + + try { + const response = await fetch( + endPoint + '/v1/model_registrations/embedding', + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: JSON.stringify(formData), + persist: true, + }), + } + ) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }` + ) + } else { + setSuccessMsg( + 'Model has been registered successfully! Navigate to launch model page to proceed.' + ) + } + } catch (error) { + console.error('There was a problem with the fetch operation:', error) + setErrorMsg(error.message || 'An unexpected error occurred.') + } + } + + const toggleLanguage = (lang) => { + if (formData.language.includes(lang)) { + setFormData({ + ...formData, + language: formData.language.filter((l) => l !== lang), + }) + } else { + setFormData({ + ...formData, + language: [...formData.language, lang], + }) + } + } + + return ( + <React.Fragment> + <Box padding="20px"></Box> + {/* Base Information */} + <FormControl sx={styles.baseFormControl}> + <TextField + label="Model Name" + error={errorModelName} + defaultValue={formData.model_name} + size="small" + helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." + onChange={(event) => + setFormData({ ...formData, model_name: event.target.value }) + } + /> + <Box padding="15px"></Box> + + <TextField + error={errorDimensions} + label="Dimensions" + value={formData.dimensions} + size="small" + onChange={(event) => { + setFormData({ + ...formData, + dimensions: parseInt(event.target.value, 10), + }) + }} + /> + <Box padding="15px"></Box> + + <TextField + error={errorMaxTokens} + label="Max Tokens" + value={formData.max_tokens} + size="small" + onChange={(event) => { + setFormData({ + ...formData, + max_tokens: parseInt(event.target.value, 10), + }) + }} + /> + <Box padding="15px"></Box> + + <TextField + label="Model Path" + size="small" + value={formData.model_uri} + onChange={(e) => { + setFormData({ + ...formData, + model_uri: e.target.value, + }) + }} + helperText="Provide the model directory path." + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + > + Model Languages + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_LANGUAGES.map((lang) => ( + <Box key={lang} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.language.includes(lang)} + onChange={() => toggleLanguage(lang)} + name={lang} + sx={ + errorLanguage + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={SUPPORTED_LANGUAGES_DICT[lang]} + style={{ + paddingLeft: 10, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + /> + </Box> + ))} + </Box> + <Box padding="15px"></Box> + </FormControl> + + <Box width={'100%'}> + {successMsg !== '' && ( + <Alert severity="success"> + <AlertTitle>Success</AlertTitle> + {successMsg} + </Alert> + )} + <Button + variant="contained" + color="primary" + type="submit" + onClick={handleClick} + > + Register Model + </Button> + </Box> + </React.Fragment> + ) +} + +export default RegisterEmbeddingModel + +const styles = { + baseFormControl: { + width: '100%', + margin: 'normal', + size: 'small', + }, + checkboxWrapper: { + display: 'flex', + flexWrap: 'wrap', + maxWidth: '80%', + }, + labelPaddingLeft: { + paddingLeft: 5, + }, + formControlLabelPaddingLeft: { + paddingLeft: 10, + }, + buttonBox: { + width: '100%', + margin: '20px', + }, + error: { + fontWeight: 'bold', + margin: '5px 0', + padding: '1px', + borderRadius: '5px', + }, +}