Skip to content

Commit

Permalink
Merge pull request #919 from Gerrnperl/fix-asset-lib
Browse files Browse the repository at this point in the history
Fix Issues Related to Asset Library
  • Loading branch information
JiepengTan authored Sep 29, 2024
2 parents 316c261 + 31e462b commit f69d0c9
Show file tree
Hide file tree
Showing 12 changed files with 305 additions and 13 deletions.
2 changes: 1 addition & 1 deletion spx-backend/internal/aigc/aigc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func NewAigcClient(endpoint string) *AigcClient {
return &AigcClient{
endpoint: endpoint,
client: &http.Client{
Timeout: 20 * time.Second,
Timeout: 60 * time.Second,
},
}
}
Expand Down
37 changes: 37 additions & 0 deletions spx-backend/internal/model/milvus.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,38 @@ func SearchByVector(ctx context.Context, cli client.Client, collectionName strin
return assetNames, nil
}

func ExistsMilvusAsset(ctx context.Context, cli client.Client, assetID string) bool {
logger := log.GetReqLogger(ctx)

if cli == nil || assetID == "" {
logger.Printf("Invalid input: %v, %v", cli, assetID)
return false
}

opt := client.SearchQueryOptionFunc(func(option *client.SearchQueryOption) {
option.Limit = 3
option.Offset = 0
option.ConsistencyLevel = entity.ClStrong
option.IgnoreGrowing = false
})

// Search for the asset ID in the collection
_, err := cli.Query(
ctx,
"asset",
[]string{},
"asset_id == '"+assetID+"'",
[]string{"asset_id"},
opt,
)
if err != nil {
logger.Printf("Failed to search: %v", err)
return false
}

return true
}

// Add an asset
func AddMilvusAsset(ctx context.Context, cli client.Client, asset *MilvusAsset) error {
logger := log.GetReqLogger(ctx)
Expand All @@ -92,6 +124,11 @@ func AddMilvusAsset(ctx context.Context, cli client.Client, asset *MilvusAsset)
return nil
}

if ExistsMilvusAsset(ctx, cli, asset.AssetID) {
logger.Printf("Asset %s already exists in Milvus", asset.AssetName)
return nil
}

vector := asset.Vector

columns := []entity.Column{
Expand Down
15 changes: 14 additions & 1 deletion spx-backend/internal/model/user_asset.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,20 @@ const TableUserAsset = "user_asset"
// AddUserAsset adds an asset.
func AddUserAsset(ctx context.Context, db *gorm.DB, p *UserAsset) error {
logger := log.GetReqLogger(ctx)
result := db.Create(p)

// check if the asset already exists
var count int64
result := db.Model(&UserAsset{}).Where("asset_id = ? AND relation_type = ? AND owner = ?", p.AssetID, p.RelationType, p.Owner).Count(&count)
if result.Error != nil {
logger.Printf("failed to check if asset exists: %v", result.Error)
return result.Error
}

if count > 0 {
return nil
}

result = db.Create(p)
if result.Error != nil {
logger.Printf("failed to add asset: %v", result.Error)
return result.Error
Expand Down
8 changes: 8 additions & 0 deletions spx-backend/internal/model/user_asset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func TestAddUserAsset(t *testing.T) {
}), &gorm.Config{})
require.NoError(t, err)

mock.ExpectQuery("SELECT count(*) FROM `user_assets` WHERE asset_id = ? AND relation_type = ? AND owner = ?").
WithArgs(1, "owned", "user1").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

mock.ExpectBegin()
mock.ExpectExec("INSERT INTO `user_assets` (`owner`,`asset_id`,`relation_type`,`relation_timestamp`) VALUES (?,?,?,?)").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
Expand Down Expand Up @@ -55,6 +59,10 @@ func TestAddUserAsset(t *testing.T) {
SkipInitializeWithVersion: true}), &gorm.Config{})
require.NoError(t, err)

mock.ExpectQuery("SELECT count(*) FROM `user_assets` WHERE asset_id = ? AND relation_type = ? AND owner = ?").
WithArgs(1, "owned", "user1").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

mock.ExpectBegin()
mock.ExpectExec("INSERT INTO `user_assets` (`owner`,`asset_id`,`relation_type`,`relation_timestamp`) VALUES (?,?,?,?)").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
Expand Down
131 changes: 131 additions & 0 deletions spx-backend/loadToMilvus/loadToMilvus.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// This script is used to load all assets to Milvus.
//
// The script reads the assets from the database
// and then calls the AIGC service to get the embeddings of the assets
// and inserts them into the Milvus database.
package main

import (
"context"
"database/sql"
"errors"
_ "image/png"
"io/fs"
"net/http"
"os"

_ "github.com/go-sql-driver/mysql"
"github.com/goplus/builder/spx-backend/internal/aigc"
"github.com/goplus/builder/spx-backend/internal/log"
"github.com/joho/godotenv"
milvus "github.com/milvus-io/milvus-sdk-go/v2/client"
_ "github.com/qiniu/go-cdk-driver/kodoblob"
qiniuLog "github.com/qiniu/x/log"

"github.com/goplus/builder/spx-backend/internal/controller"
"github.com/goplus/builder/spx-backend/internal/model"
)

var (
ErrNotExist = errors.New("not exist")
ErrUnauthorized = errors.New("unauthorized")
ErrForbidden = errors.New("forbidden")
)

func Load() (err error) {
logger := log.GetLogger()
ctx := context.Background()

if err := godotenv.Load(); err != nil && !errors.Is(err, fs.ErrNotExist) {
logger.Printf("failed to load env: %v", err)
return err
}

dsn := mustEnv(logger, "GOP_SPX_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
logger.Printf("failed to connect sql: %v", err)
return err
}

aigcClient := aigc.NewAigcClient(mustEnv(logger, "AIGC_ENDPOINT"))

var milvusClient milvus.Client
if os.Getenv("ENV") != "test" && os.Getenv("MILVUS_ADDRESS") != "disabled" {
milvusClient, err = milvus.NewClient(ctx, milvus.Config{
Address: os.Getenv("MILVUS_ADDRESS"),
})
if err != nil {
logger.Printf("failed to create milvus client: %v,%v", err, os.Getenv("MILVUS_ADDRESS"))
return err
}
}

// load all assets from the database
assets, err := LoadAssets(ctx, db)
if err != nil {
logger.Printf("Failed to load assets: %v", err)
return err
}

// for each asset, call embedding service to get the embedding
for i, asset := range assets {
logger.Printf("Processing asset %d/%d: %s", i+1, len(assets), asset.DisplayName)

// check if the asset id is already in the milvus
if model.ExistsMilvusAsset(ctx, milvusClient, asset.ID) {
logger.Printf("Asset %s already exists in Milvus", asset.DisplayName)
continue
}

var embeddingResult controller.GetEmbeddingResult
err = aigcClient.Call(ctx, http.MethodPost, "/embedding", &controller.GetEmbeddingParams{
Prompt: asset.DisplayName,
CallbackUrl: "",
}, &embeddingResult)

if err != nil {
logger.Printf("failed to call: %v", err)
return err
}

// insert the embedding into the milvus
model.AddMilvusAsset(ctx, milvusClient, &model.MilvusAsset{
AssetID: asset.ID,
AssetName: asset.DisplayName,
Vector: embeddingResult.Embedding,
})
}

return nil
}

// LoadAssets loads all assets from the database.
func LoadAssets(ctx context.Context, db *sql.DB) ([]model.Asset, error) {
logger := log.GetReqLogger(ctx)

assets, err := model.ListAssets(ctx, db, model.Pagination{
Index: 1,
Size: 65535,
}, nil, nil, nil)
if err != nil {
logger.Printf("ListAssets failed: %v", err)
return nil, err
}
return assets.Data, nil
}

// mustEnv gets the environment variable value or exits the program.
func mustEnv(logger *qiniuLog.Logger, key string) string {
value := os.Getenv(key)
if value == "" {
logger.Fatalf("Missing required environment variable: %s", key)
}
return value
}

func main() {
if err := Load(); err != nil {
os.Exit(1)
}
}
4 changes: 2 additions & 2 deletions spx-gui/src/components/asset/animation/VideoRecorder.vue
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<Transition name="slide-fade" mode="out-in" appear>
<NSpin v-if="generatingAnimation" size="large" class="recorder-content loading">
<template #description>
<span style="color: white">
<span style="color: white; text-shadow: 0 0 4px #222;">
{{
generatingAnimationMessage ||
$t({
Expand Down Expand Up @@ -325,7 +325,7 @@ const generateAnimation = async () => {
generatingAnimation.value = false
return
}
generatingAnimation.value = false
emit('resolve', materialUrl)
} catch (error: any) {
errorMessage.error(
Expand Down
6 changes: 6 additions & 0 deletions spx-gui/src/components/asset/library/AIAssetItem.vue
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ const loadCloudFiles = async (cloudFiles?: AIGCFiles) => {
flex-direction: row;
align-items: center;
gap: 6px;
white-space: nowrap;
text-wrap: ellipsis;
overflow: hidden;
text-overflow: ellipsis;
width: 100%;
text-align: left;
}
.generating-text {
Expand Down
7 changes: 7 additions & 0 deletions spx-gui/src/components/asset/library/AssetItem.vue
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ const handleAddToProject = () => {
<style lang="scss" scoped>
.asset-item {
--flex-basis: calc(90% / var(--column-count, 5));
width: 0;
flex: 0 1 var(--flex-basis);
display: flex;
flex-direction: column;
Expand Down Expand Up @@ -210,6 +211,12 @@ const handleAddToProject = () => {
.asset-name {
margin-top: 8px;
white-space: nowrap;
text-wrap: ellipsis;
overflow: hidden;
text-overflow: ellipsis;
width: 100%;
text-align: left;
}
.asset-operations {
Expand Down
6 changes: 5 additions & 1 deletion spx-gui/src/components/asset/library/ai/AIPreviewModal.vue
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ const publicAsset = ref<AssetData | null>(null)
const searchCtx = useSearchCtx()
const renameAsset = useRenameAsset()
/**
* Get the public asset data from the asset data
* If the asset data is not exported, export it first
Expand All @@ -196,6 +198,8 @@ const exportAssetDataToPublic = async () => {
if (!props.asset[isContentReady]) {
throw new Error('Could not export an incomplete asset')
}
await renameAsset(props.asset, isFavorite.value, searchCtx.keyword)
// let addAssetParam = props.asset
let addAssetParam: AddAssetParams = {
...props.asset,
Expand All @@ -214,6 +218,7 @@ const exportAssetDataToPublic = async () => {
return publicAsset
}
const handleAddButton = async () => {
if (props.addToProjectPending) {
return
Expand All @@ -226,7 +231,6 @@ const handleAddButton = async () => {
emit('addToProject', publicAsset.value)
}
const renameAsset = useRenameAsset()
const handleRename = useMessageHandle(
async () => {
isFavorite.value = !isFavorite.value
Expand Down
Loading

0 comments on commit f69d0c9

Please sign in to comment.