From 1532ed05766f4c59e8ab1ac1fe77e22946155cab Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 5 Nov 2024 11:35:41 +0800 Subject: [PATCH] Add file size and type validation in processUpload --- fs/process.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 4 deletions(-) diff --git a/fs/process.go b/fs/process.go index 26bbe4b..bd0d7fe 100644 --- a/fs/process.go +++ b/fs/process.go @@ -6,6 +6,7 @@ import ( "path" "path/filepath" "sort" + "strconv" "strings" "time" @@ -493,14 +494,22 @@ func processUpload(process *process.Process) interface{} { exception.New(err.Error(), 500).Throw() } + // Validate the file size + if props.Has("maxFilesize") { + validateFileSize(stor, filename, props.Get("maxFilesize")) + } + // Cheek the file is exists - size, err := Size(stor, filename) + size, err := stor.Size(filename) if err != nil { exception.New(err.Error(), 500).Throw() } total := tmpfile.TotalSize() if int64(size) == total { + if props.Has("accept") { + validateAcceptType(stor, filename, props.Get("accept"), true) + } return filename } @@ -531,6 +540,11 @@ func processUpload(process *process.Process) interface{} { exception.New(err.Error(), 500).Throw() } + // Validate the file size + if props.Has("maxFilesize") { + validateFileSize(stor, filename, props.Get("maxFilesize")) + } + // Check if all chunks are uploaded. progress, err := uploadProgress(stor, tmpDir) if err != nil { @@ -541,6 +555,11 @@ func processUpload(process *process.Process) interface{} { if progress.Completed { defer stor.RemoveAll(tmpDir) + // Validate the file type + if props.Has("accept") { + validateAcceptType(stor, filename, props.Get("accept"), true) + } + // Get Files files, err := getChunkFiles(stor, tmpDir, true) if err != nil { @@ -566,12 +585,12 @@ func processUpload(process *process.Process) interface{} { // Validate the file type if props.Has("accept") { - fmt.Println("Validate the file type") + validateAcceptType(stor, tmpfile.TempFile, props.Get("accept"), true) } // Validate the file size - if props.Has("maxSize") { - fmt.Println("Validate the file size") + if props.Has("maxFilesize") { + validateFileSize(stor, tmpfile.TempFile, props.Get("maxFilesize")) } // For normal upload. @@ -606,6 +625,104 @@ func processDownload(process *process.Process) interface{} { } } +func validateAcceptType(stor FileSystem, file string, accept interface{}, checkMime bool) { + + // Check the file type + acceptstr, ok := accept.(string) + if !ok { + exception.New("the accept type is invalid", 400).Throw() + } + + ext := filepath.Ext(file) + acceptList := strings.Split(acceptstr, ",") + for _, accept := range acceptList { + accept = strings.TrimSpace(accept) + + // Check the file extension + if strings.HasPrefix(accept, ".") { + if ext == accept { + return + } + } + + // Check the file mime type + if !checkMime { + continue + } + + mime, err := MimeType(stor, file) + if err != nil { + exception.New(err.Error(), 500).Throw() + } + + if mime == accept { + return + } + + // accept = image/*, video/*, audio/*, application/* ... + if strings.HasSuffix(accept, "/*") { + if strings.HasPrefix(mime, strings.TrimRight(accept, "/*")) { + return + } + } + } + + // Remove the file + stor.Remove(file) + exception.New("File type should be %v", 415, acceptstr).Throw() +} + +func validateFileSize(stor FileSystem, file string, hmSize interface{}) { + // Get the file size + maxFilesize, err := parseFileSize(hmSize) + if err != nil { + defer stor.Remove(file) + exception.New(err.Error(), 500).Throw() + } + + // Get the file size + size, err := stor.Size(file) + if err != nil { + defer stor.Remove(file) + exception.New(err.Error(), 500).Throw() + } + + if size > int(maxFilesize) { + defer stor.Remove(file) + exception.New("File size too large, max size is %v", 413, hmSize).Throw() + } +} + +func parseFileSize(hmSize interface{}) (int64, error) { + if hmSize == nil { + return 1024 * 1024, nil // 1MB + } + + switch v := hmSize.(type) { + case int: + return int64(v), nil + case int64: + return v, nil + case string: + unit := strings.ToUpper(v[len(v)-1:]) + size, err := strconv.ParseInt(v[:len(v)-1], 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid chunk size: %v", err) + } + + switch unit { + case "M": + return size * 1024 * 1024, nil // MB + case "K": + return size * 1024, nil // KB + default: + return size, nil // bytes + } + default: + return 0, fmt.Errorf("invalid type for hmSize") + } +} + func uploadProgress(stor FileSystem, path string) (types.UploadProgress, error) { files, err := getChunkFiles(stor, path, true) if err != nil {