diff --git a/base/atomiccounter/atomiccounter_test.go b/base/atomicx/atomicx_test.go similarity index 77% rename from base/atomiccounter/atomiccounter_test.go rename to base/atomicx/atomicx_test.go index 327700d2eb..b8f4b89223 100644 --- a/base/atomiccounter/atomiccounter_test.go +++ b/base/atomicx/atomicx_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package atomiccounter +package atomicx import ( "testing" @@ -26,3 +26,13 @@ func TestCounter(t *testing.T) { c.Set(2) assert.Equal(t, int64(2), c.Value()) } + +func TestMax(t *testing.T) { + a := int32(10) + MaxInt32(&a, 5) + assert.Equal(t, a, int32(10)) + MaxInt32(&a, 10) + assert.Equal(t, a, int32(10)) + MaxInt32(&a, 11) + assert.Equal(t, a, int32(11)) +} diff --git a/base/atomiccounter/atomiccounter.go b/base/atomicx/counter.go similarity index 92% rename from base/atomiccounter/atomiccounter.go rename to base/atomicx/counter.go index 9846df1c74..e8aa6e0562 100644 --- a/base/atomiccounter/atomiccounter.go +++ b/base/atomicx/counter.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package atomiccounter implements a basic atomic int64 counter. -package atomiccounter +// Package atomicx implements misc atomic functions. +package atomicx import ( "sync/atomic" diff --git a/base/atomicx/max.go b/base/atomicx/max.go new file mode 100644 index 0000000000..763e7c6e16 --- /dev/null +++ b/base/atomicx/max.go @@ -0,0 +1,15 @@ +// Copyright (c) 2018, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package atomicx + +import "sync/atomic" + +// MaxInt32 performs an atomic Max operation: a = max(a, b) +func MaxInt32(a *int32, b int32) { + old := atomic.LoadInt32(a) + for old < b && !atomic.CompareAndSwapInt32(a, old, b) { + old = atomic.LoadInt32(a) + } +} diff --git a/base/exec/stdio.go b/base/exec/stdio.go index f871bd6284..3c04bebaa4 100644 --- a/base/exec/stdio.go +++ b/base/exec/stdio.go @@ -56,9 +56,13 @@ func (st *StdIO) Set(o *StdIO) *StdIO { func (st *StdIO) SetToOS() *StdIO { cur := &StdIO{} cur.SetFromOS() + if sif, ok := st.In.(*os.File); ok { + os.Stdin = sif + } else { + fmt.Printf("In is not an *os.File: %#v\n", st.In) + } os.Stdout = st.Out.(*os.File) os.Stderr = st.Err.(*os.File) - os.Stdin = st.In.(*os.File) return cur } @@ -98,13 +102,10 @@ func IsPipe(rw any) bool { if rw == nil { return false } - w, ok := rw.(io.Writer) + _, ok := rw.(io.Writer) if !ok { return false } - if w == os.Stdout { - return false - } of, ok := rw.(*os.File) if !ok { return false @@ -247,6 +248,9 @@ func (st *StdIOState) PopToStart() { for len(st.InStack) > st.InStart { st.PopIn() } + for len(st.PipeIn) > 0 { + CloseReader(st.PipeIn.Pop()) + } } // ErrIsInOut returns true if the given Err writer is also present diff --git a/base/exec/stdio_test.go b/base/exec/stdio_test.go index 3b8f041799..764ea1174d 100644 --- a/base/exec/stdio_test.go +++ b/base/exec/stdio_test.go @@ -18,14 +18,14 @@ func TestStdIO(t *testing.T) { assert.Equal(t, os.Stdout, st.Out) assert.Equal(t, os.Stderr, st.Err) assert.Equal(t, os.Stdin, st.In) - assert.Equal(t, false, st.OutIsPipe()) + // assert.Equal(t, false, st.OutIsPipe()) obuf := &bytes.Buffer{} ibuf := &bytes.Buffer{} var ss StdIOState ss.SetFromOS() ss.StackStart() - assert.Equal(t, false, ss.OutIsPipe()) + // assert.Equal(t, false, ss.OutIsPipe()) ss.PushOut(obuf) assert.NotEqual(t, os.Stdout, ss.Out) diff --git a/base/fileinfo/fileinfo.go b/base/fileinfo/fileinfo.go index 28caffa376..0ef6bf11f4 100644 --- 
a/base/fileinfo/fileinfo.go
+++ b/base/fileinfo/fileinfo.go
@@ -75,6 +75,10 @@ type FileInfo struct { //types:add
 	// version control system status, when enabled
 	VCS vcs.FileStatus `table:"-"`
 
+	// Generated indicates that the file is generated and should not be edited.
+	// For Go files, this regex: `^// Code generated .* DO NOT EDIT\.$` is used.
+	Generated bool `table:"-"`
+
 	// full path to file, including name; for file functions
 	Path string `table:"-"`
 }
@@ -143,6 +147,7 @@
 	}
 	fi.Cat = UnknownCategory
 	fi.Known = Unknown
+	fi.Generated = IsGeneratedFile(fi.Path)
 	fi.Kind = ""
 	mtyp, _, err := MimeFromFile(fi.Path)
 	if err != nil {
diff --git a/base/fileinfo/mimetype.go b/base/fileinfo/mimetype.go
index d44bcc819f..f0529b8916 100644
--- a/base/fileinfo/mimetype.go
+++ b/base/fileinfo/mimetype.go
@@ -7,7 +7,9 @@ package fileinfo
 import (
 	"fmt"
 	"mime"
+	"os"
 	"path/filepath"
+	"regexp"
 	"strings"
 
 	"github.com/h2non/filetype"
@@ -99,6 +101,21 @@ func MimeFromFile(fname string) (mtype, ext string, err error) {
 	return "", ext, fmt.Errorf("fileinfo.MimeFromFile could not find mime type for ext: %v file: %v", ext, fn)
 }
 
+// generatedRe matches the standard Go generated-file marker line; (?m) is
+// required so ^/$ match any line within the header, not just the first byte.
+var generatedRe = regexp.MustCompile(`(?m)^// Code generated .* DO NOT EDIT\.$`)
+
+// IsGeneratedFile returns whether the file contains a standard
+// "Code generated ... DO NOT EDIT." comment within its first 2048 bytes.
+func IsGeneratedFile(fname string) bool {
+	file, err := os.Open(fname)
+	if err != nil {
+		return false
+	}
+	defer file.Close() // note: was leaking the file handle on every call
+	head := make([]byte, 2048)
+	n, _ := file.Read(head)
+	return generatedRe.Match(head[:n])
+}
+
 // todo: use this to check against mime types!
// MimeToKindMapInit makes sure the MimeToKindMap is initialized from @@ -316,7 +330,7 @@ var StandardMimes = []MimeType{ {"text/x-forth", []string{".frt"}, Code, Forth}, // note: ".fs" conflicts with fsharp {"text/x-fortran", []string{".f", ".F"}, Code, Fortran}, {"text/x-fsharp", []string{".fs", ".fsi"}, Code, FSharp}, - {"text/x-gosrc", []string{".go", ".mod", ".work", ".cosh"}, Code, Go}, + {"text/x-gosrc", []string{".go", ".mod", ".work", ".goal"}, Code, Go}, {"text/x-haskell", []string{".hs", ".lhs"}, Code, Haskell}, {"text/x-literate-haskell", nil, Code, Haskell}, // todo: not sure if same or not diff --git a/base/fileinfo/typegen.go b/base/fileinfo/typegen.go index 0b89bbbf00..e8f4015e30 100644 --- a/base/fileinfo/typegen.go +++ b/base/fileinfo/typegen.go @@ -6,4 +6,4 @@ import ( "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/base/fileinfo.FileInfo", IDName: "file-info", Doc: "FileInfo represents the information about a given file / directory,\nincluding icon, mimetype, etc", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Duplicate", Doc: "Duplicate creates a copy of given file -- only works for regular files, not\ndirectories.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string", "error"}}, {Name: "Delete", Doc: "Delete moves the file to the trash / recycling bin.\nOn mobile and web, it deletes it directly.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"error"}}, {Name: "Rename", Doc: "Rename renames (moves) this file to given new path name.\nUpdates the FileInfo setting to the new name, although it might\nbe out of scope if it moved into a new path", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"path"}, Returns: []string{"newpath", "err"}}}, Fields: []types.Field{{Name: "Ic", Doc: "icon for file"}, {Name: "Name", Doc: "name of the file, without 
any path"}, {Name: "Size", Doc: "size of the file"}, {Name: "Kind", Doc: "type of file / directory; shorter, more user-friendly\nversion of mime type, based on category"}, {Name: "Mime", Doc: "full official mime type of the contents"}, {Name: "Cat", Doc: "functional category of the file, based on mime data etc"}, {Name: "Known", Doc: "known file type"}, {Name: "Mode", Doc: "file mode bits"}, {Name: "ModTime", Doc: "time that contents (only) were last modified"}, {Name: "VCS", Doc: "version control system status, when enabled"}, {Name: "Path", Doc: "full path to file, including name; for file functions"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/base/fileinfo.FileInfo", IDName: "file-info", Doc: "FileInfo represents the information about a given file / directory,\nincluding icon, mimetype, etc", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Duplicate", Doc: "Duplicate creates a copy of given file -- only works for regular files, not\ndirectories.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string", "error"}}, {Name: "Delete", Doc: "Delete moves the file to the trash / recycling bin.\nOn mobile and web, it deletes it directly.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"error"}}, {Name: "Rename", Doc: "Rename renames (moves) this file to given new path name.\nUpdates the FileInfo setting to the new name, although it might\nbe out of scope if it moved into a new path", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"path"}, Returns: []string{"newpath", "err"}}}, Fields: []types.Field{{Name: "Ic", Doc: "icon for file"}, {Name: "Name", Doc: "name of the file, without any path"}, {Name: "Size", Doc: "size of the file"}, {Name: "Kind", Doc: "type of file / directory; shorter, more user-friendly\nversion of mime type, based on category"}, {Name: "Mime", Doc: "full official mime 
type of the contents"}, {Name: "Cat", Doc: "functional category of the file, based on mime data etc"}, {Name: "Known", Doc: "known file type"}, {Name: "Mode", Doc: "file mode bits"}, {Name: "ModTime", Doc: "time that contents (only) were last modified"}, {Name: "VCS", Doc: "version control system status, when enabled"}, {Name: "Generated", Doc: "Generated indicates that the file is generated and should not be edited.\nFor Go files, this regex: `^// Code generated .* DO NOT EDIT\\.$` is used."}, {Name: "Path", Doc: "full path to file, including name; for file functions"}}}) diff --git a/base/fsx/fsx.go b/base/fsx/fsx.go index b757ee5f05..9fb0d865fa 100644 --- a/base/fsx/fsx.go +++ b/base/fsx/fsx.go @@ -17,6 +17,10 @@ import ( "time" ) +// Filename is used to open a file picker dialog when used as an argument +// type in a function, or as a field value. +type Filename string + // GoSrcDir tries to locate dir in GOPATH/src/ or GOROOT/src/pkg/ and returns its // full path. GOPATH may contain a list of paths. From Robin Elkind github.com/mewkiz/pkg. func GoSrcDir(dir string) (absDir string, err error) { diff --git a/base/generate/generate.go b/base/generate/generate.go index 4a41bd7e3a..c2446d2723 100644 --- a/base/generate/generate.go +++ b/base/generate/generate.go @@ -54,16 +54,28 @@ func PrintHeader(w io.Writer, pkg string, imports ...string) { } } -// Inspect goes through all of the files in the given package -// and calls the given function on each node in files that -// are not generated. The bool return value from the given function +// ExcludeFile returns true if the given file is on the exclude list. 
+func ExcludeFile(pkg *packages.Package, file *ast.File, exclude ...string) bool { + fpos := pkg.Fset.Position(file.FileStart) + _, fname := filepath.Split(fpos.Filename) + for _, ex := range exclude { + if fname == ex { + return true + } + } + return false +} + +// Inspect goes through all of the files in the given package, +// except those listed in the exclude list, and calls the given +// function on each node. The bool return value from the given function // indicates whether to continue traversing down the AST tree // of that node and look at its children. If a non-nil error value // is returned by the given function, the traversal of the tree is // stopped and the error value is returned. -func Inspect(pkg *packages.Package, f func(n ast.Node) (bool, error)) error { +func Inspect(pkg *packages.Package, f func(n ast.Node) (bool, error), exclude ...string) error { for _, file := range pkg.Syntax { - if ast.IsGenerated(file) { + if ExcludeFile(pkg, file, exclude...) { continue } var terr error diff --git a/base/iox/imagex/testing.go b/base/iox/imagex/testing.go index 1e2f3a91c5..28cec1db34 100644 --- a/base/iox/imagex/testing.go +++ b/base/iox/imagex/testing.go @@ -12,6 +12,8 @@ import ( "os" "path/filepath" "strings" + + "cogentcore.org/core/base/num" ) // TestingT is an interface wrapper around *testing.T @@ -56,6 +58,25 @@ func CompareColors(cc, ic color.RGBA, tol int) bool { return true } +// DiffImage returns the difference between two images, +// with pixels having the abs of the difference between pixels. 
+func DiffImage(a, b image.Image) image.Image { + ab := a.Bounds() + di := image.NewRGBA(ab) + for y := ab.Min.Y; y < ab.Max.Y; y++ { + for x := ab.Min.X; x < ab.Max.X; x++ { + cc := color.RGBAModel.Convert(a.At(x, y)).(color.RGBA) + ic := color.RGBAModel.Convert(b.At(x, y)).(color.RGBA) + r := uint8(num.Abs(int(cc.R) - int(ic.R))) + g := uint8(num.Abs(int(cc.G) - int(ic.G))) + b := uint8(num.Abs(int(cc.B) - int(ic.B))) + c := color.RGBA{r, g, b, 255} + di.Set(x, y, c) + } + } + return di +} + // Assert asserts that the given image is equivalent // to the image stored at the given filename in the testdata directory, // with ".png" added to the filename if there is no extension @@ -77,6 +98,7 @@ func Assert(t TestingT, img image.Image, filename string) { ext := filepath.Ext(filename) failFilename := strings.TrimSuffix(filename, ext) + ".fail" + ext + diffFilename := strings.TrimSuffix(filename, ext) + ".diff" + ext if UpdateTestImages { err := Save(img, filename) @@ -87,6 +109,7 @@ func Assert(t TestingT, img image.Image, filename string) { if err != nil { t.Errorf("AssertImage: error removing old fail image: %v", err) } + os.RemoveAll(diffFilename) return } @@ -133,10 +156,15 @@ func Assert(t TestingT, img image.Image, filename string) { if err != nil { t.Errorf("AssertImage: error saving fail image: %v", err) } + err = Save(DiffImage(img, fimg), diffFilename) + if err != nil { + t.Errorf("AssertImage: error saving diff image: %v", err) + } } else { err := os.RemoveAll(failFilename) if err != nil { t.Errorf("AssertImage: error removing old fail image: %v", err) } + os.RemoveAll(diffFilename) } } diff --git a/base/keylist/README.md b/base/keylist/README.md new file mode 100644 index 0000000000..c89691916a --- /dev/null +++ b/base/keylist/README.md @@ -0,0 +1,6 @@ +# keylist + +keylist implements an ordered list (slice) of items (Values), with a map from a Key (e.g., names) to indexes, to support fast lookup by name. There is also a Keys slice. 
+ +This is a different implementation of the [ordmap](../ordmap) package, and has the advantage of direct slice access to the values, instead of having to go through the KeyValue tuple struct in ordmap. + diff --git a/base/keylist/keylist.go b/base/keylist/keylist.go new file mode 100644 index 0000000000..0f41d23e0f --- /dev/null +++ b/base/keylist/keylist.go @@ -0,0 +1,229 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package keylist implements an ordered list (slice) of items, +with a map from a key (e.g., names) to indexes, +to support fast lookup by name. +This is a different implementation of the [ordmap] package, +that has separate slices for Values and Keys, instead of +using a tuple list of both. The awkwardness of value access +through the tuple is the major problem with ordmap. +*/ +package keylist + +import ( + "fmt" + "slices" +) + +// TODO: probably want to consolidate ordmap and keylist: https://github.com/cogentcore/core/issues/1224 + +// List implements an ordered list (slice) of Values, +// with a map from a key (e.g., names) to indexes, +// to support fast lookup by name. +type List[K comparable, V any] struct { //types:add + // List is the ordered slice of items. + Values []V + + // Keys is the ordered list of keys, in same order as [List.Values] + Keys []K + + // indexes is the key-to-index mapping. + indexes map[K]int +} + +// New returns a new [List]. The zero value +// is usable without initialization, so this is +// just a simple standard convenience method. +func New[K comparable, V any]() *List[K, V] { + return &List[K, V]{} +} + +func (kl *List[K, V]) makeIndexes() { + kl.indexes = make(map[K]int) +} + +// initIndexes ensures that the index map exists. +func (kl *List[K, V]) initIndexes() { + if kl.indexes == nil { + kl.makeIndexes() + } +} + +// Reset resets the list, removing any existing elements. 
+func (kl *List[K, V]) Reset() {
+	kl.Values = nil
+	kl.Keys = nil
+	kl.makeIndexes()
+}
+
+// Set sets given key to given value, adding to the end of the list
+// if not already present, and otherwise replacing with this new value.
+// This is the same semantics as a Go map.
+// See [List.Add] for version that only adds and does not replace.
+func (kl *List[K, V]) Set(key K, val V) {
+	kl.initIndexes()
+	if idx, ok := kl.indexes[key]; ok {
+		kl.Values[idx] = val
+		kl.Keys[idx] = key
+		return
+	}
+	kl.indexes[key] = len(kl.Values)
+	kl.Values = append(kl.Values, val)
+	kl.Keys = append(kl.Keys, key)
+}
+
+// Add adds an item to the list with given key.
+// An error is returned if the key is already on the list.
+// See [List.Set] for a method that automatically replaces.
+func (kl *List[K, V]) Add(key K, val V) error {
+	kl.initIndexes()
+	if _, ok := kl.indexes[key]; ok {
+		return fmt.Errorf("keylist.Add: key %v is already on the list", key)
+	}
+	kl.indexes[key] = len(kl.Values)
+	kl.Values = append(kl.Values, val)
+	kl.Keys = append(kl.Keys, key)
+	return nil
+}
+
+// Insert inserts the given value with the given key at the given index.
+// This is relatively slow because it needs to regenerate the index map.
+// It panics if the key already exists because the behavior is undefined
+// in that situation.
+func (kl *List[K, V]) Insert(idx int, key K, val V) {
+	if _, has := kl.indexes[key]; has {
+		panic("keylist.Insert: key is already on the list")
+	}
+
+	kl.Keys = slices.Insert(kl.Keys, idx, key)
+	kl.Values = slices.Insert(kl.Values, idx, val)
+	kl.makeIndexes()
+	for i, k := range kl.Keys {
+		kl.indexes[k] = i
+	}
+}
+
+// At returns the value corresponding to the given key,
+// with a zero value returned for a missing key. See [List.AtTry]
+// for one that returns a bool for missing keys.
+// For index-based access, use [List.Values] or [List.Keys] slices directly.
+func (kl *List[K, V]) At(key K) V { + idx, ok := kl.indexes[key] + if ok { + return kl.Values[idx] + } + var zv V + return zv +} + +// AtTry returns the value corresponding to the given key, +// with false returned for a missing key, in case the zero value +// is not diagnostic. +func (kl *List[K, V]) AtTry(key K) (V, bool) { + idx, ok := kl.indexes[key] + if ok { + return kl.Values[idx], true + } + var zv V + return zv, false +} + +// IndexIsValid returns an error if the given index is invalid. +func (kl *List[K, V]) IndexIsValid(idx int) error { + if idx >= len(kl.Values) || idx < 0 { + return fmt.Errorf("keylist.List: IndexIsValid: index %d is out of range of a list of length %d", idx, len(kl.Values)) + } + return nil +} + +// IndexByKey returns the index of the given key, with a -1 for missing key. +func (kl *List[K, V]) IndexByKey(key K) int { + idx, ok := kl.indexes[key] + if !ok { + return -1 + } + return idx +} + +// Len returns the number of items in the list. +func (kl *List[K, V]) Len() int { + if kl == nil { + return 0 + } + return len(kl.Values) +} + +// DeleteByIndex deletes item(s) within the index range [i:j]. +// This is relatively slow because it needs to regenerate the +// index map. +func (kl *List[K, V]) DeleteByIndex(i, j int) { + ndel := j - i + if ndel <= 0 { + panic("index range is <= 0") + } + kl.Keys = slices.Delete(kl.Keys, i, j) + kl.Values = slices.Delete(kl.Values, i, j) + kl.makeIndexes() + for i, k := range kl.Keys { + kl.indexes[k] = i + } + +} + +// DeleteByKey deletes the item with the given key, +// returning false if it does not find it. +// This is relatively slow because it needs to regenerate the +// index map. +func (kl *List[K, V]) DeleteByKey(key K) bool { + idx, ok := kl.indexes[key] + if !ok { + return false + } + kl.DeleteByIndex(idx, idx+1) + return true +} + +// RenameIndex renames the item at given index to new key. 
+func (kl *List[K, V]) RenameIndex(i int, key K) {
+	old := kl.Keys[i]
+	delete(kl.indexes, old)
+	kl.Keys[i] = key
+	kl.indexes[key] = i
+}
+
+// Copy copies all of the entries from the given key list
+// into this list. It keeps existing entries in this
+// list unless they also exist in the given list, in which case
+// they are overwritten. Use [List.Reset] first to get an exact copy.
+func (kl *List[K, V]) Copy(from *List[K, V]) {
+	for i, v := range from.Values {
+		kl.Set(from.Keys[i], v)
+	}
+}
+
+// String returns a string representation of the list.
+func (kl *List[K, V]) String() string {
+	sv := "{"
+	for i, v := range kl.Values {
+		sv += fmt.Sprintf("%v", kl.Keys[i]) + ": " + fmt.Sprintf("%v", v) + ", "
+	}
+	sv += "}"
+	return sv
+}
+
+/*
+// GoString returns the list as Go code.
+func (kl *List[K, V]) GoString() string {
+	var zk K
+	var zv V
+	res := fmt.Sprintf("ordlist.Make([]ordlist.KeyVal[%T, %T]{\n", zk, zv)
+	for _, kv := range kl.Order {
+		res += fmt.Sprintf("{%#v, %#v},\n", kv.Key, kv.Value)
+	}
+	res += "})"
+	return res
+}
+*/
diff --git a/base/keylist/keylist_test.go b/base/keylist/keylist_test.go
new file mode 100644
index 0000000000..b88876ecf0
--- /dev/null
+++ b/base/keylist/keylist_test.go
@@ -0,0 +1,39 @@
+// Copyright (c) 2024, Cogent Core. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +package keylist + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeyList(t *testing.T) { + kl := New[string, int]() + kl.Add("key0", 0) + kl.Add("key1", 1) + kl.Add("key2", 2) + + assert.Equal(t, 1, kl.At("key1")) + assert.Equal(t, 2, kl.IndexByKey("key2")) + + assert.Equal(t, 1, kl.Values[1]) + + assert.Equal(t, 3, kl.Len()) + + kl.DeleteByIndex(1, 2) + assert.Equal(t, 2, kl.Values[1]) + assert.Equal(t, 1, kl.IndexByKey("key2")) + + kl.Insert(0, "new0", 3) + assert.Equal(t, 3, kl.Values[0]) + assert.Equal(t, 0, kl.Values[1]) + assert.Equal(t, 2, kl.IndexByKey("key2")) + + // nm := Make([]KeyValue[string, int]{{"one", 1}, {"two", 2}, {"three", 3}}) + // assert.Equal(t, 3, nm.Values[2]) + // assert.Equal(t, 2, nm.Values[1]) + // assert.Equal(t, 3, nm.At("three")) +} diff --git a/base/metadata/metadata.go b/base/metadata/metadata.go index 1fbc1852e2..a2f964ed79 100644 --- a/base/metadata/metadata.go +++ b/base/metadata/metadata.go @@ -2,15 +2,27 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package metadata provides a map of named any elements +// with generic support for type-safe Get and nil-safe Set. +// Metadata keys often function as optional fields in a struct, +// and therefore a CamelCase naming convention is typical. +// Provides default support for "Name", "Doc", "File" standard keys. package metadata import ( "fmt" "maps" + + "cogentcore.org/core/base/errors" ) // Data is metadata as a map of named any elements // with generic support for type-safe Get and nil-safe Set. +// Metadata keys often function as optional fields in a struct, +// and therefore a CamelCase naming convention is typical. +// Provides default support for "Name" and "Doc" standard keys. +// In general it is good practice to provide access functions +// that establish standard key names, to avoid issues with typos. 
type Data map[string]any func (md *Data) init() { @@ -26,8 +38,8 @@ func (md *Data) Set(key string, value any) { (*md)[key] = value } -// Get gets metadata value of given type. -// returns error if not present or item is a different type. +// Get gets metadata value of given type from given Data. +// Returns error if not present or item is a different type. func Get[T any](md Data, key string) (T, error) { var z T x, ok := md[key] @@ -41,14 +53,73 @@ func Get[T any](md Data, key string) (T, error) { return v, nil } -// Copy does a shallow copy of metadata from source. +// CopyFrom does a shallow copy of metadata from source. // Any pointer-based values will still point to the same // underlying data as the source, but the two maps remain // distinct. It uses [maps.Copy]. -func (md *Data) Copy(src Data) { +func (md *Data) CopyFrom(src Data) { if src == nil { return } md.init() maps.Copy(*md, src) } + +//////// Metadataer + +// Metadataer is an interface for a type that returns associated +// metadata.Data using a Metadata() method. To be able to set metadata, +// the method should be defined with a pointer receiver. +type Metadataer interface { + Metadata() *Data +} + +// GetData gets the Data from given object, if it implements the +// Metadata() method. Returns nil if it does not. +// Must pass a pointer to the object. +func GetData(obj any) *Data { + if md, ok := obj.(Metadataer); ok { + return md.Metadata() + } + return nil +} + +// GetFrom gets metadata value of given type from given object, +// if it implements the Metadata() method. +// Must pass a pointer to the object. +// Returns error if not present or item is a different type. +func GetFrom[T any](obj any, key string) (T, error) { + md := GetData(obj) + if md == nil { + var zv T + return zv, errors.New("metadata not available for given object type") + } + return Get[T](*md, key) +} + +// SetTo sets metadata value on given object, if it implements +// the Metadata() method. 
Returns error if no Metadata on object. +// Must pass a pointer to the object. +func SetTo(obj any, key string, value any) error { + md := GetData(obj) + if md == nil { + return errors.Log(errors.New("metadata not available for given object type")) + } + md.Set(key, value) + return nil +} + +// CopyFrom copies metadata from source +// Must pass a pointer to the object. +func CopyFrom(to, src any) *Data { + tod := GetData(to) + if tod == nil { + return nil + } + srcd := GetData(src) + if srcd == nil { + return tod + } + tod.CopyFrom(*srcd) + return tod +} diff --git a/base/metadata/metadata_test.go b/base/metadata/metadata_test.go new file mode 100644 index 0000000000..43b2735020 --- /dev/null +++ b/base/metadata/metadata_test.go @@ -0,0 +1,32 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metadata + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type mytest struct { + Meta Data +} + +func (mt *mytest) Metadata() *Data { + return &mt.Meta +} + +func TestMetadata(t *testing.T) { + mt := &mytest{} + + SetName(mt, "test") + assert.Equal(t, "test", Name(mt)) + + SetDoc(mt, "this is good") + assert.Equal(t, "this is good", Doc(mt)) + + SetFilename(mt, "path/me.go") + assert.Equal(t, "path/me.go", Filename(mt)) +} diff --git a/base/metadata/std.go b/base/metadata/std.go new file mode 100644 index 0000000000..74eca83052 --- /dev/null +++ b/base/metadata/std.go @@ -0,0 +1,51 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metadata + +import "os" + +// SetName sets the "Name" standard key. +func SetName(obj any, name string) { + SetTo(obj, "Name", name) +} + +// Name returns the "Name" standard key value (empty if not set). 
+func Name(obj any) string { + nm, _ := GetFrom[string](obj, "Name") + return nm +} + +// SetDoc sets the "Doc" standard key. +func SetDoc(obj any, doc string) { + SetTo(obj, "Doc", doc) +} + +// Doc returns the "Doc" standard key value (empty if not set). +func Doc(obj any) string { + doc, _ := GetFrom[string](obj, "Doc") + return doc +} + +// SetFile sets the "File" standard key for *os.File. +func SetFile(obj any, file *os.File) { + SetTo(obj, "File", file) +} + +// File returns the "File" standard key value (nil if not set). +func File(obj any) *os.File { + doc, _ := GetFrom[*os.File](obj, "File") + return doc +} + +// SetFilename sets the "Filename" standard key. +func SetFilename(obj any, file string) { + SetTo(obj, "Filename", file) +} + +// Filename returns the "Filename" standard key value (empty if not set). +func Filename(obj any) string { + doc, _ := GetFrom[string](obj, "Filename") + return doc +} diff --git a/base/randx/dists_test.go b/base/randx/dists_test.go index 059dd031d8..7bfb2759af 100644 --- a/base/randx/dists_test.go +++ b/base/randx/dists_test.go @@ -8,14 +8,14 @@ import ( "math" "testing" - "cogentcore.org/core/base/errors" "cogentcore.org/core/tensor/stats/stats" "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" ) func TestGaussianGen(t *testing.T) { nsamp := int(1e6) - dt := &table.Table{} + dt := table.New() dt.AddFloat32Column("Val") dt.SetNumRows(nsamp) @@ -25,18 +25,15 @@ func TestGaussianGen(t *testing.T) { for i := 0; i < nsamp; i++ { vl := GaussianGen(mean, sig) - dt.SetFloat("Val", i, vl) + dt.Column("Val").SetFloatRow(vl, i, 0) } - ix := table.NewIndexView(dt) - desc := stats.DescAll(ix) + dir, _ := tensorfs.NewDir("Desc") + stats.DescribeTableAll(dir, dt) + desc := tensorfs.DirTable(dir, nil) + // fmt.Println(desc.Columns.Keys) - meanRow := errors.Log1(desc.RowsByString("Stat", "Mean", table.Equals, table.UseCase))[0] - stdRow := errors.Log1(desc.RowsByString("Stat", "Std", table.Equals, 
table.UseCase))[0] - // minRow := errors.Log1(desc.RowsByString("Stat", "Min", table.Equals, table.UseCase))[0] - // maxRow := errors.Log1(desc.RowsByString("Stat", "Max", table.Equals, table.UseCase))[0] - - actMean := desc.Float("Val", meanRow) - actStd := desc.Float("Val", stdRow) + actMean := desc.Column("Val/Mean").FloatRow(0, 0) + actStd := desc.Column("Val/Std").FloatRow(0, 0) if math.Abs(actMean-mean) > tol { t.Errorf("Gaussian: mean %g\t out of tolerance vs target: %g\n", actMean, mean) @@ -44,14 +41,11 @@ func TestGaussianGen(t *testing.T) { if math.Abs(actStd-sig) > tol { t.Errorf("Gaussian: stdev %g\t out of tolerance vs target: %g\n", actStd, sig) } - // b := bytes.NewBuffer(nil) - // desc.WriteCSV(b, table.Tab, table.Headers) - // fmt.Printf("%s\n", string(b.Bytes())) } func TestBinomialGen(t *testing.T) { nsamp := int(1e6) - dt := &table.Table{} + dt := table.New() dt.AddFloat32Column("Val") dt.SetNumRows(nsamp) @@ -61,21 +55,15 @@ func TestBinomialGen(t *testing.T) { for i := 0; i < nsamp; i++ { vl := BinomialGen(n, p) - dt.SetFloat("Val", i, vl) - } - ix := table.NewIndexView(dt) - desc := stats.DescAll(ix) - - meanRow := errors.Log1(desc.RowsByString("Stat", "Mean", table.Equals, table.UseCase))[0] - stdRow := errors.Log1(desc.RowsByString("Stat", "Std", table.Equals, table.UseCase))[0] - minRow := errors.Log1(desc.RowsByString("Stat", "Min", table.Equals, table.UseCase))[0] - maxRow := errors.Log1(desc.RowsByString("Stat", "Max", table.Equals, table.UseCase))[0] - - actMean := desc.Float("Val", meanRow) - actStd := desc.Float("Val", stdRow) - actMin := desc.Float("Val", minRow) - actMax := desc.Float("Val", maxRow) - + dt.Column("Val").SetFloat(vl, i) + } + dir, _ := tensorfs.NewDir("Desc") + stats.DescribeTableAll(dir, dt) + desc := tensorfs.DirTable(dir, nil) + actMean := desc.Column("Val/Mean").FloatRow(0, 0) + actStd := desc.Column("Val/Std").FloatRow(0, 0) + actMin := desc.Column("Val/Min").FloatRow(0, 0) + actMax := 
desc.Column("Val/Max").FloatRow(0, 0) mean := n * p if math.Abs(actMean-mean) > tol { t.Errorf("Binomial: mean %g\t out of tolerance vs target: %g\n", actMean, mean) @@ -90,14 +78,11 @@ func TestBinomialGen(t *testing.T) { if actMax < 0 { t.Errorf("Binomial: max %g\t should not be > 1\n", actMax) } - // b := bytes.NewBuffer(nil) - // desc.WriteCSV(b, table.Tab, table.Headers) - // fmt.Printf("%s\n", string(b.Bytes())) } func TestPoissonGen(t *testing.T) { nsamp := int(1e6) - dt := &table.Table{} + dt := table.New() dt.AddFloat32Column("Val") dt.SetNumRows(nsamp) @@ -106,20 +91,15 @@ func TestPoissonGen(t *testing.T) { for i := 0; i < nsamp; i++ { vl := PoissonGen(lambda) - dt.SetFloat("Val", i, vl) + dt.Column("Val").SetFloatRow(vl, i, 0) } - ix := table.NewIndexView(dt) - desc := stats.DescAll(ix) - - meanRow := errors.Log1(desc.RowsByString("Stat", "Mean", table.Equals, table.UseCase))[0] - stdRow := errors.Log1(desc.RowsByString("Stat", "Std", table.Equals, table.UseCase))[0] - minRow := errors.Log1(desc.RowsByString("Stat", "Min", table.Equals, table.UseCase))[0] - // maxRow := errors.Log1(desc.RowsByString("Stat", "Max", table.Equals, table.UseCase))[0] - - actMean := desc.Float("Val", meanRow) - actStd := desc.Float("Val", stdRow) - actMin := desc.Float("Val", minRow) - // actMax := desc.Float("Val", maxRow) + dir, _ := tensorfs.NewDir("Desc") + stats.DescribeTableAll(dir, dt) + desc := tensorfs.DirTable(dir, nil) + actMean := desc.Column("Val/Mean").FloatRow(0, 0) + actStd := desc.Column("Val/Std").FloatRow(0, 0) + actMin := desc.Column("Val/Min").FloatRow(0, 0) + // actMax := desc.Column("Val/Max").FloatRow(0, 0) mean := lambda if math.Abs(actMean-mean) > tol { @@ -135,14 +115,11 @@ func TestPoissonGen(t *testing.T) { // if actMax < 0 { // t.Errorf("Poisson: max %g\t should not be > 1\n", actMax) // } - // b := bytes.NewBuffer(nil) - // desc.WriteCSV(b, table.Tab, table.Headers) - // fmt.Printf("%s\n", string(b.Bytes())) } func TestGammaGen(t *testing.T) { 
nsamp := int(1e6) - dt := &table.Table{} + dt := table.New() dt.AddFloat32Column("Val") dt.SetNumRows(nsamp) @@ -152,17 +129,15 @@ func TestGammaGen(t *testing.T) { for i := 0; i < nsamp; i++ { vl := GammaGen(alpha, beta) - dt.SetFloat("Val", i, vl) - } - ix := table.NewIndexView(dt) - desc := stats.DescAll(ix) - - meanRow := errors.Log1(desc.RowsByString("Stat", "Mean", table.Equals, table.UseCase))[0] - stdRow := errors.Log1(desc.RowsByString("Stat", "Std", table.Equals, table.UseCase))[0] - - actMean := desc.Float("Val", meanRow) - actStd := desc.Float("Val", stdRow) - + dt.Column("Val").SetFloatRow(vl, i, 0) + } + dir, _ := tensorfs.NewDir("Desc") + stats.DescribeTableAll(dir, dt) + desc := tensorfs.DirTable(dir, nil) + actMean := desc.Column("Val/Mean").FloatRow(0, 0) + actStd := desc.Column("Val/Std").FloatRow(0, 0) + // actMin := desc.Column("Val/Min").FloatRow(0, 0) + // actMax := desc.Column("Val/Max").FloatRow(0, 0) mean := alpha / beta if math.Abs(actMean-mean) > tol { t.Errorf("Gamma: mean %g\t out of tolerance vs target: %g\n", actMean, mean) @@ -171,14 +146,11 @@ func TestGammaGen(t *testing.T) { if math.Abs(actStd-sig) > tol { t.Errorf("Gamma: stdev %g\t out of tolerance vs target: %g\n", actStd, sig) } - // b := bytes.NewBuffer(nil) - // desc.WriteCSV(b, table.Tab, table.Headers) - // fmt.Printf("%s\n", string(b.Bytes())) } func TestBetaGen(t *testing.T) { nsamp := int(1e6) - dt := &table.Table{} + dt := table.New() dt.AddFloat32Column("Val") dt.SetNumRows(nsamp) @@ -188,17 +160,15 @@ func TestBetaGen(t *testing.T) { for i := 0; i < nsamp; i++ { vl := BetaGen(alpha, beta) - dt.SetFloat("Val", i, vl) - } - ix := table.NewIndexView(dt) - desc := stats.DescAll(ix) - - meanRow := errors.Log1(desc.RowsByString("Stat", "Mean", table.Equals, table.UseCase))[0] - stdRow := errors.Log1(desc.RowsByString("Stat", "Std", table.Equals, table.UseCase))[0] - - actMean := desc.Float("Val", meanRow) - actStd := desc.Float("Val", stdRow) - + 
dt.Column("Val").SetFloatRow(vl, i, 0) + } + dir, _ := tensorfs.NewDir("Desc") + stats.DescribeTableAll(dir, dt) + desc := tensorfs.DirTable(dir, nil) + actMean := desc.Column("Val/Mean").FloatRow(0, 0) + actStd := desc.Column("Val/Std").FloatRow(0, 0) + // actMin := desc.Column("Val/Min").FloatRow(0, 0) + // actMax := desc.Column("Val/Max").FloatRow(0, 0) mean := alpha / (alpha + beta) if math.Abs(actMean-mean) > tol { t.Errorf("Beta: mean %g\t out of tolerance vs target: %g\n", actMean, mean) @@ -208,7 +178,4 @@ func TestBetaGen(t *testing.T) { if math.Abs(actStd-sig) > tol { t.Errorf("Beta: stdev %g\t out of tolerance vs target: %g\n", actStd, sig) } - // b := bytes.NewBuffer(nil) - // desc.WriteCSV(b, table.Tab, table.Headers) - // fmt.Printf("%s\n", string(b.Bytes())) } diff --git a/base/reflectx/pointers_test.go b/base/reflectx/pointers_test.go index 0238307a57..d4cff26ea0 100644 --- a/base/reflectx/pointers_test.go +++ b/base/reflectx/pointers_test.go @@ -244,14 +244,14 @@ func InitPointerTest() { pt.Mbr2 = 2 } -func FieldValue(obj any, fld reflect.StructField) reflect.Value { +func fieldValue(obj any, fld reflect.StructField) reflect.Value { ov := reflect.ValueOf(obj) f := unsafe.Pointer(ov.Pointer() + fld.Offset) nw := reflect.NewAt(fld.Type, f) return nw } -func SubFieldValue(obj any, fld reflect.StructField, sub reflect.StructField) reflect.Value { +func subFieldValue(obj any, fld reflect.StructField, sub reflect.StructField) reflect.Value { ov := reflect.ValueOf(obj) f := unsafe.Pointer(ov.Pointer() + fld.Offset + sub.Offset) nw := reflect.NewAt(sub.Type, f) @@ -263,7 +263,7 @@ func TestNewAt(t *testing.T) { InitPointerTest() typ := reflect.TypeOf(pt) fld, _ := typ.FieldByName("Mbr2") - vf := FieldValue(&pt, fld) + vf := fieldValue(&pt, fld) // fmt.Printf("Fld: %v Typ: %v vf: %v vfi: %v vfT: %v vfp: %v canaddr: %v canset: %v caninterface: %v\n", fld.Name, vf.Type().String(), vf.String(), vf.Interface(), vf.Interface(), vf.Interface(), vf.CanAddr(), 
vf.CanSet(), vf.CanInterface()) @@ -274,7 +274,7 @@ func TestNewAt(t *testing.T) { } fld, _ = typ.FieldByName("Mbr1") - vf = FieldValue(&pt, fld) + vf = fieldValue(&pt, fld) // fmt.Printf("Fld: %v Typ: %v vf: %v vfi: %v vfT: %v vfp: %v canaddr: %v canset: %v caninterface: %v\n", fld.Name, vf.Type().String(), vf.String(), vf.Interface(), vf.Interface(), vf.Interface(), vf.CanAddr(), vf.CanSet(), vf.CanInterface()) diff --git a/base/reflectx/structs.go b/base/reflectx/structs.go index e647b944bd..4a9f8af015 100644 --- a/base/reflectx/structs.go +++ b/base/reflectx/structs.go @@ -6,6 +6,7 @@ package reflectx import ( "fmt" + "log" "log/slog" "reflect" "strconv" @@ -229,3 +230,110 @@ func StructTags(tags reflect.StructTag) map[string]string { func StringJSON(v any) string { return string(errors.Log1(jsonx.WriteBytesIndent(v))) } + +// FieldValue returns the [reflect.Value] of given field within given struct value, +// where the field can be a path with . separators, for fields within struct fields. +func FieldValue(s reflect.Value, fieldPath string) (reflect.Value, error) { + sv := UnderlyingPointer(s) + var zv reflect.Value + if sv.Elem().Kind() != reflect.Struct { + return zv, errors.New("reflectx.FieldValue: kind is not struct") + } + fps := strings.Split(fieldPath, ".") + fv := sv.Elem().FieldByName(fps[0]) + if fv == zv { + return zv, errors.New("reflectx.FieldValue: field name not found: " + fps[0]) + } + if len(fps) == 1 { + return fv, nil + } + return FieldValue(fv, strings.Join(fps[1:], ".")) +} + +// CopyFields copies the named fields from src struct into dest struct. +// Fields can be paths with . separators for sub-fields of fields. 
+func CopyFields(dest, src any, fields ...string) error { + dsv := UnderlyingPointer(reflect.ValueOf(dest)) + if dsv.Elem().Kind() != reflect.Struct { + return errors.New("reflectx.CopyFields: destination kind is not struct") + } + ssv := UnderlyingPointer(reflect.ValueOf(src)) + if ssv.Elem().Kind() != reflect.Struct { + return errors.New("reflectx.CopyFields: source kind is not struct") + } + var errs []error + for _, f := range fields { + dfv, err := FieldValue(dsv, f) + if err != nil { + errs = append(errs, err) + continue + } + sfv, err := FieldValue(ssv, f) + if err != nil { + errs = append(errs, err) + continue + } + err = SetRobust(PointerValue(dfv).Interface(), sfv.Interface()) + if err != nil { + errs = append(errs, err) + continue + } + } + return errors.Join(errs...) +} + +// FieldAtPath parses a path to a field within the given struct, +// using . delimted field names, and returns the [reflect.Value] for +// the field. Returns an error if not found. +func FieldAtPath(val reflect.Value, path string) (reflect.Value, error) { + npv := NonPointerValue(val) + if npv.Kind() != reflect.Struct { + if !npv.IsValid() { + err := fmt.Errorf("FieldAtPath: struct is nil, for path: %q", path) + return npv, err + } + err := fmt.Errorf("FieldAtPath: object is not a struct: %q kind: %q, for path: %q", npv.String(), npv.Kind(), path) + return npv, err + } + paths := strings.Split(path, ".") + fnm := paths[0] + fld := npv.FieldByName(fnm) + if !fld.IsValid() { + err := fmt.Errorf("FieldAtPath: could not find Field named: %q in struct: %q kind: %q, path: %v", fnm, npv.String(), npv.Kind(), path) + return fld, err + } + if len(paths) == 1 { + return fld.Addr(), nil + } + return FieldAtPath(fld.Addr(), strings.Join(paths[1:], ".")) +} + +// SetFieldsFromMap sets given map[string]any values to fields of given object, +// where the map keys are field paths (with . delimiters for sub-field paths). +// The value can be any appropriate type that applies to the given field. 
+// It prints a message if a parameter fails to be set, and returns an error. +func SetFieldsFromMap(obj any, vals map[string]any) error { + objv := reflect.ValueOf(obj) + npv := NonPointerValue(objv) + if npv.Kind() == reflect.Map { + err := CopyMapRobust(obj, vals) + if err != nil { + log.Println(err) + return err + } + } + var errs []error + for k, v := range vals { + fld, err := FieldAtPath(objv, k) + if err != nil { + errs = append(errs, err) + } + err = SetRobust(fld.Interface(), v) + if err != nil { + err = fmt.Errorf("SetFieldsFromMap: was not able to apply value: %v to field: %s", v, k) + log.Println(err) + errs = append(errs, err) + } + } + return errors.Join(errs...) +} diff --git a/base/reflectx/structs_test.go b/base/reflectx/structs_test.go index e42b07f175..a6e64a0537 100644 --- a/base/reflectx/structs_test.go +++ b/base/reflectx/structs_test.go @@ -5,8 +5,12 @@ package reflectx import ( + "image" "reflect" "testing" + + "cogentcore.org/core/colors" + "github.com/stretchr/testify/assert" ) type person struct { @@ -53,3 +57,64 @@ func TestNonDefaultFields(t *testing.T) { t.Errorf("expected\n%v\n\tbut got\n%v", want, have) } } + +type imgfield struct { + Mycolor image.Image +} + +func TestCopyFields(t *testing.T) { + sp := &person{ + Name: "Go Gopher", + Age: 23, + ProgrammingLanguage: "Go", + FavoriteFruit: "Peach", + Data: "abcdef", + Pet: pet{ + Name: "Pet Gopher", + Type: "Dog", + Age: 7, + }, + } + dp := &person{} + CopyFields(dp, sp, "Name", "Pet.Age") + assert.Equal(t, sp.Name, dp.Name) + assert.Equal(t, sp.Pet.Age, dp.Pet.Age) + + sif := &imgfield{ + Mycolor: colors.Uniform(colors.Black), + } + dif := &imgfield{} + CopyFields(dif, sif, "Mycolor") + assert.Equal(t, sif.Mycolor, dif.Mycolor) +} + +func TestFieldAtPath(t *testing.T) { + sp := &person{ + Name: "Go Gopher", + Age: 23, + ProgrammingLanguage: "Go", + FavoriteFruit: "Peach", + Data: "abcdef", + Pet: pet{ + Name: "Pet Gopher", + Type: "Dog", + Age: 7, + }, + } + spv := 
reflect.ValueOf(sp) + fv, err := FieldAtPath(spv, "Pet.Age") + assert.NoError(t, err) + assert.Equal(t, 7, fv.Elem().Interface()) + fv, err = FieldAtPath(spv, "Pet.Name") + assert.NoError(t, err) + assert.Equal(t, "Pet Gopher", fv.Elem().Interface()) + fv, err = FieldAtPath(spv, "Pet.Ages") + assert.Error(t, err) + fv, err = FieldAtPath(spv, "Pets.Age") + assert.Error(t, err) + + err = SetFieldsFromMap(sp, map[string]any{"Pet.Age": 8, "Data": "ddd"}) + assert.NoError(t, err) + assert.Equal(t, 8, sp.Pet.Age) + assert.Equal(t, "ddd", sp.Data) +} diff --git a/base/reflectx/values.go b/base/reflectx/values.go index ebadbbbef6..459ced7c8c 100644 --- a/base/reflectx/values.go +++ b/base/reflectx/values.go @@ -47,6 +47,18 @@ func KindIsNumber(vk reflect.Kind) bool { return vk >= reflect.Int && vk <= reflect.Complex128 } +// KindIsInt returns whether the given [reflect.Kind] is an int +// type such as int, int32 etc. +func KindIsInt(vk reflect.Kind) bool { + return vk >= reflect.Int && vk <= reflect.Uintptr +} + +// KindIsFloat returns whether the given [reflect.Kind] is a +// float32 or float64. +func KindIsFloat(vk reflect.Kind) bool { + return vk >= reflect.Float32 && vk <= reflect.Float64 +} + // ToBool robustly converts to a bool any basic elemental type // (including pointers to such) using a big type switch organized // for greatest efficiency. It tries the [bools.Booler] @@ -939,7 +951,13 @@ func ToStringPrec(v any, prec int) string { // set to be fully equivalent to the source slice. 
func SetRobust(to, from any) error { rto := reflect.ValueOf(to) - pto := UnderlyingPointer(rto) + if IsNil(rto) { + return fmt.Errorf("got nil destination value") + } + pto := rto + if !(pto.Kind() == reflect.Pointer && pto.Elem().Kind() == reflect.Pointer) { + pto = UnderlyingPointer(rto) + } if IsNil(pto) { return fmt.Errorf("got nil destination value") } @@ -951,6 +969,15 @@ func SetRobust(to, from any) error { return fmt.Errorf("destination value cannot be set; it must be a variable or field, not a const or tmp or other value that cannot be set (value: %v of type %T)", pto, pto) } + // images should not be copied per content: just set the pointer! + // otherwise the original images (esp colors!) are altered. + if img, ok := to.(*image.Image); ok { + if fimg, ok := from.(image.Image); ok { + *img = fimg + return nil + } + } + // first we do the generic AssignableTo case if rto.Kind() == reflect.Pointer { fv := reflect.ValueOf(from) diff --git a/base/reflectx/values_test.go b/base/reflectx/values_test.go index 6a26f7bfdb..b959bf9ccf 100644 --- a/base/reflectx/values_test.go +++ b/base/reflectx/values_test.go @@ -146,6 +146,13 @@ func TestPointerSetRobust(t *testing.T) { t.Errorf(err.Error()) } assert.Equal(t, aptr, bptr) + + aptr = nil // also must work if dest pointer is nil + err = SetRobust(&aptr, bptr) + if err != nil { + t.Errorf(err.Error()) + } + assert.Equal(t, aptr, bptr) } func BenchmarkFloatToFloat(b *testing.B) { diff --git a/base/stringsx/stringsx.go b/base/stringsx/stringsx.go index 79dc0941ca..84450ffad8 100644 --- a/base/stringsx/stringsx.go +++ b/base/stringsx/stringsx.go @@ -8,6 +8,7 @@ package stringsx import ( "bytes" + "slices" "strings" ) @@ -88,3 +89,18 @@ func InsertFirstUnique(strs *[]string, str string, max int) { (*strs)[0] = str } } + +// DedupeList removes duplicates from given string list, +// preserving the order. 
+func DedupeList(strs []string) []string { + n := len(strs) + for i := n - 1; i >= 0; i-- { + p := strs[i] + for j, s := range strs { + if p == s && i != j { + strs = slices.Delete(strs, i, i+1) + } + } + } + return strs +} diff --git a/cli/cli.go b/cli/cli.go index 992245faf4..11f71aa3b6 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -40,6 +40,9 @@ func Run[T any, C CmdOrFunc[T]](opts *Options, cfg T, cmds ...C) error { } return err } + if len(cmds) == 1 { // one command is always the root + cs[0].Root = true + } cmd, err := config(opts, cfg, cs...) if err != nil { if opts.Fatal { diff --git a/cmd/core/cmd/build.go b/cmd/core/cmd/build.go index 614fe40a8a..848173b862 100644 --- a/cmd/core/cmd/build.go +++ b/cmd/core/cmd/build.go @@ -78,9 +78,21 @@ func buildDesktop(c *config.Config, platform config.Platform) error { } } ldflags += " " + config.LinkerFlags(c) + + inCmd := false + fi, err := os.Stat("cmd") + if err == nil && fi.IsDir() { + os.Chdir("cmd") + inCmd = true + } + args = append(args, "-ldflags", ldflags, "-o", filepath.Join(c.Build.Output, output)) - err := xc.Run("go", args...) + err = xc.Run("go", args...) 
+ if inCmd { + os.Rename(output, filepath.Join("..", output)) + os.Chdir("../") + } if err != nil { return fmt.Errorf("error building for platform %s/%s: %w", platform.OS, platform.Arch, err) } diff --git a/core/bars.go b/core/bars.go index afa2d491b6..3910d93a19 100644 --- a/core/bars.go +++ b/core/bars.go @@ -114,8 +114,7 @@ func (sc *Scene) addDefaultBars() { } } -////////////////////////////////////////////////////////////// -// Scene wrappers +//////// Scene wrappers // AddTopBar adds the given function for configuring a control bar // at the top of the window diff --git a/core/events.go b/core/events.go index 19528e9e17..905124cb64 100644 --- a/core/events.go +++ b/core/events.go @@ -474,6 +474,10 @@ func (em *Events) handlePosEvent(e events.Event) { if !sentMulti { em.lastDoubleClickWidget = nil em.lastClickWidget = up + if em.focus != up { + em.focusClear() // always any other focus before the click is processed + // this causes textfields etc to apply their changes. + } up.AsWidget().Send(events.Click, e) } case events.Right: // note: automatically gets Control+Left diff --git a/core/filepicker.go b/core/filepicker.go index f3c1406b0b..69a48f2a16 100644 --- a/core/filepicker.go +++ b/core/filepicker.go @@ -21,6 +21,7 @@ import ( "cogentcore.org/core/base/elide" "cogentcore.org/core/base/errors" "cogentcore.org/core/base/fileinfo" + "cogentcore.org/core/base/fsx" "cogentcore.org/core/colors" "cogentcore.org/core/cursors" "cogentcore.org/core/events" @@ -705,7 +706,7 @@ func (fp *FilePicker) editRecentPaths() { // Filename is used to specify an file path. // It results in a [FileButton] [Value]. -type Filename string +type Filename = fsx.Filename // FileButton represents a [Filename] value with a button // that opens a [FilePicker]. diff --git a/core/form.go b/core/form.go index 9785dbfd56..fb4a2c0038 100644 --- a/core/form.go +++ b/core/form.go @@ -30,6 +30,11 @@ type Form struct { // Inline is whether to display the form in one line. 
Inline bool + // Modified optionally highlights and tracks fields that have been modified + // through an OnChange event. If present, it replaces the default value highlighting + // and resetting logic. Ignored if nil. + Modified map[string]bool + // structFields are the fields of the current struct. structFields []*structField @@ -179,13 +184,20 @@ func (fm *Form) Init() { // (see https://github.com/cogentcore/core/issues/1098). doc, _ := types.GetDoc(f.value, f.parent, f.field, label) w.SetTooltip(doc) - if hasDef { - w.SetTooltip("(Default: " + def + ") " + w.Tooltip) + if hasDef || fm.Modified != nil { + if hasDef { + w.SetTooltip("(Default: " + def + ") " + w.Tooltip) + } var isDef bool w.Styler(func(s *styles.Style) { f := fm.structFields[i] - isDef = reflectx.ValueIsDefault(f.value, def) dcr := "(Double click to reset to default) " + if fm.Modified != nil { + isDef = !fm.Modified[f.path] + dcr = "(Double click to mark as non-modified) " + } else { + isDef = reflectx.ValueIsDefault(f.value, def) + } if !isDef { s.Color = colors.Scheme.Primary.Base s.Cursor = cursors.Poof @@ -202,13 +214,20 @@ func (fm *Form) Init() { return } e.SetHandled() - err := reflectx.SetFromDefaultTag(f.value, def) + var err error + if fm.Modified != nil { + fm.Modified[f.path] = false + } else { + err = reflectx.SetFromDefaultTag(f.value, def) + } if err != nil { ErrorSnackbar(w, err, "Error setting default value") } else { w.Update() valueWidget.AsWidget().Update() - valueWidget.AsWidget().SendChange(e) + if fm.Modified == nil { + valueWidget.AsWidget().SendChange(e) + } } }) } @@ -243,8 +262,11 @@ func (fm *Form) Init() { }) if !fm.IsReadOnly() && !readOnlyTag { wb.OnChange(func(e events.Event) { + if fm.Modified != nil { + fm.Modified[f.path] = true + } fm.SendChange(e) - if hasDef { + if hasDef || fm.Modified != nil { labelWidget.Update() } if fm.isShouldDisplayer { diff --git a/core/form_test.go b/core/form_test.go index 91942c95e8..280d2a03e3 100644 --- a/core/form_test.go +++ 
b/core/form_test.go @@ -7,6 +7,7 @@ package core import ( "testing" + "cogentcore.org/core/colors" "cogentcore.org/core/events" "cogentcore.org/core/styles" "cogentcore.org/core/styles/abilities" @@ -73,3 +74,30 @@ func TestFormStyle(t *testing.T) { NewForm(b).SetStruct(s) b.AssertRender(t, "form/style") } + +type giveUpParams struct { + ProbThr float32 + MinGiveUpSum float32 + Utility float32 + Timing float32 + Progress float32 + MinUtility float32 + ProgressRateTau float32 + ProgressRateDt float32 +} + +type addFields struct { + GiveUp giveUpParams `display:"add-fields"` +} + +func TestFormAddFields(t *testing.T) { + AppearanceSettings.Spacing = 30 + b := NewBody() + b.Styler(func(s *styles.Style) { + s.Min.X.Ch(100) + }) + NewForm(b).SetStruct(&addFields{}).Styler(func(s *styles.Style) { + s.Background = colors.Scheme.SurfaceContainerLow + }) + b.AssertRender(t, "form/addfields") +} diff --git a/core/layout.go b/core/layout.go index e8d938bd4a..6a0fdaad82 100644 --- a/core/layout.go +++ b/core/layout.go @@ -140,8 +140,8 @@ type Layouter interface { SetScrollParams(d math32.Dims, sb *Slider) } -// AsFrame returns the given value as a value of type [Frame] if the type -// of the given value embeds [Frame], or nil otherwise. +// AsFrame returns the given value as a [Frame] if it has +// an AsFrame() method, or nil otherwise. func AsFrame(n tree.Node) *Frame { if t, ok := n.(Layouter); ok { return t.AsFrame() @@ -149,7 +149,6 @@ func AsFrame(n tree.Node) *Frame { return nil } -// AsFrame satisfies the [Layouter] interface. 
func (t *Frame) AsFrame() *Frame { return t } @@ -295,13 +294,13 @@ func (ls *geomState) contentRangeDim(d math32.Dims) (cmin, cmax float32) { // totalRect returns Pos.Total -- Size.Actual.Total // as an image.Rectangle, e.g., for bounding box func (ls *geomState) totalRect() image.Rectangle { - return math32.RectFromPosSizeMax(ls.Pos.Total, ls.Size.Actual.Total) + return math32.RectFromPosSizeMax(ls.Pos.Total, ls.Size.Alloc.Total) } // contentRect returns Pos.Content, Size.Actual.Content // as an image.Rectangle, e.g., for bounding box. func (ls *geomState) contentRect() image.Rectangle { - return math32.RectFromPosSizeMax(ls.Pos.Content, ls.Size.Actual.Content) + return math32.RectFromPosSizeMax(ls.Pos.Content, ls.Size.Alloc.Content) } // ScrollOffset computes the net scrolling offset as a function of diff --git a/core/list.go b/core/list.go index 707d9f7f5e..4f5b6f1399 100644 --- a/core/list.go +++ b/core/list.go @@ -90,7 +90,7 @@ type Lister interface { // SliceIndex returns the logical slice index: si = i + StartIndex, // the actual value index vi into the slice value (typically = si), // which can be different if there is an index indirection as in - // tensorcore table.IndexView), and a bool that is true if the + // tensorcore table.Table), and a bool that is true if the // index is beyond the available data and is thus invisible, // given the row index provided. SliceIndex(i int) (si, vi int, invis bool) @@ -817,7 +817,7 @@ func (lb *ListBase) MakeToolbar(p *tree.Plan) { }) } -//////////////////////////////////////////////////////////// +//////// // Row access methods // NOTE: row = physical GUI display row, idx = slice index // not the same! 
@@ -1062,8 +1062,7 @@ func (lb *ListBase) movePageUpEvent(selMode events.SelectModes) int { return nidx } -////////////////////////////////////////////////////////// -// Selection: user operates on the index labels +//////// Selection: user operates on the index labels // updateSelectRow updates the selection for the given row func (lb *ListBase) updateSelectRow(row int, selMode events.SelectModes) { @@ -1251,8 +1250,7 @@ func (lb *ListBase) unselectIndexEvent(idx int) { } } -/////////////////////////////////////////////////// -// Copy / Cut / Paste +//////// Copy / Cut / Paste // mimeDataIndex adds mimedata for given idx: an application/json of the struct func (lb *ListBase) mimeDataIndex(md *mimedata.Mimes, idx int) { @@ -1438,8 +1436,7 @@ func (lb *ListBase) duplicate() int { //types:add return pasteAt } -////////////////////////////////////////////////////////////////////////////// -// Drag-n-Drop +//////// Drag-n-Drop // selectRowIfNone selects the row the mouse is on if there // are no currently selected items. Returns false if no valid mouse row. diff --git a/core/render.go b/core/render.go index 5eec22c456..dff2d2bbaa 100644 --- a/core/render.go +++ b/core/render.go @@ -160,8 +160,7 @@ func (wb *WidgetBase) doNeedsRender() { }) } -////////////////////////////////////////////////////////////////// -// Scene +//////// Scene var sceneShowIters = 2 @@ -274,8 +273,7 @@ func (sc *Scene) contentSize(initSz image.Point) image.Point { return psz.ToPointFloor() } -////////////////////////////////////////////////////////////////// -// Widget local rendering +//////// Widget local rendering // PushBounds pushes our bounding box bounds onto the bounds stack // if they are non-empty. 
This automatically limits our drawing to diff --git a/core/renderwindow.go b/core/renderwindow.go index 4f9665f1e7..6069f916a7 100644 --- a/core/renderwindow.go +++ b/core/renderwindow.go @@ -284,6 +284,8 @@ func (w *renderWindow) resized() { if DebugSettings.WindowEventTrace { fmt.Printf("Win: %v skipped same-size Resized: %v\n", w.name, curRg) } + rc.logicalDPI = w.logicalDPI() + w.mains.resize(rg) // no-op if everyone below is good // still need to apply style even if size is same for _, kv := range w.mains.stack.Order { st := kv.Value diff --git a/core/scene.go b/core/scene.go index 49fa2e6a0a..52a72359ce 100644 --- a/core/scene.go +++ b/core/scene.go @@ -210,6 +210,9 @@ func (sc *Scene) Init() { currentRenderWindow.SetStageTitle(st.Title) }) sc.Updater(func() { + if TheApp.Platform() == system.Offscreen { + return + } // At the scene level, we reset the shortcuts and add our context menu // shortcuts every time. This clears the way for buttons to add their // shortcuts in their own Updaters. We must get the shortcuts every time diff --git a/core/settings.go b/core/settings.go index f46726786b..ccd0db4450 100644 --- a/core/settings.go +++ b/core/settings.go @@ -645,8 +645,7 @@ type EditorSettings struct { //types:add DepthColor bool `default:"true"` } -////////////////////////////////////////////////////////////////// -// FavoritePaths +//////// FavoritePaths // favoritePathItem represents one item in a favorite path list, for display of // favorites. Is an ordered list instead of a map because user can organize @@ -696,8 +695,7 @@ var defaultPaths = favoritePaths{ {icons.Computer, "root", "/"}, } -////////////////////////////////////////////////////////////////// -// FilePaths +//////// FilePaths // FilePaths represents a set of file paths. 
type FilePaths []string @@ -746,8 +744,7 @@ func openRecentPaths() { } } -////////////////////////////////////////////////////////////////// -// DebugSettings +//////// DebugSettings // DebugSettings are the currently active debugging settings var DebugSettings = &DebugSettingsData{ diff --git a/core/tabs.go b/core/tabs.go index a29530e40a..1aba2856ac 100644 --- a/core/tabs.go +++ b/core/tabs.go @@ -20,6 +20,33 @@ import ( "cogentcore.org/core/tree" ) +// Tabber is an interface for getting the parent Tabs of tab buttons. +// It exposes the main Tabs interface so other packages can build on that +// to provide an augmented Tabs API. +type Tabber interface { + + // AsCoreTabs returns the underlying Tabs implementation. + AsCoreTabs() *Tabs + + // CurrentTab returns currently selected tab and its index; returns nil if none. + CurrentTab() (Widget, int) + + // TabByName returns a tab with the given name, nil if not found. + TabByName(name string) *Frame + + // SelectTabIndex selects the tab at the given index, returning it or nil. + // This is the final tab selection path. + SelectTabIndex(idx int) *Frame + + // SelectTabByName selects the tab by widget name, returning it. + // The widget name is the original full tab label, prior to any eliding. + SelectTabByName(name string) *Frame + + // RecycleTab returns a tab with the given name, first by looking for an existing one, + // and if not found, making a new one. It returns the frame for the tab. + RecycleTab(name string) *Frame +} + // Tabs divide widgets into logical groups and give users the ability // to freely navigate between them using tab buttons. type Tabs struct { @@ -101,6 +128,8 @@ func (tt TabTypes) isColumn() bool { return tt == NavigationDrawer } +func (ts *Tabs) AsCoreTabs() *Tabs { return ts } + func (ts *Tabs) Init() { ts.Frame.Init() ts.maxChars = 16 @@ -544,5 +573,8 @@ func (tb *Tab) Init() { // tabs returns the parent [Tabs] of this [Tab]. 
func (tb *Tab) tabs() *Tabs { - return tb.Parent.AsTree().Parent.(*Tabs) + if tbr, ok := tb.Parent.AsTree().Parent.(Tabber); ok { + return tbr.AsCoreTabs() + } + return nil } diff --git a/core/textfield.go b/core/textfield.go index c5f4c77486..525f721d28 100644 --- a/core/textfield.go +++ b/core/textfield.go @@ -585,8 +585,7 @@ func (tf *TextField) WidgetTooltip(pos image.Point) (string, image.Point) { return tf.error.Error(), tf.DefaultTooltipPos() } -////////////////////////////////////////////////////////////////////////////////////////// -// Cursor Navigation +//////// Cursor Navigation // cursorForward moves the cursor forward func (tf *TextField) cursorForward(steps int) { @@ -856,8 +855,7 @@ func (tf *TextField) cursorKill() { tf.cursorDelete(steps) } -/////////////////////////////////////////////////////////////////////////////// -// Selection +//////// Selection // clearSelected resets both the global selected flag and any current selection func (tf *TextField) clearSelected() { @@ -1098,8 +1096,7 @@ func (tf *TextField) contextMenu(m *Scene) { } } -/////////////////////////////////////////////////////////////////////////////// -// Undo +//////// Undo // textFieldUndoRecord holds one undo record type textFieldUndoRecord struct { @@ -1190,8 +1187,7 @@ func (tf *TextField) redo() { } } -/////////////////////////////////////////////////////////////////////////////// -// Complete +//////// Complete // SetCompleter sets completion functions so that completions will // automatically be offered as the user types. 
@@ -1241,8 +1237,7 @@ func (tf *TextField) completeText(s string) { tf.editDone() } -/////////////////////////////////////////////////////////////////////////////// -// Rendering +//////// Rendering // hasWordWrap returns true if the layout is multi-line word wrapping func (tf *TextField) hasWordWrap() bool { @@ -1461,7 +1456,7 @@ func (tf *TextField) autoScroll() { availSz := sz.Actual.Content.Sub(icsz) tf.configTextSize(availSz) n := len(tf.editText) - tf.cursorPos = math32.ClampInt(tf.cursorPos, 0, n) + tf.cursorPos = math32.Clamp(tf.cursorPos, 0, n) if tf.hasWordWrap() { // does not scroll tf.startPos = 0 diff --git a/core/tree.go b/core/tree.go index 5b06a3f728..d60096d3da 100644 --- a/core/tree.go +++ b/core/tree.go @@ -60,8 +60,8 @@ type Treer interface { //types:add DropDeleteSource(e events.Event) } -// AsTree returns the given value as a value of type [Tree] if the type -// of the given value embeds [Tree], or nil otherwise. +// AsTree returns the given value as a [Tree] if it has +// an AsCoreTree() method, or nil otherwise. func AsTree(n tree.Node) *Tree { if t, ok := n.(Treer); ok { return t.AsCoreTree() @@ -123,7 +123,7 @@ type Tree struct { // with each child tree node when it is initialized. It is only // called with the root node itself in [Tree.SetTreeInit], so you // should typically call that instead of setting this directly. - TreeInit func(tr *Tree) `set:"-"` + TreeInit func(tr *Tree) `set:"-" json:"-" xml:"-"` // Indent is the amount to indent children relative to this node. // It should be set in a Styler like all other style properties. @@ -151,8 +151,8 @@ type Tree struct { // our alloc includes all of our children, but we only draw us. widgetSize math32.Vector2 - // root is the cached root of the tree. It is automatically set. - root *Tree + // Root is the cached root of the tree. It is automatically set. + Root Treer `copier:"-" json:"-" xml:"-" edit:"-" set:"-"` // SelectedNodes holds the currently selected nodes. 
// It is only set on the root node. See [Tree.GetSelectedNodes] @@ -186,7 +186,9 @@ func (tr *Tree) rootSetViewIndex() int { tvn := AsTree(cw) if tvn != nil { tvn.viewIndex = idx - tvn.root = tr + if tvn.Root == nil { + tvn.Root = tr + } idx++ } return tree.Continue @@ -480,15 +482,18 @@ func (tr *Tree) OnAdd() { tr.WidgetBase.OnAdd() tr.Text = tr.Name if ptv := AsTree(tr.Parent); ptv != nil { - tr.root = ptv.root + tr.Root = ptv.Root tr.IconOpen = ptv.IconOpen tr.IconClosed = ptv.IconClosed tr.IconLeaf = ptv.IconLeaf } else { - tr.root = tr + if tr.Root == nil { + tr.Root = tr + } } - if tr.root.TreeInit != nil { - tr.root.TreeInit(tr) + troot := tr.Root.AsCoreTree() + if troot.TreeInit != nil { + troot.TreeInit(tr) } } @@ -507,10 +512,10 @@ func (tr *Tree) SetTreeInit(v func(tr *Tree)) *Tree { // which is what controls the functional inactivity of the tree // if individual nodes are ReadOnly that only affects display typically. func (tr *Tree) rootIsReadOnly() bool { - if tr.root == nil { + if tr.Root == nil { return true } - return tr.root.IsReadOnly() + return tr.Root.AsCoreTree().IsReadOnly() } func (tr *Tree) Style() { @@ -549,8 +554,8 @@ func (tr *Tree) SizeUp() { tr.widgetSize = tr.Geom.Size.Actual.Total h := tr.widgetSize.Y w := tr.widgetSize.X - if tr.root.This == tr.This { // do it every time on root - tr.root.rootSetViewIndex() + if tr.IsRoot() { // do it every time on root + tr.rootSetViewIndex() } if !tr.Closed { @@ -583,11 +588,11 @@ func (tr *Tree) SizeDown(iter int) bool { } func (tr *Tree) Position() { - rn := tr.root - if rn == nil { + if tr.Root == nil { slog.Error("core.Tree: RootView is nil", "in node:", tr) return } + rn := tr.Root.AsCoreTree() tr.setBranchState() sz := &tr.Geom.Size sz.Actual.Total.X = rn.Geom.Size.Actual.Total.X - (tr.Geom.Pos.Total.X - rn.Geom.Pos.Total.X) @@ -658,26 +663,26 @@ func (tr *Tree) RenderWidget() { } } -////////////////////////////////////////////////////////////////////////////// -// Selection +//////// 
Selection // GetSelectedNodes returns a slice of the currently selected // Trees within the entire tree, using a list maintained // by the root node. func (tr *Tree) GetSelectedNodes() []Treer { - if tr.root == nil { + if tr.Root == nil { return nil } - if len(tr.root.SelectedNodes) == 0 { - return tr.root.SelectedNodes + rn := tr.Root.AsCoreTree() + if len(rn.SelectedNodes) == 0 { + return rn.SelectedNodes } - return tr.root.SelectedNodes + return rn.SelectedNodes } // SetSelectedNodes updates the selected nodes on the root node to the given list. func (tr *Tree) SetSelectedNodes(sl []Treer) { - if tr.root != nil { - tr.root.SelectedNodes = sl + if tr.Root != nil { + tr.Root.AsCoreTree().SelectedNodes = sl } } @@ -743,7 +748,7 @@ func (tr *Tree) SelectAll() { return } tr.UnselectAll() - nn := tr.root + nn := tr.Root.AsCoreTree() nn.Select() for nn != nil { nn = nn.moveDown(events.SelectQuiet) @@ -830,26 +835,27 @@ func (tr *Tree) selectUpdate(mode events.SelectModes) bool { // sendSelectEvent sends an [events.Select] event on both this node and the root node. func (tr *Tree) sendSelectEvent(original ...events.Event) { - if tr.This != tr.root.This { + if !tr.IsRoot() { tr.Send(events.Select, original...) } - tr.root.Send(events.Select, original...) + tr.Root.AsCoreTree().Send(events.Select, original...) } // sendChangeEvent sends an [events.Change] event on both this node and the root node. func (tr *Tree) sendChangeEvent(original ...events.Event) { - if tr.This != tr.root.This { + if !tr.IsRoot() { tr.SendChange(original...) } - tr.root.SendChange(original...) + tr.Root.AsCoreTree().SendChange(original...) } // sendChangeEventReSync sends an [events.Change] event on the RootView node. // If SyncNode != nil, it also does a re-sync from root. func (tr *Tree) sendChangeEventReSync(original ...events.Event) { tr.sendChangeEvent(original...) 
- if tr.root.SyncNode != nil { - tr.root.Resync() + rn := tr.Root.AsCoreTree() + if rn.SyncNode != nil { + rn.Resync() } } @@ -873,8 +879,7 @@ func (tr *Tree) UnselectEvent() { } } -////////////////////////////////////////////////////////////////////////////// -// Moving +//////// Moving // moveDown moves the selection down to next element in the tree, // using given select mode (from keyboard modifiers). @@ -916,7 +921,7 @@ func (tr *Tree) moveDownSibling(selMode events.SelectModes) *Tree { if tr.Parent == nil { return nil } - if tr == tr.root { + if tr == tr.Root { return nil } myidx := tr.IndexInParent() @@ -936,7 +941,7 @@ func (tr *Tree) moveDownSibling(selMode events.SelectModes) *Tree { // using given select mode (from keyboard modifiers). // Returns newly selected node func (tr *Tree) moveUp(selMode events.SelectModes) *Tree { - if tr.Parent == nil || tr == tr.root { + if tr.Parent == nil || tr == tr.Root { return nil } myidx := tr.IndexInParent() @@ -1038,7 +1043,7 @@ func (tr *Tree) movePageDownEvent(selMode events.SelectModes) *Tree { // moveToLastChild moves to the last child under me, using given select mode // (from keyboard modifiers) func (tr *Tree) moveToLastChild(selMode events.SelectModes) *Tree { - if tr.Parent == nil || tr == tr.root { + if tr.Parent == nil || tr == tr.Root { return nil } if !tr.Closed && tr.HasChildren() { @@ -1054,11 +1059,12 @@ func (tr *Tree) moveToLastChild(selMode events.SelectModes) *Tree { // using given select mode (from keyboard modifiers) // and emits select event for newly selected item func (tr *Tree) moveHomeEvent(selMode events.SelectModes) *Tree { - tr.root.selectUpdate(selMode) - tr.root.SetFocusQuiet() - tr.root.ScrollToThis() - tr.root.sendSelectEvent() - return tr.root + rn := tr.Root.AsCoreTree() + rn.selectUpdate(selMode) + rn.SetFocusQuiet() + rn.ScrollToThis() + rn.sendSelectEvent() + return rn } // moveEndEvent moves the selection to the very last node in the tree, @@ -1195,8 +1201,7 @@ func (tr *Tree) 
OpenParents() { tr.NeedsLayout() } -///////////////////////////////////////////////////////////// -// Modifying Source Tree +//////// Modifying Source Tree func (tr *Tree) ContextMenuPos(e events.Event) (pos image.Point) { if e != nil { @@ -1249,18 +1254,17 @@ func (tr *Tree) contextMenu(m *Scene) { // IsRoot returns true if given node is the root of the tree, // creating an error snackbar if it is and action is non-empty. -func (tr *Tree) IsRoot(action string) bool { - if tr.This == tr.root.This { - if action != "" { - MessageSnackbar(tr, fmt.Sprintf("Cannot %v the root of the tree", action)) +func (tr *Tree) IsRoot(action ...string) bool { + if tr.This == tr.Root.AsCoreTree().This { + if len(action) == 1 { + MessageSnackbar(tr, fmt.Sprintf("Cannot %v the root of the tree", action[0])) } return true } return false } -//////////////////////////////////////////////////////////// -// Copy / Cut / Paste +//////// Copy / Cut / Paste // MimeData adds mimedata for this node: a text/plain of the Path. func (tr *Tree) MimeData(md *mimedata.Mimes) { @@ -1268,7 +1272,7 @@ func (tr *Tree) MimeData(md *mimedata.Mimes) { tr.mimeDataSync(md) return } - *md = append(*md, mimedata.NewTextData(tr.PathFrom(tr.root))) + *md = append(*md, mimedata.NewTextData(tr.PathFrom(tr.Root.AsCoreTree()))) var buf bytes.Buffer err := jsonx.Write(tr.This, &buf) if err == nil { @@ -1326,13 +1330,13 @@ func (tr *Tree) Cut() { //types:add } tr.Copy() sels := tr.GetSelectedNodes() - root := tr.root + rn := tr.Root.AsCoreTree() tr.UnselectAll() for _, sn := range sels { sn.AsTree().Delete() } - root.Update() - root.sendChangeEvent() + rn.Update() + rn.sendChangeEvent() } // Paste pastes clipboard at given node. 
@@ -1370,7 +1374,7 @@ func (tr *Tree) makePasteMenu(m *Scene, md mimedata.Mimes, fun func()) { fun() } }) - if !tr.IsRoot("") && tr.root.This != tr.This { + if !tr.IsRoot() { NewButton(m).SetText("Insert Before").OnClick(func(e events.Event) { tr.pasteBefore(md, events.DropCopy) if fun != nil { @@ -1459,10 +1463,10 @@ func (tr *Tree) pasteAt(md mimedata.Mimes, mod events.DropMods, rel int, actNm s parent.InsertChild(ns, myidx+i) nwb := AsWidget(ns) ntv := AsTree(ns) - ntv.root = tr.root + ntv.Root = tr.Root nwb.setScene(tr.Scene) nwb.Update() // incl children - npath := ns.AsTree().PathFrom(tr.root) + npath := ns.AsTree().PathFrom(tr.Root) if mod == events.DropMove && npath == orgpath { // we will be nuked immediately after drag ns.AsTree().SetName(ns.AsTree().Name + treeTempMovedTag) // special keyword :) } @@ -1490,7 +1494,7 @@ func (tr *Tree) pasteChildren(md mimedata.Mimes, mod events.DropMods) { tr.AddChild(ns) nwb := AsWidget(ns) ntv := AsTree(ns) - ntv.root = tr.root + ntv.Root = tr.Root nwb.setScene(tr.Scene) } tr.Update() @@ -1498,8 +1502,7 @@ func (tr *Tree) pasteChildren(md mimedata.Mimes, mod events.DropMods) { tr.sendChangeEvent() } -////////////////////////////////////////////////////////////////////////////// -// Drag-n-Drop +//////// Drag-n-Drop // dragStart starts a drag-n-drop on this node -- it includes any other // selected nodes as well, each as additional records in mimedata. 
@@ -1568,17 +1571,17 @@ func (tr *Tree) DropDeleteSource(e events.Event) { return } md := de.Data.(mimedata.Mimes) - root := tr.root + rn := tr.Root.AsCoreTree() for _, d := range md { if d.Type != fileinfo.TextPlain { // link continue } path := string(d.Data) - sn := root.FindPath(path) + sn := rn.FindPath(path) if sn != nil { sn.AsTree().Delete() } - sn = root.FindPath(path + treeTempMovedTag) + sn = rn.FindPath(path + treeTempMovedTag) if sn != nil { psplt := strings.Split(path, "/") orgnm := psplt[len(psplt)-1] diff --git a/core/treesync.go b/core/treesync.go index 1cda830ea7..6eebb22df1 100644 --- a/core/treesync.go +++ b/core/treesync.go @@ -62,13 +62,13 @@ func (tr *Tree) Resync() { func (tr *Tree) syncToSrc(tvIndex *int, init bool, depth int) { sn := tr.SyncNode // root must keep the same name for continuity with surrounding context - if tr != tr.root { + if tr != tr.Root { nm := "tv_" + sn.AsTree().Name tr.SetName(nm) } tr.viewIndex = *tvIndex *tvIndex++ - if init && depth >= tr.root.OpenDepth { + if init && depth >= tr.Root.AsCoreTree().OpenDepth { tr.SetClosed(true) } skids := sn.AsTree().Children @@ -377,7 +377,7 @@ func (tr *Tree) inspectNode() { //types:add // mimeDataSync adds mimedata for this node: a text/plain of the Path, // and an application/json of the sync node. 
func (tr *Tree) mimeDataSync(md *mimedata.Mimes) { - sroot := tr.root.SyncNode + sroot := tr.Root.AsCoreTree().SyncNode src := tr.SyncNode *md = append(*md, mimedata.NewTextData(src.AsTree().PathFrom(sroot))) var buf bytes.Buffer @@ -435,7 +435,7 @@ func (tr *Tree) pasteAtSync(md mimedata.Mimes, mod events.DropMods, rel int, act return } myidx += rel - sroot := tr.root.SyncNode + sroot := tr.Root.AsCoreTree().SyncNode sz := len(sl) var seln tree.Node for i, ns := range sl { @@ -488,7 +488,7 @@ func (tr *Tree) cutSync() { // dropDeleteSourceSync handles delete source event for DropMove case, for Sync func (tr *Tree) dropDeleteSourceSync(de *events.DragDrop) { md := de.Data.(mimedata.Mimes) - sroot := tr.root.SyncNode + sroot := tr.Root.AsCoreTree().SyncNode for _, d := range md { if d.Type != fileinfo.TextPlain { // link continue diff --git a/core/typegen.go b/core/typegen.go index 0d3a07e588..0ffc6c3b35 100644 --- a/core/typegen.go +++ b/core/typegen.go @@ -237,7 +237,7 @@ func (t *FileButton) SetFilename(v string) *FileButton { t.Filename = v; return // Extensions are the target file extensions for the file picker. 
func (t *FileButton) SetExtensions(v string) *FileButton { t.Extensions = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Form", IDName: "form", Doc: "Form represents a struct with rows of field names and editable values.", Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "Struct", Doc: "Struct is the pointer to the struct that we are viewing."}, {Name: "Inline", Doc: "Inline is whether to display the form in one line."}, {Name: "structFields", Doc: "structFields are the fields of the current struct."}, {Name: "isShouldDisplayer", Doc: "isShouldDisplayer is whether the struct implements [ShouldDisplayer], which results\nin additional updating being done at certain points."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Form", IDName: "form", Doc: "Form represents a struct with rows of field names and editable values.", Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "Struct", Doc: "Struct is the pointer to the struct that we are viewing."}, {Name: "Inline", Doc: "Inline is whether to display the form in one line."}, {Name: "Modified", Doc: "Modified optionally highlights and tracks fields that have been modified\nthrough an OnChange event. If present, it replaces the default value highlighting\nand resetting logic. Ignored if nil."}, {Name: "structFields", Doc: "structFields are the fields of the current struct."}, {Name: "isShouldDisplayer", Doc: "isShouldDisplayer is whether the struct implements [ShouldDisplayer], which results\nin additional updating being done at certain points."}}}) // NewForm returns a new [Form] with the given optional parent: // Form represents a struct with rows of field names and editable values. @@ -251,6 +251,12 @@ func (t *Form) SetStruct(v any) *Form { t.Struct = v; return t } // Inline is whether to display the form in one line. 
func (t *Form) SetInline(v bool) *Form { t.Inline = v; return t } +// SetModified sets the [Form.Modified]: +// Modified optionally highlights and tracks fields that have been modified +// through an OnChange event. If present, it replaces the default value highlighting +// and resetting logic. Ignored if nil. +func (t *Form) SetModified(v map[string]bool) *Form { t.Modified = v; return t } + var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Frame", IDName: "frame", Doc: "Frame is the primary node type responsible for organizing the sizes\nand positions of child widgets. It also renders the standard box model.\nAll collections of widgets should generally be contained within a [Frame];\notherwise, the parent widget must take over responsibility for positioning.\nFrames automatically can add scrollbars depending on the [styles.Style.Overflow].\n\nFor a [styles.Grid] frame, the [styles.Style.Columns] property should\ngenerally be set to the desired number of columns, from which the number of rows\nis computed; otherwise, it uses the square root of number of\nelements.", Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "StackTop", Doc: "StackTop, for a [styles.Stacked] frame, is the index of the node to use\nas the top of the stack. 
Only the node at this index is rendered; if it is\nnot a valid index, nothing is rendered."}, {Name: "LayoutStackTopOnly", Doc: "LayoutStackTopOnly is whether to only layout the top widget\n(specified by [Frame.StackTop]) for a [styles.Stacked] frame.\nThis is appropriate for widgets such as [Tabs], which do a full\nredraw on stack changes, but not for widgets such as [Switch]es\nwhich don't."}, {Name: "layout", Doc: "layout contains implementation state info for doing layout"}, {Name: "HasScroll", Doc: "HasScroll is whether scrollbars exist for each dimension."}, {Name: "scrolls", Doc: "scrolls are the scroll bars, which are fully managed as needed."}, {Name: "focusName", Doc: "accumulated name to search for when keys are typed"}, {Name: "focusNameTime", Doc: "time of last focus name event; for timeout"}, {Name: "focusNameLast", Doc: "last element focused on; used as a starting point if name is the same"}}}) // NewFrame returns a new [Frame] with the given optional parent: @@ -1177,7 +1183,7 @@ func NewToolbar(parent ...tree.Node) *Toolbar { return tree.New[Toolbar](parent. 
var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Treer", IDName: "treer", Doc: "Treer is an interface for [Tree] types\nproviding access to the base [Tree] and\noverridable method hooks for actions taken on the [Tree],\nincluding OnOpen, OnClose, etc.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "AsCoreTree", Doc: "AsTree returns the base [Tree] for this node.", Returns: []string{"Tree"}}, {Name: "CanOpen", Doc: "CanOpen returns true if the node is able to open.\nBy default it checks HasChildren(), but could check other properties\nto perform lazy building of the tree.", Returns: []string{"bool"}}, {Name: "OnOpen", Doc: "OnOpen is called when a node is opened.\nThe base version does nothing."}, {Name: "OnClose", Doc: "OnClose is called when a node is closed\nThe base version does nothing."}, {Name: "MimeData", Args: []string{"md"}}, {Name: "Cut"}, {Name: "Copy"}, {Name: "Paste"}, {Name: "DragDrop", Args: []string{"e"}}, {Name: "DropDeleteSource", Args: []string{"e"}}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Tree", IDName: "tree", Doc: "Tree provides a graphical representation of a tree structure,\nproviding full navigation and manipulation abilities.\n\nIt does not handle layout by itself, so if you want it to scroll\nseparately from the rest of the surrounding context, you must\nplace it in a [Frame].\n\nIf the [Tree.SyncNode] field is non-nil, typically via the\n[Tree.SyncTree] method, then the Tree mirrors another\ntree structure, and tree editing functions apply to\nthe source tree first, and then to the Tree by sync.\n\nOtherwise, data can be directly encoded in a Tree\nderived type, to represent any kind of tree structure\nand associated data.\n\nStandard [events.Event]s are sent to any listeners, including\n[events.Select], [events.Change], and [events.DoubleClick].\nThe selected nodes are in the root [Tree.SelectedNodes] list.", Methods: []types.Method{{Name: 
"OpenAll", Doc: "OpenAll opens the node and all of its sub-nodes.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "CloseAll", Doc: "CloseAll closes the node and all of its sub-nodes.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Copy", Doc: "Copy copies the tree to the clipboard.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Cut", Doc: "Cut copies to [system.Clipboard] and deletes selected items.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Paste", Doc: "Paste pastes clipboard at given node.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "InsertAfter", Doc: "InsertAfter inserts a new node in the tree\nafter this node, at the same (sibling) level,\nprompting for the type of node to insert.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "InsertBefore", Doc: "InsertBefore inserts a new node in the tree\nbefore this node, at the same (sibling) level,\nprompting for the type of node to insert\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "AddChildNode", Doc: "AddChildNode adds a new child node to this one in the tree,\nprompting the user for the type of node to add\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "DeleteNode", Doc: "DeleteNode deletes the tree node or sync node corresponding\nto this view node in the sync tree.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Duplicate", Doc: "Duplicate duplicates the sync node corresponding to this view node in\nthe tree, and inserts the duplicate after this node (as a new sibling).\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: 
"add"}}}, {Name: "editNode", Doc: "editNode pulls up a [Form] dialog for the node.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "inspectNode", Doc: "inspectNode pulls up a new Inspector window on the node.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "SyncNode", Doc: "SyncNode, if non-nil, is the [tree.Node] that this widget is\nviewing in the tree (the source). It should be set using\n[Tree.SyncTree]."}, {Name: "Text", Doc: "Text is the text to display for the tree item label, which automatically\ndefaults to the [tree.Node.Name] of the tree node. It has no effect\nif [Tree.SyncNode] is non-nil."}, {Name: "Icon", Doc: "Icon is an optional icon displayed to the the left of the text label."}, {Name: "IconOpen", Doc: "IconOpen is the icon to use for an open (expanded) branch;\nit defaults to [icons.KeyboardArrowDown]."}, {Name: "IconClosed", Doc: "IconClosed is the icon to use for a closed (collapsed) branch;\nit defaults to [icons.KeyboardArrowRight]."}, {Name: "IconLeaf", Doc: "IconLeaf is the icon to use for a terminal node branch that has no children;\nit defaults to [icons.Blank]."}, {Name: "TreeInit", Doc: "TreeInit is a function that can be set on the root node that is called\nwith each child tree node when it is initialized. 
It is only\ncalled with the root node itself in [Tree.SetTreeInit], so you\nshould typically call that instead of setting this directly."}, {Name: "Indent", Doc: "Indent is the amount to indent children relative to this node.\nIt should be set in a Styler like all other style properties."}, {Name: "OpenDepth", Doc: "OpenDepth is the depth for nodes be initialized as open (default 4).\nNodes beyond this depth will be initialized as closed."}, {Name: "Closed", Doc: "Closed is whether this tree node is currently toggled closed\n(children not visible)."}, {Name: "SelectMode", Doc: "SelectMode, when set on the root node, determines whether keyboard movements should update selection."}, {Name: "viewIndex", Doc: "linear index of this node within the entire tree.\nupdated on full rebuilds and may sometimes be off,\nbut close enough for expected uses"}, {Name: "widgetSize", Doc: "size of just this node widget.\nour alloc includes all of our children, but we only draw us."}, {Name: "root", Doc: "root is the cached root of the tree. It is automatically set."}, {Name: "SelectedNodes", Doc: "SelectedNodes holds the currently selected nodes.\nIt is only set on the root node. 
See [Tree.GetSelectedNodes]\nfor a version that also works on non-root nodes."}, {Name: "actStateLayer", Doc: "actStateLayer is the actual state layer of the tree, which\nshould be used when rendering it and its parts (but not its children).\nthe reason that it exists is so that the children of the tree\n(other trees) do not inherit its stateful background color, as\nthat does not look good."}, {Name: "inOpen", Doc: "inOpen is set in the Open method to prevent recursive opening for lazy-open nodes."}, {Name: "Branch", Doc: "Branch is the branch widget that is used to open and close the tree node."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/core.Tree", IDName: "tree", Doc: "Tree provides a graphical representation of a tree structure,\nproviding full navigation and manipulation abilities.\n\nIt does not handle layout by itself, so if you want it to scroll\nseparately from the rest of the surrounding context, you must\nplace it in a [Frame].\n\nIf the [Tree.SyncNode] field is non-nil, typically via the\n[Tree.SyncTree] method, then the Tree mirrors another\ntree structure, and tree editing functions apply to\nthe source tree first, and then to the Tree by sync.\n\nOtherwise, data can be directly encoded in a Tree\nderived type, to represent any kind of tree structure\nand associated data.\n\nStandard [events.Event]s are sent to any listeners, including\n[events.Select], [events.Change], and [events.DoubleClick].\nThe selected nodes are in the root [Tree.SelectedNodes] list.", Methods: []types.Method{{Name: "OpenAll", Doc: "OpenAll opens the node and all of its sub-nodes.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "CloseAll", Doc: "CloseAll closes the node and all of its sub-nodes.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Copy", Doc: "Copy copies the tree to the clipboard.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Cut", Doc: "Cut copies to 
[system.Clipboard] and deletes selected items.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Paste", Doc: "Paste pastes clipboard at given node.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "InsertAfter", Doc: "InsertAfter inserts a new node in the tree\nafter this node, at the same (sibling) level,\nprompting for the type of node to insert.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "InsertBefore", Doc: "InsertBefore inserts a new node in the tree\nbefore this node, at the same (sibling) level,\nprompting for the type of node to insert\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "AddChildNode", Doc: "AddChildNode adds a new child node to this one in the tree,\nprompting the user for the type of node to add\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "DeleteNode", Doc: "DeleteNode deletes the tree node or sync node corresponding\nto this view node in the sync tree.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Duplicate", Doc: "Duplicate duplicates the sync node corresponding to this view node in\nthe tree, and inserts the duplicate after this node (as a new sibling).\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "editNode", Doc: "editNode pulls up a [Form] dialog for the node.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "inspectNode", Doc: "inspectNode pulls up a new Inspector window on the node.\nIf SyncNode is set, operates on Sync Tree.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, 
Fields: []types.Field{{Name: "SyncNode", Doc: "SyncNode, if non-nil, is the [tree.Node] that this widget is\nviewing in the tree (the source). It should be set using\n[Tree.SyncTree]."}, {Name: "Text", Doc: "Text is the text to display for the tree item label, which automatically\ndefaults to the [tree.Node.Name] of the tree node. It has no effect\nif [Tree.SyncNode] is non-nil."}, {Name: "Icon", Doc: "Icon is an optional icon displayed to the the left of the text label."}, {Name: "IconOpen", Doc: "IconOpen is the icon to use for an open (expanded) branch;\nit defaults to [icons.KeyboardArrowDown]."}, {Name: "IconClosed", Doc: "IconClosed is the icon to use for a closed (collapsed) branch;\nit defaults to [icons.KeyboardArrowRight]."}, {Name: "IconLeaf", Doc: "IconLeaf is the icon to use for a terminal node branch that has no children;\nit defaults to [icons.Blank]."}, {Name: "TreeInit", Doc: "TreeInit is a function that can be set on the root node that is called\nwith each child tree node when it is initialized. 
It is only\ncalled with the root node itself in [Tree.SetTreeInit], so you\nshould typically call that instead of setting this directly."}, {Name: "Indent", Doc: "Indent is the amount to indent children relative to this node.\nIt should be set in a Styler like all other style properties."}, {Name: "OpenDepth", Doc: "OpenDepth is the depth for nodes be initialized as open (default 4).\nNodes beyond this depth will be initialized as closed."}, {Name: "Closed", Doc: "Closed is whether this tree node is currently toggled closed\n(children not visible)."}, {Name: "SelectMode", Doc: "SelectMode, when set on the root node, determines whether keyboard movements should update selection."}, {Name: "viewIndex", Doc: "linear index of this node within the entire tree.\nupdated on full rebuilds and may sometimes be off,\nbut close enough for expected uses"}, {Name: "widgetSize", Doc: "size of just this node widget.\nour alloc includes all of our children, but we only draw us."}, {Name: "Root", Doc: "Root is the cached root of the tree. It is automatically set."}, {Name: "SelectedNodes", Doc: "SelectedNodes holds the currently selected nodes.\nIt is only set on the root node. 
See [Tree.GetSelectedNodes]\nfor a version that also works on non-root nodes."}, {Name: "actStateLayer", Doc: "actStateLayer is the actual state layer of the tree, which\nshould be used when rendering it and its parts (but not its children).\nthe reason that it exists is so that the children of the tree\n(other trees) do not inherit its stateful background color, as\nthat does not look good."}, {Name: "inOpen", Doc: "inOpen is set in the Open method to prevent recursive opening for lazy-open nodes."}, {Name: "Branch", Doc: "Branch is the branch widget that is used to open and close the tree node."}}}) // NewTree returns a new [Tree] with the given optional parent: // Tree provides a graphical representation of a tree structure, diff --git a/enums/enumgen/config.go b/enums/enumgen/config.go index 142dcc5c39..2afee384f2 100644 --- a/enums/enumgen/config.go +++ b/enums/enumgen/config.go @@ -54,4 +54,7 @@ type Config struct { //types:add // whether to allow enums to extend other enums; this should be on in almost all circumstances, // but can be turned off for specific enum types that extend non-enum types Extend bool `default:"true"` + + // generate gosl:start and gosl:end tags around generated N values. 
+ Gosl bool } diff --git a/enums/enumgen/enumgen.go b/enums/enumgen/enumgen.go index 69d71fcc0e..5b394afc15 100644 --- a/enums/enumgen/enumgen.go +++ b/enums/enumgen/enumgen.go @@ -10,8 +10,11 @@ package enumgen import ( "fmt" + "log/slog" + "cogentcore.org/core/base/errors" "cogentcore.org/core/base/generate" + "cogentcore.org/core/base/logx" "golang.org/x/tools/go/packages" ) @@ -43,9 +46,16 @@ func ParsePackages(cfg *Config) ([]*packages.Package, error) { func Generate(cfg *Config) error { //types:add pkgs, err := ParsePackages(cfg) if err != nil { + if logx.UserLevel <= slog.LevelInfo { + errors.Log(err) + } return err } - return GeneratePkgs(cfg, pkgs) + err = GeneratePkgs(cfg, pkgs) + if logx.UserLevel <= slog.LevelInfo { + errors.Log(err) + } + return err } // GeneratePkgs generates enum methods using diff --git a/enums/enumgen/generator.go b/enums/enumgen/generator.go index 08babc821c..26405ba79a 100644 --- a/enums/enumgen/generator.go +++ b/enums/enumgen/generator.go @@ -70,7 +70,7 @@ func (g *Generator) PrintHeader() { // or enums:bitflag. It stores the resulting types in [Generator.Types]. 
func (g *Generator) FindEnumTypes() error { g.Types = []*Type{} - return generate.Inspect(g.Pkg, g.InspectForType) + return generate.Inspect(g.Pkg, g.InspectForType, "enumgen.go", "typegen.go") } // AllowedEnumTypes are the types that can be used for enums @@ -159,7 +159,7 @@ func (g *Generator) Generate() (bool, error) { for _, typ := range g.Types { values := make([]Value, 0, 100) for _, file := range g.Pkg.Syntax { - if ast.IsGenerated(file) { + if generate.ExcludeFile(g.Pkg, file, "enumgen.go", "typegen.go") { continue } var terr error diff --git a/enums/enumgen/methods.go b/enums/enumgen/methods.go index b7cc783130..9c572b0737 100644 --- a/enums/enumgen/methods.go +++ b/enums/enumgen/methods.go @@ -49,6 +49,13 @@ var NConstantTmpl = template.Must(template.New("StringNConstant").Parse( const {{.Name}}N {{.Name}} = {{.MaxValueP1}} `)) +var NConstantTmplGosl = template.Must(template.New("StringNConstant").Parse( + `//gosl:start +//{{.Name}}N is the highest valid value for type {{.Name}}, plus one. +const {{.Name}}N {{.Name}} = {{.MaxValueP1}} +//gosl:end +`)) + var SetStringMethodTmpl = template.Must(template.New("SetStringMethod").Parse( `// SetString sets the {{.Name}} value from its string representation, // and returns an error if the string is invalid. 
@@ -111,7 +118,11 @@ func (g *Generator) BuildBasicMethods(values []Value, typ *Type) { typ.MaxValueP1 = max + 1 - g.ExecTmpl(NConstantTmpl, typ) + if g.Config.Gosl { + g.ExecTmpl(NConstantTmplGosl, typ) + } else { + g.ExecTmpl(NConstantTmpl, typ) + } // Print the map between name and value g.PrintValueMap(values, typ) diff --git a/enums/enumgen/testdata/enumgen.go b/enums/enumgen/testdata/enumgen.go index e02efad438..4826cd43fd 100644 --- a/enums/enumgen/testdata/enumgen.go +++ b/enums/enumgen/testdata/enumgen.go @@ -1,4 +1,4 @@ -// Code generated by "enumgen.test -test.testlogfile=/var/folders/x1/r8shprmj7j71zbw3qvgl9dqc0000gq/T/go-build1829688390/b649/testlog.txt -test.paniconexit0 -test.timeout=20s"; DO NOT EDIT. +// Code generated by "enumgen.test -test.paniconexit0 -test.timeout=10m0s -test.v=true"; DO NOT EDIT. package testdata diff --git a/enums/enumgen/testdata/enumgen.golden b/enums/enumgen/testdata/enumgen.golden index df0be10917..4826cd43fd 100644 --- a/enums/enumgen/testdata/enumgen.golden +++ b/enums/enumgen/testdata/enumgen.golden @@ -1,4 +1,4 @@ -// Code generated by "enumgen.test -test.paniconexit0 -test.timeout=10m0s"; DO NOT EDIT. +// Code generated by "enumgen.test -test.paniconexit0 -test.timeout=10m0s -test.v=true"; DO NOT EDIT. 
package testdata diff --git a/enums/enumgen/typegen.go b/enums/enumgen/typegen.go index eae66358d7..cc2bd1862d 100644 --- a/enums/enumgen/typegen.go +++ b/enums/enumgen/typegen.go @@ -6,6 +6,4 @@ import ( "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/enums/enumgen.Config", IDName: "config", Doc: "Config contains the configuration information\nused by enumgen", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Dir", Doc: "the source directory to run enumgen on (can be set to multiple through paths like ./...)"}, {Name: "Output", Doc: "the output file location relative to the package on which enumgen is being called"}, {Name: "Transform", Doc: "if specified, the enum item transformation method (upper, lower, snake, SNAKE, kebab, KEBAB,\ncamel, lower-camel, title, sentence, first, first-upper, or first-lower)"}, {Name: "TrimPrefix", Doc: "if specified, a comma-separated list of prefixes to trim from each item"}, {Name: "AddPrefix", Doc: "if specified, the prefix to add to each item"}, {Name: "LineComment", Doc: "whether to use line comment text as printed text when present"}, {Name: "AcceptLower", Doc: "whether to accept lowercase versions of enum names in SetString"}, {Name: "IsValid", Doc: "whether to generate a method returning whether a value is\na valid option for its enum type; this must also be set for\nany base enum type being extended"}, {Name: "Text", Doc: "whether to generate text marshaling methods"}, {Name: "SQL", Doc: "whether to generate methods that implement the SQL Scanner and Valuer interfaces"}, {Name: "GQL", Doc: "whether to generate GraphQL marshaling methods for gqlgen"}, {Name: "Extend", Doc: "whether to allow enums to extend other enums; this should be on in almost all circumstances,\nbut can be turned off for specific enum types that extend non-enum types"}}}) - -var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/enums/enumgen.Generate", Doc: 
"Generate generates enum methods, using the\nconfiguration information, loading the packages from the\nconfiguration source directory, and writing the result\nto the configuration output file.\n\nIt is a simple entry point to enumgen that does all\nof the steps; for more specific functionality, create\na new [Generator] with [NewGenerator] and call methods on it.", Directives: []types.Directive{{Tool: "cli", Directive: "cmd", Args: []string{"-root"}}, {Tool: "types", Directive: "add"}}, Args: []string{"cfg"}, Returns: []string{"error"}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/enums/enumgen.Config", IDName: "config", Doc: "Config contains the configuration information\nused by enumgen", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Dir", Doc: "the source directory to run enumgen on (can be set to multiple through paths like ./...)"}, {Name: "Output", Doc: "the output file location relative to the package on which enumgen is being called"}, {Name: "Transform", Doc: "if specified, the enum item transformation method (upper, lower, snake, SNAKE, kebab, KEBAB,\ncamel, lower-camel, title, sentence, first, first-upper, or first-lower)"}, {Name: "TrimPrefix", Doc: "if specified, a comma-separated list of prefixes to trim from each item"}, {Name: "AddPrefix", Doc: "if specified, the prefix to add to each item"}, {Name: "LineComment", Doc: "whether to use line comment text as printed text when present"}, {Name: "AcceptLower", Doc: "whether to accept lowercase versions of enum names in SetString"}, {Name: "IsValid", Doc: "whether to generate a method returning whether a value is\na valid option for its enum type; this must also be set for\nany base enum type being extended"}, {Name: "Text", Doc: "whether to generate text marshaling methods"}, {Name: "SQL", Doc: "whether to generate methods that implement the SQL Scanner and Valuer interfaces"}, {Name: "GQL", Doc: "whether to generate GraphQL marshaling 
methods for gqlgen"}, {Name: "Extend", Doc: "whether to allow enums to extend other enums; this should be on in almost all circumstances,\nbut can be turned off for specific enum types that extend non-enum types"}, {Name: "Gosl", Doc: "generate gosl:start and gosl:end tags around generated N values."}}}) diff --git a/examples/demo/demo.go b/examples/demo/demo.go index dbab36d905..fd4f211ff5 100644 --- a/examples/demo/demo.go +++ b/examples/demo/demo.go @@ -43,6 +43,7 @@ func main() { makeStyles(ts) b.RunMainWindow() + // b.NewWindow().SetFullscreen(true).RunMain() } func home(ts *core.Tabs) { diff --git a/examples/plot/plot.go b/examples/plot/plot.go index 226f565e72..64ff53e62c 100644 --- a/examples/plot/plot.go +++ b/examples/plot/plot.go @@ -8,7 +8,9 @@ import ( "embed" "cogentcore.org/core/core" + "cogentcore.org/core/plot" "cogentcore.org/core/plot/plotcore" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" ) @@ -18,15 +20,19 @@ var tsv embed.FS func main() { b := core.NewBody("Plot Example") - epc := table.NewTable("epc") - epc.OpenFS(tsv, "ra25epoch.tsv", table.Tab) + epc := table.New("epc") + epc.OpenFS(tsv, "ra25epoch.tsv", tensor.Tab) + pst := func(s *plot.Style) { + s.Plot.Title = "RA25 Epoch Train" + } + perr := epc.Column("PctErr") + plot.SetStylersTo(perr, plot.Stylers{pst, func(s *plot.Style) { + s.On = true + s.Role = plot.Y + }}) pl := plotcore.NewPlotEditor(b) - pl.Options.Title = "RA25 Epoch Train" - pl.Options.XAxis = "Epoch" - pl.Options.Points = true pl.SetTable(epc) - pl.ColumnOptions("UnitErr").On = true b.AddTopBar(func(bar *core.Frame) { core.NewToolbar(bar).Maker(pl.MakeToolbar) }) diff --git a/filetree/copypaste.go b/filetree/copypaste.go index f75dd7ae6d..9c7244a74c 100644 --- a/filetree/copypaste.go +++ b/filetree/copypaste.go @@ -24,7 +24,7 @@ import ( // MimeData adds mimedata for this node: a text/plain of the Path, // text/plain of filename, and text/ func (fn *Node) MimeData(md *mimedata.Mimes) { - froot := 
fn.FileRoot + froot := fn.FileRoot() path := string(fn.Filepath) punq := fn.PathFrom(froot) // note: tree paths have . escaped -> \, *md = append(*md, mimedata.NewTextData(punq)) @@ -77,7 +77,7 @@ func (fn *Node) DragDrop(e events.Event) { // that is non-nil (otherwise just uses absolute path), and returns list of existing // and node for last one if exists. func (fn *Node) pasteCheckExisting(tfn *Node, md mimedata.Mimes, externalDrop bool) ([]string, *Node) { - froot := fn.FileRoot + froot := fn.FileRoot() tpath := "" if tfn != nil { tpath = string(tfn.Filepath) @@ -118,7 +118,7 @@ func (fn *Node) pasteCheckExisting(tfn *Node, md mimedata.Mimes, externalDrop bo // pasteCopyFiles copies files in given data into given target directory func (fn *Node) pasteCopyFiles(tdir *Node, md mimedata.Mimes, externalDrop bool) { - froot := fn.FileRoot + froot := fn.FileRoot() nf := len(md) if !externalDrop { nf /= 3 @@ -283,7 +283,7 @@ func (fn *Node) pasteFiles(md mimedata.Mimes, externalDrop bool, dropFinal func( // satisfies core.DragNDropper interface and can be overridden by subtypes func (fn *Node) DropDeleteSource(e events.Event) { de := e.(*events.DragDrop) - froot := fn.FileRoot + froot := fn.FileRoot() if froot == nil || fn.isExternal() { return } diff --git a/filetree/enumgen.go b/filetree/enumgen.go index 9d2068dcb7..f01aa7ef8b 100644 --- a/filetree/enumgen.go +++ b/filetree/enumgen.go @@ -13,7 +13,7 @@ const dirFlagsN dirFlags = 3 var _dirFlagsValueMap = map[string]dirFlags{`IsOpen`: 0, `SortByName`: 1, `SortByModTime`: 2} -var _dirFlagsDescMap = map[dirFlags]string{0: `dirIsOpen means directory is open -- else closed`, 1: `dirSortByName means sort the directory entries by name. 
this is mutex with other sorts -- keeping option open for non-binary sort choices.`, 2: `dirSortByModTime means sort the directory entries by modification time`} +var _dirFlagsDescMap = map[dirFlags]string{0: `dirIsOpen means directory is open -- else closed`, 1: `dirSortByName means sort the directory entries by name. this overrides SortByModTime default on Tree if set.`, 2: `dirSortByModTime means sort the directory entries by modification time.`} var _dirFlagsMap = map[dirFlags]string{0: `IsOpen`, 1: `SortByName`, 2: `SortByModTime`} diff --git a/filetree/file.go b/filetree/file.go index 813c731b36..dc6082888a 100644 --- a/filetree/file.go +++ b/filetree/file.go @@ -59,7 +59,7 @@ func (fn *Node) OpenFileDefault() error { // duplicateFiles makes a copy of selected files func (fn *Node) duplicateFiles() { //types:add - fn.FileRoot.NeedsLayout() + fn.FileRoot().NeedsLayout() fn.SelectedFunc(func(sn *Node) { sn.duplicateFile() }) @@ -92,7 +92,7 @@ func (fn *Node) deleteFiles() { //types:add // deleteFilesImpl does the actual deletion, no prompts func (fn *Node) deleteFilesImpl() { - fn.FileRoot.NeedsLayout() + fn.FileRoot().NeedsLayout() fn.SelectedFunc(func(sn *Node) { if !sn.Info.IsDir() { sn.deleteFile() @@ -100,7 +100,7 @@ func (fn *Node) deleteFilesImpl() { } var fns []string sn.Info.Filenames(&fns) - ft := sn.FileRoot + ft := sn.FileRoot() for _, filename := range fns { sn, ok := ft.FindFile(filename) if !ok { @@ -145,7 +145,7 @@ func (fn *Node) deleteFile() error { // renames any selected files func (fn *Node) RenameFiles() { //types:add - fn.FileRoot.NeedsLayout() + fn.FileRoot().NeedsLayout() fn.SelectedFunc(func(sn *Node) { fb := core.NewSoloFuncButton(sn).SetFunc(sn.RenameFile) fb.Args[0].SetValue(sn.Name) @@ -158,7 +158,7 @@ func (fn *Node) RenameFile(newpath string) error { //types:add if fn.isExternal() { return nil } - root := fn.FileRoot + root := fn.FileRoot() var err error fn.closeBuf() // invalid after this point orgpath := fn.Filepath @@ -167,8 
+167,8 @@ func (fn *Node) RenameFile(newpath string) error { //types:add return err } if fn.IsDir() { - if fn.FileRoot.isDirOpen(orgpath) { - fn.FileRoot.setDirOpen(core.Filename(newpath)) + if fn.FileRoot().isDirOpen(orgpath) { + fn.FileRoot().setDirOpen(core.Filename(newpath)) } } repo, _ := fn.Repo() @@ -230,15 +230,15 @@ func (fn *Node) newFile(filename string, addToVCS bool) { //types:add return } if addToVCS { - nfn, ok := fn.FileRoot.FindFile(np) - if ok && nfn.This != fn.FileRoot.This && string(nfn.Filepath) == np { + nfn, ok := fn.FileRoot().FindFile(np) + if ok && !nfn.IsRoot() && string(nfn.Filepath) == np { // todo: this is where it is erroneously adding too many files to vcs! fmt.Println("Adding new file to VCS:", nfn.Filepath) core.MessageSnackbar(fn, "Adding new file to VCS: "+fsx.DirAndFile(string(nfn.Filepath))) nfn.AddToVCS() } } - fn.FileRoot.UpdatePath(np) + fn.FileRoot().UpdatePath(np) } // makes a new folder in the given selected directory @@ -267,7 +267,7 @@ func (fn *Node) newFolder(foldername string) { //types:add core.ErrorSnackbar(fn, err) return } - fn.FileRoot.UpdatePath(ppath) + fn.FileRoot().UpdatePath(ppath) } // copyFileToDir copies given file path into node that is a directory. 
@@ -280,11 +280,11 @@ func (fn *Node) copyFileToDir(filename string, perm os.FileMode) { sfn := filepath.Base(filename) tpath := filepath.Join(ppath, sfn) fileinfo.CopyFile(tpath, filename, perm) - fn.FileRoot.UpdatePath(ppath) - ofn, ok := fn.FileRoot.FindFile(filename) + fn.FileRoot().UpdatePath(ppath) + ofn, ok := fn.FileRoot().FindFile(filename) if ok && ofn.Info.VCS >= vcs.Stored { - nfn, ok := fn.FileRoot.FindFile(tpath) - if ok && nfn.This != fn.FileRoot.This { + nfn, ok := fn.FileRoot().FindFile(tpath) + if ok && !nfn.IsRoot() { if string(nfn.Filepath) != tpath { fmt.Printf("error: nfn.FPath != tpath; %q != %q, see bug #453\n", nfn.Filepath, tpath) } else { diff --git a/filetree/find.go b/filetree/find.go index b48b8a912f..7ecac4c304 100644 --- a/filetree/find.go +++ b/filetree/find.go @@ -62,7 +62,7 @@ func (fn *Node) FindFile(fnm string) (*Node, bool) { } } - if efn, err := fn.FileRoot.externalNodeByPath(fnm); err == nil { + if efn, err := fn.FileRoot().externalNodeByPath(fnm); err == nil { return efn, true } diff --git a/filetree/menu.go b/filetree/menu.go index 0f30d373b1..f3d5dd3a4d 100644 --- a/filetree/menu.go +++ b/filetree/menu.go @@ -26,7 +26,7 @@ func vcsLabelFunc(fn *Node, label string) string { } func (fn *Node) VCSContextMenu(m *core.Scene) { - if fn.FileRoot.FS != nil { + if fn.FileRoot().FS != nil { return } core.NewFuncButton(m).SetFunc(fn.addToVCSSelected).SetText(vcsLabelFunc(fn, "Add to VCS")).SetIcon(icons.Add). diff --git a/filetree/node.go b/filetree/node.go index 433ae767bf..5807d95999 100644 --- a/filetree/node.go +++ b/filetree/node.go @@ -53,9 +53,6 @@ type Node struct { //core:embedder // Buffer is the file buffer for editing this file. Buffer *texteditor.Buffer `edit:"-" set:"-" json:"-" xml:"-" copier:"-"` - // FileRoot is the root [Tree] of the tree, which has global state. 
- FileRoot *Tree `edit:"-" set:"-" json:"-" xml:"-" copier:"-"` - // DirRepo is the version control system repository for this directory, // only non-nil if this is the highest-level directory in the tree under vcs control. DirRepo vcs.Repo `edit:"-" set:"-" json:"-" xml:"-" copier:"-"` @@ -70,6 +67,11 @@ func (fn *Node) AsFileNode() *Node { return fn } +// FileRoot returns the Root node as a [Tree]. +func (fn *Node) FileRoot() *Tree { + return AsTree(fn.Root) +} + func (fn *Node) Init() { fn.Tree.Init() fn.IconOpen = icons.FolderOpen @@ -94,6 +96,9 @@ func (fn *Node) Init() { case status == vcs.Stored: s.Color = colors.Scheme.OnSurface } + if fn.Info.Generated { + s.Color = errors.Must1(gradient.FromString("#8080C0")) + } }) fn.On(events.KeyChord, func(e events.Event) { if core.DebugSettings.KeyEventTrace { @@ -189,14 +194,14 @@ func (fn *Node) Init() { return } if fn.Name == externalFilesName { - files := fn.FileRoot.externalFiles + files := fn.FileRoot().externalFiles for _, fi := range files { tree.AddNew(p, fi, func() Filer { - return tree.NewOfType(fn.FileRoot.FileNodeType).(Filer) + return tree.NewOfType(fn.FileRoot().FileNodeType).(Filer) }, func(wf Filer) { w := wf.AsFileNode() + w.Root = fn.Root w.NeedsLayout() - w.FileRoot = fn.FileRoot w.Filepath = core.Filename(fi) w.Info.Mode = os.ModeIrregular w.Info.VCS = vcs.Stored @@ -207,25 +212,25 @@ func (fn *Node) Init() { if !fn.IsDir() || fn.IsIrregular() { return } - if !((fn.FileRoot.inOpenAll && !fn.Info.IsHidden()) || fn.FileRoot.isDirOpen(fn.Filepath)) { + if !((fn.FileRoot().inOpenAll && !fn.Info.IsHidden()) || fn.FileRoot().isDirOpen(fn.Filepath)) { return } repo, _ := fn.Repo() files := fn.dirFileList() for _, fi := range files { fpath := filepath.Join(string(fn.Filepath), fi.Name()) - if fn.FileRoot.FilterFunc != nil && !fn.FileRoot.FilterFunc(fpath, fi) { + if fn.FileRoot().FilterFunc != nil && !fn.FileRoot().FilterFunc(fpath, fi) { continue } tree.AddNew(p, fi.Name(), func() Filer { - return 
tree.NewOfType(fn.FileRoot.FileNodeType).(Filer) + return tree.NewOfType(fn.FileRoot().FileNodeType).(Filer) }, func(wf Filer) { w := wf.AsFileNode() + w.Root = fn.Root w.NeedsLayout() - w.FileRoot = fn.FileRoot w.Filepath = core.Filename(fpath) w.This.(Filer).GetFileInfo() - if w.FileRoot.FS == nil { + if w.FileRoot().FS == nil { if w.IsDir() && repo == nil { w.detectVCSRepo(true) // update files } @@ -284,10 +289,10 @@ func (fn *Node) isAutoSave() bool { // RelativePath returns the relative path from root for this node func (fn *Node) RelativePath() string { - if fn.IsIrregular() || fn.FileRoot == nil { + if fn.IsIrregular() || fn.FileRoot() == nil { return fn.Name } - return fsx.RelativeFilePath(string(fn.Filepath), string(fn.FileRoot.Filepath)) + return fsx.RelativeFilePath(string(fn.Filepath), string(fn.FileRoot().Filepath)) } // dirFileList returns the list of files in this directory, @@ -297,14 +302,16 @@ func (fn *Node) dirFileList() []fs.FileInfo { var files []fs.FileInfo var dirs []fs.FileInfo // for DirsOnTop mode var di []fs.DirEntry - if fn.FileRoot.FS == nil { + isFS := false + if fn.FileRoot().FS == nil { di = errors.Log1(os.ReadDir(path)) } else { - di = errors.Log1(fs.ReadDir(fn.FileRoot.FS, path)) + isFS = true + di = errors.Log1(fs.ReadDir(fn.FileRoot().FS, path)) } for _, d := range di { info := errors.Log1(d.Info()) - if fn.FileRoot.DirsOnTop { + if fn.FileRoot().DirsOnTop { if d.IsDir() { dirs = append(dirs, info) } else { @@ -314,30 +321,35 @@ func (fn *Node) dirFileList() []fs.FileInfo { files = append(files, info) } } - doModSort := fn.FileRoot.SortByModTime + doModSort := fn.FileRoot().SortByModTime if doModSort { - doModSort = !fn.FileRoot.dirSortByName(core.Filename(path)) + doModSort = !fn.FileRoot().dirSortByName(core.Filename(path)) } else { - doModSort = fn.FileRoot.dirSortByModTime(core.Filename(path)) + doModSort = fn.FileRoot().dirSortByModTime(core.Filename(path)) } - if fn.FileRoot.DirsOnTop { + if fn.FileRoot().DirsOnTop { if 
doModSort { - sortByModTime(dirs) - sortByModTime(files) + sortByModTime(dirs, isFS) // note: FS = ascending, otherwise descending + sortByModTime(files, isFS) } files = append(dirs, files...) } else { if doModSort { - sortByModTime(files) + sortByModTime(files, isFS) } } return files } -func sortByModTime(files []fs.FileInfo) { +// sortByModTime sorts by _reverse_ mod time (newest first) +func sortByModTime(files []fs.FileInfo, ascending bool) { slices.SortFunc(files, func(a, b fs.FileInfo) int { - return a.ModTime().Compare(b.ModTime()) + if ascending { + return a.ModTime().Compare(b.ModTime()) + } else { + return b.ModTime().Compare(a.ModTime()) + } }) } @@ -371,7 +383,7 @@ func (fn *Node) InitFileInfo() error { return nil } var err error - if fn.FileRoot.FS == nil { // deal with symlinks + if fn.FileRoot().FS == nil { // deal with symlinks ls, err := os.Lstat(string(fn.Filepath)) if errors.Log(err) != nil { return err @@ -387,7 +399,7 @@ func (fn *Node) InitFileInfo() error { } err = fn.Info.InitFile(string(fn.Filepath)) } else { - err = fn.Info.InitFileFS(fn.FileRoot.FS, string(fn.Filepath)) + err = fn.Info.InitFileFS(fn.FileRoot().FS, string(fn.Filepath)) } if err != nil { emsg := fmt.Errorf("filetree.Node InitFileInfo Path %q: Error: %v", fn.Filepath, err) @@ -431,7 +443,7 @@ func (fn *Node) OnClose() { if !fn.IsDir() { return } - fn.FileRoot.setDirClosed(fn.Filepath) + fn.FileRoot().setDirClosed(fn.Filepath) } func (fn *Node) CanOpen() bool { @@ -443,7 +455,7 @@ func (fn *Node) openDir() { if !fn.IsDir() { return } - fn.FileRoot.setDirOpen(fn.Filepath) + fn.FileRoot().setDirOpen(fn.Filepath) fn.Update() } @@ -458,15 +470,15 @@ func (fn *Node) sortBys(modTime bool) { //types:add // sortBy determines how to sort the files in the directory -- default is alpha by name, // optionally can be sorted by modification time. 
func (fn *Node) sortBy(modTime bool) { - fn.FileRoot.setDirSortBy(fn.Filepath, modTime) + fn.FileRoot().setDirSortBy(fn.Filepath, modTime) fn.Update() } // openAll opens all directories under this one func (fn *Node) openAll() { //types:add - fn.FileRoot.inOpenAll = true // causes chaining of opening + fn.FileRoot().inOpenAll = true // causes chaining of opening fn.Tree.OpenAll() - fn.FileRoot.inOpenAll = false + fn.FileRoot().inOpenAll = false } // OpenBuf opens the file in its buffer if it is not already open. @@ -499,7 +511,7 @@ func (fn *Node) removeFromExterns() { //types:add if !sn.isExternal() { return } - sn.FileRoot.removeExternalFile(string(sn.Filepath)) + sn.FileRoot().removeExternalFile(string(sn.Filepath)) sn.closeBuf() sn.Delete() }) diff --git a/filetree/search.go b/filetree/search.go index 911b646139..7c3c5ae23e 100644 --- a/filetree/search.go +++ b/filetree/search.go @@ -77,7 +77,7 @@ func Search(start *Node, find string, ignoreCase, regExp bool, loc FindLocation, // fmt.Printf("dir: %v closed\n", sfn.FPath) return tree.Break // don't go down into closed directories! } - if sfn.IsDir() || sfn.IsExec() || sfn.Info.Kind == "octet-stream" || sfn.isAutoSave() { + if sfn.IsDir() || sfn.IsExec() || sfn.Info.Kind == "octet-stream" || sfn.isAutoSave() || sfn.Info.Generated { // fmt.Printf("dir: %v opened\n", sfn.Nm) return tree.Continue } @@ -163,6 +163,9 @@ func findAll(start *Node, find string, ignoreCase, regExp bool, langs []fileinfo if strings.HasSuffix(info.Name(), ".code") { // exclude self return nil } + if fileinfo.IsGeneratedFile(path) { + return nil + } if len(langs) > 0 { mtyp, _, err := fileinfo.MimeFromFile(path) if err != nil { diff --git a/filetree/tree.go b/filetree/tree.go index f0d561782c..d0bb62db56 100644 --- a/filetree/tree.go +++ b/filetree/tree.go @@ -26,6 +26,20 @@ const ( externalFilesName = "[external files]" ) +// Treer is an interface for getting the Root node as a Tree struct. 
+type Treer interface { + AsFileTree() *Tree +} + +// AsTree returns the given value as a [Tree] if it has +// an AsFileTree() method, or nil otherwise. +func AsTree(n tree.Node) *Tree { + if t, ok := n.(Treer); ok { + return t.AsFileTree() + } + return nil +} + // Tree is the root widget of a file tree representing files in a given directory // (and subdirectories thereof), and has some overall management state for how to // view things. @@ -80,16 +94,19 @@ type Tree struct { func (ft *Tree) Init() { ft.Node.Init() - ft.FileRoot = ft + ft.Root = ft ft.FileNodeType = types.For[Node]() ft.OpenDepth = 4 ft.DirsOnTop = true ft.FirstMaker(func(p *tree.Plan) { + if len(ft.externalFiles) == 0 { + return + } tree.AddNew(p, externalFilesName, func() Filer { return tree.NewOfType(ft.FileNodeType).(Filer) }, func(wf Filer) { w := wf.AsFileNode() - w.FileRoot = ft + w.Root = ft.Root w.Filepath = externalFilesName w.Info.Mode = os.ModeDir w.Info.VCS = vcs.Stored @@ -110,6 +127,10 @@ func (fv *Tree) Destroy() { fv.Tree.Destroy() } +func (ft *Tree) AsFileTree() *Tree { + return ft +} + // OpenPath opens the filetree at the given os file system directory path. // It reads all the files at the given path into this tree. // Only paths listed in [Tree.Dirs] will be opened. @@ -152,7 +173,7 @@ func (ft *Tree) OpenPathFS(fsys fs.FS, path string) *Tree { } // UpdatePath updates the tree at the directory level for given path -// and everything below it. It flags that it needs render update, +// and everything below it. It flags that it needs render update, // but if a deletion or insertion happened, then NeedsLayout should also // be called. 
func (ft *Tree) UpdatePath(path string) { @@ -334,8 +355,13 @@ func (ft *Tree) AddExternalFile(fpath string) (*Node, error) { if has, _ := ft.hasExternalFile(pth); has { return ft.externalNodeByPath(pth) } + newExt := len(ft.externalFiles) == 0 ft.externalFiles = append(ft.externalFiles, pth) - ft.Child(0).(Filer).AsFileNode().Update() + if newExt { + ft.Update() + } else { + ft.Child(0).(Filer).AsFileNode().Update() + } return ft.externalNodeByPath(pth) } diff --git a/filetree/typegen.go b/filetree/typegen.go index 376b56a57c..547cb5d1fb 100644 --- a/filetree/typegen.go +++ b/filetree/typegen.go @@ -3,14 +3,16 @@ package filetree import ( + "io/fs" + "cogentcore.org/core/base/vcs" "cogentcore.org/core/tree" "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Filer", IDName: "filer", Doc: "Filer is an interface for file tree file actions that all [Node]s satisfy.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "AsFileNode", Doc: "AsFileNode returns the [Node]", Returns: []string{"Node"}}, {Name: "RenameFiles", Doc: "RenameFiles renames any selected files."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Filer", IDName: "filer", Doc: "Filer is an interface for file tree file actions that all [Node]s satisfy.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "AsFileNode", Doc: "AsFileNode returns the [Node]", Returns: []string{"Node"}}, {Name: "RenameFiles", Doc: "RenameFiles renames any selected files."}, {Name: "GetFileInfo", Doc: "GetFileInfo updates the .Info for this file", Returns: []string{"error"}}, {Name: "OpenFile", Doc: "OpenFile opens the file for node. 
This is called by OpenFilesDefault", Returns: []string{"error"}}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Node", IDName: "node", Doc: "Node represents a file in the file system, as a [core.Tree] node.\nThe name of the node is the name of the file.\nFolders have children containing further nodes.", Directives: []types.Directive{{Tool: "core", Directive: "embedder"}}, Methods: []types.Method{{Name: "Cut", Doc: "Cut copies the selected files to the clipboard and then deletes them.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Paste", Doc: "Paste inserts files from the clipboard.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "OpenFilesDefault", Doc: "OpenFilesDefault opens selected files with default app for that file type (os defined).\nruns open on Mac, xdg-open on Linux, and start on Windows", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "duplicateFiles", Doc: "duplicateFiles makes a copy of selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "deleteFiles", Doc: "deletes any selected files or directories. 
If any directory is selected,\nall files and subdirectories in that directory are also deleted.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "RenameFiles", Doc: "renames any selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "RenameFile", Doc: "RenameFile renames file to new name", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"newpath"}, Returns: []string{"error"}}, {Name: "newFiles", Doc: "newFiles makes a new file in selected directory", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "addToVCS"}}, {Name: "newFile", Doc: "newFile makes a new file in this directory node", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "addToVCS"}}, {Name: "newFolders", Doc: "makes a new folder in the given selected directory", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"foldername"}}, {Name: "newFolder", Doc: "newFolder makes a new folder (directory) in this directory node", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"foldername"}}, {Name: "showFileInfo", Doc: "Shows file information about selected file(s)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "sortBys", Doc: "sortBys determines how to sort the selected files in the directory.\nDefault is alpha by name, optionally can be sorted by modification time.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"modTime"}}, {Name: "openAll", Doc: "openAll opens all directories under this one", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "removeFromExterns", Doc: "removeFromExterns removes file from list of external files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "addToVCSSelected", Doc: "addToVCSSelected adds selected files to version control 
system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "deleteFromVCSSelected", Doc: "deleteFromVCSSelected removes selected files from version control system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "commitToVCSSelected", Doc: "commitToVCSSelected commits to version control system based on last selected file", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "revertVCSSelected", Doc: "revertVCSSelected removes selected files from version control system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "diffVCSSelected", Doc: "diffVCSSelected shows the diffs between two versions of selected files, given by the\nrevision specifiers -- if empty, defaults to A = current HEAD, B = current WC file.\n-1, -2 etc also work as universal ways of specifying prior revisions.\nDiffs are shown in a DiffEditorDialog.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"rev_a", "rev_b"}}, {Name: "logVCSSelected", Doc: "logVCSSelected shows the VCS log of commits for selected files.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "blameVCSSelected", Doc: "blameVCSSelected shows the VCS blame report for this file, reporting for each line\nthe revision and author of the last change.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Tree"}}, Fields: []types.Field{{Name: "Filepath", Doc: "Filepath is the full path to this file."}, {Name: "Info", Doc: "Info is the full standard file info about this file."}, {Name: "Buffer", Doc: "Buffer is the file buffer for editing this file."}, {Name: "FileRoot", Doc: "FileRoot is the root [Tree] of the tree, which has global state."}, {Name: "DirRepo", Doc: "DirRepo is the version control system repository for this directory,\nonly non-nil if this is the highest-level directory in the tree under vcs control."}, {Name: "repoFiles", 
Doc: "repoFiles has the version control system repository file status,\nproviding a much faster way to get file status, vs. the repo.Status\ncall which is exceptionally slow."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Node", IDName: "node", Doc: "Node represents a file in the file system, as a [core.Tree] node.\nThe name of the node is the name of the file.\nFolders have children containing further nodes.", Directives: []types.Directive{{Tool: "core", Directive: "embedder"}}, Methods: []types.Method{{Name: "Cut", Doc: "Cut copies the selected files to the clipboard and then deletes them.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Paste", Doc: "Paste inserts files from the clipboard.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "OpenFilesDefault", Doc: "OpenFilesDefault opens selected files with default app for that file type (os defined).\nruns open on Mac, xdg-open on Linux, and start on Windows", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "duplicateFiles", Doc: "duplicateFiles makes a copy of selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "deleteFiles", Doc: "deletes any selected files or directories. 
If any directory is selected,\nall files and subdirectories in that directory are also deleted.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "RenameFiles", Doc: "renames any selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "RenameFile", Doc: "RenameFile renames file to new name", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"newpath"}, Returns: []string{"error"}}, {Name: "newFiles", Doc: "newFiles makes a new file in selected directory", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "addToVCS"}}, {Name: "newFile", Doc: "newFile makes a new file in this directory node", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "addToVCS"}}, {Name: "newFolders", Doc: "makes a new folder in the given selected directory", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"foldername"}}, {Name: "newFolder", Doc: "newFolder makes a new folder (directory) in this directory node", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"foldername"}}, {Name: "showFileInfo", Doc: "Shows file information about selected file(s)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "sortBys", Doc: "sortBys determines how to sort the selected files in the directory.\nDefault is alpha by name, optionally can be sorted by modification time.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"modTime"}}, {Name: "openAll", Doc: "openAll opens all directories under this one", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "removeFromExterns", Doc: "removeFromExterns removes file from list of external files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "addToVCSSelected", Doc: "addToVCSSelected adds selected files to version control 
system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "deleteFromVCSSelected", Doc: "deleteFromVCSSelected removes selected files from version control system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "commitToVCSSelected", Doc: "commitToVCSSelected commits to version control system based on last selected file", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "revertVCSSelected", Doc: "revertVCSSelected removes selected files from version control system", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "diffVCSSelected", Doc: "diffVCSSelected shows the diffs between two versions of selected files, given by the\nrevision specifiers -- if empty, defaults to A = current HEAD, B = current WC file.\n-1, -2 etc also work as universal ways of specifying prior revisions.\nDiffs are shown in a DiffEditorDialog.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"rev_a", "rev_b"}}, {Name: "logVCSSelected", Doc: "logVCSSelected shows the VCS log of commits for selected files.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "blameVCSSelected", Doc: "blameVCSSelected shows the VCS blame report for this file, reporting for each line\nthe revision and author of the last change.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Tree"}}, Fields: []types.Field{{Name: "Filepath", Doc: "Filepath is the full path to this file."}, {Name: "Info", Doc: "Info is the full standard file info about this file."}, {Name: "Buffer", Doc: "Buffer is the file buffer for editing this file."}, {Name: "DirRepo", Doc: "DirRepo is the version control system repository for this directory,\nonly non-nil if this is the highest-level directory in the tree under vcs control."}, {Name: "repoFiles", Doc: "repoFiles has the version control system repository file status,\nproviding a much 
faster way to get file status, vs. the repo.Status\ncall which is exceptionally slow."}}}) // NewNode returns a new [Node] with the given optional parent: // Node represents a file in the file system, as a [core.Tree] node. @@ -35,7 +37,7 @@ func AsNode(n tree.Node) *Node { // AsNode satisfies the [NodeEmbedder] interface func (t *Node) AsNode() *Node { return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Tree", IDName: "tree", Doc: "Tree is the root widget of a file tree representing files in a given directory\n(and subdirectories thereof), and has some overall management state for how to\nview things.", Embeds: []types.Field{{Name: "Node"}}, Fields: []types.Field{{Name: "externalFiles", Doc: "externalFiles are external files outside the root path of the tree.\nThey are stored in terms of their absolute paths. These are shown\nin the first sub-node if present; use [Tree.AddExternalFile] to add one."}, {Name: "Dirs", Doc: "records state of directories within the tree (encoded using paths relative to root),\ne.g., open (have been opened by the user) -- can persist this to restore prior view of a tree"}, {Name: "DirsOnTop", Doc: "if true, then all directories are placed at the top of the tree.\nOtherwise everything is mixed."}, {Name: "FileNodeType", Doc: "type of node to create; defaults to [Node] but can use custom node types"}, {Name: "inOpenAll", Doc: "if true, we are in midst of an OpenAll call; nodes should open all dirs"}, {Name: "watcher", Doc: "change notify for all dirs"}, {Name: "doneWatcher", Doc: "channel to close watcher watcher"}, {Name: "watchedPaths", Doc: "map of paths that have been added to watcher; only active if bool = true"}, {Name: "lastWatchUpdate", Doc: "last path updated by watcher"}, {Name: "lastWatchTime", Doc: "timestamp of last update"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Tree", IDName: "tree", Doc: "Tree is the root widget of a file tree representing files in a given 
directory\n(and subdirectories thereof), and has some overall management state for how to\nview things.", Embeds: []types.Field{{Name: "Node"}}, Fields: []types.Field{{Name: "externalFiles", Doc: "externalFiles are external files outside the root path of the tree.\nThey are stored in terms of their absolute paths. These are shown\nin the first sub-node if present; use [Tree.AddExternalFile] to add one."}, {Name: "Dirs", Doc: "Dirs records state of directories within the tree (encoded using paths relative to root),\ne.g., open (have been opened by the user) -- can persist this to restore prior view of a tree"}, {Name: "DirsOnTop", Doc: "DirsOnTop indicates whether all directories are placed at the top of the tree.\nOtherwise everything is mixed. This is the default."}, {Name: "SortByModTime", Doc: "SortByModTime causes files to be sorted by modification time by default.\nOtherwise it is a per-directory option."}, {Name: "FileNodeType", Doc: "FileNodeType is the type of node to create; defaults to [Node] but can use custom node types"}, {Name: "FilterFunc", Doc: "FilterFunc, if set, determines whether to include the given node in the tree.\nreturn true to include, false to not. 
This applies to files and directories alike."}, {Name: "FS", Doc: "FS is the file system we are browsing, if it is an FS (nil = os filesystem)"}, {Name: "inOpenAll", Doc: "inOpenAll indicates whether we are in midst of an OpenAll call; nodes should open all dirs."}, {Name: "watcher", Doc: "watcher does change notify for all dirs"}, {Name: "doneWatcher", Doc: "doneWatcher is channel to close watcher watcher"}, {Name: "watchedPaths", Doc: "watchedPaths is map of paths that have been added to watcher; only active if bool = true"}, {Name: "lastWatchUpdate", Doc: "lastWatchUpdate is last path updated by watcher"}, {Name: "lastWatchTime", Doc: "lastWatchTime is timestamp of last update"}}}) // NewTree returns a new [Tree] with the given optional parent: // Tree is the root widget of a file tree representing files in a given directory @@ -44,14 +46,31 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.Tree", IDN func NewTree(parent ...tree.Node) *Tree { return tree.New[Tree](parent...) } // SetDirsOnTop sets the [Tree.DirsOnTop]: -// if true, then all directories are placed at the top of the tree. -// Otherwise everything is mixed. +// DirsOnTop indicates whether all directories are placed at the top of the tree. +// Otherwise everything is mixed. This is the default. func (t *Tree) SetDirsOnTop(v bool) *Tree { t.DirsOnTop = v; return t } +// SetSortByModTime sets the [Tree.SortByModTime]: +// SortByModTime causes files to be sorted by modification time by default. +// Otherwise it is a per-directory option. 
+func (t *Tree) SetSortByModTime(v bool) *Tree { t.SortByModTime = v; return t } + // SetFileNodeType sets the [Tree.FileNodeType]: -// type of node to create; defaults to [Node] but can use custom node types +// FileNodeType is the type of node to create; defaults to [Node] but can use custom node types func (t *Tree) SetFileNodeType(v *types.Type) *Tree { t.FileNodeType = v; return t } +// SetFilterFunc sets the [Tree.FilterFunc]: +// FilterFunc, if set, determines whether to include the given node in the tree. +// return true to include, false to not. This applies to files and directories alike. +func (t *Tree) SetFilterFunc(v func(path string, info fs.FileInfo) bool) *Tree { + t.FilterFunc = v + return t +} + +// SetFS sets the [Tree.FS]: +// FS is the file system we are browsing, if it is an FS (nil = os filesystem) +func (t *Tree) SetFS(v fs.FS) *Tree { t.FS = v; return t } + var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/filetree.VCSLog", IDName: "vcs-log", Doc: "VCSLog is a widget that represents VCS log data.", Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "Log", Doc: "current log"}, {Name: "File", Doc: "file that this is a log of -- if blank then it is entire repository"}, {Name: "Since", Doc: "date expression for how long ago to include log entries from"}, {Name: "Repo", Doc: "version control system repository"}, {Name: "revisionA", Doc: "revision A -- defaults to HEAD"}, {Name: "revisionB", Doc: "revision B -- blank means current working copy"}, {Name: "setA", Doc: "double-click will set the A revision -- else B"}, {Name: "arev"}, {Name: "brev"}, {Name: "atf"}, {Name: "btf"}}}) // NewVCSLog returns a new [VCSLog] with the given optional parent: diff --git a/filetree/vcs.go b/filetree/vcs.go index 6e9b1fc5cb..db9a72e938 100644 --- a/filetree/vcs.go +++ b/filetree/vcs.go @@ -22,7 +22,7 @@ import ( // FirstVCS returns the first VCS repository starting from this node and going down. 
// also returns the node having that repository func (fn *Node) FirstVCS() (vcs.Repo, *Node) { - if fn.FileRoot.FS != nil { + if fn.FileRoot().FS != nil { return nil, nil } var repo vcs.Repo @@ -73,7 +73,11 @@ func (fn *Node) detectVCSRepo(updateFiles bool) bool { // and the node for the directory where the repo is based. // Goes up the tree until a repository is found. func (fn *Node) Repo() (vcs.Repo, *Node) { - if fn.isExternal() || fn.FileRoot.FS != nil { + fr := fn.FileRoot() + if fr == nil { + return nil, nil + } + if fn.isExternal() || fr == nil || fr.FS != nil { return nil, nil } if fn.DirRepo != nil { diff --git a/goal/GPU.md b/goal/GPU.md new file mode 100644 index 0000000000..165ffa9a29 --- /dev/null +++ b/goal/GPU.md @@ -0,0 +1,124 @@ +# Goal GPU support + +The use of massively parallel _Graphical Processing Unit_ (_GPU_) hardware has revolutionized machine learning and other fields, producing many factors of speedup relative to traditional _CPU_ (_Central Processing Unit_) computation. However, there are numerous challenges for supporting GPU-based computation, relative to the more flexible CPU coding. + +The Goal framework provides a solution to these challenges that enables the same Go-based code to work efficiently and reasonably naturally on both the GPU and CPU (i.e., standard Go execution), via the [gosl](gosl) package. Debugging code on the GPU is notoriously difficult because the usual tools are not directly available (not even print statements), so the ability to run exactly the same code on the CPU is invaluable, in addition to the benefits in portability across platforms without GPU hardware. + +See the [gosl](gosl) documentation for the details on how to write code that works on the GPU. The remainder of this document provides an overview of the overall approach in relation to other related tools.
+ +The two most important challenges are: + +* The GPU _has its own separate memory space_ that needs to be synchronized explicitly and bidirectionally with the standard CPU memory (this is true programmatically even if at a hardware level there is shared memory). + +* Computation must be organized into discrete chunks that can be computed efficiently in parallel, and each such chunk of computation lives in its own separate _kernel_ (_compute shader_) in the GPU, as an entirely separate, self-contained program, operating on _global variables_ that define the entire memory space of the computation. + +To be maximally efficient, both of these factors must be optimized, such that: + +* The bidirectional syncing of memory between CPU and GPU should be minimized, because such transfers incur a significant overhead. + +* The overall computation should be broken down into the _largest possible chunks_ to minimize the number of discrete kernel runs, each of which incurs significant overhead. + +Thus, it is unfortunately _highly inefficient_ to implement GPU-based computation by running each elemental vectorizable tensor operation (add, multiply, etc) as a separate GPU kernel, with its own separate bidirectional memory sync, even though that is a conceptually attractive and simple way to organize GPU computation, with minimal disruption relative to the CPU model. + +The [JAX](https://github.com/jax-ml/jax) framework in Python provides one solution to this situation, optimized for neural network machine learning uses, by imposing strict _functional programming_ constraints on the code you write (i.e., all functions must be _read-only_), and leveraging those to automatically combine elemental computations into larger parallelizable chunks, using a "just in time" (_jit_) compiler. 
+ +We take a different approach, which is much simpler implementationally but requires a bit more work from the developer, which is to provide tools that allow _you_ to organize your computation into kernel-sized chunks according to your knowledge of the problem, and transparently turn that code into the final CPU and GPU programs. + +In many cases, a human programmer can most likely out-perform the automatic compilation process, by knowing the full scope of what needs to be computed, and figuring out how to package it most efficiently per the above constraints. In the end, you get maximum efficiency and complete transparency about exactly what is being computed, perhaps with fewer "gotcha" bugs arising from all the magic happening under the hood, but again it may take a bit more work to get there. + +The role of Goal is to allow you to express the full computation in the clear, simple, Go language, using intuitive data structures that minimize the need for additional boilerplate to run efficiently on CPU and GPU. This ability to write a single codebase that runs efficiently on CPU and GPU is similar to the [SYCL](https://en.wikipedia.org/wiki/SYCL) framework (and several others discussed on that wikipedia page), which builds on [OpenCL](https://en.wikipedia.org/wiki/OpenCL), both of which are based on the C / C++ programming language. + +In addition to the critical differences between Go and C++ as languages, Goal targets only one hardware platform: WebGPU (via our [gpu](../gpu) package), so it is more specifically optimized for this use-case. Furthermore, SYCL and other approaches require you to write GPU-like code that can also run on the CPU (with lots of explicit fine-grained memory and compute management), whereas Goal provides a more natural CPU-like programming model, while imposing some stronger constraints that encourage more efficient implementations. 
+ +The bottom line is that the fantasy of being able to write CPU-native code and have it magically "just work" on the GPU with high levels of efficiency is just that: a fantasy. The reality is that code must be specifically structured and organized to work efficiently on the GPU. Goal just makes this process relatively clean and efficient and easy to read, with a minimum of extra boilerplate. The resulting code should be easily understood by anyone familiar with the Go language, even if that isn't the way you would have written it in the first place. The reward is that you can get highly efficient results with significant GPU-accelerated speedups that works on _any platform_, including the web and mobile phones, all with a single easy-to-read codebase. + +# Kernel functions + +First, we assume the scope is a single Go package that implements a set of computations on some number of associated data representations. The package will likely contain a lot of CPU-only Go code that manages all the surrounding infrastructure for the computations, in terms of creating and configuring the data in memory, visualization, i/o, etc. + +The GPU-specific computation is organized into some (hopefully small) number of **kernel** functions, that are conceptually called using a **parallel for loop**, e.g., something like this: +```Go +for i := range parallel(data) { + Compute(i) +} +``` + +The `i` index effectively iterates over the range of the values of the `data` variable, with the GPU version launching kernels on the GPU for each different index value. The CPU version actually runs in parallel as well, using goroutines. + +We assume that multiple kernels will in general be required, and that there is likely to be a significant amount of shared code infrastructure across these kernels. Thus, the kernel functions are typically relatively short, and call into a large body of code that is likely shared among the different kernel functions. 
+ +Even though the GPU kernels must each be compiled separately into a single distinct WGSL _shader_ file that is run under WebGPU, they can `import` a shared codebase of files, and thus replicate the same overall shared code structure as the CPU versions. + +The GPU code can only handle a highly restricted _subset_ of Go code, with data structures having strict alignment requirements, and no `string` or other composite variable-length data structures (maps, slices etc). Thus, the [gosl](gosl) package recognizes `//gosl:start` and `//gosl:end` comment directives surrounding the GPU-safe (and relevant) portions of the overall code. Any `.go` or `.goal` file can contribute GPU relevant code, including in other packages, and the gosl system automatically builds a shadow package-based set of `.wgsl` files accordingly. + +> Each kernel function is marked with a `//gosl:kernel` directive, and the name of the function is used to create the name of the GPU shader file. + +```Go +// Compute does the main computation. +func Compute(i uint32) { //gosl:kernel + Params[0].IntegFromRaw(&Data[i]) +} +``` + +## Memory Organization + +Perhaps the strongest constraints for GPU programming stem from the need to organize and synchronize all the memory buffers holding the data that the GPU kernel operates on. Furthermore, within a GPU kernel, the variables representing this data are _global variables_, which is sensible given the standalone nature of each kernel. + +> To provide a common programming environment, all GPU-relevant variables must be Go global variables. + +Thus, names must be chosen appropriately for these variables, given their global scope within the Go package. The specific _values_ for these variables can be dynamically set in an easy way, but the variables themselves are global. 
+ +Within the [gpu](../gpu) framework, each `ComputeSystem` defines a specific organization of such GPU buffer variables, and maximum efficiency is achieved by minimizing the number of such compute systems, and associated memory buffers. Each system also encapsulates the associated kernel shaders that operate on the associated memory data, so + +> Kernels and variables both must be defined within a specific system context. + +### tensorfs mapping + +TODO: + +The grouped global variables can be mapped directly to a corresponding [tensorfs](../tensor/tensorfs) directory, which provides direct accessibility to this data within interactive Goal usage. Further, different sets of variable values can be easily managed by saving and loading different such directories. + +```Go + gosl.ToDataFS("path/to/dir" [, system]) // set tensorfs items in given path to current global vars + + gosl.FromDataFS("path/to/dir" [,system]) // set global vars from given tensorfs path +``` + +These and all such `gosl` functions use the current system if none is explicitly specified, which is settable using the `gosl.SetSystem` call. Any given variable can use the `get` or `set` Goal math mode functions directly. + +## Memory access + +In general, all global GPU variables will be arrays (slices) or tensors, which are exposed to the GPU as an array of floats. + +The tensor-based indexing syntax in Goal math mode transparently works across CPU and GPU modes, and is thus the preferred way to access tensor data. + +It is critical to appreciate that none of the other convenient math-mode operations will work as you expect on the GPU, because: + +> There is only one outer-loop, kernel-level parallel looping operation allowed at a time. + +You cannot nest multiple such loops within each other. A kernel cannot launch another kernel. Therefore, as noted above, you must directly organize your computation to maximize the amount of parallel computation happening within each such kernel call.
+ +> Therefore, tensor indexing on the GPU only supports direct index values, not ranges. + +Furthermore: + +> Pointer-based access of global variables is not supported in GPU mode. + +You have to use _indexes_ into arrays exclusively. Thus, some of the data structures you may need to copy up to the GPU include index variables that determine how to access other variables. TODO: do we need helpers for any of this? + +# Examples + +A large and complex biologically-based neural network simulation framework called [axon](https://github.com/emer/axon) has been implemented using `gosl`, allowing 1000's of lines of equations and data structures to run through standard Go on the CPU, and accelerated significantly on the GPU. This allows efficient debugging and unit testing of the code in Go, whereas debugging on the GPU is notoriously difficult. + +# TODO + +## Optimization + +can run naga on wgsl code to get wgsl code out, but it doesn't seem to do much dead code elimination: https://github.com/gfx-rs/wgpu/tree/trunk/naga + +``` +naga --compact gpu_applyext.wgsl tmp.wgsl +``` + +https://github.com/LucentFlux/wgsl-minifier does radical minification but the result is unreadable so we don't know if it is doing dead code elimination. in theory it is just calling naga --compact for that. + diff --git a/goal/README.md b/goal/README.md new file mode 100644 index 0000000000..5136f6129b --- /dev/null +++ b/goal/README.md @@ -0,0 +1,437 @@ +# Goal: Go augmented language + +Goal is an augmented version of the Go language, which combines the best parts of Go, `bash`, and Python, to provide an integrated shell and numerical expression processing experience, which can be combined with the [yaegi](https://github.com/traefik/yaegi) interpreter to provide an interactive "REPL" (read, evaluate, print loop). + +Goal transpiles directly into Go, so it automatically leverages all the great features of Go, and remains fully compatible with it.
The augmentation is designed to overcome some of the limitations of Go in specific domains: + +* Shell scripting, where you want to be able to directly call other executable programs with arguments, without having to navigate all the complexity of the standard [os.exec](https://pkg.go.dev/os/exec) package. + +* Numerical / math / data processing, where you want to be able to write simple mathematical expressions operating on vectors, matrices and other more powerful data types, without having to constantly worry about numerical type conversions, and advanced n-dimensional indexing and slicing expressions are critical. Python is the dominant language here precisely because it lets you ignore type information and write such expressions, using operator overloading. + +* GPU-based parallel computation, which can greatly speed up some types of parallelizable computations by effectively running many instances of the same code in parallel across a large array of data. The [gosl](gosl) package (automatically run in `goal build` mode) allows you to run the same Go-based code on a GPU or CPU (using parallel goroutines). See the [GPU](GPU.md) docs for an overview and comparison to other approaches to GPU computation. + +The main goal of Goal is to achieve a "best of both worlds" solution that retains all the type safety and explicitness of Go for all the surrounding control flow and large-scale application logic, while also allowing for a more relaxed syntax in specific, well-defined domains where the Go language has been a barrier. Thus, unlike Python where there are various weak attempts to try to encourage better coding habits, Goal retains in its Go foundation a fundamentally scalable, "industrial strength" language that has already proven its worth in countless real-world applications.
+ +For the shell scripting aspect of Goal, the simple idea is that each line of code is either Go or shell commands, determined in a fairly intuitive way mostly by the content at the start of the line (formal rules below). If a line starts off with something like `ls -la...` then it is clear that it is not valid Go code, and it is therefore processed as a shell command. + +You can intermix Go within a shell line by wrapping an expression with `{ }` braces, and a Go expression can contain shell code by using `$`. Here's an example: +```go +for i, f := range goalib.SplitLines($ls -la$) { // ls executes, returns string + echo {i} {strings.ToLower(f)} // {} surrounds Go within shell +} +``` +where `goalib.SplitLines` is a function that runs `strings.Split(arg, "\n")`, defined in the `goalib` standard library of such frequently-used helper functions. + +For cases where most of the code is standard Go with relatively infrequent use of shell expressions, or in the rare cases where the default interpretation doesn't work, you can explicitly tag a line as shell code using `$`: + +```go +$ chmod +x *.goal +``` + +For mathematical expressions, we use `#` symbols (`#` = number) to demarcate such expressions. Often you will write entire lines of such expressions: +```go +# x := 1. / (1. + exp(-wts[:, :, :n] * acts[:])) +``` +You can also intermix within Go code: +```go +for _, x := range #[1,2,3]# { + fmt.Println(#x^2#) +} +``` +Note that you cannot enter math mode directly from shell mode, which is unlikely to be useful anyway (you can wrap in go mode `{ }` if really needed). + +In general, the math mode syntax in Goal is designed to be as compatible with Python NumPy / SciPy syntax as possible, while also adding a few Go-specific additions as well: see the [Math mode](#math-mode) section for details. All elements of a Goal math expression are [tensors](../tensor), which can represent everything from a scalar to an n-dimensional tensor.
These are called "ndarray" in NumPy terms. + +The one special form of tensor processing that is available in regular Go code is _n dimensional indexing_, e.g., `tsr[1,2]`. This kind of expression with square brackets `[ ]` and a comma is illegal according to standard Go syntax, so when we detect it, we know that it is being used on a tensor object, and can transpile it into the corresponding `tensor.Value` or `tensor.Set*` expression. This is particularly convenient for [gosl](gosl) GPU code that has special support for tensor data. Note that for this GPU use-case, you actually do _not_ want to use math mode, because that engages a different, more complex form of indexing that does _not_ work on the GPU. + +The rationale and mnemonics for using `$` and `#` are as follows: + +* These are two of the three common ASCII keyboard symbols that are not part of standard Go syntax (`@` being the other). + +* `$` can be thought of as "S" in _S_hell, and is often used for a `bash` prompt, and many bash examples use it as a prefix. Furthermore, in bash, `$( )` is used to wrap shell expressions. + +* `#` is commonly used to refer to numbers. It is also often used as a comment syntax, but on balance the number semantics and uniqueness relative to Go syntax outweigh that issue. 
+ +# Examples + +Here are a few useful examples of Goal code: + +You can easily perform handy duration and data size formatting: + +```go +22010706 * time.Nanosecond // 22.010706ms +datasize.Size(44610930) // 42.5 MB +``` + +# Shell mode + +## Environment variables + +* `set ` (space delimited as in all shell mode, no equals) + +## Output redirection + +* Standard output redirect: `>` and `>&` (and `|`, `|&` if needed) + +## Control flow + +* Any error stops the script execution, except for statements wrapped in `[ ]`, indicating an "optional" statement, e.g.: + +```sh +cd some; [mkdir sub]; cd sub +``` + +* `&` at the end of a statement runs in the background (as in bash) -- otherwise it waits until it completes before it continues. + +* `jobs`, `fg`, `bg`, and `kill` builtin commands function as in usual bash. + +## Shell functions (aliases) + +Use the `command` keyword to define new functions for Shell mode execution, which can then be used like any other command, for example: + +```sh +command list { + ls -la args... +} +``` + +```sh +cd data +list *.tsv +``` + +The `command` is transpiled into a Go function that takes `args ...string`. In the command function body, you can use the `args...` expression to pass all of the args, or `args[1]` etc to refer to specific positional indexes, as usual. + +The command function name is registered so that the standard shell execution code can run the function, passing the args. You can also call it directly from Go code using the standard parentheses expression. + +## Script Files and Makefile-like functionality + +As with most scripting languages, a file of goal code can be made directly executable by appending a "shebang" expression at the start of the file: + +```sh +#!/usr/bin/env goal +``` + +When executed this way, any additional args are available via an `args []any` variable, which can be passed to a command as follows: +```go +install {args...} +``` +or by referring to specific arg indexes etc.
+ +To make a script behave like a standard Makefile, you can define different `command`s for each of the make commands, and then add the following at the end of the file to use the args to run commands: + +```go +goal.RunCommands(args) +``` + +See [make](cmd/goal/testdata/make) for an example, in `cmd/goal/testdata/make`, which can be run for example using: + +```sh +./make build +``` + +Note that there is nothing special about the name `make` here, so this can be done with any file. + +The `make` package defines a number of useful utility functions that accomplish the standard dependency and file timestamp checking functionality from the standard `make` command, as in the [magefile](https://magefile.org/dependencies/) system. Note that the goal direct shell command syntax makes the resulting make files much closer to a standard bash-like Makefile, while still having all the benefits of Go control and expressions, compared to magefile. + +TODO: implement and document above. + +## SSH connections to remote hosts + +Any number of active SSH connections can be maintained and used dynamically within a script, including simple ways of copying data among the different hosts (including the local host). The Go mode execution is always on the local host in one running process, and only the shell commands are executed remotely, enabling a unique ability to easily coordinate and distribute processing and data across various hosts. + +Each host maintains its own working directory and environment variables, which can be configured and re-used by default whenever using a given host. + +* `cossh hostname.org [name]` establishes a connection, using given optional name to refer to this connection. If the name is not provided, a sequential number will be used, starting with 1, with 0 referring always to the local host. + +* `@name` then refers to the given host in all subsequent commands, with `@0` referring to the local host where the goal script is running. 
+ +### Explicit per-command specification of host + +```sh +@name cd subdir; ls +``` + +### Default host + +```sh +@name // or: +cossh @name +``` + +uses the given host for all subsequent commands (unless explicitly specified), until the default is changed. Use `cossh @0` to return to localhost. + +### Redirect input / output among hosts + +The output of a remote host command can be sent to a file on the local host: +```sh +@name cat hostfile.tsv > @0:localfile.tsv +``` +Note the use of the `:` colon delimiter after the host name here. TODO: You cannot send output to a remote host file (e.g., `> @host:remotefile.tsv`) -- maybe with sftp? + +The output of any command can also be piped to a remote host as its standard input: +```sh +ls *.tsv | @host cat > files.txt +``` + +### scp to copy files easily + +The builtin `scp` function allows easy copying of files across hosts, using the persistent connections established with `cossh` instead of creating new connections as in the standard scp command. + +`scp` is _always_ run from the local host, with the remote host filename specified as `@name:remotefile` + +```sh +scp @name:hostfile.tsv localfile.tsv +``` + +Importantly, file wildcard globbing works as expected: +```sh +scp @name:*.tsv @0:data/ +``` + +and entire directories can be copied, as in `cp -a` or `cp -r` (this behavior is automatic and does not require a flag). + +### Close connections + +```sh +cossh close +``` + +Will close all active connections and return the default host to @0. All active connections are also automatically closed when the shell terminates. + +## Other Utilities + +** TODO: need a replacement for findnm -- very powerful but garbage.. + +## Rules for Go vs. Shell determination + +These are the rules used to determine whether a line is Go vs. Shell (word = IDENT token): + +* `$` at the start: Shell.
+* Within Shell, `{}`: Go
+* Within Go, `$ $`: Shell
+* Line starts with `go` keyword: if no `( )` then Shell, else Go
+* Line is one word: Shell
+* Line starts with `path` expression (e.g., `./myexec`) : Shell
+* Line starts with `"string"`: Shell
+* Line starts with `word word`: Shell
+* Line starts with `word {`: Shell
+* Otherwise: Go
+
+TODO: update above
+
+## Multiple statements per line
+
+* Multiple statements can be combined on one line, separated by `;` as in regular Go and shell languages. Critically, the language determination for the first statement determines the language for the remaining statements; you cannot intermix the two on one line, when using `;`
+
+# Math mode
+
+The math mode in Goal is designed to be generally compatible with Python NumPy / SciPy syntax, so that the widespread experience with that syntax transfers well to Goal. This syntax is also largely compatible with MATLAB and other languages as well. However, we did not fully replicate the NumPy syntax, instead choosing to clean up a few things and generally increase consistency with Go.
+
+In general the Goal global functions are named the same as NumPy, without the `np.` prefix, which improves readability. It should be very straightforward to write a conversion utility that converts existing NumPy code into Goal code, and that is a better process than trying to make Goal itself perfectly compatible.
+
+All elements of a Goal math expression are [tensors](../tensor) (i.e., `tensor.Tensor`), which can represent everything from a scalar to an n-dimensional tensor, with different _views_ that support the arbitrary slicing and flexible forms of indexing documented in the table below. These are called an `ndarray` in NumPy terms. See [array vs. tensor](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html#array-or-matrix-which-should-i-use) NumPy docs for more information. 
Note that Goal does not have a distinct `matrix` type; everything is a tensor, and when these are 2D, they function appropriately via the [matrix](../tensor/matrix) package. + +The _view_ versions of `Tensor` include `Sliced`, `Reshaped`, `Masked`, `Indexed`, and `Rows`, each of which wraps around another "source" `Tensor`, and provides its own way of accessing the underlying data: + +* `Sliced` has an arbitrary set of indexes for each dimension, so access to values along that dimension go through the indexes. Thus, you could reverse the order of the columns (dimension 1), or only operate on a subset of them. + +* `Masked` has a `tensor.Bool` tensor that filters access to the underlying source tensor through a mask: anywhere the bool value is `false`, the corresponding source value is not settable, and returns `NaN` (missing value) when accessed. + +* `Indexed` uses a tensor of indexes where the final, innermost dimension is the same size as the number of dimensions in the wrapped source tensor. The overall shape of this view is that of the remaining outer dimensions of the Indexes tensor, and like other views, assignment and return values are taken from the corresponding indexed value in the wrapped source tensor. + + The current NumPy version of indexed is rather complex and difficult for many people to understand, as articulated in this [NEP 21 proposal](https://numpy.org/neps/nep-0021-advanced-indexing.html). The `Indexed` view at least provides a simpler way of representing the indexes into the source tensor, instead of requiring multiple parallel 1D arrays. + +* `Rows` is an optimized version of `Sliced` with indexes only for the first, outermost, _row_ dimension. + +The following sections provide a full list of equivalents between the `tensor` Go code, Goal, NumPy, and MATLAB, based on the table in [numpy-for-matlab-users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html). 
+* The _same:_ in Goal means that the same NumPy syntax works in Goal, minus the `np.` prefix, and likewise for _or:_ (where Goal also has additional syntax). +* In the `tensor.Go` code, we sometimes just write a scalar number for simplicity, but these are actually `tensor.NewFloat64Scalar` etc. +* Goal also has support for `string` tensors, e.g., for labels, and operators such as addition that make sense for strings are supported. Otherwise, strings are automatically converted to numbers using the `tensor.Float` interface. If you have any doubt about whether you've got a `tensor.Float64` when you expect one, use `tensor.AsFloat64Tensor` which makes sure. + +## Tensor shape + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `a.NumDim()` | `ndim(a)` or `a.ndim` | `np.ndim(a)` or `a.ndim` | `ndims(a)` | number of dimensions of tensor `a` | +| `a.Len()` | `len(a)` or `a.len` or: | `np.size(a)` or `a.size` | `numel(a)` | number of elements of tensor `a` | +| `a.Shape().Sizes` | same: | `np.shape(a)` or `a.shape` | `size(a)` | "size" of each dimension in a; `shape` returns a 1D `int` tensor | +| `a.Shape().Sizes[1]` | same: | `a.shape[1]` | `size(a,2)` | the number of elements of the 2nd dimension of tensor `a` | +| `tensor.Reshape(a, 10, 2)` | same except no `a.shape = (10,2)`: | `a.reshape(10, 2)` or `np.reshape(a, 10, 2)` or `a.shape = (10,2)` | `reshape(a,10,2)` | set the shape of `a` to a new shape that has the same total number of values (len or size); No option to change order in Goal: always row major; Goal does _not_ support direct shape assignment version. 
| +| `tensor.Reshape(a, tensor.AsIntSlice(sh)...)` | same: | `a.reshape(10, sh)` or `np.reshape(a, sh)` | `reshape(a,sh)` | set shape based on list of dimension sizes in tensor `sh` | +| `tensor.Reshape(a, -1)` or `tensor.As1D(a)` | same: | `a.reshape(-1)` or `np.reshape(a, -1)` | `reshape(a,-1)` | a 1D vector view of `a`; Goal does not support `ravel`, which is nearly identical. | +| `tensor.Flatten(a)` | same: | `b = a.flatten()` | `b=a(:)` | returns a 1D copy of a | +| `b := tensor.Clone(a)` | `b := copy(a)` or: | `b = a.copy()` | `b=a` | direct assignment `b = a` in Goal or NumPy just makes variable b point to tensor a; `copy` is needed to generate new underlying values (MATLAB always makes a copy) | +| `tensor.Squeeze(a)` | same: |`a.squeeze()` | `squeeze(a)` | remove singleton dimensions of tensor `a`. | + + +## Constructing + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `tensor.NewFloat64FromValues(` `[]float64{1, 2, 3})` | `[1., 2., 3.]` | `np.array([1., 2., 3.])` | `[ 1 2 3 ]` | define a 1D tensor | +| | `[[1., 2., 3.], [4., 5., 6.]]` or: | `(np.array([[1., 2., 3.], [4., 5., 6.]])` | `[ 1 2 3; 4 5 6 ]` | define a 2x3 2D tensor | +| | | `[[a, b], [c, d]]` or `block([[a, b], [c, d]])` | `np.block([[a, b], [c, d]])` | `[ a b; c d ]` | construct a matrix from blocks `a`, `b`, `c`, and `d` | +| `tensor.NewFloat64(3,4)` | `zeros(3,4)` | `np.zeros((3, 4))` | `zeros(3,4)` | 3x4 2D tensor of float64 zeros; Goal does not use "tuple" so no double parens | +| `tensor.NewFloat64(3,4,5)` | `zeros(3, 4, 5)` | `np.zeros((3, 4, 5))` | `zeros(3,4,5)` | 3x4x5 three-dimensional tensor of float64 zeros | +| `tensor.NewFloat64Ones(3,4)` | `ones(3, 4)` | `np.ones((3, 4))` | `ones(3,4)` | 3x4 2D tensor of 64-bit floating point ones | +| `tensor.NewFloat64Full(5.5, 3,4)` | `full(5.5, 3, 4)` | `np.full((3, 4), 5.5)` | ? 
| 3x4 2D tensor of 5.5; Goal variadic arg structure requires value to come first | +| `tensor.NewFloat64Rand(3,4)` | `rand(3, 4)` or `slrand(c, fi, 3, 4)` | `rng.random(3, 4)` | `rand(3,4)` | 3x4 2D float64 tensor with uniform random 0..1 elements; `rand` uses current Go `rand` source, while `slrand` uses [gosl](../gpu/gosl/slrand) GPU-safe call with counter `c` and function index `fi` and key = index of element | +| TODO: | |`np.concatenate((a,b),1)` or `np.hstack((a,b))` or `np.column_stack((a,b))` or `np.c_[a,b]` | `[a b]` | concatenate columns of a and b | +| TODO: | |`np.concatenate((a,b))` or `np.vstack((a,b))` or `np.r_[a,b]` | `[a; b]` | concatenate rows of a and b | +| TODO: | |`np.tile(a, (m, n))` | `repmat(a, m, n)` | create m by n copies of a | +| TODO: | |`a[np.r_[:len(a),0]]` | `a([1:end 1],:)` | `a` with copy of the first row appended to the end | + +## Ranges and grids + +See [NumPy](https://numpy.org/doc/stable/user/how-to-partition.html) docs for details. + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `tensor.NewIntRange(1, 11)` | same: |`np.arange(1., 11.)` or `np.r_[1.:11.]` or `np.r_[1:10:10j]` | `1:10` | create an increasing vector; `arange` in goal is always ints; use `linspace` or `tensor.AsFloat64` for floats | +| | same: |`np.arange(10.)` or `np.r_[:10.]` or `np.r_[:9:10j]` | `0:9` | create an increasing vector; 1 arg is the stop value in a slice | +| | |`np.arange(1.,11.)` `[:, np.newaxis]` | `[1:10]'` | create a column vector | +| `t.NewFloat64` `SpacedLinear(` `1, 3, 4, true)` | `linspace(1,3,4,true)` |`np.linspace(1,3,4)` | `linspace(1,3,4)` | 4 equally spaced samples between 1 and 3, inclusive of end (use `false` at end for exclusive) | +| | |`np.mgrid[0:9.,0:6.]` or `np.meshgrid(r_[0:9.],` `r_[0:6.])` | `[x,y]=meshgrid(0:8,0:5)` | two 2D tensors: one of x values, the other of y values | +| | |`ogrid[0:9.,0:6.]` or `np.ix_(np.r_[0:9.],` `np.r_[0:6.]` | | the 
best way to eval functions on a grid | +| | |`np.meshgrid([1,2,4],` `[2,4,5])` | `[x,y]=meshgrid([1,2,4],[2,4,5])` | | +| | |`np.ix_([1,2,4],` `[2,4,5])` | | the best way to eval functions on a grid | + +## Basic indexing + +See [NumPy basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing). Tensor Go uses the `Reslice` function for all cases (repeated `tensor.` prefix replaced with `t.` to take less space). Here you can clearly see the advantage of Goal in allowing significantly more succinct expressions to be written for accomplishing critical tensor functionality. + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `t.Reslice(a, 1, 4)` | same: |`a[1, 4]` | `a(2,5)` | access element in second row, fifth column in 2D tensor `a` | +| `t.Reslice(a, -1)` | same: |`a[-1]` | `a(end)` | access last element | +| `t.Reslice(a,` `1, t.FullAxis)` | same: |`a[1]` or `a[1, :]` | `a(2,:)` | entire second row of 2D tensor `a`; unspecified dimensions are equivalent to `:` (could omit second arg in Reslice too) | +| `t.Reslice(a,` `Slice{Stop:5})` | same: |`a[0:5]` or `a[:5]` or `a[0:5, :]` | `a(1:5,:)` | 0..4 rows of `a`; uses same Go slice ranging here: (start:stop) where stop is _exclusive_ | +| `t.Reslice(a,` `Slice{Start:-5})` | same: |`a[-5:]` | `a(end-4:end,:)` | last 5 rows of 2D tensor `a` | +| `t.Reslice(a,` `t.NewAxis,` `Slice{Start:-5})` | same: |`a[newaxis, -5:]` | ? | last 5 rows of 2D tensor `a`, as a column vector | +| `t.Reslice(a,` `Slice{Stop:3},` `Slice{Start:4, Stop:9})` | same: |`a[0:3, 4:9]` | `a(1:3,5:9)` | The first through third rows and fifth through ninth columns of a 2D tensor, `a`. 
| +| `t.Reslice(a,` `Slice{Start:2,` `Stop:25,` `Step:2}, t.FullAxis)` | same: |`a[2:21:2,:]` | `a(3:2:21,:)` | every other row of `a`, starting with the third and going to the twenty-first | +| `t.Reslice(a,` `Slice{Step:2},` `t.FullAxis)` | same: |`a[::2, :]` | `a(1:2:end,:)` | every other row of `a`, starting with the first | +| `t.Reslice(a,`, `Slice{Step:-1},` `t.FullAxis)` | same: |`a[::-1,:]` | `a(end:-1:1,:) or flipud(a)` | `a` with rows in reverse order | +| `t.Clone(t.Reslice(a,` `1, t.FullAxis))` | `b = copy(a[1, :])` or: | `b = a[1, :].copy()` | `y=x(2,:)` | without the copy, `y` would point to a view of values in `x`; `copy` creates distinct values, in this case of _only_ the 2nd row of `x` -- i.e., it "concretizes" a given view into a literal, memory-continuous set of values for that view. | +| `tmath.Assign(` `t.Reslice(a,` `Slice{Stop:5}),` `t.NewIntScalar(2))` | same: |`a[:5] = 2` | `a(1:5,:) = 2` | assign the value 2 to 0..4 rows of `a` | +| (you get the idea) | same: |`a[:5] = b[:, :5]` | `a(1:5,:) = b(:, 1:5)` | assign the values in the first 5 columns of `b` to the first 5 rows of `a` | + +## Boolean tensors and indexing + +See [NumPy boolean indexing](https://numpy.org/doc/stable/user/basics.indexing.html#boolean-array-indexing). + +Note that Goal only supports boolean logical operators (`&&` and `||`) on boolean tensors, not the single bitwise operators `&` and `|`. + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `tmath.Greater(a, 0.5)` | same: | `(a > 0.5)` | `(a > 0.5)` | `bool` tensor of shape `a` with elements `(v > 0.5)` | +| `tmath.And(a, b)` | `a && b` | `logical_and(a,b)` | `a & b` | element-wise AND operator on `bool` tensors | +| `tmath.Or(a, b)` | `a \|\| b` | `np.logical_or(a,b)` | `a \| b` | element-wise OR operator on `bool` tensors | +| `tmath.Negate(a)` | `!a` | ? | ? 
| element-wise negation on `bool` tensors | +| `tmath.Assign(` `tensor.Mask(a,` `tmath.Less(a, 0.5),` `0)` | same: |`a[a < 0.5]=0` | `a(a<0.5)=0` | `a` with elements less than 0.5 zeroed out | +| `tensor.Flatten(` `tensor.Mask(a,` `tmath.Less(a, 0.5)))` | same: |`a[a < 0.5].flatten()` | ? | a 1D list of the elements of `a` < 0.5 (as a copy, not a view) | +| `tensor.Mul(a,` `tmath.Greater(a, 0.5))` | same: |`a * (a > 0.5)` | `a .* (a>0.5)` | `a` with elements less than 0.5 zeroed out | + +## Advanced index-based indexing + +See [NumPy integer indexing](https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing). Note that the current NumPy version of indexed is rather complex and difficult for many people to understand, as articulated in this [NEP 21 proposal](https://numpy.org/neps/nep-0021-advanced-indexing.html). + +**TODO:** not yet implemented: + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| | |`a[np.ix_([1, 3, 4], [0, 2])]` | `a([2,4,5],[1,3])` | rows 2,4 and 5 and columns 1 and 3. 
|
+| | |`np.nonzero(a > 0.5)` | `find(a > 0.5)` | find the indices where (a > 0.5) |
+| | |`a[:, v.T > 0.5]` | `a(:,find(v>0.5))` | extract the columns of `a` where column vector `v` > 0.5 |
+| | |`a[:,np.nonzero(v > 0.5)[0]]` | `a(:,find(v > 0.5))` | extract the columns of `a` where vector `v` > 0.5 |
+| | |`a[:] = 3` | `a(:) = 3` | set all values to the same scalar value |
+| | |`np.sort(a)` or `a.sort(axis=0)` | `sort(a)` | sort each column of a 2D tensor, `a` |
+| | |`np.sort(a, axis=1)` or `a.sort(axis=1)` | `sort(a, 2)` | sort each row of 2D tensor, `a` |
+| | |`I = np.argsort(a[:, 0]); b = a[I,:]` | `[b,I]=sortrows(a,1)` | save the tensor `a` as tensor `b` with rows sorted by the first column |
+| | |`np.unique(a)` | `unique(a)` | a vector of unique values in tensor `a` |
+
+## Basic math operations (add, multiply, etc)
+
+In Goal and NumPy, the standard `+, -, *, /` operators perform _element-wise_ operations because those are well-defined for all dimensionalities and are consistent across the different operators, whereas matrix multiplication is specifically used in a 2D linear algebra context, and is not well defined for the other operators.
+
+| `tensor` Go | Goal | NumPy | MATLAB | Notes |
+| ------------ | ----------- | ------ | ------ | ---------------- |
+| `tmath.Add(a,b)` | same: |`a + b` | `a .+ b` | element-wise addition; Goal does this string-wise for string tensors |
+| `tmath.Mul(a,b)` | same: |`a * b` | `a .* b` | element-wise multiply |
+| `tmath.Div(a,b)` | same: |`a/b` | `a./b` | element-wise divide. _important:_ this always produces a floating point result. 
+| `tmath.Mod(a,b)` | same: |`a%b` | `mod(a,b)` | element-wise modulus (works for float and int) |
+| `tmath.Pow(a,3)` | same: | `a**3` | `a.^3` | element-wise exponentiation |
+| `tmath.Cos(a)` | same: | `cos(a)` | `cos(a)` | element-wise function application |
+
+## 2D Matrix Linear Algebra
+
+| `tensor` Go | Goal | NumPy | MATLAB | Notes |
+| ------------ | ----------- | ------ | ------ | ---------------- |
+| `matrix.Mul(a,b)` | same: |`a @ b` | `a * b` | matrix multiply |
+| `tensor.Transpose(a)` | <- or `a.T` |`a.transpose()` or `a.T` | `a.'` | transpose of `a` |
+| TODO: | |`a.conj().transpose() or a.conj().T` | `a'` | conjugate transpose of `a` |
+| `matrix.Det(a)` | `matrix.Det(a)` | `np.linalg.det(a)` | ? | determinant of `a` |
+| `matrix.Identity(3)` | <- |`np.eye(3)` | `eye(3)` | 3x3 identity matrix |
+| `matrix.Diagonal(a)` | <- |`np.diag(a)` | `diag(a)` | returns a vector of the diagonal elements of 2D tensor, `a`. Goal returns a read / write view. |
+| | |`np.diag(v, 0)` | `diag(v,0)` | returns a square diagonal matrix whose nonzero values are the elements of vector, v |
+| `matrix.Trace(a)` | <- |`np.trace(a)` | `trace(a)` | returns the sum of the elements along the diagonal of `a`. 
| +| `matrix.Tri()` | <- |`np.tri()` | `tri()` | returns a new 2D Float64 matrix with 1s in the lower triangular region (including the diagonal) and the remaining upper triangular elements zero | +| `matrix.TriL(a)` | <- |`np.tril(a)` | `tril(a)` | returns a copy of `a` with the lower triangular elements (including the diagonal) from `a` and the remaining upper triangular elements zeroed out | +| `matrix.TriU(a)` | <- |`np.triu(a)` | `triu(a)` | returns a copy of `a` with the upper triangular elements (including the diagonal) from `a` and the remaining lower triangular elements zeroed out | +| | |`linalg.inv(a)` | `inv(a)` | inverse of square 2D tensor a | +| | |`linalg.pinv(a)` | `pinv(a)` | pseudo-inverse of 2D tensor a | +| | |`np.linalg.matrix_rank(a)` | `rank(a)` | matrix rank of a 2D tensor a | +| | |`linalg.solve(a, b)` if `a` is square; `linalg.lstsq(a, b)` otherwise | `a\b` | solution of `a x = b` for x | +| | |Solve `a.T x.T = b.T` instead | `b/a` | solution of x a = b for x | +| | |`U, S, Vh = linalg.svd(a); V = Vh.T` | `[U,S,V]=svd(a)` | singular value decomposition of a | +| | |`linalg.cholesky(a)` | `chol(a)` | Cholesky factorization of a 2D tensor | +| | |`D,V = linalg.eig(a)` | `[V,D]=eig(a)` | eigenvalues and eigenvectors of `a`, where `[V,D]=eig(a,b)` eigenvalues and eigenvectors of `a, b` where | +| | |`D,V = eigs(a, k=3)` | `D,V = linalg.eig(a, b)` | `[V,D]=eigs(a,3)` | find the k=3 largest eigenvalues and eigenvectors of 2D tensor, a | +| | |`Q,R = linalg.qr(a)` | `[Q,R]=qr(a,0)` | QR decomposition +| | |`P,L,U = linalg.lu(a)` where `a == P@L@U` | `[L,U,P]=lu(a)` where `a==P'*L*U` | LU decomposition with partial pivoting (note: P(MATLAB) == transpose(P(NumPy))) | +| | |`x = linalg.lstsq(Z, y)` | `x = Z\y` | perform a linear regression of the form | + +## Statistics + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| `a.max()` or `max(a)` or `stats.Max(a)` | `a.max()` or 
`np.nanmax(a)` | `max(max(a))` | maximum element of `a`, Goal always ignores `NaN` as missing data | +| | |`a.max(0)` | `max(a)` | maximum element of each column of tensor `a` | +| | |`a.max(1)` | `max(a,[],2)` | maximum element of each row of tensor `a` | +| | |`np.maximum(a, b)` | `max(a,b)` | compares a and b element-wise, and returns the maximum value from each pair | +| `stats.L2Norm(a)` | `np.sqrt(v @ v)` or `np.linalg.norm(v)` | `norm(v)` | L2 norm of vector v | +| | |`cg` | `conjgrad` | conjugate gradients solver | + +## FFT and complex numbers + +todo: huge amount of work needed to support complex numbers throughout! + +| `tensor` Go | Goal | NumPy | MATLAB | Notes | +| ------------ | ----------- | ------ | ------ | ---------------- | +| | |`np.fft.fft(a)` | `fft(a)` | Fourier transform of `a` | +| | |`np.fft.ifft(a)` | `ifft(a)` | inverse Fourier transform of `a` | +| | |`signal.resample(x, np.ceil(len(x)/q))` | `decimate(x, q)` | downsample with low-pass filtering | + +## tensorfs + +The [tensorfs](../tensor/tensorfs) data filesystem provides a global filesystem-like workspace for storing tensor data, and Goal has special commands and functions to facilitate interacting with it. In an interactive `goal` shell, when you do `##` to switch into math mode, the prompt changes to show your current directory in the tensorfs, not the regular OS filesystem, and the final prompt character turns into a `#`. + +Use `get` and `set` (aliases for `tensorfs.Get` and `tensorfs.Set`) to retrieve and store data in the tensorfs: + +* `x := get("path/to/item")` retrieves the tensor data value at given path, which can then be used directly in an expression or saved to a new variable as in this example. + +* `set("path/to/item", x)` saves tensor data to given path, overwriting any existing value for that item if it already exists, and creating a new one if not. `x` can be any data expression. 
+ +You can use the standard shell commands to navigate around the data filesystem: + +* `cd ` to change the current working directory. By default, new variables created in the shell are also recorded into the current working directory for later access. + +* `ls [-l,r] [dir]` list the contents of a directory; without arguments, it shows the current directory. The `-l` option shows each element on a separate line with its shape. `-r` does a recursive list through subdirectories. + +* `mkdir ` makes a new subdirectory. + +TODO: other commands, etc. + + + diff --git a/shell/builtins.go b/goal/builtins.go similarity index 55% rename from shell/builtins.go rename to goal/builtins.go index 941c24cd53..8f3abd023a 100644 --- a/shell/builtins.go +++ b/goal/builtins.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package shell +package goal import ( "context" @@ -10,34 +10,38 @@ import ( "log/slog" "os" "path/filepath" + "runtime" "strconv" "strings" + "cogentcore.org/core/base/errors" "cogentcore.org/core/base/exec" "cogentcore.org/core/base/logx" "cogentcore.org/core/base/sshclient" + "cogentcore.org/core/base/stringsx" "github.com/mitchellh/go-homedir" ) -// InstallBuiltins adds the builtin shell commands to [Shell.Builtins]. -func (sh *Shell) InstallBuiltins() { - sh.Builtins = make(map[string]func(cmdIO *exec.CmdIO, args ...string) error) - sh.Builtins["cd"] = sh.Cd - sh.Builtins["exit"] = sh.Exit - sh.Builtins["jobs"] = sh.JobsCmd - sh.Builtins["kill"] = sh.Kill - sh.Builtins["set"] = sh.Set - sh.Builtins["add-path"] = sh.AddPath - sh.Builtins["which"] = sh.Which - sh.Builtins["source"] = sh.Source - sh.Builtins["cossh"] = sh.CoSSH - sh.Builtins["scp"] = sh.Scp - sh.Builtins["debug"] = sh.Debug - sh.Builtins["history"] = sh.History +// InstallBuiltins adds the builtin goal commands to [Goal.Builtins]. 
+func (gl *Goal) InstallBuiltins() { + gl.Builtins = make(map[string]func(cmdIO *exec.CmdIO, args ...string) error) + gl.Builtins["cd"] = gl.Cd + gl.Builtins["exit"] = gl.Exit + gl.Builtins["jobs"] = gl.JobsCmd + gl.Builtins["kill"] = gl.Kill + gl.Builtins["set"] = gl.Set + gl.Builtins["unset"] = gl.Unset + gl.Builtins["add-path"] = gl.AddPath + gl.Builtins["which"] = gl.Which + gl.Builtins["source"] = gl.Source + gl.Builtins["cossh"] = gl.CoSSH + gl.Builtins["scp"] = gl.Scp + gl.Builtins["debug"] = gl.Debug + gl.Builtins["history"] = gl.History } // Cd changes the current directory. -func (sh *Shell) Cd(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Cd(cmdIO *exec.CmdIO, args ...string) error { if len(args) > 1 { return fmt.Errorf("no more than one argument can be passed to cd") } @@ -63,27 +67,49 @@ func (sh *Shell) Cd(cmdIO *exec.CmdIO, args ...string) error { if err != nil { return err } - sh.Config.Dir = dir + gl.Config.Dir = dir return nil } // Exit exits the shell. -func (sh *Shell) Exit(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Exit(cmdIO *exec.CmdIO, args ...string) error { os.Exit(0) return nil } // Set sets the given environment variable to the given value. -func (sh *Shell) Set(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Set(cmdIO *exec.CmdIO, args ...string) error { if len(args) != 2 { return fmt.Errorf("expected two arguments, got %d", len(args)) } - return os.Setenv(args[0], args[1]) + val := args[1] + if strings.Count(val, ":") > 1 || strings.Contains(val, "~") { + vl := stringsx.DedupeList(strings.Split(val, ":")) + vl = AddHomeExpand([]string{}, vl...) + val = strings.Join(vl, ":") + } + err := os.Setenv(args[0], val) + if runtime.GOOS == "darwin" { + gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", args[0], val) + } + return err +} + +// Unset un-sets the given environment variable. 
+func (gl *Goal) Unset(cmdIO *exec.CmdIO, args ...string) error { + if len(args) != 1 { + return fmt.Errorf("expected one argument, got %d", len(args)) + } + err := os.Unsetenv(args[0]) + if runtime.GOOS == "darwin" { + gl.Config.RunIO(cmdIO, "/bin/launchctl", "unsetenv", args[0]) + } + return err } // JobsCmd is the builtin jobs command -func (sh *Shell) JobsCmd(cmdIO *exec.CmdIO, args ...string) error { - for i, jb := range sh.Jobs { +func (gl *Goal) JobsCmd(cmdIO *exec.CmdIO, args ...string) error { + for i, jb := range gl.Jobs { cmdIO.Printf("[%d] %s\n", i+1, jb.String()) } return nil @@ -91,27 +117,27 @@ func (sh *Shell) JobsCmd(cmdIO *exec.CmdIO, args ...string) error { // Kill kills a job by job number or PID. // Just expands the job id expressions %n into PIDs and calls system kill. -func (sh *Shell) Kill(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Kill(cmdIO *exec.CmdIO, args ...string) error { if len(args) == 0 { - return fmt.Errorf("cosh kill: expected at least one argument") + return fmt.Errorf("goal kill: expected at least one argument") } - sh.JobIDExpand(args) - sh.Config.RunIO(cmdIO, "kill", args...) + gl.JobIDExpand(args) + gl.Config.RunIO(cmdIO, "kill", args...) 
return nil } // Fg foregrounds a job by job number -func (sh *Shell) Fg(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Fg(cmdIO *exec.CmdIO, args ...string) error { if len(args) != 1 { - return fmt.Errorf("cosh fg: requires exactly one job id argument") + return fmt.Errorf("goal fg: requires exactly one job id argument") } jid := args[0] - exp := sh.JobIDExpand(args) + exp := gl.JobIDExpand(args) if exp != 1 { - return fmt.Errorf("cosh fg: argument was not a job id in the form %%n") + return fmt.Errorf("goal fg: argument was not a job id in the form %%n") } jno, _ := strconv.Atoi(jid[1:]) // guaranteed good - job := sh.Jobs[jno] + job := gl.Jobs[jno] cmdIO.Printf("foregrounding job [%d]\n", jno) _ = job // todo: the problem here is we need to change the stdio for running job @@ -122,49 +148,72 @@ func (sh *Shell) Fg(cmdIO *exec.CmdIO, args ...string) error { return nil } +// AddHomeExpand adds given strings to the given list of strings, +// expanding any ~ symbols with the home directory, +// and returns the updated list. +func AddHomeExpand(list []string, adds ...string) []string { + for _, add := range adds { + add, err := homedir.Expand(add) + errors.Log(err) + has := false + for _, s := range list { + if s == add { + has = true + } + } + if !has { + list = append(list, add) + } + } + return list +} + // AddPath adds the given path(s) to $PATH. -func (sh *Shell) AddPath(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) AddPath(cmdIO *exec.CmdIO, args ...string) error { if len(args) == 0 { - return fmt.Errorf("cosh add-path expected at least one argument") + return fmt.Errorf("goal add-path expected at least one argument") } path := os.Getenv("PATH") - for _, arg := range args { - arg, err := homedir.Expand(arg) - if err != nil { - return err - } - path = path + ":" + arg - } - return os.Setenv("PATH", path) + ps := strings.Split(path, ":") + ps = stringsx.DedupeList(ps) + ps = AddHomeExpand(ps, args...) 
+ path = strings.Join(ps, ":") + err := os.Setenv("PATH", path) + // if runtime.GOOS == "darwin" { + // this is what would be required to work: + // sudo launchctl config user path $PATH -- the following does not work: + // gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", "PATH", path) + // } + return err } // Which reports the executable associated with the given command. // Processes builtins and commands, and if not found, then passes on // to exec which. -func (sh *Shell) Which(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Which(cmdIO *exec.CmdIO, args ...string) error { if len(args) != 1 { - return fmt.Errorf("cosh which: requires one argument") + return fmt.Errorf("goal which: requires one argument") } cmd := args[0] - if _, hasCmd := sh.Commands[cmd]; hasCmd { + if _, hasCmd := gl.Commands[cmd]; hasCmd { cmdIO.Println(cmd, "is a user-defined command") return nil } - if _, hasBlt := sh.Builtins[cmd]; hasBlt { - cmdIO.Println(cmd, "is a cosh builtin command") + if _, hasBlt := gl.Builtins[cmd]; hasBlt { + cmdIO.Println(cmd, "is a goal builtin command") return nil } - sh.Config.RunIO(cmdIO, "which", args...) + gl.Config.RunIO(cmdIO, "which", args...) return nil } // Source loads and evaluates the given file(s) -func (sh *Shell) Source(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Source(cmdIO *exec.CmdIO, args ...string) error { if len(args) == 0 { - return fmt.Errorf("cosh source: requires at least one argument") + return fmt.Errorf("goal source: requires at least one argument") } for _, fn := range args { - sh.TranspileCodeFromFile(fn) + gl.TranspileCodeFromFile(fn) } // note that we do not execute the file -- just loads it in return nil @@ -176,18 +225,18 @@ func (sh *Shell) Source(cmdIO *exec.CmdIO, args ...string) error { // - host [name] -- connects to a server specified in first arg and switches // to using it, with optional name instead of default sequential number. 
// - close -- closes all open connections, or the specified one -func (sh *Shell) CoSSH(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) CoSSH(cmdIO *exec.CmdIO, args ...string) error { if len(args) < 1 { return fmt.Errorf("cossh: requires at least one argument") } cmd := args[0] var err error host := "" - name := fmt.Sprintf("%d", 1+len(sh.SSHClients)) + name := fmt.Sprintf("%d", 1+len(gl.SSHClients)) con := false switch { case cmd == "close": - sh.CloseSSH() + gl.CloseSSH() return nil case cmd == "@" && len(args) == 2: name = args[1] @@ -200,21 +249,21 @@ func (sh *Shell) CoSSH(cmdIO *exec.CmdIO, args ...string) error { host = args[0] } if con { - cl := sshclient.NewClient(sh.SSH) + cl := sshclient.NewClient(gl.SSH) err = cl.Connect(host) if err != nil { return err } - sh.SSHClients[name] = cl - sh.SSHActive = name + gl.SSHClients[name] = cl + gl.SSHActive = name } else { if name == "0" { - sh.SSHActive = "" + gl.SSHActive = "" } else { - sh.SSHActive = name - cl := sh.ActiveSSH() + gl.SSHActive = name + cl := gl.ActiveSSH() if cl == nil { - err = fmt.Errorf("cosh: ssh connection named: %q not found", name) + err = fmt.Errorf("goal: ssh connection named: %q not found", name) } } } @@ -226,7 +275,7 @@ func (sh *Shell) CoSSH(cmdIO *exec.CmdIO, args ...string) error { // The order is from -> to, as in standard cp. // The remote filename is automatically relative to the current working // directory on the remote host. 
-func (sh *Shell) Scp(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Scp(cmdIO *exec.CmdIO, args ...string) error { if len(args) != 2 { return fmt.Errorf("scp: requires exactly two arguments") } @@ -250,12 +299,12 @@ func (sh *Shell) Scp(cmdIO *exec.CmdIO, args ...string) error { host := hfn[1:ci] hfn = hfn[ci+1:] - cl, err := sh.SSHByHost(host) + cl, err := gl.SSHByHost(host) if err != nil { return err } - ctx := sh.Ctx + ctx := gl.Ctx if ctx == nil { ctx = context.Background() } @@ -269,7 +318,7 @@ func (sh *Shell) Scp(cmdIO *exec.CmdIO, args ...string) error { } // Debug changes log level -func (sh *Shell) Debug(cmdIO *exec.CmdIO, args ...string) error { +func (gl *Goal) Debug(cmdIO *exec.CmdIO, args ...string) error { if len(args) == 0 { if logx.UserLevel == slog.LevelDebug { logx.UserLevel = slog.LevelInfo @@ -289,8 +338,8 @@ func (sh *Shell) Debug(cmdIO *exec.CmdIO, args ...string) error { } // History shows history -func (sh *Shell) History(cmdIO *exec.CmdIO, args ...string) error { - n := len(sh.Hist) +func (gl *Goal) History(cmdIO *exec.CmdIO, args ...string) error { + n := len(gl.Hist) nh := n if len(args) == 1 { an, err := strconv.Atoi(args[0]) @@ -302,7 +351,7 @@ func (sh *Shell) History(cmdIO *exec.CmdIO, args ...string) error { return fmt.Errorf("history: uses at most one argument") } for i := n - nh; i < n; i++ { - cmdIO.Printf("%d:\t%s\n", i, sh.Hist[i]) + cmdIO.Printf("%d:\t%s\n", i, gl.Hist[i]) } return nil } diff --git a/shell/cmd/cosh/cfg.cosh b/goal/cmd/goal/cfg.cosh similarity index 100% rename from shell/cmd/cosh/cfg.cosh rename to goal/cmd/goal/cfg.cosh diff --git a/goal/cmd/goal/goal.go b/goal/cmd/goal/goal.go new file mode 100644 index 0000000000..7aeced5f8e --- /dev/null +++ b/goal/cmd/goal/goal.go @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// Command goal is an interactive cli for running and compiling Goal code. +package main + +import ( + "cogentcore.org/core/cli" + "cogentcore.org/core/goal/interpreter" +) + +func main() { //types:skip + opts := cli.DefaultOptions("goal", "An interactive tool for running and compiling Goal (Go augmented language).") + cfg := &interpreter.Config{} + cfg.InteractiveFunc = interpreter.Interactive + cli.Run(opts, cfg, interpreter.Run, interpreter.Build) +} diff --git a/shell/cmd/cosh/testdata/make b/goal/cmd/goal/testdata/make similarity index 67% rename from shell/cmd/cosh/testdata/make rename to goal/cmd/goal/testdata/make index a47bd26d3f..79dc53408b 100755 --- a/shell/cmd/cosh/testdata/make +++ b/goal/cmd/goal/testdata/make @@ -1,5 +1,5 @@ -#!/usr/bin/env cosh -// test makefile for cosh. +#!/usr/bin/env goal +// test makefile for goal. // example usage: // ./make build @@ -11,5 +11,5 @@ command test { println("running the test command") } -shell.RunCommands(args) +goal.RunCommands(args) diff --git a/goal/cmd/goal/testdata/test.goal b/goal/cmd/goal/testdata/test.goal new file mode 100644 index 0000000000..ab8686e543 --- /dev/null +++ b/goal/cmd/goal/testdata/test.goal @@ -0,0 +1,38 @@ +// test file for goal cli + +// todo: doesn't work: #1152 +// echo {args} + +for i, fn := range goalib.SplitLines(`/bin/ls -1`) { + fmt.Println(i, fn) +} + +## + +x := 1 +y := 4 +a := x * 2 +b := x + y +c := x * y + a * b + +fmt.Println(c) + +l := linspace(0.1, 0.2, 5, true) +fmt.Println(l) + +m := reshape(arange(36),6,6) + +if m[1,1] == 7 { + fmt.Println(true) +} + +for i := 0; i < 3; i++ { + fmt.Println(i) +} + +for i, v := range m { + fmt.Println(i, v) +} + +## + diff --git a/shell/complete.go b/goal/complete.go similarity index 82% rename from shell/complete.go rename to goal/complete.go index bf79e0bf92..25caa5f51d 100644 --- a/shell/complete.go +++ b/goal/complete.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the 
LICENSE file. -package shell +package goal import ( "os" @@ -16,14 +16,14 @@ import ( ) // CompleteMatch is the [complete.MatchFunc] for the shell. -func (sh *Shell) CompleteMatch(data any, text string, posLine, posChar int) (md complete.Matches) { +func (gl *Goal) CompleteMatch(data any, text string, posLine, posChar int) (md complete.Matches) { comps := complete.Completions{} text = text[:posChar] md.Seed = complete.SeedPath(text) fullPath := complete.SeedSpace(text) fullPath = errors.Log1(homedir.Expand(fullPath)) parent := strings.TrimSuffix(fullPath, md.Seed) - dir := filepath.Join(sh.Config.Dir, parent) + dir := filepath.Join(gl.Config.Dir, parent) if filepath.IsAbs(parent) { dir = parent } @@ -37,18 +37,18 @@ func (sh *Shell) CompleteMatch(data any, text string, posLine, posChar int) (md comps = append(comps, complete.Completion{ Text: name, Icon: icon, - Desc: filepath.Join(sh.Config.Dir, name), + Desc: filepath.Join(gl.Config.Dir, name), }) } if parent == "" { - for cmd := range sh.Builtins { + for cmd := range gl.Builtins { comps = append(comps, complete.Completion{ Text: cmd, Icon: icons.Terminal, Desc: "Builtin command: " + cmd, }) } - for cmd := range sh.Commands { + for cmd := range gl.Commands { comps = append(comps, complete.Completion{ Text: cmd, Icon: icons.Terminal, @@ -63,18 +63,18 @@ func (sh *Shell) CompleteMatch(data any, text string, posLine, posChar int) (md } // CompleteEdit is the [complete.EditFunc] for the shell. -func (sh *Shell) CompleteEdit(data any, text string, cursorPos int, completion complete.Completion, seed string) (ed complete.Edit) { +func (gl *Goal) CompleteEdit(data any, text string, cursorPos int, completion complete.Completion, seed string) (ed complete.Edit) { return complete.EditWord(text, cursorPos, completion.Text, seed) } // ReadlineCompleter implements [github.com/ergochat/readline.AutoCompleter]. 
type ReadlineCompleter struct { - Shell *Shell + Goal *Goal } func (rc *ReadlineCompleter) Do(line []rune, pos int) (newLine [][]rune, length int) { text := string(line) - md := rc.Shell.CompleteMatch(nil, text, 0, pos) + md := rc.Goal.CompleteMatch(nil, text, 0, pos) res := [][]rune{} for _, match := range md.Matches { after := strings.TrimPrefix(match.Text, md.Seed) diff --git a/shell/exec.go b/goal/exec.go similarity index 57% rename from shell/exec.go rename to goal/exec.go index e79cb151cc..57b768c334 100644 --- a/shell/exec.go +++ b/goal/exec.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package shell +package goal import ( "bytes" @@ -11,6 +11,8 @@ import ( "path/filepath" "slices" "strings" + "sync" + "time" "cogentcore.org/core/base/exec" "cogentcore.org/core/base/reflectx" @@ -21,22 +23,22 @@ import ( // Exec handles command execution for all cases, parameterized by the args. // It executes the given command string, waiting for the command to finish, // handling the given arguments appropriately. -// If there is any error, it adds it to the shell, and triggers CancelExecution. +// If there is any error, it adds it to the goal, and triggers CancelExecution. 
// - errOk = don't call AddError so execution will not stop on error // - start = calls Start on the command, which then runs asynchronously, with // a goroutine forked to Wait for it and close its IO // - output = return the output of the command as a string (otherwise return is "") -func (sh *Shell) Exec(errOk, start, output bool, cmd any, args ...any) string { +func (gl *Goal) Exec(errOk, start, output bool, cmd any, args ...any) string { out := "" - if !errOk && len(sh.Errors) > 0 { + if !errOk && len(gl.Errors) > 0 { return out } - cmdIO := exec.NewCmdIO(&sh.Config) + cmdIO := exec.NewCmdIO(&gl.Config) cmdIO.StackStart() if start { cmdIO.PushIn(nil) // no stdin for bg } - cl, scmd, sargs := sh.ExecArgs(cmdIO, errOk, cmd, args...) + cl, scmd, sargs := gl.ExecArgs(cmdIO, errOk, cmd, args...) if scmd == "" { return out } @@ -52,35 +54,38 @@ func (sh *Shell) Exec(errOk, start, output bool, cmd any, args ...any) string { err = cl.Run(&cmdIO.StdIOState, scmd, sargs...) } if !errOk { - sh.AddError(err) + gl.AddError(err) } } else { ran := false - ran, out = sh.RunBuiltinOrCommand(cmdIO, errOk, output, scmd, sargs...) + ran, out = gl.RunBuiltinOrCommand(cmdIO, errOk, start, output, scmd, sargs...) if !ran { - sh.isCommand.Push(false) + gl.isCommand.Push(false) switch { case start: - err = sh.Config.StartIO(cmdIO, scmd, sargs...) - sh.Jobs.Push(cmdIO) + // fmt.Fprintf(gl.debugTrace, "start exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe()) + err = gl.Config.StartIO(cmdIO, scmd, sargs...) + job := &Job{CmdIO: cmdIO} + gl.Jobs.Push(job) go func() { if !cmdIO.OutIsPipe() { - fmt.Printf("[%d] %s\n", len(sh.Jobs), cmdIO.String()) + fmt.Printf("[%d] %s\n", len(gl.Jobs), cmdIO.String()) } cmdIO.Cmd.Wait() cmdIO.PopToStart() - sh.DeleteJob(cmdIO) + gl.DeleteJob(job) }() case output: cmdIO.PushOut(nil) - out, err = sh.Config.OutputIO(cmdIO, scmd, sargs...) + out, err = gl.Config.OutputIO(cmdIO, scmd, sargs...) 
default: - err = sh.Config.RunIO(cmdIO, scmd, sargs...) + // fmt.Fprintf(gl.debugTrace, "run exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe()) + err = gl.Config.RunIO(cmdIO, scmd, sargs...) } if !errOk { - sh.AddError(err) + gl.AddError(err) } - sh.isCommand.Pop() + gl.isCommand.Pop() } } if !start { @@ -91,64 +96,105 @@ func (sh *Shell) Exec(errOk, start, output bool, cmd any, args ...any) string { // RunBuiltinOrCommand runs a builtin or a command, returning true if it ran, // and the output string if running in output mode. -func (sh *Shell) RunBuiltinOrCommand(cmdIO *exec.CmdIO, errOk, output bool, cmd string, args ...string) (bool, string) { +func (gl *Goal) RunBuiltinOrCommand(cmdIO *exec.CmdIO, errOk, start, output bool, cmd string, args ...string) (bool, string) { out := "" - cmdFun, hasCmd := sh.Commands[cmd] - bltFun, hasBlt := sh.Builtins[cmd] + cmdFun, hasCmd := gl.Commands[cmd] + bltFun, hasBlt := gl.Builtins[cmd] if !hasCmd && !hasBlt { return false, out } if hasCmd { - sh.commandArgs.Push(args) - sh.isCommand.Push(true) + gl.commandArgs.Push(args) + gl.isCommand.Push(true) } // note: we need to set both os. and wrapper versions, so it works the same // in compiled vs. 
interpreted mode - oldsh := sh.Config.StdIO.Set(&cmdIO.StdIO) - oldwrap := sh.StdIOWrappers.SetWrappers(&cmdIO.StdIO) - oldstd := cmdIO.SetToOS() - if output { + var oldsh, oldwrap, oldstd *exec.StdIO + save := func() { + oldsh = gl.Config.StdIO.Set(&cmdIO.StdIO) + oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO) + oldstd = cmdIO.SetToOS() + } + + done := func() { + if hasCmd { + gl.isCommand.Pop() + gl.commandArgs.Pop() + } + // fmt.Fprintf(gl.debugTrace, "%s restore %#v\n", cmd, oldstd.In) + oldstd.SetToOS() + gl.StdIOWrappers.SetWrappers(oldwrap) + gl.Config.StdIO = *oldsh + } + + switch { + case start: + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + if !cmdIO.OutIsPipe() { + fmt.Printf("[%d] %s\n", len(gl.Jobs), cmd) + } + if hasCmd { + oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO) + // oldstd = cmdIO.SetToOS() + // fmt.Fprintf(gl.debugTrace, "%s oldstd in: %#v out: %#v\n", cmd, oldstd.In, oldstd.Out) + cmdFun(args...) + // oldstd.SetToOS() + gl.StdIOWrappers.SetWrappers(oldwrap) + gl.isCommand.Pop() + gl.commandArgs.Pop() + } else { + gl.AddError(bltFun(cmdIO, args...)) + } + time.Sleep(time.Millisecond) + wg.Done() + }() + // fmt.Fprintf(gl.debugTrace, "%s push: %#v out: %#v %v\n", cmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe()) + job := &Job{CmdIO: cmdIO} + gl.Jobs.Push(job) + go func() { + wg.Wait() + cmdIO.PopToStart() + gl.DeleteJob(job) + }() + case output: + save() obuf := &bytes.Buffer{} // os.Stdout = obuf // needs a file - sh.Config.StdIO.Out = obuf - sh.StdIOWrappers.SetWrappedOut(obuf) + gl.Config.StdIO.Out = obuf + gl.StdIOWrappers.SetWrappedOut(obuf) cmdIO.PushOut(obuf) if hasCmd { cmdFun(args...) } else { - sh.AddError(bltFun(cmdIO, args...)) + gl.AddError(bltFun(cmdIO, args...)) } out = strings.TrimSuffix(obuf.String(), "\n") - } else { + done() + default: + save() if hasCmd { cmdFun(args...) 
} else { - sh.AddError(bltFun(cmdIO, args...)) + gl.AddError(bltFun(cmdIO, args...)) } + done() } - - if hasCmd { - sh.isCommand.Pop() - sh.commandArgs.Pop() - } - oldstd.SetToOS() - sh.StdIOWrappers.SetWrappers(oldwrap) - sh.Config.StdIO = *oldsh - return true, out } -func (sh *Shell) HandleArgErr(errok bool, err error) error { +func (gl *Goal) HandleArgErr(errok bool, err error) error { if err == nil { return err } if errok { - sh.Config.StdIO.ErrPrintln(err.Error()) + gl.Config.StdIO.ErrPrintln(err.Error()) } else { - sh.AddError(err) + gl.AddError(err) } return err } @@ -156,16 +202,17 @@ func (sh *Shell) HandleArgErr(errok bool, err error) error { // ExecArgs processes the args to given exec command, // handling all of the input / output redirection and // file globbing, homedir expansion, etc. -func (sh *Shell) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) (*sshclient.Client, string, []string) { - if len(sh.Jobs) > 0 { - jb := sh.Jobs.Peek() - if jb.OutIsPipe() { +func (gl *Goal) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) (*sshclient.Client, string, []string) { + if len(gl.Jobs) > 0 { + jb := gl.Jobs.Peek() + if jb.OutIsPipe() && !jb.GotPipe { + jb.GotPipe = true cmdIO.PushIn(jb.PipeIn.Peek()) } } scmd := reflectx.ToString(cmd) - cl := sh.ActiveSSH() - isCmd := sh.isCommand.Peek() + cl := gl.ActiveSSH() + // isCmd := gl.isCommand.Peek() sargs := make([]string, 0, len(args)) var err error for _, a := range args { @@ -175,7 +222,7 @@ func (sh *Shell) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) ( } if cl == nil { s, err = homedir.Expand(s) - sh.HandleArgErr(errOk, err) + gl.HandleArgErr(errOk, err) // note: handling globbing in a later pass, to not clutter.. 
} else { if s[0] == '~' { @@ -190,18 +237,18 @@ func (sh *Shell) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) ( cl = nil } else { hnm := scmd[1:] - if scl, ok := sh.SSHClients[hnm]; ok { + if scl, ok := gl.SSHClients[hnm]; ok { newHost = hnm cl = scl } else { - sh.HandleArgErr(errOk, fmt.Errorf("cosh: ssh connection named: %q not found", hnm)) + gl.HandleArgErr(errOk, fmt.Errorf("goal: ssh connection named: %q not found", hnm)) } } if len(sargs) > 0 { scmd = sargs[0] sargs = sargs[1:] } else { // just a ssh switch - sh.SSHActive = newHost + gl.SSHActive = newHost return nil, "", nil } } @@ -209,11 +256,11 @@ func (sh *Shell) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) ( s := sargs[i] switch { case s[0] == '>': - sargs = sh.OutToFile(cl, cmdIO, errOk, sargs, i) + sargs = gl.OutToFile(cl, cmdIO, errOk, sargs, i) case s[0] == '|': - sargs = sh.OutToPipe(cl, cmdIO, errOk, sargs, i) - case cl == nil && isCmd && strings.HasPrefix(s, "args"): - sargs = sh.CmdArgs(errOk, sargs, i) + sargs = gl.OutToPipe(cl, cmdIO, errOk, sargs, i) + case cl == nil && strings.HasPrefix(s, "args"): + sargs = gl.CmdArgs(errOk, sargs, i) i-- // back up because we consume this one } } @@ -235,7 +282,7 @@ func (sh *Shell) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) ( } // OutToFile processes the > arg that sends output to a file -func (sh *Shell) OutToFile(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string { +func (gl *Goal) OutToFile(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string { n := len(sargs) s := sargs[i] sn := len(s) @@ -260,7 +307,7 @@ func (sh *Shell) OutToFile(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, narg = 1 } if fn == "" { - sh.HandleArgErr(errOk, fmt.Errorf("cosh: no output file specified")) + gl.HandleArgErr(errOk, fmt.Errorf("goal: no output file specified")) return sargs } if cl != nil { @@ -285,35 +332,33 @@ func (sh *Shell) OutToFile(cl 
*sshclient.Client, cmdIO *exec.CmdIO, errOk bool, cmdIO.PushErr(f) } } else { - sh.HandleArgErr(errOk, err) + gl.HandleArgErr(errOk, err) } return sargs } // OutToPipe processes the | arg that sends output to a pipe -func (sh *Shell) OutToPipe(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string { +func (gl *Goal) OutToPipe(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string { s := sargs[i] sn := len(s) errf := false if sn > 1 && s[1] == '&' { errf = true } - // todo: what to do here? sargs = slices.Delete(sargs, i, i+1) cmdIO.PushOutPipe() if errf { cmdIO.PushErr(cmdIO.Out) } - // sh.HandleArgErr(errok, err) return sargs } // CmdArgs processes expressions involving "args" for commands -func (sh *Shell) CmdArgs(errOk bool, sargs []string, i int) []string { +func (gl *Goal) CmdArgs(errOk bool, sargs []string, i int) []string { // n := len(sargs) // s := sargs[i] // sn := len(s) - args := sh.commandArgs.Peek() + args := gl.commandArgs.Peek() // fmt.Println("command args:", args) @@ -327,8 +372,8 @@ func (sh *Shell) CmdArgs(errOk bool, sargs []string, i int) []string { } // CancelExecution calls the Cancel() function if set. -func (sh *Shell) CancelExecution() { - if sh.Cancel != nil { - sh.Cancel() +func (gl *Goal) CancelExecution() { + if gl.Cancel != nil { + gl.Cancel() } } diff --git a/shell/exec_test.go b/goal/exec_test.go similarity index 78% rename from shell/exec_test.go rename to goal/exec_test.go index 657bf0275d..9efb0a82c7 100644 --- a/shell/exec_test.go +++ b/goal/exec_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package shell +package goal import ( "testing" @@ -11,5 +11,5 @@ import ( ) func TestExec(t *testing.T) { - assert.Equal(t, "hi", NewShell().Output("echo", "hi")) + assert.Equal(t, "hi", NewGoal().Output("echo", "hi")) } diff --git a/goal/goal.go b/goal/goal.go new file mode 100644 index 0000000000..2e7e67a87c --- /dev/null +++ b/goal/goal.go @@ -0,0 +1,372 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package goal provides the Goal Go augmented language transpiler, +// which combines the best parts of Go, bash, and Python to provide +// an integrated shell and numerical expression processing experience. +package goal + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/exec" + "cogentcore.org/core/base/logx" + "cogentcore.org/core/base/reflectx" + "cogentcore.org/core/base/sshclient" + "cogentcore.org/core/base/stack" + "cogentcore.org/core/goal/transpile" + "github.com/mitchellh/go-homedir" +) + +// Goal represents one running Goal language context. +type Goal struct { + + // Config is the [exec.Config] used to run commands. + Config exec.Config + + // StdIOWrappers are IO wrappers sent to the interpreter, so we can + // control the IO streams used within the interpreter. + // Call SetWrappers on this with another StdIO object to update settings. + StdIOWrappers exec.StdIO + + // ssh connection, configuration + SSH *sshclient.Config + + // collection of ssh clients + SSHClients map[string]*sshclient.Client + + // SSHActive is the name of the active SSH client + SSHActive string + + // Builtins are all the builtin shell commands + Builtins map[string]func(cmdIO *exec.CmdIO, args ...string) error + + // commands that have been defined, which can be run in Exec mode. 
+	Commands map[string]func(args ...string)
+
+	// Jobs is a stack of commands running in the background
+	// (via Start instead of Run)
+	Jobs stack.Stack[*Job]
+
+	// Cancel, while the interpreter is running, can be called
+	// to stop the code interpreting.
+	// It is connected to the Ctx context, by StartContext()
+	// Both can be nil.
+	Cancel func()
+
+	// Errors is a stack of runtime errors
+	Errors []error
+
+	// Ctx is the context used for cancelling current shell running
+	// a single chunk of code, typically from the interpreter.
+	// We are not able to pass the context around so it is set here,
+	// in the StartContext function. Clear when done with ClearContext.
+	Ctx context.Context
+
+	// original standard IO settings, to restore
+	OrigStdIO exec.StdIO
+
+	// Hist is the accumulated list of command-line input,
+	// which is displayed with the history builtin command,
+	// and saved / restored from ~/.goalhist file
+	Hist []string
+
+	// transpiling state
+	TrState transpile.State
+
+	// commandArgs is a stack of args passed to a command, used for simplified
+	// processing of args expressions.
+	commandArgs stack.Stack[[]string]
+
+	// isCommand is a stack of bools indicating whether the _immediate_ run context
+	// is a command, which affects the way that args are processed.
+	isCommand stack.Stack[bool]
+
+	// debugTrace is a file written to for debugging
+	debugTrace *os.File
+}
+
+// NewGoal returns a new [Goal] with default options.
+func NewGoal() *Goal { + gl := &Goal{ + Config: exec.Config{ + Dir: errors.Log1(os.Getwd()), + Env: map[string]string{}, + Buffer: false, + }, + } + gl.TrState.FuncToVar = true + gl.Config.StdIO.SetFromOS() + gl.SSH = sshclient.NewConfig(&gl.Config) + gl.SSHClients = make(map[string]*sshclient.Client) + gl.Commands = make(map[string]func(args ...string)) + gl.InstallBuiltins() + // gl.debugTrace, _ = os.Create("goal.debug") // debugging + return gl +} + +// StartContext starts a processing context, +// setting the Ctx and Cancel Fields. +// Call EndContext when current operation finishes. +func (gl *Goal) StartContext() context.Context { + gl.Ctx, gl.Cancel = context.WithCancel(context.Background()) + return gl.Ctx +} + +// EndContext ends a processing context, clearing the +// Ctx and Cancel fields. +func (gl *Goal) EndContext() { + gl.Ctx = nil + gl.Cancel = nil +} + +// SaveOrigStdIO saves the current Config.StdIO as the original to revert to +// after an error, and sets the StdIOWrappers to use them. +func (gl *Goal) SaveOrigStdIO() { + gl.OrigStdIO = gl.Config.StdIO + gl.StdIOWrappers.NewWrappers(&gl.OrigStdIO) +} + +// RestoreOrigStdIO reverts to using the saved OrigStdIO +func (gl *Goal) RestoreOrigStdIO() { + gl.Config.StdIO = gl.OrigStdIO + gl.OrigStdIO.SetToOS() + gl.StdIOWrappers.SetWrappers(&gl.OrigStdIO) +} + +// Close closes any resources associated with the shell, +// including terminating any commands that are not running "nohup" +// in the background. 
+func (gl *Goal) Close() { + gl.CloseSSH() + // todo: kill jobs etc +} + +// CloseSSH closes all open ssh client connections +func (gl *Goal) CloseSSH() { + gl.SSHActive = "" + for _, cl := range gl.SSHClients { + cl.Close() + } + gl.SSHClients = make(map[string]*sshclient.Client) +} + +// ActiveSSH returns the active ssh client +func (gl *Goal) ActiveSSH() *sshclient.Client { + if gl.SSHActive == "" { + return nil + } + return gl.SSHClients[gl.SSHActive] +} + +// Host returns the name we're running commands on, +// which is empty if localhost (default). +func (gl *Goal) Host() string { + cl := gl.ActiveSSH() + if cl == nil { + return "" + } + return "@" + gl.SSHActive + ":" + cl.Host +} + +// HostAndDir returns the name we're running commands on, +// which is empty if localhost (default), +// and the current directory on that host. +func (gl *Goal) HostAndDir() string { + host := "" + dir := gl.Config.Dir + home := errors.Log1(homedir.Dir()) + cl := gl.ActiveSSH() + if cl != nil { + host = "@" + gl.SSHActive + ":" + cl.Host + ":" + dir = cl.Dir + home = cl.HomeDir + } + rel := errors.Log1(filepath.Rel(home, dir)) + // if it has to go back, then it is not in home dir, so no ~ + if strings.Contains(rel, "..") { + return host + dir + string(filepath.Separator) + } + return host + filepath.Join("~", rel) + string(filepath.Separator) +} + +// SSHByHost returns the SSH client for given host name, with err if not found +func (gl *Goal) SSHByHost(host string) (*sshclient.Client, error) { + if scl, ok := gl.SSHClients[host]; ok { + return scl, nil + } + return nil, fmt.Errorf("ssh connection named: %q not found", host) +} + +// TranspileCode processes each line of given code, +// adding the results to the LineStack +func (gl *Goal) TranspileCode(code string) { + gl.TrState.TranspileCode(code) +} + +// TranspileCodeFromFile transpiles the code in given file +func (gl *Goal) TranspileCodeFromFile(file string) error { + b, err := os.ReadFile(file) + if err != nil { + return 
err + } + gl.TranspileCode(string(b)) + return nil +} + +// TranspileFile transpiles the given input goal file to the +// given output Go file. If no existing package declaration +// is found, then package main and func main declarations are +// added. This also affects how functions are interpreted. +func (gl *Goal) TranspileFile(in string, out string) error { + return gl.TrState.TranspileFile(in, out) +} + +// AddError adds the given error to the error stack if it is non-nil, +// and calls the Cancel function if set, to stop execution. +// This is the main way that goal errors are handled. +// It also prints the error. +func (gl *Goal) AddError(err error) error { + if err == nil { + return nil + } + gl.Errors = append(gl.Errors, err) + logx.PrintlnError(err) + gl.CancelExecution() + return err +} + +// TranspileConfig transpiles the .goal startup config file in the user's +// home directory if it exists. +func (gl *Goal) TranspileConfig() error { + path, err := homedir.Expand("~/.goal") + if err != nil { + return err + } + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + gl.TranspileCode(string(b)) + return nil +} + +// AddHistory adds given line to the Hist record of commands +func (gl *Goal) AddHistory(line string) { + gl.Hist = append(gl.Hist, line) +} + +// SaveHistory saves up to the given number of lines of current history +// to given file, e.g., ~/.goalhist for the default goal program. +// If n is <= 0 all lines are saved. n is typically 500 by default. 
+func (gl *Goal) SaveHistory(n int, file string) error { + path, err := homedir.Expand(file) + if err != nil { + return err + } + hn := len(gl.Hist) + sn := hn + if n > 0 { + sn = min(n, hn) + } + lh := strings.Join(gl.Hist[hn-sn:hn], "\n") + err = os.WriteFile(path, []byte(lh), 0666) + if err != nil { + return err + } + return nil +} + +// OpenHistory opens Hist history lines from given file, +// e.g., ~/.goalhist +func (gl *Goal) OpenHistory(file string) error { + path, err := homedir.Expand(file) + if err != nil { + return err + } + b, err := os.ReadFile(path) + if err != nil { + return err + } + gl.Hist = strings.Split(string(b), "\n") + return nil +} + +// AddCommand adds given command to list of available commands +func (gl *Goal) AddCommand(name string, cmd func(args ...string)) { + gl.Commands[name] = cmd +} + +// RunCommands runs the given command(s). This is typically called +// from a Makefile-style goal script. +func (gl *Goal) RunCommands(cmds []any) error { + for _, cmd := range cmds { + if cmdFun, hasCmd := gl.Commands[reflectx.ToString(cmd)]; hasCmd { + cmdFun() + } else { + return errors.Log(fmt.Errorf("command %q not found", cmd)) + } + } + return nil +} + +// DeleteAllJobs deletes any existing jobs, closing stdio. 
+func (gl *Goal) DeleteAllJobs() { + n := len(gl.Jobs) + for i := n - 1; i >= 0; i-- { + jb := gl.Jobs.Pop() + jb.CmdIO.PopToStart() + } +} + +// DeleteJob deletes the given job and returns true if successful, +func (gl *Goal) DeleteJob(job *Job) bool { + idx := slices.Index(gl.Jobs, job) + if idx >= 0 { + gl.Jobs = slices.Delete(gl.Jobs, idx, idx+1) + return true + } + return false +} + +// JobIDExpand expands %n job id values in args with the full PID +// returns number of PIDs expanded +func (gl *Goal) JobIDExpand(args []string) int { + exp := 0 + for i, id := range args { + if id[0] == '%' { + idx, err := strconv.Atoi(id[1:]) + if err == nil { + if idx > 0 && idx <= len(gl.Jobs) { + jb := gl.Jobs[idx-1] + if jb.Cmd != nil && jb.Cmd.Process != nil { + args[i] = fmt.Sprintf("%d", jb.Cmd.Process.Pid) + exp++ + } + } else { + gl.AddError(fmt.Errorf("goal: job number out of range: %d", idx)) + } + } + } + } + return exp +} + +// Job represents a job that has been started and we're waiting for it to finish. +type Job struct { + *exec.CmdIO + IsExec bool + GotPipe bool +} diff --git a/shell/cosh/coshlib.go b/goal/goalib/goalib.go similarity index 94% rename from shell/cosh/coshlib.go rename to goal/goalib/goalib.go index ddd1499063..556abbc947 100644 --- a/shell/cosh/coshlib.go +++ b/goal/goalib/goalib.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package cosh defines convenient utility functions for -// use in the cosh shell, available with the cosh prefix. -package cosh +// Package goalib defines convenient utility functions for +// use in the goal shell, available with the goalib prefix. 
+package goalib import ( "io/fs" diff --git a/goal/gosl/README.md b/goal/gosl/README.md new file mode 100644 index 0000000000..f2c7aa34e2 --- /dev/null +++ b/goal/gosl/README.md @@ -0,0 +1,254 @@ +# gosl: Go as a shader language + +`gosl` implements _Go as a shader language_ for GPU compute shaders (using [WebGPU](https://www.w3.org/TR/webgpu/)), **enabling standard Go code to run on the GPU**. + +`gosl` converts Go code to WGSL which can then be loaded directly into a WebGPU compute shader, using the [gpu](../../gpu) GPU compute shader system. It operates within the overall [Goal](../README.md) framework of an augmented version of the Go language. See the [GPU](../GPU.md) documentation for an overview of issues in GPU computation. + +The relevant regions of Go code to be run on the GPU are tagged using the `//gosl:start` and `//gosl:end` comment directives, and this code must only use basic expressions and concrete types that will compile correctly in a GPU shader (see [Restrictions](#restrictions) below). Method functions and pass-by-reference pointer arguments to `struct` types are supported and incur no additional compute cost due to inlining (see notes below for more detail). + +See [examples/basic](examples/basic) and [rand](examples/rand) for complete working examples. + +Typically, `gosl` is called from a go generate command, e.g., by including this comment directive: + +``` +//go:generate gosl +``` + +To install the `gosl` command: +```bash +$ go install cogentcore.org/core/goal/gosl@latest +``` + +It is also strongly recommended to install the `naga` WGSL compiler from https://github.com/gfx-rs/wgpu and the `tint` compiler from https://dawn.googlesource.com/dawn/ Both of these are used if available to validate the generated GPU shader code. It is much faster to fix the issues at generation time rather than when trying to run the app later. 
Once code passes validation in both of these compilers, it should load fine in your app, and if the Go version runs correctly, there is a good chance of at least some reasonable behavior on the GPU.
+
+# Usage
+
+There are two key elements for GPU-enabled code:
+
+1. One or more [Kernel](#kernels) compute functions that take an _index_ argument and perform computations for that specific index of data, _in parallel_. **GPU computation is effectively just a parallel `for` loop**. On the GPU, each such kernel is implemented by its own separate compute shader code, and one of the main functions of `gosl` is to generate this code from the Go sources, in the automatically created `shaders/` directory.
+
+2. [Global variables](#global-variables) on which the kernel functions _exclusively_ operate: all relevant data must be specifically copied from the CPU to the GPU and back. As explained in the [GPU](../GPU.md) docs, each GPU compute shader is effectively a _standalone_ program operating on these global variables. To replicate this environment on the CPU, so the code works in both contexts, we need to make these variables global in the CPU (Go) environment as well.
+
+`gosl` generates a file named `gosl.go` in your package directory that initializes the GPU with all of the global variables, and functions for running the kernels and syncing the global variable data back and forth between the CPU and GPU.
+
+## Kernels
+
+Each distinct compute kernel must be tagged with a `//gosl:kernel` comment directive, as in this example (from `examples/basic`):
+```Go
+// Compute does the main computation.
+func Compute(i uint32) { //gosl:kernel
+	Params[0].IntegFromRaw(int(i))
+}
+```
+
+The kernel functions receive a `uint32` index argument, and use this to index into the global variables containing the relevant data. Typically the kernel code itself just calls other relevant function(s) using the index, as in the above example. 
Critically, _all_ of the data that a kernel function ultimately depends on must be contained within the global variables, and these variables must have been sync'd up to the GPU from the CPU prior to running the kernel (more on this below).
+
+In the CPU mode, the kernel is effectively run in a `for` loop like this:
+```Go
+	for i := range n {
+		Compute(uint32(i))
+	}
+```
+A parallel goroutine-based mechanism is actually used, but conceptually this is what it does, on both the CPU and the GPU. To reiterate: **GPU computation is effectively just a parallel for loop**.
+
+## Global variables
+
+The global variables on which the kernels operate are declared in the usual Go manner, as a single `var` block, which is marked at the top using the `//gosl:vars` comment directive:
+
+```Go
+//gosl:vars
+var (
+	// Params are the parameters for the computation.
+	//gosl:read-only
+	Params []ParamStruct
+
+	// Data is the data on which the computation operates.
+	// 2D: outer index is data, inner index is: Raw, Integ, Exp vars.
+	//gosl:dims 2
+	Data tensor.Float32
+)
+```
+
+All such variables must be either:
+1. A `slice` of GPU-alignment compatible `struct` types, such as `ParamStruct` in the above example.
+2. A `tensor` of a GPU-compatible elemental data type (`float32`, `uint32`, or `int32`), with the number of dimensions indicated by the `//gosl:dims ` tag as shown above.
+
+You can also just declare a slice of elemental GPU-compatible data values such as `float32`, but it is generally preferable to use the tensor instead.
+
+### Tensor data
+
+On the GPU, the tensor data is represented using a simple flat array of the basic data type. To index into this array, the _strides_ for each dimension are encoded in a special `TensorStrides` tensor that is managed by `gosl`, in the generated `gosl.go` file. `gosl` automatically generates the appropriate indexing code using these strides (which is why the number of dimensions is needed). 
+
+Whenever the strides of any tensor variable change, and at least once at initialization, your code must call the function that copies the current strides up to the GPU:
+```Go
+	ToGPUTensorStrides()
+```
+
+### Multiple tensor variables for large data
+
+The size of each memory buffer is limited by the GPU, to a maximum of at most 4GB on modern GPU hardware. Therefore, if you need to have any single tensor that holds more than this amount of data, then a bank of multiple vars is required. `gosl` provides helper functions to make this relatively straightforward.
+
+TODO: this could be encoded in the TensorStrides. It will always be the outer-most index that determines when it gets over threshold, which all can be pre-computed.
+
+### Systems and Groups
+
+Each kernel belongs to a `gpu.ComputeSystem`, and each such system has one specific configuration of memory variables. In general, it is best to use a single set of global variables, and perform as much of the computation as possible on this set of variables, to minimize the number of memory transfers. However, if necessary, multiple systems can be defined, using an optional additional system name argument for the `args` and `kernel` tags.
+
+In addition, the vars can be organized into _groups_, which generally should have similar memory syncing behavior, as documented in the [gpu](../gpu) system. 
+
+Here's an example with multiple groups:
+```Go
+//gosl:vars [system name]
+var (
+	// Layer-level parameters
+	//gosl:group -uniform Params
+	Layers []LayerParam // note: struct with appropriate memory alignment
+
+	// Path-level parameters
+	Paths []PathParam
+
+	// Unit state values
+	//gosl:group Units
+	Units tensor.Float32
+
+	// Synapse weight state values
+	Weights tensor.Float32
+)
+```
+
+## Memory syncing
+
+Each global variable gets an automatically-generated `*Var` enum (e.g., `DataVar` for global variable named `Data`), that is used for the memory syncing functions, to make it easy to specify any number of such variables to sync, which is by far the most efficient. All of this is in the generated `gosl.go` file. For example:
+
+```Go
+	ToGPU(ParamsVar, DataVar)
+```
+
+Specifies that the current contents of `Params` and `Data` are to be copied up to the GPU, which is guaranteed to complete by the time the next kernel run starts, within a given system.
+
+## Kernel running
+
+As with memory transfers, it is much more efficient to run multiple kernels in sequence, all operating on the current data variables, followed by a single sync of the updated global variable data that has been computed. Thus, there are separate functions for specifying the kernels to run, followed by a single "Done" function that actually submits the entire batch of kernels, along with memory sync commands to get the data back from the GPU. For example:
+
+```Go
+	RunCompute1(n)
+	RunCompute2(n)
+	...
+	RunDone(Data1Var, Data2Var) // launch all kernels and get data back to given vars
+```
+
+For CPU mode, `RunDone` is a no-op, and it just runs each kernel during each `Run` command.
+
+It is absolutely essential to understand that _all data must already be on the GPU_ at the start of the first Run command, and that any CPU-based computation between these calls is completely irrelevant for the GPU. 
Thus, it typically makes sense to just have a sequence of Run commands grouped together into a logical unit, with the relevant `ToGPU` calls at the start, and the final `RunDone` grabbing everything of relevance back from the GPU.
+
+## GPU relevant code tagging
+
+In a large GPU-based application, you should organize your code as you normally would in any standard Go application, distributing it across different files and packages. The GPU-relevant parts of each of those files can be tagged with the gosl tags:
+```
+//gosl:start
+
+< Go code to be translated >
+
+//gosl:end
+```
+to make this code available to all of the shaders that are generated.
+
+Use the `//gosl:import "package/path"` directive to import GPU-relevant code from other packages, similar to the standard Go import directive. It is assumed that many other Go imports are not GPU relevant, so this separate directive is required.
+
+If any `enums` variables are defined, pass the `-gosl` flag to the `core generate` command to ensure that the `N` value is tagged with `//gosl:start` and `//gosl:end` tags.
+
+**IMPORTANT:** all `.go` and `.wgsl` files are removed from the `shaders` directory prior to processing to ensure everything there is current -- always specify a different source location for any custom `.wgsl` files that are included.
+
+# Command line usage
+
+```
+gosl [flags]
+```
+
+The flags are:
+```
+  -debug
+    	enable debugging messages while running
+  -exclude string
+    	comma-separated list of names of functions to exclude from exporting to WGSL (default "Update,Defaults")
+  -keep
+    	keep temporary converted versions of the source files, for debugging
+  -out string
+    	output directory for shader code, relative to where gosl is invoked -- must not be an empty string (default "shaders")
+```
+
+`gosl` always operates on the current directory, looking for all files with `//gosl:` tags, and accumulating all the `import` files that they include, etc. 
+ +Any `struct` types encountered will be checked for 16-byte alignment of sub-types and overall sizes as an even multiple of 16 bytes (4 `float32` or `int32` values), which is the alignment used in WGSL and glsl shader languages, and the underlying GPU hardware presumably. Look for error messages on the output from the gosl run. This ensures that direct byte-wise copies of data between CPU and GPU will be successful. The fact that `gosl` operates directly on the original CPU-side Go code uniquely enables it to perform these alignment checks, which are otherwise a major source of difficult-to-diagnose bugs. + +# Restrictions + +In general shader code should be simple mathematical expressions and data types, with minimal control logic via `if`, `for` statements, and only using the subset of Go that is consistent with C. Here are specific restrictions: + +* Can only use `float32`, `[u]int32` for basic types (`int` is converted to `int32` automatically), and `struct` types composed of these same types -- no other Go types (i.e., `map`, slices, `string`, etc) are compatible. There are strict alignment restrictions on 16 byte (e.g., 4 `float32`'s) intervals that are enforced via the `alignsl` sub-package. + +* WGSL does _not_ support 64 bit float or int. + +* Use `slbool.Bool` instead of `bool` -- it defines a Go-friendly interface based on a `int32` basic type. + +* Alignment and padding of `struct` fields is key -- this is automatically checked by `gosl`. + +* WGSL does not support enum types, but standard go `const` declarations will be converted. Use an `int32` or `uint32` data type. It will automatically deal with the simple incrementing `iota` values, but not more complex cases. Also, for bitflags, define explicitly, not using `bitflags` package, and use `0x01`, `0x02`, `0x04` etc instead of `1<<2` -- in theory the latter should be ok but in practice it complains. 
+
+* Cannot use multiple return values, or multiple assignment of variables in a single `=` expression.
+
+* *Can* use multiple variable names with the same type (e.g., `min, max float32`) -- this will be properly converted to the more redundant form with the type repeated, for WGSL.
+
+* `switch` `case` statements are _purely_ self-contained -- no `fallthrough` allowed! does support multiple items per `case` however. Every `switch` _must_ have a `default` case.
+
+* WGSL does specify that new variables are initialized to 0, like Go, but also somehow discourages that use-case. It is safer to initialize directly:
+```Go
+	val := float32(0) // guaranteed 0 value
+	var val float32 // ok but generally avoid
+```
+
+* A local variable to a global `struct` array variable (e.g., `par := &Params[i]`) can only be created as a function argument. There are special access restrictions that make it impossible to do otherwise.
+
+* tensor variables can only be used in `storage` (not `uniform`) memory, due to restrictions on dynamic sizing and alignment. Aside from this constraint, it is possible to designate a group of variables to use uniform memory, with the `-uniform` argument as the first item in the `//gosl:group` comment directive.
+
+## Other language features
+
+* [tour-of-wgsl](https://google.github.io/tour-of-wgsl/types/pointers/passing_pointers/) is a good reference to explain things more directly than the spec.
+
+* `ptr` provides a pointer arg
+* `private` scope = within the shader code "module", i.e., one thread.
+* `function` = within the function, not outside it.
+* `workgroup` = shared across workgroup -- could be powerful (but slow!) -- need to learn more.
+
+## Atomic access
+
+WGSL adopts the Metal (lowest common denominator) strong constraint of imposing a _type_ level restriction on atomic operations: you can only do atomic operations on variables that have been declared atomic, as in:
+
+```
+var<storage, read_write> PathGBuf: array<atomic<i32>>;
+...
+atomicAdd(&PathGBuf[idx], val); +``` + +This also unfortunately has the side-effect that you cannot do _non-atomic_ operations on atomic variables, as discussed extensively here: https://github.com/gpuweb/gpuweb/issues/2377 Gosl automatically detects the use of atomic functions on GPU variables, and tags them as atomic. + +## Random numbers: slrand + +See [slrand](https://github.com/emer/gosl/v2/tree/main/slrand) for a shader-optimized random number generation package, which is supported by `gosl` -- it will convert `slrand` calls into appropriate WGSL named function calls. `gosl` will also copy the `slrand.wgsl` file, which contains the full source code for the RNG, into the destination `shaders` directory, so it can be included with a simple local path: + +```Go +//gosl:wgsl mycode +// #include "slrand.wgsl" +//gosl:end mycode +``` + +# Performance + +With sufficiently large N, and ignoring the data copying setup time, around ~80x speedup is typical on a Macbook Pro with M1 processor. The `rand` example produces a 175x speedup! 
+ +# Implementation / Design Notes + +# Links + +Key docs for WGSL as compute shaders: + diff --git a/gpu/gosl/alignsl/README.md b/goal/gosl/alignsl/README.md similarity index 100% rename from gpu/gosl/alignsl/README.md rename to goal/gosl/alignsl/README.md diff --git a/gpu/gosl/alignsl/alignsl.go b/goal/gosl/alignsl/alignsl.go similarity index 97% rename from gpu/gosl/alignsl/alignsl.go rename to goal/gosl/alignsl/alignsl.go index 21cf353bdb..27d39250a2 100644 --- a/gpu/gosl/alignsl/alignsl.go +++ b/goal/gosl/alignsl/alignsl.go @@ -97,7 +97,8 @@ func CheckStruct(cx *Context, st *types.Struct, stName string) bool { last := cx.Sizes.Sizeof(flds[nf-1].Type()) totsz := int(offs[nf-1] + last) mod := totsz % 16 - if mod != 0 { + vectyp := strings.Contains(strings.ToLower(stName), "vec") // vector types are ok + if !vectyp && mod != 0 { needs := 4 - (mod / 4) hasErr = cx.AddError(fmt.Sprintf(" total size: %d not even multiple of 16 -- needs %d extra 32bit padding fields", totsz, needs), hasErr, stName) } diff --git a/gpu/gosl/doc.go b/goal/gosl/doc.go similarity index 73% rename from gpu/gosl/doc.go rename to goal/gosl/doc.go index b10cfe9835..a644ce0cbb 100644 --- a/gpu/gosl/doc.go +++ b/goal/gosl/doc.go @@ -10,11 +10,7 @@ /* gosl translates Go source code into WGSL compatible shader code. -use //gosl:start and //gosl:end to -bracket code that should be copied into shaders/.wgsl -Use //gosl:main instead of start for shader code that is -commented out in the .go file, which will be copied into the filename -and uncommented. +Use //gosl:start and //gosl:end to bracket code to generate. pass filenames or directory names for files to process. 
@@ -27,7 +23,7 @@ The flags are: -debug enable debugging messages while running -exclude string - comma-separated list of names of functions to exclude from exporting to HLSL (default "Update,Defaults") + comma-separated list of names of functions to exclude from exporting to WGSL (default "Update,Defaults") -keep keep temporary converted versions of the source files, for debugging -out string diff --git a/gpu/gosl/examples/basic/README.md b/goal/gosl/examples/basic/README.md similarity index 100% rename from gpu/gosl/examples/basic/README.md rename to goal/gosl/examples/basic/README.md diff --git a/goal/gosl/examples/basic/atomic.go b/goal/gosl/examples/basic/atomic.go new file mode 100644 index 0000000000..681f1ba127 --- /dev/null +++ b/goal/gosl/examples/basic/atomic.go @@ -0,0 +1,18 @@ +// Code generated by "goal build"; DO NOT EDIT. +//line atomic.goal:1 +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "sync/atomic" + +//gosl:start + +// Atomic does an atomic computation on the data. +func Atomic(i uint32) { //gosl:kernel + atomic.AddInt32(IntData.ValuePtr(int(i), int(Integ)), 1) +} + +//gosl:end diff --git a/goal/gosl/examples/basic/atomic.goal b/goal/gosl/examples/basic/atomic.goal new file mode 100644 index 0000000000..b4e7f97858 --- /dev/null +++ b/goal/gosl/examples/basic/atomic.goal @@ -0,0 +1,16 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "sync/atomic" + +//gosl:start + +// Atomic does an atomic computation on the data. 
+func Atomic(i uint32) { //gosl:kernel + atomic.AddInt32(&IntData[i, Integ], 1) +} + +//gosl:end diff --git a/goal/gosl/examples/basic/compute.go b/goal/gosl/examples/basic/compute.go new file mode 100644 index 0000000000..0ce3857ac7 --- /dev/null +++ b/goal/gosl/examples/basic/compute.go @@ -0,0 +1,84 @@ +// Code generated by "goal build"; DO NOT EDIT. +//line compute.goal:1 +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "cogentcore.org/core/math32" + "cogentcore.org/core/tensor" +) + +//gosl:start +//gosl:import "cogentcore.org/core/math32" + +//gosl:vars +var ( + // Params are the parameters for the computation. + // + //gosl:group Params + //gosl:read-only + Params []ParamStruct + + // Data is the data on which the computation operates. + // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. + // + //gosl:group Data + //gosl:dims 2 + Data *tensor.Float32 + + // IntData is the int data on which the computation operates. + // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. + // + //gosl:dims 2 + IntData *tensor.Int32 +) + +const ( + Raw int = iota + Integ + Exp + NVars +) + +// ParamStruct has the test params +type ParamStruct struct { + + // rate constant in msec + Tau float32 + + // 1/Tau + Dt float32 + + pad float32 + pad1 float32 +} + +// IntegFromRaw computes integrated value from current raw value +func (ps *ParamStruct) IntegFromRaw(idx int) { + integ := Data.Value(int(idx), int(Integ)) + integ += ps.Dt * (Data.Value(int(idx), int(Raw)) - integ) + Data.Set(integ, int(idx), int(Integ)) + Data.Set(math32.FastExp(-integ), int(idx), int(Exp)) +} + +// Compute does the main computation. 
+func Compute(i uint32) { //gosl:kernel + params := GetParams(0) + params.IntegFromRaw(int(i)) +} + +//gosl:end + +// note: only core compute code needs to be in shader -- all init is done CPU-side + +func (ps *ParamStruct) Defaults() { + ps.Tau = 5 + ps.Update() +} + +func (ps *ParamStruct) Update() { + ps.Dt = 1.0 / ps.Tau +} diff --git a/goal/gosl/examples/basic/compute.goal b/goal/gosl/examples/basic/compute.goal new file mode 100644 index 0000000000..246461e6ff --- /dev/null +++ b/goal/gosl/examples/basic/compute.goal @@ -0,0 +1,79 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "cogentcore.org/core/math32" + "cogentcore.org/core/tensor" +) + +//gosl:start +//gosl:import "cogentcore.org/core/math32" + +//gosl:vars +var ( + // Params are the parameters for the computation. + //gosl:group Params + //gosl:read-only + Params []ParamStruct + + // Data is the data on which the computation operates. + // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. + //gosl:group Data + //gosl:dims 2 + Data *tensor.Float32 + + // IntData is the int data on which the computation operates. + // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. + //gosl:dims 2 + IntData *tensor.Int32 +) + +const ( + Raw int = iota + Integ + Exp + NVars +) + +// ParamStruct has the test params +type ParamStruct struct { + + // rate constant in msec + Tau float32 + + // 1/Tau + Dt float32 + + pad float32 + pad1 float32 +} + +// IntegFromRaw computes integrated value from current raw value +func (ps *ParamStruct) IntegFromRaw(idx int) { + integ := Data[idx, Integ] + integ += ps.Dt * (Data[idx, Raw] - integ) + Data[idx, Integ] = integ + Data[idx, Exp] = math32.FastExp(-integ) +} + +// Compute does the main computation. 
+func Compute(i uint32) { //gosl:kernel + params := GetParams(0) + params.IntegFromRaw(int(i)) +} + +//gosl:end + +// note: only core compute code needs to be in shader -- all init is done CPU-side + +func (ps *ParamStruct) Defaults() { + ps.Tau = 5 + ps.Update() +} + +func (ps *ParamStruct) Update() { + ps.Dt = 1.0 / ps.Tau +} diff --git a/goal/gosl/examples/basic/gosl.go b/goal/gosl/examples/basic/gosl.go new file mode 100644 index 0000000000..7d9852296e --- /dev/null +++ b/goal/gosl/examples/basic/gosl.go @@ -0,0 +1,273 @@ +// Code generated by "gosl"; DO NOT EDIT + +package main + +import ( + "embed" + "unsafe" + "cogentcore.org/core/gpu" + "cogentcore.org/core/tensor" +) + +//go:embed shaders/*.wgsl +var shaders embed.FS + +// ComputeGPU is the compute gpu device +var ComputeGPU *gpu.GPU + +// UseGPU indicates whether to use GPU vs. CPU. +var UseGPU bool + +// GPUSystem is a GPU compute System with kernels operating on the +// same set of data variables. +var GPUSystem *gpu.ComputeSystem + +// GPUVars is an enum for GPU variables, for specifying what to sync. +type GPUVars int32 //enums:enum + +const ( + ParamsVar GPUVars = 0 + DataVar GPUVars = 1 + IntDataVar GPUVars = 2 +) + +// Tensor stride variables +var TensorStrides tensor.Uint32 + +// GPUInit initializes the GPU compute system, +// configuring system(s), variables and kernels. +// It is safe to call multiple times: detects if already run. 
+func GPUInit() { + if ComputeGPU != nil { + return + } + gp := gpu.NewComputeGPU() + ComputeGPU = gp + { + sy := gpu.NewComputeSystem(gp, "Default") + GPUSystem = sy + gpu.NewComputePipelineShaderFS(shaders, "shaders/Atomic.wgsl", sy) + gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy) + vars := sy.Vars() + { + sgp := vars.AddGroup(gpu.Storage, "Params") + var vr *gpu.Var + _ = vr + vr = sgp.Add("TensorStrides", gpu.Uint32, 1, gpu.ComputeShader) + vr.ReadOnly = true + vr = sgp.AddStruct("Params", int(unsafe.Sizeof(ParamStruct{})), 1, gpu.ComputeShader) + vr.ReadOnly = true + sgp.SetNValues(1) + } + { + sgp := vars.AddGroup(gpu.Storage, "Data") + var vr *gpu.Var + _ = vr + vr = sgp.Add("Data", gpu.Float32, 1, gpu.ComputeShader) + vr = sgp.Add("IntData", gpu.Int32, 1, gpu.ComputeShader) + sgp.SetNValues(1) + } + sy.Config() + } +} + +// GPURelease releases the GPU compute system resources. +// Call this at program exit. +func GPURelease() { + if GPUSystem != nil { + GPUSystem.Release() + GPUSystem = nil + } + + if ComputeGPU != nil { + ComputeGPU.Release() + ComputeGPU = nil + } +} + +// RunAtomic runs the Atomic kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// Can call multiple Run* kernels in a row, which are then all launched +// in the same command submission on the GPU, which is by far the most efficient. +// MUST call RunDone (with optional vars to sync) after all Run calls. +// Alternatively, a single-shot RunOneAtomic call does Run and Done for a +// single run-and-sync case. +func RunAtomic(n int) { + if UseGPU { + RunAtomicGPU(n) + } else { + RunAtomicCPU(n) + } +} + +// RunAtomicGPU runs the Atomic kernel on the GPU. See [RunAtomic] for more info. +func RunAtomicGPU(n int) { + sy := GPUSystem + pl := sy.ComputePipelines["Atomic"] + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, 64) +} + +// RunAtomicCPU runs the Atomic kernel on the CPU. 
+func RunAtomicCPU(n int) { + gpu.VectorizeFunc(0, n, Atomic) +} + +// RunOneAtomic runs the Atomic kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// This version then calls RunDone with the given variables to sync +// after the Run, for a single-shot Run-and-Done call. If multiple kernels +// can be run in sequence, it is much more efficient to do multiple Run* +// calls followed by a RunDone call. +func RunOneAtomic(n int, syncVars ...GPUVars) { + if UseGPU { + RunAtomicGPU(n) + RunDone(syncVars...) + } else { + RunAtomicCPU(n) + } +} +// RunCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// Can call multiple Run* kernels in a row, which are then all launched +// in the same command submission on the GPU, which is by far the most efficient. +// MUST call RunDone (with optional vars to sync) after all Run calls. +// Alternatively, a single-shot RunOneCompute call does Run and Done for a +// single run-and-sync case. +func RunCompute(n int) { + if UseGPU { + RunComputeGPU(n) + } else { + RunComputeCPU(n) + } +} + +// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info. +func RunComputeGPU(n int) { + sy := GPUSystem + pl := sy.ComputePipelines["Compute"] + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, 64) +} + +// RunComputeCPU runs the Compute kernel on the CPU. +func RunComputeCPU(n int) { + gpu.VectorizeFunc(0, n, Compute) +} + +// RunOneCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// This version then calls RunDone with the given variables to sync +// after the Run, for a single-shot Run-and-Done call. If multiple kernels +// can be run in sequence, it is much more efficient to do multiple Run* +// calls followed by a RunDone call. 
+func RunOneCompute(n int, syncVars ...GPUVars) { + if UseGPU { + RunComputeGPU(n) + RunDone(syncVars...) + } else { + RunComputeCPU(n) + } +} +// RunDone must be called after Run* calls to start compute kernels. +// This actually submits the kernel jobs to the GPU, and adds commands +// to synchronize the given variables back from the GPU to the CPU. +// After this function completes, the GPU results will be available in +// the specified variables. +func RunDone(syncVars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + sy.ComputeEncoder.End() + ReadFromGPU(syncVars...) + sy.EndComputePass() + SyncFromGPU(syncVars...) +} + +// ToGPU copies given variables to the GPU for the system. +func ToGPU(vars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + gpu.SetValueFrom(v, Params) + case DataVar: + v, _ := syVars.ValueByIndex(1, "Data", 0) + gpu.SetValueFrom(v, Data.Values) + case IntDataVar: + v, _ := syVars.ValueByIndex(1, "IntData", 0) + gpu.SetValueFrom(v, IntData.Values) + } + } +} + +// ToGPUTensorStrides gets tensor strides and starts copying to the GPU. +func ToGPUTensorStrides() { + if !UseGPU { + return + } + sy := GPUSystem + syVars := sy.Vars() + TensorStrides.SetShapeSizes(20) + TensorStrides.SetInt1D(Data.Shape().Strides[0], 0) + TensorStrides.SetInt1D(Data.Shape().Strides[1], 1) + TensorStrides.SetInt1D(IntData.Shape().Strides[0], 10) + TensorStrides.SetInt1D(IntData.Shape().Strides[1], 11) + v, _ := syVars.ValueByIndex(0, "TensorStrides", 0) + gpu.SetValueFrom(v, TensorStrides.Values) +} + +// ReadFromGPU starts the process of copying vars to the GPU. 
+func ReadFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + v.GPUToRead(sy.CommandEncoder) + case DataVar: + v, _ := syVars.ValueByIndex(1, "Data", 0) + v.GPUToRead(sy.CommandEncoder) + case IntDataVar: + v, _ := syVars.ValueByIndex(1, "IntData", 0) + v.GPUToRead(sy.CommandEncoder) + } + } +} + +// SyncFromGPU synchronizes vars from the GPU to the actual variable. +func SyncFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + v.ReadSync() + gpu.ReadToBytes(v, Params) + case DataVar: + v, _ := syVars.ValueByIndex(1, "Data", 0) + v.ReadSync() + gpu.ReadToBytes(v, Data.Values) + case IntDataVar: + v, _ := syVars.ValueByIndex(1, "IntData", 0) + v.ReadSync() + gpu.ReadToBytes(v, IntData.Values) + } + } +} + +// GetParams returns a pointer to the given global variable: +// [Params] []ParamStruct at given index. +// To ensure that values are updated on the GPU, you must call [SetParams]. +// after all changes have been made. +func GetParams(idx uint32) *ParamStruct { + return &Params[idx] +} diff --git a/goal/gosl/examples/basic/main.go b/goal/gosl/examples/basic/main.go new file mode 100644 index 0000000000..0e5b93ac4d --- /dev/null +++ b/goal/gosl/examples/basic/main.go @@ -0,0 +1,106 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This example just does some basic calculations on data structures and +// reports the time difference between the CPU and GPU. +package main + +import ( + "fmt" + "math/rand" + "runtime" + + "cogentcore.org/core/base/timer" + "cogentcore.org/core/gpu" + "cogentcore.org/core/tensor" +) + +//go:generate gosl + +func init() { + // must lock main thread for gpu! 
+ runtime.LockOSThread() +} + +func main() { + gpu.Debug = true + GPUInit() + + rand.Seed(0) + // gpu.NumThreads = 1 // to restrict to sequential for loop + n := 16_000_000 + // n := 2_000_000 + + Params = make([]ParamStruct, 1) + Params[0].Defaults() + + Data = tensor.NewFloat32() + Data.SetShapeSizes(n, 3) + nt := Data.Len() + + IntData = tensor.NewInt32() + IntData.SetShapeSizes(n, 3) + + for i := range nt { + Data.Set1D(rand.Float32(), i) + } + + sid := tensor.NewInt32() + sid.SetShapeSizes(n, 3) + + sd := tensor.NewFloat32() + sd.SetShapeSizes(n, 3) + for i := range nt { + sd.Set1D(Data.Value1D(i), i) + } + + cpuTmr := timer.Time{} + cpuTmr.Start() + + RunOneAtomic(n) + RunOneCompute(n) + + cpuTmr.Stop() + + cd := Data + cid := IntData + Data = sd + IntData = sid + + gpuFullTmr := timer.Time{} + gpuFullTmr.Start() + + UseGPU = true + ToGPUTensorStrides() + ToGPU(ParamsVar, DataVar, IntDataVar) + + gpuTmr := timer.Time{} + gpuTmr.Start() + + RunAtomic(n) + RunCompute(n) + gpuTmr.Stop() + + RunDone(DataVar, IntDataVar) + gpuFullTmr.Stop() + + mx := min(n, 5) + for i := 0; i < mx; i++ { + fmt.Printf("%d\t CPU IntData: %d\t GPU: %d\n", i, cid.Value(1, Integ), sid.Value(i, Integ)) + } + fmt.Println() + for i := 0; i < mx; i++ { + d := cd.Value(i, Exp) - sd.Value(i, Exp) + fmt.Printf("CPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tGPU: %6.4g\tDiff: %g\n", i, cd.Value(i, Raw), cd.Value(i, Integ), cd.Value(i, Exp), sd.Value(i, Exp), d) + fmt.Printf("GPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tCPU: %6.4g\tDiff: %g\n\n", i, sd.Value(i, Raw), sd.Value(i, Integ), sd.Value(i, Exp), cd.Value(i, Exp), d) + } + fmt.Printf("\n") + + cpu := cpuTmr.Total + gpu := gpuTmr.Total + gpuFull := gpuFullTmr.Total + fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFull, float64(cpu)/float64(gpu)) + + GPURelease() +} diff --git a/goal/gosl/examples/basic/shaders/Atomic.wgsl b/goal/gosl/examples/basic/shaders/Atomic.wgsl new file mode 100644 
index 0000000000..6b30bb330c --- /dev/null +++ b/goal/gosl/examples/basic/shaders/Atomic.wgsl @@ -0,0 +1,45 @@ +// Code generated by "gosl"; DO NOT EDIT +// kernel: Atomic + +// // Params are the parameters for the computation. // +@group(0) @binding(0) +var TensorStrides: array; +@group(0) @binding(1) +var Params: array; +// // Data is the data on which the computation operates. // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. // +@group(1) @binding(0) +var Data: array; +@group(1) @binding(1) +var IntData: array>; + +alias GPUVars = i32; + +@compute @workgroup_size(64, 1, 1) +fn main(@builtin(workgroup_id) wgid: vec3, @builtin(num_workgroups) nwg: vec3, @builtin(local_invocation_index) loci: u32) { + let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64; + Atomic(idx); +} + +fn Index2D(s0: u32, s1: u32, i0: u32, i1: u32) -> u32 { + return s0 * i0 + s1 * i1; +} + + +//////// import: "compute.go" +const Raw: i32 = 0; +const Integ: i32 = 1; +const Exp: i32 = 2; +const NVars: i32 = 3; +struct ParamStruct { + Tau: f32, + Dt: f32, + pad: f32, + pad1: f32, +} + +//////// import: "atomic.go" +fn Atomic(i: u32) { //gosl:kernel + atomicAdd(&IntData[Index2D(TensorStrides[10], TensorStrides[11], u32(i), u32(Integ))], 1); +} + +//////// import: "math32-fastexp.go" \ No newline at end of file diff --git a/goal/gosl/examples/basic/shaders/Compute.wgsl b/goal/gosl/examples/basic/shaders/Compute.wgsl new file mode 100644 index 0000000000..66cca304f5 --- /dev/null +++ b/goal/gosl/examples/basic/shaders/Compute.wgsl @@ -0,0 +1,60 @@ +// Code generated by "gosl"; DO NOT EDIT +// kernel: Compute + +// // Params are the parameters for the computation. // +@group(0) @binding(0) +var TensorStrides: array; +@group(0) @binding(1) +var Params: array; +// // Data is the data on which the computation operates. // 2D: outer index is data, inner index is: Raw, Integ, Exp vars. 
// +@group(1) @binding(0) +var Data: array; +@group(1) @binding(1) +var IntData: array; + +alias GPUVars = i32; + +@compute @workgroup_size(64, 1, 1) +fn main(@builtin(workgroup_id) wgid: vec3, @builtin(num_workgroups) nwg: vec3, @builtin(local_invocation_index) loci: u32) { + let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64; + Compute(idx); +} + +fn Index2D(s0: u32, s1: u32, i0: u32, i1: u32) -> u32 { + return s0 * i0 + s1 * i1; +} + + +//////// import: "compute.go" +const Raw: i32 = 0; +const Integ: i32 = 1; +const Exp: i32 = 2; +const NVars: i32 = 3; +struct ParamStruct { + Tau: f32, + Dt: f32, + pad: f32, + pad1: f32, +} +fn ParamStruct_IntegFromRaw(ps: ptr, idx: i32) { + var integ = Data[Index2D(TensorStrides[0], TensorStrides[1], u32(idx), u32(Integ))]; + integ += (*ps).Dt * (Data[Index2D(TensorStrides[0], TensorStrides[1], u32(idx), u32(Raw))] - integ); + Data[Index2D(TensorStrides[0], TensorStrides[1], u32(idx), u32(Integ))] = integ; + Data[Index2D(TensorStrides[0], TensorStrides[1], u32(idx), u32(Exp))] = FastExp(-integ); +} +fn Compute(i: u32) { //gosl:kernel + var params = Params[0]; + ParamStruct_IntegFromRaw(¶ms, i32(i)); +} + +//////// import: "atomic.go" + +//////// import: "math32-fastexp.go" +fn FastExp(x: f32) -> f32 { + if (x <= -88.02969) { // this doesn't add anything and -exp is main use-case anyway + return f32(0.0); + } + var i = i32(12102203*x) + i32(127)*(i32(1)<<23); + var m = (i >> 7) & 0xFFFF; // copy mantissa + i += (((((((((((3537 * m) >> 16) + 13668) * m) >> 18) + 15817) * m) >> 14) - 80470) * m) >> 11);return bitcast(u32(i)); +} \ No newline at end of file diff --git a/gpu/gosl/examples/rand/README.md b/goal/gosl/examples/rand/README.md similarity index 100% rename from gpu/gosl/examples/rand/README.md rename to goal/gosl/examples/rand/README.md diff --git a/goal/gosl/examples/rand/gosl.go b/goal/gosl/examples/rand/gosl.go new file mode 100644 index 0000000000..d2a989c0d4 --- /dev/null +++ 
b/goal/gosl/examples/rand/gosl.go @@ -0,0 +1,202 @@ +// Code generated by "gosl"; DO NOT EDIT + +package main + +import ( + "embed" + "unsafe" + "cogentcore.org/core/gpu" + "cogentcore.org/core/tensor" +) + +//go:embed shaders/*.wgsl +var shaders embed.FS + +// ComputeGPU is the compute gpu device +var ComputeGPU *gpu.GPU + +// UseGPU indicates whether to use GPU vs. CPU. +var UseGPU bool + +// GPUSystem is a GPU compute System with kernels operating on the +// same set of data variables. +var GPUSystem *gpu.ComputeSystem + +// GPUVars is an enum for GPU variables, for specifying what to sync. +type GPUVars int32 //enums:enum + +const ( + SeedVar GPUVars = 0 + DataVar GPUVars = 1 +) + +// Dummy tensor stride variable to avoid import error +var __TensorStrides tensor.Uint32 + +// GPUInit initializes the GPU compute system, +// configuring system(s), variables and kernels. +// It is safe to call multiple times: detects if already run. +func GPUInit() { + if ComputeGPU != nil { + return + } + gp := gpu.NewComputeGPU() + ComputeGPU = gp + { + sy := gpu.NewComputeSystem(gp, "Default") + GPUSystem = sy + gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy) + vars := sy.Vars() + { + sgp := vars.AddGroup(gpu.Storage, "Group_0") + var vr *gpu.Var + _ = vr + vr = sgp.AddStruct("Seed", int(unsafe.Sizeof(Seeds{})), 1, gpu.ComputeShader) + vr.ReadOnly = true + vr = sgp.AddStruct("Data", int(unsafe.Sizeof(Rnds{})), 1, gpu.ComputeShader) + sgp.SetNValues(1) + } + sy.Config() + } +} + +// GPURelease releases the GPU compute system resources. +// Call this at program exit. +func GPURelease() { + if GPUSystem != nil { + GPUSystem.Release() + GPUSystem = nil + } + + if ComputeGPU != nil { + ComputeGPU.Release() + ComputeGPU = nil + } +} + +// RunCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. 
+// Can call multiple Run* kernels in a row, which are then all launched +// in the same command submission on the GPU, which is by far the most efficient. +// MUST call RunDone (with optional vars to sync) after all Run calls. +// Alternatively, a single-shot RunOneCompute call does Run and Done for a +// single run-and-sync case. +func RunCompute(n int) { + if UseGPU { + RunComputeGPU(n) + } else { + RunComputeCPU(n) + } +} + +// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info. +func RunComputeGPU(n int) { + sy := GPUSystem + pl := sy.ComputePipelines["Compute"] + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, 64) +} + +// RunComputeCPU runs the Compute kernel on the CPU. +func RunComputeCPU(n int) { + gpu.VectorizeFunc(0, n, Compute) +} + +// RunOneCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// This version then calls RunDone with the given variables to sync +// after the Run, for a single-shot Run-and-Done call. If multiple kernels +// can be run in sequence, it is much more efficient to do multiple Run* +// calls followed by a RunDone call. +func RunOneCompute(n int, syncVars ...GPUVars) { + if UseGPU { + RunComputeGPU(n) + RunDone(syncVars...) + } else { + RunComputeCPU(n) + } +} +// RunDone must be called after Run* calls to start compute kernels. +// This actually submits the kernel jobs to the GPU, and adds commands +// to synchronize the given variables back from the GPU to the CPU. +// After this function completes, the GPU results will be available in +// the specified variables. +func RunDone(syncVars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + sy.ComputeEncoder.End() + ReadFromGPU(syncVars...) + sy.EndComputePass() + SyncFromGPU(syncVars...) +} + +// ToGPU copies given variables to the GPU for the system. 
+func ToGPU(vars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case SeedVar: + v, _ := syVars.ValueByIndex(0, "Seed", 0) + gpu.SetValueFrom(v, Seed) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + gpu.SetValueFrom(v, Data) + } + } +} + +// ReadFromGPU starts the process of copying vars to the GPU. +func ReadFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case SeedVar: + v, _ := syVars.ValueByIndex(0, "Seed", 0) + v.GPUToRead(sy.CommandEncoder) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + v.GPUToRead(sy.CommandEncoder) + } + } +} + +// SyncFromGPU synchronizes vars from the GPU to the actual variable. +func SyncFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case SeedVar: + v, _ := syVars.ValueByIndex(0, "Seed", 0) + v.ReadSync() + gpu.ReadToBytes(v, Seed) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + v.ReadSync() + gpu.ReadToBytes(v, Data) + } + } +} + +// GetSeed returns a pointer to the given global variable: +// [Seed] []Seeds at given index. +// To ensure that values are updated on the GPU, you must call [SetSeed]. +// after all changes have been made. +func GetSeed(idx uint32) *Seeds { + return &Seed[idx] +} + +// GetData returns a pointer to the given global variable: +// [Data] []Rnds at given index. +// To ensure that values are updated on the GPU, you must call [SetData]. +// after all changes have been made. 
+func GetData(idx uint32) *Rnds { + return &Data[idx] +} diff --git a/gpu/gosl/examples/rand/main.go b/goal/gosl/examples/rand/main.go similarity index 55% rename from gpu/gosl/examples/rand/main.go rename to goal/gosl/examples/rand/main.go index f053e96766..e43e3bf4db 100644 --- a/gpu/gosl/examples/rand/main.go +++ b/goal/gosl/examples/rand/main.go @@ -5,23 +5,15 @@ package main import ( - "embed" "fmt" "runtime" - "unsafe" "log/slog" "cogentcore.org/core/base/timer" - "cogentcore.org/core/gpu" ) -// note: standard one to use is plain "gosl" which should be go install'd - -//go:generate ../../gosl rand.go rand.wgsl - -//go:embed shaders/*.wgsl -var shaders embed.FS +//go:generate gosl func init() { // must lock main thread for gpu! @@ -29,63 +21,41 @@ func init() { } func main() { - gpu.Debug = true - gp := gpu.NewComputeGPU() - fmt.Printf("Running on GPU: %s\n", gp.DeviceName) + GPUInit() // n := 10 - n := 4_000_000 // 5_000_000 is too much -- 256_000_000 -- up against buf size limit - threads := 64 + // n := 16_000_000 // max for macbook M* + n := 200_000 + + UseGPU = false + + Seed = make([]Seeds, 1) dataC := make([]Rnds, n) dataG := make([]Rnds, n) + Data = dataC + cpuTmr := timer.Time{} cpuTmr.Start() - - seed := uint64(0) - for i := range dataC { - d := &dataC[i] - d.RndGen(seed, uint32(i)) - } + RunOneCompute(n) cpuTmr.Stop() - sy := gpu.NewComputeSystem(gp, "slrand") - pl := gpu.NewComputePipelineShaderFS(shaders, "shaders/rand.wgsl", sy) - vars := sy.Vars() - sgp := vars.AddGroup(gpu.Storage) - - ctrv := sgp.AddStruct("Counter", int(unsafe.Sizeof(seed)), 1, gpu.ComputeShader) - datav := sgp.AddStruct("Data", int(unsafe.Sizeof(Rnds{})), n, gpu.ComputeShader) - - sgp.SetNValues(1) - sy.Config() - - cvl := ctrv.Values.Values[0] - dvl := datav.Values.Values[0] + UseGPU = true + Data = dataG gpuFullTmr := timer.Time{} gpuFullTmr.Start() - gpu.SetValueFrom(cvl, []uint64{seed}) - gpu.SetValueFrom(dvl, dataG) - - sgp.CreateReadBuffers() + ToGPU(SeedVar, DataVar) 
gpuTmr := timer.Time{} gpuTmr.Start() - ce, _ := sy.BeginComputePass() - pl.Dispatch1D(ce, n, threads) - ce.End() - dvl.GPUToRead(sy.CommandEncoder) - sy.EndComputePass(ce) - + RunCompute(n) gpuTmr.Stop() - dvl.ReadSync() - gpu.ReadToBytes(dvl, dataG) - + RunDone(DataVar) gpuFullTmr.Stop() anyDiffEx := false @@ -128,6 +98,5 @@ func main() { gpu := gpuTmr.Total fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFullTmr.Total, float64(cpu)/float64(gpu)) - sy.Release() - gp.Release() + GPURelease() } diff --git a/gpu/gosl/examples/rand/rand.go b/goal/gosl/examples/rand/rand.go similarity index 86% rename from gpu/gosl/examples/rand/rand.go rename to goal/gosl/examples/rand/rand.go index 2f28690f8c..cc7bf7dac2 100644 --- a/gpu/gosl/examples/rand/rand.go +++ b/goal/gosl/examples/rand/rand.go @@ -3,16 +3,26 @@ package main import ( "fmt" - "cogentcore.org/core/gpu/gosl/slrand" - "cogentcore.org/core/gpu/gosl/sltype" + "cogentcore.org/core/goal/gosl/slrand" + "cogentcore.org/core/goal/gosl/sltype" "cogentcore.org/core/math32" ) -//gosl:wgsl rand -// #include "slrand.wgsl" -//gosl:end rand +//gosl:start -//gosl:start rand +//gosl:vars +var ( + //gosl:read-only + Seed []Seeds + + // Data + Data []Rnds +) + +type Seeds struct { + Seed uint64 + pad, pad1 int32 +} type Rnds struct { Uints sltype.Uint32Vec2 @@ -39,7 +49,11 @@ func (r *Rnds) RndGen(counter uint64, idx uint32) { r.Gauss = slrand.Float32NormVec2(counter, uint32(3), idx) } -//gosl:end rand +func Compute(i uint32) { //gosl:kernel + Data[i].RndGen(Seed[0].Seed, i) +} + +//gosl:end const Tol = 1.0e-4 // fails at lower tol eventually -- -6 works for many diff --git a/goal/gosl/examples/rand/shaders/Compute.wgsl b/goal/gosl/examples/rand/shaders/Compute.wgsl new file mode 100644 index 0000000000..f29eaec34d --- /dev/null +++ b/goal/gosl/examples/rand/shaders/Compute.wgsl @@ -0,0 +1,215 @@ +// Code generated by "gosl"; DO NOT EDIT +// kernel: Compute + +@group(0) @binding(0) +var Seed: 
array; +@group(0) @binding(1) +var Data: array; + +alias GPUVars = i32; + +@compute @workgroup_size(64, 1, 1) +fn main(@builtin(workgroup_id) wgid: vec3, @builtin(num_workgroups) nwg: vec3, @builtin(local_invocation_index) loci: u32) { + let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64; + Compute(idx); +} + + +//////// import: "rand.go" +struct Seeds { + Seed: su64, + pad: i32, + pad1: i32, +} +struct Rnds { + Uints: vec2, + pad: i32, + pad1: i32, + Floats: vec2, + pad2: i32, + pad3: i32, + Floats11: vec2, + pad4: i32, + pad5: i32, + Gauss: vec2, + pad6: i32, + pad7: i32, +} +fn Rnds_RndGen(r: ptr, counter: su64, idx: u32) { + (*r).Uints = RandUint32Vec2(counter, u32(0), idx); + (*r).Floats = RandFloat32Vec2(counter, u32(1), idx); + (*r).Floats11 = RandFloat32Range11Vec2(counter, u32(2), idx); + (*r).Gauss = RandFloat32NormVec2(counter, u32(3), idx); +} +fn Compute(i: u32) { //gosl:kernel + var data=Data[i]; + Rnds_RndGen(&data, Seed[0].Seed, i); + Data[i]=data; +} + +//////// import: "slrand.wgsl" +fn Philox2x32round(counter: su64, key: u32) -> su64 { + let mul = Uint32Mul64(u32(0xD256D193), counter.x); + var ctr: su64; + ctr.x = mul.y ^ key ^ counter.y; + ctr.y = mul.x; + return ctr; +} +fn Philox2x32bumpkey(key: u32) -> u32 { + return key + u32(0x9E3779B9); +} +fn Philox2x32(counter: su64, key: u32) -> vec2 { + var ctr = Philox2x32round(counter, key); // 1 + var ky = Philox2x32bumpkey(key); + ctr = Philox2x32round(ctr, ky); // 2 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 3 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 4 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 5 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 6 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 7 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 8 + ky = Philox2x32bumpkey(ky); + ctr = Philox2x32round(ctr, ky); // 9 + ky = Philox2x32bumpkey(ky); + 
return Philox2x32round(ctr, ky); // 10 +} +fn RandUint32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { + return Philox2x32(Uint64Add32(counter, funcIndex), key); +} +fn RandUint32(counter: su64, funcIndex: u32, key: u32) -> u32 { + return Philox2x32(Uint64Add32(counter, funcIndex), key).x; +} +fn RandFloat32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { + return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); +} +fn RandFloat32(counter: su64, funcIndex: u32, key: u32) -> f32 { + return Uint32ToFloat32(RandUint32(counter, funcIndex, key)); +} +fn RandFloat32Range11Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { + return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); +} +fn RandFloat32Range11(counter: su64, funcIndex: u32, key: u32) -> f32 { + return Uint32ToFloat32Range11(RandUint32(counter, funcIndex, key)); +} +fn RandBoolP(counter: su64, funcIndex: u32, key: u32, p: f32) -> bool { + return (RandFloat32(counter, funcIndex, key) < p); +} +fn sincospi(x: f32) -> vec2 { + let PIf = 3.1415926535897932; + var r: vec2; + r.x = cos(PIf*x); + r.y = sin(PIf*x); + return r; +} +fn RandFloat32NormVec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { + let ur = RandUint32Vec2(counter, funcIndex, key); + var f = sincospi(Uint32ToFloat32Range11(ur.x)); + let r = sqrt(-2.0 * log(Uint32ToFloat32(ur.y))); // guaranteed to avoid 0. 
+ return f * r; +} +fn RandFloat32Norm(counter: su64, funcIndex: u32, key: u32) -> f32 { + return RandFloat32Vec2(counter, funcIndex, key).x; +} +fn RandUint32N(counter: su64, funcIndex: u32, key: u32, n: u32) -> u32 { + let v = RandFloat32(counter, funcIndex, key); + return u32(v * f32(n)); +} +struct RandCounter { + Counter: su64, + HiSeed: u32, + pad: u32, +} +fn RandCounter_Reset(ct: ptr) { + (*ct).Counter.x = u32(0); + (*ct).Counter.y = (*ct).HiSeed; +} +fn RandCounter_Seed(ct: ptr, seed: u32) { + (*ct).HiSeed = seed; + RandCounter_Reset(ct); +} +fn RandCounter_Add(ct: ptr, inc: u32) { + (*ct).Counter = Uint64Add32((*ct).Counter, inc); +} + +//////// import: "sltype.wgsl" +alias su64 = vec2; +fn Uint32Mul64(a: u32, b: u32) -> su64 { + let LOMASK = (((u32(1))<<16)-1); + var r: su64; + r.x = a * b; /* full low multiply */ + let ahi = a >> 16; + let alo = a & LOMASK; + let bhi = b >> 16; + let blo = b & LOMASK; + let ahbl = ahi * blo; + let albh = alo * bhi; + let ahbl_albh = ((ahbl&LOMASK) + (albh&LOMASK)); + var hit = ahi*bhi + (ahbl>>16) + (albh>>16); + hit += ahbl_albh >> 16; /* carry from the sum of lo(ahbl) + lo(albh) ) */ + /* carry from the sum with alo*blo */ + if ((r.x >> u32(16)) < (ahbl_albh&LOMASK)) { + hit += u32(1); + } + r.y = hit; + return r; +} +/* +fn Uint32Mul64(a: u32, b: u32) -> su64 { + return su64(a) * su64(b); +} +*/ +fn Uint64Add32(a: su64, b: u32) -> su64 { + if (b == 0) { + return a; + } + var s = a; + if (s.x > u32(0xffffffff) - b) { + s.y++; + s.x = (b - 1) - (u32(0xffffffff) - s.x); + } else { + s.x += b; + } + return s; +} +fn Uint64Incr(a: su64) -> su64 { + var s = a; + if(s.x == 0xffffffff) { + s.y++; + s.x = u32(0); + } else { + s.x++; + } + return s; +} +fn Uint32ToFloat32(val: u32) -> f32 { + let factor = f32(1.0) / (f32(u32(0xffffffff)) + f32(1.0)); + let halffactor = f32(0.5) * factor; + var f = f32(val) * factor + halffactor; + if (f == 1.0) { // exclude 1 + return bitcast(0x3F7FFFFF); + } + return f; +} +fn 
Uint32ToFloat32Vec2(val: vec2) -> vec2 { + var r: vec2; + r.x = Uint32ToFloat32(val.x); + r.y = Uint32ToFloat32(val.y); + return r; +} +fn Uint32ToFloat32Range11(val: u32) -> f32 { + let factor = f32(1.0) / (f32(i32(0x7fffffff)) + f32(1.0)); + let halffactor = f32(0.5) * factor; + return (f32(val) * factor + halffactor); +} +fn Uint32ToFloat32Range11Vec2(val: vec2) -> vec2 { + var r: vec2; + r.x = Uint32ToFloat32Range11(val.x); + r.y = Uint32ToFloat32Range11(val.y); + return r; +} \ No newline at end of file diff --git a/goal/gosl/gosl.go b/goal/gosl/gosl.go new file mode 100644 index 0000000000..c3343364df --- /dev/null +++ b/goal/gosl/gosl.go @@ -0,0 +1,16 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "cogentcore.org/core/cli" + "cogentcore.org/core/goal/gosl/gotosl" +) + +func main() { //types:skip + opts := cli.DefaultOptions("gosl", "Go as a shader language converts Go code to WGSL WebGPU shader code, which can be run on the GPU through WebGPU.") + cfg := &gotosl.Config{} + cli.Run(opts, cfg, gotosl.Run) +} diff --git a/goal/gosl/gotosl/README.md b/goal/gosl/gotosl/README.md new file mode 100644 index 0000000000..af647c0218 --- /dev/null +++ b/goal/gosl/gotosl/README.md @@ -0,0 +1,23 @@ +# Implementational details of Go to SL translation process + +Overall, there are three main steps: + +1. Translate all the `.go` files in the current package, and all the files they `//gosl:import`, into corresponding `.wgsl` files, and put those in `shaders/imports`. All these files will be pasted into the generated primary kernel files, that go in `shaders`, and are saved to disk for reference. All the key kernel, system, variable info is extracted from the package .go file directives during this phase. + +2. 
Generate the `main` kernel `.wgsl` files, for each kernel function, which: a) declare the global buffer variables; b) include everything from imports; c) define the `main` function entry point. Each resulting file is pre-processed by `naga` to ensure it compiles, and to remove dead code not needed for this particular shader. + +3. Generate the `gosl.go` file in the package directory, which contains generated Go code for configuring the gpu compute systems according to the vars. + +## Go to SL translation + +1. `files.go`: Get a list of all the .go files in the current directory that have a `//gosl:` tag (`ProjectFiles`) and all the `//gosl:import` package files that those files import, recursively. + +2. `extract.go`: Extract the `//gosl:start` -> `end` regions from all the package and imported filees. + +3. Save all these files as new `.go` files in `shaders/imports`. We manually append a simple go "main" package header with basic gosl imports for each file, which allows the go compiler to process them properly. This is then removed in the next step. + +4. `translate.go:` Run `TranslateDir` on shaders/imports using the "golang.org/x/tools/go/packages" `Load` function, which gets `ast` and type information for all that code. Run the resulting `ast.File` for each file through the modified version of the Go stdlib `src/go/printer` code (`printer.go`, `nodes.go`, `gobuild.go`, `comment.go`), which prints out WGSL instead of Go code from the underlying `ast` representation of the Go files. This is what does the actual translation. + +5. `sledits.go:` Do various forms of post-processing text replacement cleanup on the generated WGSL files, in `SLEdits` function. + + diff --git a/goal/gosl/gotosl/callgraph.go b/goal/gosl/gotosl/callgraph.go new file mode 100644 index 0000000000..f686ac02e6 --- /dev/null +++ b/goal/gosl/gotosl/callgraph.go @@ -0,0 +1,91 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "fmt" + "sort" + + "golang.org/x/exp/maps" +) + +// Function represents the call graph of functions +type Function struct { + Name string + Funcs map[string]*Function + Atomics map[string]*Var // variables that have atomic operations in this function +} + +func NewFunction(name string) *Function { + return &Function{Name: name, Funcs: make(map[string]*Function)} +} + +// get or add a function of given name +func (st *State) RecycleFunc(name string) *Function { + fn, ok := st.FuncGraph[name] + if !ok { + fn = NewFunction(name) + st.FuncGraph[name] = fn + } + return fn +} + +func getAllFuncs(f *Function, all map[string]*Function) { + for fnm, fn := range f.Funcs { + _, ok := all[fnm] + if ok { + continue + } + all[fnm] = fn + getAllFuncs(fn, all) + } +} + +// AllFuncs returns aggregated list of all functions called be given function +func (st *State) AllFuncs(name string) map[string]*Function { + fn, ok := st.FuncGraph[name] + if !ok { + fmt.Printf("gosl: ERROR kernel function named: %q not found\n", name) + return nil + } + all := make(map[string]*Function) + all[name] = fn + getAllFuncs(fn, all) + // cfs := maps.Keys(all) + // sort.Strings(cfs) + // for _, cfnm := range cfs { + // fmt.Println("\t" + cfnm) + // } + return all +} + +// AtomicVars returns all the variables marked as atomic +// within the list of functions. 
+func (st *State) AtomicVars(funcs map[string]*Function) map[string]*Var { + avars := make(map[string]*Var) + for _, fn := range funcs { + if fn.Atomics == nil { + continue + } + for vn, v := range fn.Atomics { + avars[vn] = v + } + } + return avars +} + +func (st *State) PrintFuncGraph() { + funs := maps.Keys(st.FuncGraph) + sort.Strings(funs) + for _, fname := range funs { + fmt.Println(fname) + fn := st.FuncGraph[fname] + cfs := maps.Keys(fn.Funcs) + sort.Strings(cfs) + for _, cfnm := range cfs { + fmt.Println("\t" + cfnm) + } + } +} diff --git a/gpu/gosl/slprint/comment.go b/goal/gosl/gotosl/comment.go similarity index 93% rename from gpu/gosl/slprint/comment.go rename to goal/gosl/gotosl/comment.go index f97a9a2084..ede45eeb98 100644 --- a/gpu/gosl/slprint/comment.go +++ b/goal/gosl/gotosl/comment.go @@ -1,8 +1,15 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is largely copied from the Go source, +// src/go/printer/comment.go: + // Copyright 2022 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package slprint +package gotosl import ( "go/ast" diff --git a/goal/gosl/gotosl/config.go b/goal/gosl/gotosl/config.go new file mode 100644 index 0000000000..746edd7686 --- /dev/null +++ b/goal/gosl/gotosl/config.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +//go:generate core generate -add-types -add-funcs + +// Keep these in sync with go/format/format.go. +const ( + tabWidth = 8 + printerMode = UseSpaces | TabIndent | printerNormalizeNumbers + + // printerNormalizeNumbers means to canonicalize number literal prefixes + // and exponents while printing. 
See https://golang.org/doc/go1.13#gosl. + // + // This value is defined in go/printer specifically for go/format and cmd/gosl. + printerNormalizeNumbers = 1 << 30 +) + +// Config has the configuration info for the gosl system. +type Config struct { + + // Output is the output directory for shader code, + // relative to where gosl is invoked; must not be an empty string. + Output string `flag:"out" default:"shaders"` + + // Exclude is a comma-separated list of names of functions to exclude from exporting to WGSL. + Exclude string `default:"Update,Defaults"` + + // Keep keeps temporary converted versions of the source files, for debugging. + Keep bool + + // Debug enables debugging messages while running. + Debug bool +} + +//cli:cmd -root +func Run(cfg *Config) error { //types:add + st := &State{} + st.Init(cfg) + return st.Run() +} diff --git a/goal/gosl/gotosl/extract.go b/goal/gosl/gotosl/extract.go new file mode 100644 index 0000000000..3df7dd75e4 --- /dev/null +++ b/goal/gosl/gotosl/extract.go @@ -0,0 +1,228 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "bytes" + "fmt" + "path/filepath" + "strings" + + "slices" +) + +// ExtractFiles processes all the package files and saves the corresponding +// .go files with simple go header. +func (st *State) ExtractFiles() { + st.ImportPackages = make(map[string]bool) + for impath := range st.GoImports { + _, pkg := filepath.Split(impath) + if pkg != "math32" { + st.ImportPackages[pkg] = true + } + } + + for fn, fl := range st.GoFiles { + hasVars := false + fl.Lines, hasVars = st.ExtractGosl(fl.Lines) + if hasVars { + st.GoVarsFiles[fn] = fl + delete(st.GoFiles, fn) + } + WriteFileLines(filepath.Join(st.ImportsDir, fn), st.AppendGoHeader(fl.Lines)) + } +} + +// ExtractImports processes all the imported files and saves the corresponding +// .go files with simple go header. 
+func (st *State) ExtractImports() { + if len(st.GoImports) == 0 { + return + } + for impath, im := range st.GoImports { + _, pkg := filepath.Split(impath) + for fn, fl := range im { + fl.Lines, _ = st.ExtractGosl(fl.Lines) + WriteFileLines(filepath.Join(st.ImportsDir, pkg+"-"+fn), st.AppendGoHeader(fl.Lines)) + } + } +} + +// ExtractGosl gosl comment-directive tagged regions from given file. +func (st *State) ExtractGosl(lines [][]byte) (outLines [][]byte, hasVars bool) { + key := []byte("//gosl:") + start := []byte("start") + wgsl := []byte("wgsl") + nowgsl := []byte("nowgsl") + end := []byte("end") + vars := []byte("vars") + imp := []byte("import") + kernel := []byte("//gosl:kernel") + fnc := []byte("func") + + inReg := false + inHlsl := false + inNoHlsl := false + for li, ln := range lines { + tln := bytes.TrimSpace(ln) + isKey := bytes.HasPrefix(tln, key) + var keyStr []byte + if isKey { + keyStr = tln[len(key):] + // fmt.Printf("key: %s\n", string(keyStr)) + } + switch { + case inReg && isKey && bytes.HasPrefix(keyStr, end): + if inHlsl || inNoHlsl { + outLines = append(outLines, ln) + } + inReg = false + inHlsl = false + inNoHlsl = false + case inReg && isKey && bytes.HasPrefix(keyStr, vars): + hasVars = true + outLines = append(outLines, ln) + case inReg: + for pkg := range st.ImportPackages { // remove package prefixes + if !bytes.Contains(ln, imp) { + ln = bytes.ReplaceAll(ln, []byte(pkg+"."), []byte{}) + } + } + if bytes.HasPrefix(ln, fnc) && bytes.Contains(ln, kernel) { + sysnm := strings.TrimSpace(string(ln[bytes.LastIndex(ln, kernel)+len(kernel):])) + sy := st.System(sysnm) + fcall := string(ln[5:]) + lp := strings.Index(fcall, "(") + rp := strings.LastIndex(fcall, ")") + args := fcall[lp+1 : rp] + fnm := fcall[:lp] + funcode := "" + for ki := li + 1; ki < len(lines); ki++ { + kl := lines[ki] + if len(kl) > 0 && kl[0] == '}' { + break + } + funcode += string(kl) + "\n" + } + kn := &Kernel{Name: fnm, Args: args, FuncCode: funcode} + sy.Kernels[fnm] = 
kn + if st.Config.Debug { + fmt.Println("\tAdded kernel:", fnm, "args:", args, "system:", sy.Name) + } + } + outLines = append(outLines, ln) + case isKey && bytes.HasPrefix(keyStr, start): + inReg = true + case isKey && bytes.HasPrefix(keyStr, nowgsl): + inReg = true + inNoHlsl = true + outLines = append(outLines, ln) // key to include self here + case isKey && bytes.HasPrefix(keyStr, wgsl): + inReg = true + inHlsl = true + outLines = append(outLines, ln) + } + } + return +} + +// AppendGoHeader appends Go header +func (st *State) AppendGoHeader(lines [][]byte) [][]byte { + olns := make([][]byte, 0, len(lines)+10) + olns = append(olns, []byte("package imports")) + olns = append(olns, []byte(`import ( + "math" + "cogentcore.org/core/goal/gosl/slbool" + "cogentcore.org/core/goal/gosl/slrand" + "cogentcore.org/core/goal/gosl/sltype" + "cogentcore.org/core/tensor" +`)) + for impath := range st.GoImports { + if strings.Contains(impath, "core/goal/gosl") { + continue + } + olns = append(olns, []byte("\t\""+impath+"\"")) + } + olns = append(olns, []byte(")")) + olns = append(olns, lines...) + SlBoolReplace(olns) + return olns +} + +// ExtractWGSL extracts the WGSL code embedded within .Go files, +// which is commented out in the Go code -- remove comments. 
+func (st *State) ExtractWGSL(lines [][]byte) [][]byte { + key := []byte("//gosl:") + wgsl := []byte("wgsl") + nowgsl := []byte("nowgsl") + end := []byte("end") + stComment := []byte("/*") + edComment := []byte("*/") + comment := []byte("// ") + pack := []byte("package") + imp := []byte("import") + lparen := []byte("(") + rparen := []byte(")") + + mx := min(10, len(lines)) + stln := 0 + gotImp := false + for li := 0; li < mx; li++ { + ln := lines[li] + switch { + case bytes.HasPrefix(ln, pack): + stln = li + 1 + case bytes.HasPrefix(ln, imp): + if bytes.HasSuffix(ln, lparen) { + gotImp = true + } else { + stln = li + 1 + } + case gotImp && bytes.HasPrefix(ln, rparen): + stln = li + 1 + } + } + + lines = lines[stln:] // get rid of package, import + + inHlsl := false + inNoHlsl := false + noHlslStart := 0 + for li := 0; li < len(lines); li++ { + ln := lines[li] + isKey := bytes.HasPrefix(ln, key) + var keyStr []byte + if isKey { + keyStr = ln[len(key):] + // fmt.Printf("key: %s\n", string(keyStr)) + } + switch { + case inNoHlsl && isKey && bytes.HasPrefix(keyStr, end): + lines = slices.Delete(lines, noHlslStart, li+1) + li -= ((li + 1) - noHlslStart) + inNoHlsl = false + case inHlsl && isKey && bytes.HasPrefix(keyStr, end): + lines = slices.Delete(lines, li, li+1) + li-- + inHlsl = false + case inHlsl: + switch { + case bytes.HasPrefix(ln, stComment) || bytes.HasPrefix(ln, edComment): + lines = slices.Delete(lines, li, li+1) + li-- + case bytes.HasPrefix(ln, comment): + lines[li] = ln[3:] + } + case isKey && bytes.HasPrefix(keyStr, wgsl): + inHlsl = true + lines = slices.Delete(lines, li, li+1) + li-- + case isKey && bytes.HasPrefix(keyStr, nowgsl): + inNoHlsl = true + noHlslStart = li + } + } + return lines +} diff --git a/goal/gosl/gotosl/files.go b/goal/gosl/gotosl/files.go new file mode 100644 index 0000000000..a086cd00d6 --- /dev/null +++ b/goal/gosl/gotosl/files.go @@ -0,0 +1,214 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "bytes" + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + + "cogentcore.org/core/base/fsx" + "golang.org/x/tools/go/packages" +) + +// wgslFile returns the file with a ".wgsl" extension +func wgslFile(fn string) string { + f, _ := fsx.ExtSplit(fn) + return f + ".wgsl" +} + +// bareFile returns the file with no extention +func bareFile(fn string) string { + f, _ := fsx.ExtSplit(fn) + return f +} + +func ReadFileLines(fn string) ([][]byte, error) { + nl := []byte("\n") + buf, err := os.ReadFile(fn) + if err != nil { + fmt.Println(err) + return nil, err + } + lines := bytes.Split(buf, nl) + return lines, nil +} + +func WriteFileLines(fn string, lines [][]byte) error { + res := bytes.Join(lines, []byte("\n")) + return os.WriteFile(fn, res, 0644) +} + +// HasGoslTag returns true if given file has a //gosl: tag +func (st *State) HasGoslTag(lines [][]byte) bool { + key := []byte("//gosl:") + pkg := []byte("package ") + for _, ln := range lines { + tln := bytes.TrimSpace(ln) + if st.Package == "" { + if bytes.HasPrefix(tln, pkg) { + st.Package = string(bytes.TrimPrefix(tln, pkg)) + } + } + if bytes.HasPrefix(tln, key) { + return true + } + } + return false +} + +func IsGoFile(f fs.DirEntry) bool { + name := f.Name() + return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") && !f.IsDir() +} + +func IsWGSLFile(f fs.DirEntry) bool { + name := f.Name() + return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".wgsl") && !f.IsDir() +} + +// ProjectFiles gets the files in the current directory. 
+func (st *State) ProjectFiles() { + fls := fsx.Filenames(".", ".go") + st.GoFiles = make(map[string]*File) + st.GoVarsFiles = make(map[string]*File) + for _, fn := range fls { + fl := &File{Name: fn} + var err error + fl.Lines, err = ReadFileLines(fn) + if err != nil { + continue + } + if !st.HasGoslTag(fl.Lines) { + continue + } + st.GoFiles[fn] = fl + st.ImportFiles(fl.Lines) + } +} + +// ImportFiles checks the given content for //gosl:import tags +// and imports the package if so. +func (st *State) ImportFiles(lines [][]byte) { + key := []byte("//gosl:import ") + for _, ln := range lines { + tln := bytes.TrimSpace(ln) + if !bytes.HasPrefix(tln, key) { + continue + } + impath := strings.TrimSpace(string(tln[len(key):])) + if impath[0] == '"' { + impath = impath[1:] + } + if impath[len(impath)-1] == '"' { + impath = impath[:len(impath)-1] + } + _, ok := st.GoImports[impath] + if ok { + continue + } + var pkgs []*packages.Package + var err error + pkgs, err = packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, impath) + if err != nil { + fmt.Println(err) + continue + } + pfls := make(map[string]*File) + st.GoImports[impath] = pfls + pkg := pkgs[0] + gofls := pkg.GoFiles + if len(gofls) == 0 { + fmt.Printf("WARNING: no go files found in path: %s\n", impath) + } + for _, gf := range gofls { + lns, err := ReadFileLines(gf) + if err != nil { + continue + } + if !st.HasGoslTag(lns) { + continue + } + _, fo := filepath.Split(gf) + pfls[fo] = &File{Name: fo, Lines: lns} + st.ImportFiles(lns) + // fmt.Printf("added file: %s from package: %s\n", gf, impath) + } + st.GoImports[impath] = pfls + } +} + +// RemoveGenFiles removes .go, .wgsl, .spv files in shader generated dir +func RemoveGenFiles(dir string) { + err := filepath.WalkDir(dir, func(path string, f fs.DirEntry, err error) error { + if err != nil { + return err + } + if IsGoFile(f) || IsWGSLFile(f) { + os.Remove(path) + } + return nil + }) + if err != nil { + log.Println(err) + } +} + +// 
CopyPackageFile copies given file name from given package path
+// into the current imports directory.
+// e.g., "slrand.wgsl", "cogentcore.org/core/goal/gosl/slrand"
+func (st *State) CopyPackageFile(fnm, packagePath string) error {
+	for _, f := range st.SLImportFiles {
+		if f.Name == fnm { // don't re-import
+			return nil
+		}
+	}
+	tofn := filepath.Join(st.ImportsDir, fnm)
+	pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, packagePath)
+	if err != nil {
+		fmt.Println(err)
+		return err
+	}
+	if len(pkgs) != 1 {
+		err = fmt.Errorf("%s package not found", packagePath)
+		fmt.Println(err)
+		return err
+	}
+	pkg := pkgs[0]
+	var fn string
+	if len(pkg.GoFiles) > 0 {
+		fn = pkg.GoFiles[0]
+	} else if len(pkg.OtherFiles) > 0 {
+		fn = pkg.OtherFiles[0]
+	} else {
+		err = fmt.Errorf("No files found in package: %s", packagePath)
+		fmt.Println(err)
+		return err
+	}
+	dir, _ := filepath.Split(fn)
+	fmfn := filepath.Join(dir, fnm)
+	lines, err := CopyFile(fmfn, tofn)
+	if err == nil {
+		lines = SlRemoveComments(lines)
+		st.SLImportFiles = append(st.SLImportFiles, &File{Name: fnm, Lines: lines})
+	}
+	return err
+}
+
+func CopyFile(src, dst string) ([][]byte, error) {
+	lines, err := ReadFileLines(src)
+	if err != nil {
+		return lines, err
+	}
+	err = WriteFileLines(dst, lines)
+	if err != nil {
+		return lines, err
+	}
+	return lines, err
+}
diff --git a/goal/gosl/gotosl/gengpu.go b/goal/gosl/gotosl/gengpu.go
new file mode 100644
index 0000000000..13a3b848d8
--- /dev/null
+++ b/goal/gosl/gotosl/gengpu.go
@@ -0,0 +1,420 @@
+// Copyright (c) 2024, Cogent Core. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gotosl
+
+import (
+	"fmt"
+	"os"
+	"slices"
+	"strings"
+
+	"golang.org/x/exp/maps"
+)
+
+// genSysName is the name to use for system in generating code. 
+// if only one system, the name is empty +func (st *State) genSysName(sy *System) string { + if len(st.Systems) == 1 { + return "" + } + return sy.Name +} + +// genSysVar is the name to use for system in generating code. +// if only one system, the name is empty +func (st *State) genSysVar(sy *System) string { + return fmt.Sprintf("GPU%sSystem", st.genSysName(sy)) +} + +// GenGPU generates and writes the Go GPU helper code +func (st *State) GenGPU() { + var b strings.Builder + + header := `// Code generated by "gosl"; DO NOT EDIT + +package %s + +import ( + "embed" + "unsafe" + "cogentcore.org/core/gpu" + "cogentcore.org/core/tensor" +) + +//go:embed %s/*.wgsl +var shaders embed.FS + +// ComputeGPU is the compute gpu device +var ComputeGPU *gpu.GPU + +// UseGPU indicates whether to use GPU vs. CPU. +var UseGPU bool + +` + + b.WriteString(fmt.Sprintf(header, st.Package, st.Config.Output)) + + sys := maps.Keys(st.Systems) + slices.Sort(sys) + + for _, synm := range sys { + sy := st.Systems[synm] + b.WriteString(fmt.Sprintf("// %s is a GPU compute System with kernels operating on the\n// same set of data variables.\n", st.genSysVar(sy))) + b.WriteString(fmt.Sprintf("var %s *gpu.ComputeSystem\n", st.genSysVar(sy))) + } + + venum := ` +// GPUVars is an enum for GPU variables, for specifying what to sync. 
+type GPUVars int32 //enums:enum + +const ( +` + + b.WriteString(venum) + + vidx := 0 + hasTensors := false + for _, synm := range sys { + sy := st.Systems[synm] + + if sy.NTensors > 0 { + hasTensors = true + } + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + b.WriteString(fmt.Sprintf("\t%sVar GPUVars = %d\n", vr.Name, vidx)) + vidx++ + } + } + } + b.WriteString(")\n") + + if hasTensors { + b.WriteString("\n// Tensor stride variables\n") + for _, synm := range sys { + sy := st.Systems[synm] + genSynm := st.genSysName(sy) + b.WriteString(fmt.Sprintf("var %sTensorStrides tensor.Uint32\n", genSynm)) + } + } else { + b.WriteString("\n// Dummy tensor stride variable to avoid import error\n") + b.WriteString("var __TensorStrides tensor.Uint32\n") + } + + initf := ` +// GPUInit initializes the GPU compute system, +// configuring system(s), variables and kernels. +// It is safe to call multiple times: detects if already run. +func GPUInit() { + if ComputeGPU != nil { + return + } + gp := gpu.NewComputeGPU() + ComputeGPU = gp +` + + b.WriteString(initf) + + for _, synm := range sys { + sy := st.Systems[synm] + b.WriteString(st.GenGPUSystemInit(sy)) + } + b.WriteString("}\n\n") + + release := `// GPURelease releases the GPU compute system resources. +// Call this at program exit. +func GPURelease() { +` + + b.WriteString(release) + + sysRelease := ` if %[1]s != nil { + %[1]s.Release() + %[1]s = nil + } +` + + for _, synm := range sys { + sy := st.Systems[synm] + b.WriteString(fmt.Sprintf(sysRelease, st.genSysVar(sy))) + } + + gpuRelease := ` + if ComputeGPU != nil { + ComputeGPU.Release() + ComputeGPU = nil + } +} + +` + b.WriteString(gpuRelease) + + for _, synm := range sys { + sy := st.Systems[synm] + b.WriteString(st.GenGPUSystemOps(sy)) + } + + gs := b.String() + fn := "gosl.go" + os.WriteFile(fn, []byte(gs), 0644) +} + +// GenGPUSystemInit generates GPU Init code for given system. 
+func (st *State) GenGPUSystemInit(sy *System) string { + var b strings.Builder + + syvar := st.genSysVar(sy) + + b.WriteString("\t{\n") + b.WriteString(fmt.Sprintf("\t\tsy := gpu.NewComputeSystem(gp, %q)\n", sy.Name)) + b.WriteString(fmt.Sprintf("\t\t%s = sy\n", syvar)) + + kns := maps.Keys(sy.Kernels) + slices.Sort(kns) + for _, knm := range kns { + kn := sy.Kernels[knm] + b.WriteString(fmt.Sprintf("\t\tgpu.NewComputePipelineShaderFS(shaders, %q, sy)\n", kn.Filename)) + } + b.WriteString("\t\tvars := sy.Vars()\n") + for gi, gp := range sy.Groups { + b.WriteString("\t\t{\n") + gtyp := "gpu.Storage" + if gp.Uniform { + gtyp = "gpu.Uniform" + } + b.WriteString(fmt.Sprintf("\t\t\tsgp := vars.AddGroup(%s, %q)\n", gtyp, gp.Name)) + b.WriteString("\t\t\tvar vr *gpu.Var\n\t\t\t_ = vr\n") + if sy.NTensors > 0 && gi == 0 { + b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", "TensorStrides", "Uint32")) + b.WriteString("\t\t\tvr.ReadOnly = true\n") + } + for _, vr := range gp.Vars { + if vr.Tensor { + typ := strings.TrimPrefix(vr.Type, "tensor.") + b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", vr.Name, typ)) + } else { + b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.AddStruct(%q, int(unsafe.Sizeof(%s{})), 1, gpu.ComputeShader)\n", vr.Name, vr.SLType())) + } + if vr.ReadOnly { + b.WriteString("\t\t\tvr.ReadOnly = true\n") + } + } + b.WriteString("\t\t\tsgp.SetNValues(1)\n") + b.WriteString("\t\t}\n") + } + b.WriteString("\t\tsy.Config()\n") + b.WriteString("\t}\n") + return b.String() +} + +// GenGPUSystemOps generates GPU helper functions for given system. +func (st *State) GenGPUSystemOps(sy *System) string { + var b strings.Builder + + syvar := st.genSysVar(sy) + synm := st.genSysName(sy) + + // 1 = kernel, 2 = system var, 3 = sysname (blank for 1 default) + run := `// Run%[1]s runs the %[1]s kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. 
+// Can call multiple Run* kernels in a row, which are then all launched +// in the same command submission on the GPU, which is by far the most efficient. +// MUST call RunDone (with optional vars to sync) after all Run calls. +// Alternatively, a single-shot RunOne%[1]s call does Run and Done for a +// single run-and-sync case. +func Run%[1]s(n int) { + if UseGPU { + Run%[1]sGPU(n) + } else { + Run%[1]sCPU(n) + } +} + +// Run%[1]sGPU runs the %[1]s kernel on the GPU. See [Run%[1]s] for more info. +func Run%[1]sGPU(n int) { + sy := %[2]s + pl := sy.ComputePipelines[%[1]q] + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, 64) +} + +// Run%[1]sCPU runs the %[1]s kernel on the CPU. +func Run%[1]sCPU(n int) { + gpu.VectorizeFunc(0, n, %[1]s) +} + +// RunOne%[1]s runs the %[1]s kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// This version then calls RunDone with the given variables to sync +// after the Run, for a single-shot Run-and-Done call. If multiple kernels +// can be run in sequence, it is much more efficient to do multiple Run* +// calls followed by a RunDone call. +func RunOne%[1]s(n int, syncVars ...GPUVars) { + if UseGPU { + Run%[1]sGPU(n) + RunDone%[3]s(syncVars...) + } else { + Run%[1]sCPU(n) + } +} +` + // 1 = sysname (blank for 1 default), 2 = system var + runDone := `// RunDone%[1]s must be called after Run* calls to start compute kernels. +// This actually submits the kernel jobs to the GPU, and adds commands +// to synchronize the given variables back from the GPU to the CPU. +// After this function completes, the GPU results will be available in +// the specified variables. +func RunDone%[1]s(syncVars ...GPUVars) { + if !UseGPU { + return + } + sy := %[2]s + sy.ComputeEncoder.End() + %[1]sReadFromGPU(syncVars...) + sy.EndComputePass() + %[1]sSyncFromGPU(syncVars...) +} + +// %[1]sToGPU copies given variables to the GPU for the system. 
+func %[1]sToGPU(vars ...GPUVars) { + if !UseGPU { + return + } + sy := %[2]s + syVars := sy.Vars() + for _, vr := range vars { + switch vr { +` + + kns := maps.Keys(sy.Kernels) + slices.Sort(kns) + for _, knm := range kns { + kn := sy.Kernels[knm] + b.WriteString(fmt.Sprintf(run, kn.Name, syvar, synm)) + } + b.WriteString(fmt.Sprintf(runDone, synm, syvar)) + + for gi, gp := range sy.Groups { + for _, vr := range gp.Vars { + b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name)) + b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name)) + vv := vr.Name + if vr.Tensor { + vv += ".Values" + } + b.WriteString(fmt.Sprintf("\t\t\tgpu.SetValueFrom(v, %s)\n", vv)) + } + } + b.WriteString("\t\t}\n\t}\n}\n") + + runSync := `// Run%[1]sGPUSync can be called to synchronize data between CPU and GPU. +// Any prior ToGPU* calls will execute to send data to the GPU, +// and any subsequent RunDone* calls will copy data back from the GPU. +func Run%[1]sGPUSync() { + if !UseGPU { + return + } + sy := %[2]s + sy.BeginComputePass() +} +` + b.WriteString(fmt.Sprintf(runSync, synm, syvar)) + + if sy.NTensors > 0 { + tensorStrides := ` +// %[1]sToGPUTensorStrides gets tensor strides and starts copying to the GPU. 
+func %[1]sToGPUTensorStrides() { + if !UseGPU { + return + } + sy := %[2]s + syVars := sy.Vars() +` + b.WriteString(fmt.Sprintf(tensorStrides, synm, syvar)) + + strvar := synm + "TensorStrides" + + b.WriteString(fmt.Sprintf("\t%s.SetShapeSizes(%d)\n", strvar, sy.NTensors*10)) + + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + if !vr.Tensor { + continue + } + for d := range vr.TensorDims { + b.WriteString(fmt.Sprintf("\t%sTensorStrides.SetInt1D(%s.Shape().Strides[%d], %d)\n", synm, vr.Name, d, vr.TensorIndex*10+d)) + } + } + } + b.WriteString(fmt.Sprintf("\tv, _ := syVars.ValueByIndex(0, %q, 0)\n", strvar)) + b.WriteString(fmt.Sprintf("\tgpu.SetValueFrom(v, %s.Values)\n", strvar)) + b.WriteString("}\n") + } + + fmGPU := ` +// %[1]sReadFromGPU starts the process of copying vars to the GPU. +func %[1]sReadFromGPU(vars ...GPUVars) { + sy := %[2]s + syVars := sy.Vars() + for _, vr := range vars { + switch vr { +` + + b.WriteString(fmt.Sprintf(fmGPU, synm, syvar)) + + for gi, gp := range sy.Groups { + for _, vr := range gp.Vars { + b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name)) + b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name)) + b.WriteString("\t\t\tv.GPUToRead(sy.CommandEncoder)\n") + } + } + b.WriteString("\t\t}\n\t}\n}\n") + + syncGPU := ` +// %[1]sSyncFromGPU synchronizes vars from the GPU to the actual variable. 
+func %[1]sSyncFromGPU(vars ...GPUVars) { + sy := %[2]s + syVars := sy.Vars() + for _, vr := range vars { + switch vr { +` + + b.WriteString(fmt.Sprintf(syncGPU, synm, syvar)) + + for gi, gp := range sy.Groups { + for _, vr := range gp.Vars { + b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name)) + b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name)) + b.WriteString(fmt.Sprintf("\t\t\tv.ReadSync()\n")) + vv := vr.Name + if vr.Tensor { + vv += ".Values" + } + b.WriteString(fmt.Sprintf("\t\t\tgpu.ReadToBytes(v, %s)\n", vv)) + } + } + b.WriteString("\t\t}\n\t}\n}\n") + + getFun := ` +// Get%[1]s returns a pointer to the given global variable: +// [%[1]s] []%[2]s at given index. +// To ensure that values are updated on the GPU, you must call [Set%[1]s]. +// after all changes have been made. +func Get%[1]s(idx uint32) *%[2]s { + return &%[1]s[idx] +} +` + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + if vr.Tensor { + continue + } + b.WriteString(fmt.Sprintf(getFun, vr.Name, vr.SLType())) + } + } + + return b.String() +} diff --git a/goal/gosl/gotosl/genkernel.go b/goal/gosl/gotosl/genkernel.go new file mode 100644 index 0000000000..1772fa9b53 --- /dev/null +++ b/goal/gosl/gotosl/genkernel.go @@ -0,0 +1,109 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "fmt" + "strings" +) + +// GenKernelHeader returns the novel generated WGSL kernel code +// for given kernel, which goes at the top of the resulting file. 
+func (st *State) GenKernelHeader(sy *System, kn *Kernel, avars map[string]*Var) string {
+	var b strings.Builder
+	b.WriteString("// Code generated by \"gosl\"; DO NOT EDIT\n")
+	b.WriteString("// kernel: " + kn.Name + "\n\n")
+
+	for gi, gp := range sy.Groups {
+		if gp.Doc != "" {
+			b.WriteString("// " + gp.Doc + "\n")
+		}
+		str := "storage"
+		if gp.Uniform {
+			str = "uniform"
+		}
+		viOff := 0
+		if gi == 0 && sy.NTensors > 0 {
+			access := ", read"
+			if gp.Uniform {
+				access = ""
+			}
+			viOff = 1
+			b.WriteString("@group(0) @binding(0)\n")
+			b.WriteString(fmt.Sprintf("var<%s%s> TensorStrides: array<u32>;\n", str, access))
+		}
+		for vi, vr := range gp.Vars {
+			access := ", read_write"
+			if vr.ReadOnly {
+				access = ", read"
+			}
+			if gp.Uniform {
+				access = ""
+			}
+			if vr.Doc != "" {
+				b.WriteString("// " + vr.Doc + "\n")
+			}
+			b.WriteString(fmt.Sprintf("@group(%d) @binding(%d)\n", gi, vi+viOff))
+			b.WriteString(fmt.Sprintf("var<%s%s> %s: ", str, access, vr.Name))
+			if _, ok := avars[vr.Name]; ok {
+				b.WriteString(fmt.Sprintf("array<atomic<%s>>;\n", vr.SLType()))
+			} else {
+				b.WriteString(fmt.Sprintf("array<%s>;\n", vr.SLType()))
+			}
+		}
+	}
+
+	b.WriteString("\nalias GPUVars = i32;\n\n") // gets included when iteratively processing enumgen.go
+
+	b.WriteString("@compute @workgroup_size(64, 1, 1)\n")
+	b.WriteString("fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>, @builtin(local_invocation_index) loci: u32) {\n")
+	b.WriteString("\tlet idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64;\n")
+	b.WriteString(fmt.Sprintf("\t%s(idx);\n", kn.Name))
+	b.WriteString("}\n")
+	b.WriteString(st.GenTensorFuncs(sy))
+	return b.String()
+}

+// GenTensorFuncs returns the generated WGSL code
+// for indexing the tensors in given system. 
+func (st *State) GenTensorFuncs(sy *System) string { + var b strings.Builder + + done := make(map[string]bool) + + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + if !vr.Tensor { + continue + } + fn := vr.IndexFunc() + if _, ok := done[fn]; ok { + continue + } + done[fn] = true + typ := "u32" + b.WriteString("\nfn " + fn + "(") + nd := vr.TensorDims + for d := range nd { + b.WriteString(fmt.Sprintf("s%d: %s, ", d, typ)) + } + for d := range nd { + b.WriteString(fmt.Sprintf("i%d: u32", d)) + if d < nd-1 { + b.WriteString(", ") + } + } + b.WriteString(") -> u32 {\n\treturn ") + for d := range nd { + b.WriteString(fmt.Sprintf("s%d * i%d", d, d)) + if d < nd-1 { + b.WriteString(" + ") + } + } + b.WriteString(";\n}\n") + } + } + return b.String() +} diff --git a/gpu/gosl/slprint/gobuild.go b/goal/gosl/gotosl/gobuild.go similarity index 94% rename from gpu/gosl/slprint/gobuild.go rename to goal/gosl/gotosl/gobuild.go index fd0c4b4002..233196dde8 100644 --- a/gpu/gosl/slprint/gobuild.go +++ b/goal/gosl/gotosl/gobuild.go @@ -1,8 +1,15 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is largely copied from the Go source, +// src/go/printer/gobuild.go: + // Copyright 2020 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package slprint +package gotosl import ( "go/build/constraint" diff --git a/goal/gosl/gotosl/gosl_test.go b/goal/gosl/gotosl/gosl_test.go new file mode 100644 index 0000000000..056b3e6e4b --- /dev/null +++ b/goal/gosl/gotosl/gosl_test.go @@ -0,0 +1,47 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package gotosl + +import ( + "os" + "testing" + + "cogentcore.org/core/cli" + "github.com/stretchr/testify/assert" +) + +// TestTranslate +func TestTranslate(t *testing.T) { + os.Chdir("testdata") + + opts := cli.DefaultOptions("gosl", "Go as a shader language converts Go code to WGSL WebGPU shader code, which can be run on the GPU through WebGPU.") + cfg := &Config{} + cli.Run(opts, cfg, Run) + + exSh, err := os.ReadFile("Compute.golden") + if err != nil { + t.Error(err) + return + } + exGosl, err := os.ReadFile("gosl.golden") + if err != nil { + t.Error(err) + return + } + + gotSh, err := os.ReadFile("shaders/Compute.wgsl") + if err != nil { + t.Error(err) + return + } + gotGosl, err := os.ReadFile("gosl.go") + if err != nil { + t.Error(err) + return + } + + assert.Equal(t, string(exSh), string(gotSh)) + assert.Equal(t, string(exGosl), string(gotGosl)) +} diff --git a/goal/gosl/gotosl/gotosl.go b/goal/gosl/gotosl/gotosl.go new file mode 100644 index 0000000000..bb606d779f --- /dev/null +++ b/goal/gosl/gotosl/gotosl.go @@ -0,0 +1,320 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "fmt" + "go/ast" + "os" + "path/filepath" + "reflect" + "strings" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/stack" +) + +// System represents a ComputeSystem, and its kernels and variables. +type System struct { + Name string + + // Kernels are the kernels using this compute system. + Kernels map[string]*Kernel + + // Groups are the variables for this compute system. + Groups []*Group + + // NTensors is the number of tensor vars. + NTensors int +} + +func NewSystem(name string) *System { + sy := &System{Name: name} + sy.Kernels = make(map[string]*Kernel) + return sy +} + +// Kernel represents a kernel function, which is the basis for +// each wgsl generated code file. 
+type Kernel struct { + Name string + + Args string + + // Filename is the name of the kernel shader file, e.g., shaders/Compute.wgsl + Filename string + + // function code + FuncCode string + + // Lines is full shader code + Lines [][]byte +} + +// Var represents one global system buffer variable. +type Var struct { + Name string + + // comment docs about this var. + Doc string + + // Type of variable: either []Type or F32, U32 for tensors + Type string + + // ReadOnly indicates that this variable is never read back from GPU, + // specified by the gosl:read-only property in the variable comments. + // It is important to optimize GPU memory usage to indicate this. + ReadOnly bool + + // True if a tensor type + Tensor bool + + // Number of dimensions + TensorDims int + + // data kind of the tensor + TensorKind reflect.Kind + + // index of tensor in list of tensor variables, for indexing. + TensorIndex int +} + +func (vr *Var) SetTensorKind() { + kindStr := strings.TrimPrefix(vr.Type, "tensor.") + kind := reflect.Float32 + switch kindStr { + case "Float32": + kind = reflect.Float32 + case "Uint32": + kind = reflect.Uint32 + case "Int32": + kind = reflect.Int32 + default: + errors.Log(fmt.Errorf("gosl: variable %q type is not supported: %q", vr.Name, kindStr)) + } + vr.TensorKind = kind +} + +// SLType returns the WGSL type string +func (vr *Var) SLType() string { + if vr.Tensor { + switch vr.TensorKind { + case reflect.Float32: + return "f32" + case reflect.Int32: + return "i32" + case reflect.Uint32: + return "u32" + } + } else { + return vr.Type[2:] + } + return "" +} + +// IndexFunc returns the tensor index function name +func (vr *Var) IndexFunc() string { + return fmt.Sprintf("Index%dD", vr.TensorDims) +} + +// IndexStride returns the tensor stride variable reference +func (vr *Var) IndexStride(dim int) string { + return fmt.Sprintf("TensorStrides[%d]", vr.TensorIndex*10+dim) +} + +// Group represents one variable group. 
+type Group struct { + Name string + + // comment docs about this group + Doc string + + // Uniform indicates a uniform group; else default is Storage. + Uniform bool + + Vars []*Var +} + +// File has contents of a file as lines of bytes. +type File struct { + Name string + Lines [][]byte +} + +// GetGlobalVar holds GetVar expression, to Set variable back when done. +type GetGlobalVar struct { + // global variable + Var *Var + // name of temporary variable + TmpVar string + // index passed to the Get function + IdxExpr ast.Expr +} + +// State holds the current Go -> WGSL processing state. +type State struct { + // Config options. + Config *Config + + // path to shaders/imports directory. + ImportsDir string + + // name of the package + Package string + + // GoFiles are all the files with gosl content in current directory. + GoFiles map[string]*File + + // GoVarsFiles are all the files with gosl:vars content in current directory. + // These must be processed first! they are moved from GoFiles to here. + GoVarsFiles map[string]*File + + // GoImports has all the imported files. + GoImports map[string]map[string]*File + + // ImportPackages has short package names, to remove from go code + // so everything lives in same main package. + ImportPackages map[string]bool + + // Systems has the kernels and variables for each system. + // There is an initial "Default" system when system is not specified. + Systems map[string]*System + + // GetFuncs is a map of GetVar, SetVar function names for global vars. + GetFuncs map[string]*Var + + // SLImportFiles are all the extracted and translated WGSL files in shaders/imports, + // which are copied into the generated shader kernel files. + SLImportFiles []*File + + // generated Go GPU gosl.go file contents + GPUFile File + + // ExcludeMap is the compiled map of functions to exclude in Go -> WGSL translation. 
+ ExcludeMap map[string]bool + + // GetVarStack is a stack per function definition of GetVar variables + // that need to be set at the end. + GetVarStack stack.Stack[map[string]*GetGlobalVar] + + // GetFuncGraph is true if getting the function graph (first pass) + GetFuncGraph bool + + // KernelFuncs are the list of functions to include for current kernel. + KernelFuncs map[string]*Function + + // FuncGraph is the call graph of functions, for dead code elimination + FuncGraph map[string]*Function +} + +func (st *State) Init(cfg *Config) { + st.Config = cfg + st.GoImports = make(map[string]map[string]*File) + st.Systems = make(map[string]*System) + st.ExcludeMap = make(map[string]bool) + ex := strings.Split(cfg.Exclude, ",") + for _, fn := range ex { + st.ExcludeMap[fn] = true + } + st.Systems["Default"] = NewSystem("Default") +} + +func (st *State) Run() error { + if gomod := os.Getenv("GO111MODULE"); gomod == "off" { + err := errors.New("gosl only works in go modules mode, but GO111MODULE=off") + return err + } + if st.Config.Output == "" { + st.Config.Output = "shaders" + } + + st.ProjectFiles() // get list of all files, recursively gets imports etc. + if len(st.GoFiles) == 0 { + return nil + } + + st.ImportsDir = filepath.Join(st.Config.Output, "imports") + os.MkdirAll(st.Config.Output, 0755) + os.MkdirAll(st.ImportsDir, 0755) + RemoveGenFiles(st.Config.Output) + RemoveGenFiles(st.ImportsDir) + + st.ExtractFiles() // get .go from project files + st.ExtractImports() // get .go from imports + st.TranslateDir("./" + st.ImportsDir) + + st.GenGPU() + + return nil +} + +// System returns the given system by name, making if not made. +// if name is empty, "Default" is used. +func (st *State) System(sysname string) *System { + if sysname == "" { + sysname = "Default" + } + sy, ok := st.Systems[sysname] + if ok { + return sy + } + sy = NewSystem(sysname) + st.Systems[sysname] = sy + return sy +} + +// GlobalVar returns global variable of given name, if found. 
+func (st *State) GlobalVar(vrnm string) *Var { + if st == nil { + return nil + } + if st.Systems == nil { + return nil + } + for _, sy := range st.Systems { + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + if vr.Name == vrnm { + return vr + } + } + } + } + return nil +} + +// GetTempVar returns temp var for global variable of given name, if found. +func (st *State) GetTempVar(vrnm string) *GetGlobalVar { + if st == nil || st.GetVarStack == nil { + return nil + } + nv := len(st.GetVarStack) + for i := nv - 1; i >= 0; i-- { + gvars := st.GetVarStack[i] + if gv, ok := gvars[vrnm]; ok { + return gv + } + } + return nil +} + +// VarsAdded is called when a set of vars has been added; update relevant maps etc. +func (st *State) VarsAdded() { + st.GetFuncs = make(map[string]*Var) + for _, sy := range st.Systems { + tensorIdx := 0 + for _, gp := range sy.Groups { + for _, vr := range gp.Vars { + if vr.Tensor { + vr.TensorIndex = tensorIdx + tensorIdx++ + continue + } + st.GetFuncs["Get"+vr.Name] = vr + } + } + sy.NTensors = tensorIdx + } +} diff --git a/gpu/gosl/slprint/nodes.go b/goal/gosl/gotosl/nodes.go similarity index 80% rename from gpu/gosl/slprint/nodes.go rename to goal/gosl/gotosl/nodes.go index 13ace493a0..075c0cc33d 100644 --- a/gpu/gosl/slprint/nodes.go +++ b/goal/gosl/gotosl/nodes.go @@ -1,3 +1,10 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is largely copied from the Go source, +// src/go/printer/nodes.go: + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -6,7 +13,7 @@ // expressions, statements, declarations, and files. It uses // the print functionality implemented in printer.go. 
-package slprint
+package gotosl
 
 import (
 	"fmt"
@@ -64,6 +71,42 @@ func (p *printer) linebreak(line, min int, ws whiteSpace, newSection bool) (nbre
 	return
 }
 
+// gosl: find any gosl directive in given comments, returns directive(s) and remaining docs
+func (p *printer) findDirective(g *ast.CommentGroup) (dirs []string, docs string) {
+	if g == nil {
+		return
+	}
+	for _, c := range g.List {
+		if strings.HasPrefix(c.Text, "//gosl:") {
+			dirs = append(dirs, c.Text[7:])
+		} else {
+			docs += c.Text + " "
+		}
+	}
+	return
+}
+
+// gosl: hasDirective returns whether directive(s) contains string
+func hasDirective(dirs []string, dir string) bool {
+	for _, d := range dirs {
+		if strings.Contains(d, dir) {
+			return true
+		}
+	}
+	return false
+}
+
+// gosl: directiveAfter returns the directive after given leading text,
+// and a bool indicating if the string was found.
+func directiveAfter(dirs []string, dir string) (string, bool) {
+	for _, d := range dirs {
+		if strings.HasPrefix(d, dir) {
+			return strings.TrimSpace(strings.TrimPrefix(d, dir)), true
+		}
+	}
+	return "", false
+}
+
 // setComment sets g as the next comment if g != nil and if node comments
 // are enabled - this mode is used when printing source code fragments such
 // as exports only. It assumes that there is no pending comment in p.comments
@@ -416,7 +459,7 @@ func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
 			// A type parameter list [P T] where the name P and the type expression T syntactically
 			// combine to another valid (value) expression requires a trailing comma, as in [P *T,]
 			// (or an enclosing interface as in [P interface(*T)]), so that the type parameter list
-			// is not parsed as an array length [P*T].
+			// is not parsed as an array length [P*T]. 
p.print(token.COMMA) } @@ -430,24 +473,72 @@ func (p *printer) parameters(fields *ast.FieldList, mode paramMode) { p.print(closeTok) } +type rwArg struct { + idx *ast.IndexExpr + tmpVar string +} + +func (p *printer) assignRwArgs(rwargs []rwArg) { + nrw := len(rwargs) + if nrw == 0 { + return + } + p.print(token.SEMICOLON, blank, formfeed) + for i, rw := range rwargs { + p.expr(rw.idx) + p.print(token.ASSIGN) + tv := rw.tmpVar + if len(tv) > 0 && tv[0] == '&' { + tv = tv[1:] + } + p.print(tv) + if i < nrw-1 { + p.print(token.SEMICOLON, blank) + } + } +} + // gosl: ensure basic literals are properly cast -func (p *printer) matchLiteralArgs(args []ast.Expr, params *types.Tuple) []ast.Expr { +func (p *printer) goslFixArgs(args []ast.Expr, params *types.Tuple) ([]ast.Expr, []rwArg) { ags := slices.Clone(args) mx := min(len(args), params.Len()) + var rwargs []rwArg for i := 0; i < mx; i++ { ag := args[i] pr := params.At(i) - lit, ok := ag.(*ast.BasicLit) - if !ok { - continue + switch x := ag.(type) { + case *ast.BasicLit: + typ := pr.Type() + tnm := getLocalTypeName(typ) + nn := normalizedNumber(x) + nn.Value = tnm + "(" + nn.Value + ")" + ags[i] = nn + case *ast.Ident: + if gvar := p.GoToSL.GetTempVar(x.Name); gvar != nil { + x.Name = "&" + x.Name + ags[i] = x + } + case *ast.IndexExpr: + isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(x) + if isGlobal { + ags[i] = &ast.Ident{Name: tmpVar} + if !isReadOnly { + rwargs = append(rwargs, rwArg{idx: x, tmpVar: tmpVar}) + } + } + case *ast.UnaryExpr: + if idx, ok := x.X.(*ast.IndexExpr); ok { + isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(idx) + if isGlobal { + ags[i] = &ast.Ident{Name: tmpVar} + if !isReadOnly { + rwargs = append(rwargs, rwArg{idx: idx, tmpVar: tmpVar}) + } + } + } } - typ := pr.Type() - tnm := getLocalTypeName(typ) - nn := normalizedNumber(lit) - nn.Value = tnm + "(" + nn.Value + ")" - ags[i] = nn } - return ags + return ags, rwargs } // gosl: ensure basic literals are properly cast @@ -532,7 
+623,13 @@ func (p *printer) pathType(x *ast.SelectorExpr) (types.Type, error) { return nil, fmt.Errorf("gosl pathType: path not a pure selector path") } np := len(paths) - bt, err := getStructType(p.getIdType(paths[np-1])) + idt := p.getIdType(paths[np-1]) + if idt == nil { + err := fmt.Errorf("gosl pathType ERROR: cannot find type for name: %q", paths[np-1].Name) + p.userError(err) + return nil, err + } + bt, err := p.getStructType(idt) if err != nil { return nil, err } @@ -545,7 +642,7 @@ func (p *printer) pathType(x *ast.SelectorExpr) (types.Type, error) { if pi == 0 { return f.Type(), nil } else { - bt, err = getStructType(f.Type()) + bt, err = p.getStructType(f.Type()) if err != nil { return nil, err } @@ -587,16 +684,14 @@ func (p *printer) ptrType(x ast.Expr) (ast.Expr, bool) { } // gosl: printMethRecv prints the method recv prefix for function. returns true if recv is ptr -func (p *printer) printMethRecv() bool { - isPtr := false +func (p *printer) printMethRecv() (isPtr bool, typnm string) { if u, ok := p.curMethRecv.Type.(*ast.StarExpr); ok { - p.expr(u.X) + typnm = u.X.(*ast.Ident).Name isPtr = true } else { - p.expr(p.curMethRecv.Type) + typnm = p.curMethRecv.Type.(*ast.Ident).Name } - p.print("_") - return isPtr + return } // combinesWithName reports whether a name followed by the expression x @@ -1002,7 +1097,7 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) { prec := x.Op.Precedence() if prec < prec1 { // parenthesis needed - // Note: The parser inserts an ast.ParenExpr node; thus this case + // Note: The gotoslr inserts an ast.ParenExpr node; thus this case // can only occur if the AST is created in a different way. p.print(token.LPAREN) p.expr0(x, reduceDepth(depth)) // parentheses undo one level of depth @@ -1020,7 +1115,11 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) { xline := p.pos.Line // before the operator (it may be on the next line!) 
yline := p.lineFor(x.Y.Pos()) p.setPos(x.OpPos) - p.print(x.Op) + if x.Op == token.AND_NOT { + p.print(token.AND, blank, token.TILDE) + } else { + p.print(x.Op) + } if xline != yline && xline > 0 && yline > 0 { // at least one line break, but respect an extra empty line // in the source @@ -1102,7 +1201,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { } case *ast.BasicLit: - if p.Config.Mode&normalizeNumbers != 0 { + if p.PrintConfig.Mode&normalizeNumbers != 0 { x = normalizedNumber(x) } p.print(x) @@ -1209,7 +1308,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { if len(x.Args) > 1 { depth++ } - + fid, isid := x.Fun.(*ast.Ident) // Conversions to literal function types or <-chan // types require parentheses around the type. paren := false @@ -1226,21 +1325,25 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { p.methodExpr(x, depth) break // handles everything, break out of case } - p.expr1(x.Fun, token.HighestPrec, depth) - if paren { - p.print(token.RPAREN) - } - p.setPos(x.Lparen) - p.print(token.LPAREN) args := x.Args - if fid, ok := x.Fun.(*ast.Ident); ok { + var rwargs []rwArg + if isid { + if p.curFunc != nil { + p.curFunc.Funcs[fid.Name] = p.GoToSL.RecycleFunc(fid.Name) + } if obj, ok := p.pkg.TypesInfo.Uses[fid]; ok { if ft, ok := obj.(*types.Func); ok { sig := ft.Type().(*types.Signature) - args = p.matchLiteralArgs(x.Args, sig.Params()) + args, rwargs = p.goslFixArgs(x.Args, sig.Params()) } } } + p.expr1(x.Fun, token.HighestPrec, depth) + if paren { + p.print(token.RPAREN) + } + p.setPos(x.Lparen) + p.print(token.LPAREN) if x.Ellipsis.IsValid() { p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false) p.setPos(x.Ellipsis) @@ -1253,15 +1356,22 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { } p.setPos(x.Rparen) p.print(token.RPAREN) + p.assignRwArgs(rwargs) case *ast.CompositeLit: // composite literal elements that are composite literals themselves may have the type omitted + lb := token.LBRACE + rb := 
token.RBRACE + if _, isAry := x.Type.(*ast.ArrayType); isAry { + lb = token.LPAREN + rb = token.RPAREN + } if x.Type != nil { p.expr1(x.Type, token.HighestPrec, depth) } p.level++ p.setPos(x.Lbrace) - p.print(token.LBRACE) + p.print(lb) p.exprList(x.Lbrace, x.Elts, 1, commaTerm, x.Rbrace, x.Incomplete) // do not insert extra line break following a /*-style comment // before the closing '}' as it might break the code if there @@ -1276,7 +1386,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { // the proper level of indentation p.print(indent, unindent, mode) p.setPos(x.Rbrace) - p.print(token.RBRACE, mode) + p.print(rb, mode) p.level-- case *ast.Ellipsis: @@ -1286,12 +1396,13 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) { } case *ast.ArrayType: - p.print(token.LBRACK) - if x.Len != nil { - p.expr(x.Len) - } - p.print(token.RBRACK) - p.expr(x.Elt) + p.print("array") + // p.print(token.LBRACK) + // if x.Len != nil { + // p.expr(x.Len) + // } + // p.print(token.RBRACK) + // p.expr(x.Elt) case *ast.StructType: // p.print(token.STRUCT) @@ -1425,7 +1536,7 @@ func (p *printer) methodPath(x *ast.SelectorExpr) (recvPath, recvType string, pa break } err = fmt.Errorf("gosl methodPath ERROR: path for method call must be simple list of fields, not %#v:", cur.X) - fmt.Println(err.Error()) + p.userError(err) return } if p.isPtrArg(baseRecv) { @@ -1433,26 +1544,38 @@ func (p *printer) methodPath(x *ast.SelectorExpr) (recvPath, recvType string, pa } else { recvPath = "&" + baseRecv.Name } - bt, err := getStructType(p.getIdType(baseRecv)) + var idt types.Type + if gvar := p.GoToSL.GetTempVar(baseRecv.Name); gvar != nil { + idt = p.getTypeNameType(gvar.Var.SLType()) + } else { + idt = p.getIdType(baseRecv) + } + if idt == nil { + err = fmt.Errorf("gosl methodPath ERROR: cannot find type for name: %q", baseRecv.Name) + p.userError(err) + return + } + bt, err := p.getStructType(idt) if err != nil { + fmt.Println(baseRecv) return } curt := bt np := len(paths) 
for pi := np - 1; pi >= 0; pi-- { - p := paths[pi] - recvPath += "." + p - f := fieldByName(curt, p) + pth := paths[pi] + recvPath += "." + pth + f := fieldByName(curt, pth) if f == nil { - err = fmt.Errorf("gosl ERROR: field not found %q in type: %q:", p, curt.String()) - fmt.Println(err.Error()) + err = fmt.Errorf("gosl ERROR: field not found %q in type: %q:", pth, curt.String()) + p.userError(err) return } if pi == 0 { pathType = f.Type() recvType = getLocalTypeName(f.Type()) } else { - curt, err = getStructType(f.Type()) + curt, err = p.getStructType(f.Type()) if err != nil { return } @@ -1479,12 +1602,20 @@ func (p *printer) getIdType(id *ast.Ident) types.Type { return nil } +func (p *printer) getTypeNameType(typeName string) types.Type { + obj := p.pkg.Types.Scope().Lookup(typeName) + if obj != nil { + return obj.Type() + } + return nil +} + func getLocalTypeName(typ types.Type) string { _, nm := path.Split(typ.String()) return nm } -func getStructType(typ types.Type) (*types.Struct, error) { +func (p *printer) getStructType(typ types.Type) (*types.Struct, error) { typ = typ.Underlying() if st, ok := typ.(*types.Struct); ok { return st, nil @@ -1495,11 +1626,182 @@ func getStructType(typ types.Type) (*types.Struct, error) { return st, nil } } + if sl, ok := typ.(*types.Slice); ok { + typ = sl.Elem().Underlying() + if st, ok := typ.(*types.Struct); ok { + return st, nil + } + } err := fmt.Errorf("gosl ERROR: type is not a struct and it should be: %q %+t", typ.String(), typ) - fmt.Println(err.Error()) + p.userError(err) return nil, err } +func (p *printer) getNamedType(typ types.Type) (*types.Named, error) { + if nmd, ok := typ.(*types.Named); ok { + return nmd, nil + } + typ = typ.Underlying() + if ptr, ok := typ.(*types.Pointer); ok { + typ = ptr.Elem() + if nmd, ok := typ.(*types.Named); ok { + return nmd, nil + } + } + if sl, ok := typ.(*types.Slice); ok { + typ = sl.Elem() + if nmd, ok := typ.(*types.Named); ok { + return nmd, nil + } + } + err := 
fmt.Errorf("gosl ERROR: type is not a named type: %q %+t", typ.String(), typ) + p.userError(err) + return nil, err +} + +// gosl: globalVar looks up whether the id in an IndexExpr is a global gosl variable. +// in which case it returns a temp variable name to use, and the type info. +func (p *printer) globalVar(idx *ast.IndexExpr) (isGlobal bool, tmpVar, typName string, vtyp types.Type, isReadOnly bool) { + id, ok := idx.X.(*ast.Ident) + if !ok { + return + } + gvr := p.GoToSL.GlobalVar(id.Name) + if gvr == nil { + return + } + isGlobal = true + isReadOnly = gvr.ReadOnly + tmpVar = strings.ToLower(id.Name) + vtyp = p.getIdType(id) + if vtyp == nil { + err := fmt.Errorf("gosl globalVar ERROR: cannot find type for name: %q", id.Name) + p.userError(err) + return + } + nmd, err := p.getNamedType(vtyp) + if err == nil { + vtyp = nmd + } + typName = gvr.SLType() + p.print("var ", tmpVar, token.ASSIGN) + p.expr(idx) + p.print(token.SEMICOLON, blank) + tmpVar = "&" + tmpVar + return +} + +// gosl: replace GetVar function call with assignment of local var +func (p *printer) getGlobalVar(ae *ast.AssignStmt, gvr *Var) { + tmpVar := ae.Lhs[0].(*ast.Ident).Name + cf := ae.Rhs[0].(*ast.CallExpr) + p.print("var", blank, tmpVar, blank, token.ASSIGN, blank, gvr.Name, token.LBRACK) + p.expr(cf.Args[0]) + p.print(token.RBRACK, token.SEMICOLON) + gvars := p.GoToSL.GetVarStack.Peek() + gvars[tmpVar] = &GetGlobalVar{Var: gvr, TmpVar: tmpVar, IdxExpr: cf.Args[0]} + p.GoToSL.GetVarStack[len(p.GoToSL.GetVarStack)-1] = gvars +} + +// gosl: set non-read-only global vars back from temp var +func (p *printer) setGlobalVars(gvrs map[string]*GetGlobalVar) { + for _, gvr := range gvrs { + if gvr.Var.ReadOnly { + continue + } + p.print(formfeed, "\t") + p.print(gvr.Var.Name, token.LBRACK) + p.expr(gvr.IdxExpr) + p.print(token.RBRACK, blank, token.ASSIGN, blank) + p.print(gvr.TmpVar) + p.print(token.SEMICOLON) + } +} + +// gosl: methodIndex processes an index expression as receiver type of method 
call +func (p *printer) methodIndex(idx *ast.IndexExpr) (recvPath, recvType string, pathType types.Type, isReadOnly bool, err error) { + id, ok := idx.X.(*ast.Ident) + if !ok { + err = fmt.Errorf("gosl methodIndex ERROR: must have a recv variable identifier, not %#v:", idx.X) + p.userError(err) + return + } + isGlobal, tmpVar, typName, vtyp, isReadOnly := p.globalVar(idx) + if isGlobal { + recvPath = tmpVar + recvType = typName + pathType = vtyp + } else { + _ = id + // do above + } + return +} + +func (p *printer) tensorMethod(x *ast.CallExpr, vr *Var, methName string) { + args := x.Args + + stArg := 0 + if strings.HasPrefix(methName, "Set") { + stArg = 1 + } + if strings.HasSuffix(methName, "Ptr") { + p.print(token.AND) + if p.curMethIsAtomic { + gv := p.GoToSL.GlobalVar(vr.Name) + if gv != nil { + if p.curFunc != nil { + if p.curFunc.Atomics == nil { + p.curFunc.Atomics = make(map[string]*Var) + } + p.curFunc.Atomics[vr.Name] = vr + } + } + } + } + p.print(vr.Name, token.LBRACK) + p.print(vr.IndexFunc(), token.LPAREN) + nd := vr.TensorDims + for d := range nd { + p.print(vr.IndexStride(d), token.COMMA, blank) + } + n := len(args) + for i := stArg; i < n; i++ { + ag := args[i] + p.print("u32", token.LPAREN) + if ce, ok := ag.(*ast.CallExpr); ok { // get rid of int() wrapper from goal n-dim index + if fn, ok := ce.Fun.(*ast.Ident); ok { + if fn.Name == "int" { + ag = ce.Args[0] + } + } + } + p.expr(ag) + p.print(token.RPAREN) + if i < n-1 { + p.print(token.COMMA, blank) + } + } + p.print(token.RPAREN, token.RBRACK) + if strings.HasPrefix(methName, "Set") { + opnm := strings.TrimPrefix(methName, "Set") + tok := token.ASSIGN + switch opnm { + case "Add": + tok = token.ADD_ASSIGN + case "Sub": + tok = token.SUB_ASSIGN + case "Mul": + tok = token.MUL_ASSIGN + case "Div": + tok = token.QUO_ASSIGN + } + + p.print(blank, tok, blank) + p.expr(args[0]) + } +} + func (p *printer) methodExpr(x *ast.CallExpr, depth int) { path := x.Fun.(*ast.SelectorExpr) // we know fun is 
selector methName := path.Sel.Name @@ -1507,6 +1809,7 @@ func (p *printer) methodExpr(x *ast.CallExpr, depth int) { recvType := "" var err error pathIsPackage := false + var rwargs []rwArg var pathType types.Type if sl, ok := path.X.(*ast.SelectorExpr); ok { // path is itself a selector recvPath, recvType, pathType, err = p.methodPath(sl) @@ -1514,31 +1817,86 @@ func (p *printer) methodExpr(x *ast.CallExpr, depth int) { return } } else if id, ok := path.X.(*ast.Ident); ok { + gvr := p.GoToSL.GlobalVar(id.Name) + if gvr != nil && gvr.Tensor { + p.tensorMethod(x, gvr, methName) + return + } recvPath = id.Name typ := p.getIdType(id) if typ != nil { recvType = getLocalTypeName(typ) if strings.HasPrefix(recvType, "invalid") { - pathIsPackage = true - recvType = id.Name // is a package path + if gvar := p.GoToSL.GetTempVar(id.Name); gvar != nil { + recvType = gvar.Var.SLType() + recvPath = "&" + recvPath + pathType = p.getTypeNameType(gvar.Var.SLType()) + } else { + pathIsPackage = true + recvType = id.Name // is a package path + } } else { pathType = typ + recvPath = recvPath } } else { pathIsPackage = true recvType = id.Name // is a package path } + } else if idx, ok := path.X.(*ast.IndexExpr); ok { + isReadOnly := false + recvPath, recvType, pathType, isReadOnly, err = p.methodIndex(idx) + if err != nil { + return + } + if !isReadOnly { + rwargs = append(rwargs, rwArg{idx: idx, tmpVar: recvPath}) + } } else { err := fmt.Errorf("gosl methodExpr ERROR: path expression for method call must be simple list of fields, not %#v:", path.X) - fmt.Println(err.Error()) + p.userError(err) return } + args := x.Args + if pathType != nil { + meth, _, _ := types.LookupFieldOrMethod(pathType, true, p.pkg.Types, methName) + if meth != nil { + if ft, ok := meth.(*types.Func); ok { + sig := ft.Type().(*types.Signature) + var rwa []rwArg + args, rwa = p.goslFixArgs(x.Args, sig.Params()) + rwargs = append(rwargs, rwa...) 
+ } + } + if len(rwargs) > 0 { + p.print(formfeed) + } + } + // fmt.Println(pathIsPackage, recvType, methName, recvPath) if pathIsPackage { - p.print(recvType + "." + methName) + if recvType == "atomic" || recvType == "atomicx" { + p.curMethIsAtomic = true + switch { + case strings.HasPrefix(methName, "Add"): + p.print("atomicAdd") + case strings.HasPrefix(methName, "Max"): + p.print("atomicMax") + } + } else { + p.print(recvType + "." + methName) + if p.curFunc != nil { + p.curFunc.Funcs[methName] = p.GoToSL.RecycleFunc(methName) + } + } p.setPos(x.Lparen) p.print(token.LPAREN) } else { - p.print(recvType + "_" + methName) + recvType = strings.TrimPrefix(recvType, "imports.") // no! + fname := recvType + "_" + methName + if p.curFunc != nil { + p.curFunc.Funcs[fname] = p.GoToSL.RecycleFunc(fname) + } + p.print(fname) p.setPos(x.Lparen) p.print(token.LPAREN) p.print(recvPath) @@ -1546,16 +1904,6 @@ func (p *printer) methodExpr(x *ast.CallExpr, depth int) { p.print(token.COMMA, blank) } } - args := x.Args - if pathType != nil { - meth, _, _ := types.LookupFieldOrMethod(pathType, true, p.pkg.Types, methName) - if meth != nil { - if ft, ok := meth.(*types.Func); ok { - sig := ft.Type().(*types.Signature) - args = p.matchLiteralArgs(x.Args, sig.Params()) - } - } - } if x.Ellipsis.IsValid() { p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false) p.setPos(x.Ellipsis) @@ -1568,7 +1916,9 @@ func (p *printer) methodExpr(x *ast.CallExpr, depth int) { } p.setPos(x.Rparen) p.print(token.RPAREN) + p.curMethIsAtomic = false + p.assignRwArgs(rwargs) // gosl: assign temp var back to global var } func (p *printer) expr0(x ast.Expr, depth int) { @@ -1625,9 +1975,28 @@ func (p *printer) stmtList(list []ast.Stmt, nindent int, nextIsRBrace bool) { // block prints an *ast.BlockStmt; it always spans at least two lines. 
func (p *printer) block(b *ast.BlockStmt, nindent int) { + p.GoToSL.GetVarStack.Push(make(map[string]*GetGlobalVar)) p.setPos(b.Lbrace) p.print(token.LBRACE) - p.stmtList(b.List, nindent, true) + nstmt := len(b.List) + retLast := false + if nstmt > 1 { + if _, ok := b.List[nstmt-1].(*ast.ReturnStmt); ok { + retLast = true + } + } + if retLast { + p.stmtList(b.List[:nstmt-1], nindent, true) + } else { + p.stmtList(b.List, nindent, true) + } + getVars := p.GoToSL.GetVarStack.Pop() + if len(getVars) > 0 { // gosl: set the get vars + p.setGlobalVars(getVars) + } + if retLast { + p.stmt(b.List[nstmt-1], true, false) + } p.linebreak(p.lineFor(b.Rbrace), 1, ignore, true) p.setPos(b.Rbrace) p.print(token.RBRACE) @@ -1811,6 +2180,16 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, nosemi bool) { depth++ } if s.Tok == token.DEFINE { + if ce, ok := s.Rhs[0].(*ast.CallExpr); ok { + if fid, ok := ce.Fun.(*ast.Ident); ok { + if strings.HasPrefix(fid.Name, "Get") { + if gvr, ok := p.GoToSL.GetFuncs[fid.Name]; ok { + p.getGlobalVar(s, gvr) // replace GetVar function call with assignment of local var + return + } + } + } + } p.print("var", blank) // we don't know if it is var or let.. } p.exprList(s.Pos(), s.Lhs, depth, 0, s.TokPos, false) @@ -1888,7 +2267,7 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, nosemi bool) { default: // This can only happen with an incorrectly // constructed AST. Permit it but print so - // that it can be parsed without errors. + // that it can be gotosld without errors. 
p.print(token.LBRACE, indent, formfeed) p.stmt(s.Else, true, false) p.print(unindent, formfeed, token.RBRACE) @@ -1954,22 +2333,31 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, nosemi bool) { p.block(s.Body, 1) case *ast.RangeStmt: + // gosl: only supporting the for i := range 10 kind of range loop p.print(token.FOR, blank) if s.Key != nil { + p.print(token.LPAREN, "var", blank) p.expr(s.Key) - if s.Value != nil { - // use position of value following the comma as - // comma position for correct comment placement - p.setPos(s.Value.Pos()) - p.print(token.COMMA, blank) - p.expr(s.Value) - } - p.print(blank) - p.setPos(s.TokPos) - p.print(s.Tok, blank) + p.print(token.ASSIGN, "0", token.SEMICOLON, blank) + p.expr(s.Key) + p.print(token.LSS) + p.expr(stripParens(s.X)) + p.print(token.SEMICOLON, blank) + p.expr(s.Key) + p.print(token.INC, token.RPAREN) + // if s.Value != nil { + // // use position of value following the comma as + // // comma position for correct comment placement + // p.setPos(s.Value.Pos()) + // p.print(token.COMMA, blank) + // p.expr(s.Value) + // } + // p.print(blank) + // p.setPos(s.TokPos) + // p.print(s.Tok, blank) } - p.print(token.RANGE, blank) - p.expr(stripParens(s.X)) + // p.print(token.RANGE, blank) + // p.expr(stripParens(s.X)) p.print(blank) p.block(s.Body, 1) @@ -2095,7 +2483,7 @@ func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool, tok token.Token, fi } func sanitizeImportPath(lit *ast.BasicLit) *ast.BasicLit { - // Note: An unmodified AST generated by go/parser will already + // Note: An unmodified AST generated by go/gotoslr will already // contain a backward- or double-quoted path string that does // not contain any invalid characters, and most of the work // here is not needed. 
However, a modified or generated AST @@ -2226,6 +2614,107 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, tok token.Token) { } } +// gosl: process system global vars +func (p *printer) systemVars(d *ast.GenDecl, sysname string) { + if !p.GoToSL.GetFuncGraph { + return + } + sy := p.GoToSL.System(sysname) + var gp *Group + var err error + for _, s := range d.Specs { + vs := s.(*ast.ValueSpec) + dirs, docs := p.findDirective(vs.Doc) + readOnly := false + if hasDirective(dirs, "read-only") { + readOnly = true + } + if gpnm, ok := directiveAfter(dirs, "group"); ok { + if gpnm == "" { + gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs} + sy.Groups = append(sy.Groups, gp) + } else { + gps := strings.Fields(gpnm) + gp = &Group{Doc: docs} + if gps[0] == "-uniform" { + gp.Uniform = true + if len(gps) > 1 { + gp.Name = gps[1] + } + } else { + gp.Name = gps[0] + } + sy.Groups = append(sy.Groups, gp) + } + } + if gp == nil { + gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs} + sy.Groups = append(sy.Groups, gp) + } + if len(vs.Names) != 1 { + err = fmt.Errorf("gosl: system %q: vars must have only 1 variable per line", sysname) + p.userError(err) + } + nm := vs.Names[0].Name + typ := "" + if sl, ok := vs.Type.(*ast.ArrayType); ok { + id, ok := sl.Elt.(*ast.Ident) + if !ok { + err = fmt.Errorf("gosl: system %q: Var type not recognized: %#v", sysname, sl.Elt) + p.userError(err) + continue + } + typ = "[]" + id.Name + } else { + sel, ok := vs.Type.(*ast.SelectorExpr) + if !ok { + st, ok := vs.Type.(*ast.StarExpr) + if !ok { + err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname) + p.userError(err) + + continue + } + sel, ok = st.X.(*ast.SelectorExpr) + if !ok { + err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname) + p.userError(err) + continue + } + } + sid, ok := sel.X.(*ast.Ident) + if !ok { + err = fmt.Errorf("gosl: 
system %q: Var type selector is not recognized: %#v", sysname, sel.X) + p.userError(err) + continue + } + typ = sid.Name + "." + sel.Sel.Name + } + vr := &Var{Name: nm, Type: typ, ReadOnly: readOnly} + if strings.HasPrefix(typ, "tensor.") { + vr.Tensor = true + dstr, ok := directiveAfter(dirs, "dims") + if !ok { + err = fmt.Errorf("gosl: system %q: variable %q tensor vars require //gosl:dims to specify number of dimensions", sysname, nm) + p.userError(err) + continue + } + dims, err := strconv.Atoi(dstr) + if !ok { + err = fmt.Errorf("gosl: system %q: variable %q tensor dims parse error: %s", sysname, nm, err.Error()) + p.userError(err) + } + vr.SetTensorKind() + vr.TensorDims = dims + } + gp.Vars = append(gp.Vars, vr) + if p.GoToSL.Config.Debug { + fmt.Println("\tAdded var:", nm, typ, "to group:", gp.Name) + } + } + p.GoToSL.VarsAdded() +} + func (p *printer) genDecl(d *ast.GenDecl) { p.setComment(d.Doc) // note: critical to print here to trigger comment generation in right place @@ -2244,6 +2733,13 @@ func (p *printer) genDecl(d *ast.GenDecl) { // p.print(indent, formfeed) if n > 1 && (d.Tok == token.CONST || d.Tok == token.VAR) { // two or more grouped const/var declarations: + if d.Tok == token.VAR { + dirs, _ := p.findDirective(d.Doc) + if sysname, ok := directiveAfter(dirs, "vars"); ok { + p.systemVars(d, sysname) + return + } + } // determine if the type column must be kept keepType := keepTypeColumn(d.Specs) firstSpec := d.Specs[0].(*ast.ValueSpec) @@ -2257,20 +2753,27 @@ func (p *printer) genDecl(d *ast.GenDecl) { } var line int for i, s := range d.Specs { + vs := s.(*ast.ValueSpec) if i > 0 { p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0) } p.recordLine(&line) - p.valueSpec(s.(*ast.ValueSpec), keepType[i], d.Tok, firstSpec, isIota, i) + p.valueSpec(vs, keepType[i], d.Tok, firstSpec, isIota, i) } } else { + tok := d.Tok + if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope + // could add further 
comment-directive logic + // to specify or scope if needed + tok = token.CONST + } var line int for i, s := range d.Specs { if i > 0 { p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0) } p.recordLine(&line) - p.spec(s, n, false, d.Tok) + p.spec(s, n, false, tok) } } // p.print(unindent, formfeed) @@ -2278,8 +2781,12 @@ func (p *printer) genDecl(d *ast.GenDecl) { // p.setPos(d.Rparen) // p.print(token.RPAREN) } else if len(d.Specs) > 0 { + tok := d.Tok + if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope + tok = token.CONST + } // single declaration - p.spec(d.Specs[0], 1, true, d.Tok) + p.spec(d.Specs[0], 1, true, tok) } } @@ -2322,7 +2829,7 @@ func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) { // nodeSize computation must be independent of particular // style so that we always get the same decision; print // in RawFormat - cfg := Config{Mode: RawFormat} + cfg := PrintConfig{Mode: RawFormat} var counter sizeCounter if err := cfg.fprint(&counter, p.pkg, n, p.nodeSizes); err != nil { return @@ -2437,36 +2944,51 @@ func (p *printer) methRecvType(typ ast.Expr) string { } func (p *printer) funcDecl(d *ast.FuncDecl) { - p.setComment(d.Doc) - p.setPos(d.Pos()) - // We have to save startCol only after emitting FUNC; otherwise it can be on a - // different line (all whitespace preceding the FUNC is emitted only when the - // FUNC is emitted). 
- startCol := p.out.Column - len("func ") + fname := "" if d.Recv != nil { for ex := range p.ExcludeFunctions { if d.Name.Name == ex { return } } - p.print("fn", blank) if d.Recv.List[0].Names != nil { p.curMethRecv = d.Recv.List[0] - if p.printMethRecv() { + isptr, typnm := p.printMethRecv() + if isptr { p.curPtrArgs = []*ast.Ident{p.curMethRecv.Names[0]} } + fname = typnm + "_" + d.Name.Name // fmt.Printf("cur func recv: %v\n", p.curMethRecv) } // p.parameters(d.Recv, funcParam) // method: print receiver // p.print(blank) } else { - p.print("fn", blank) + fname = d.Name.Name + } + if p.GoToSL.GetFuncGraph { + p.curFunc = p.GoToSL.RecycleFunc(fname) + } else { + fn, ok := p.GoToSL.KernelFuncs[fname] + if !ok { + return + } + p.curFunc = fn } - p.expr(d.Name) + p.setComment(d.Doc) + p.setPos(d.Pos()) + // We have to save startCol only after emitting FUNC; otherwise it can be on a + // different line (all whitespace preceding the FUNC is emitted only when the + // FUNC is emitted). + startCol := p.out.Column - len("func ") + p.print("fn", blank, fname) p.signature(d.Type, d.Recv) p.funcBody(p.distanceFrom(d.Pos(), startCol), vtab, d.Body) p.curPtrArgs = nil p.curMethRecv = nil + if p.GoToSL.GetFuncGraph { + p.GoToSL.FuncGraph[fname] = p.curFunc + p.curFunc = nil + } } func (p *printer) decl(decl ast.Decl) { diff --git a/gpu/gosl/slprint/printer.go b/goal/gosl/gotosl/printer.go similarity index 95% rename from gpu/gosl/slprint/printer.go rename to goal/gosl/gotosl/printer.go index 0fa383be44..2c1cf86be8 100644 --- a/gpu/gosl/slprint/printer.go +++ b/goal/gosl/gotosl/printer.go @@ -1,8 +1,15 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is largely copied from the Go source, +// src/go/printer/printer.go: + // Copyright 2009 The Go Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package slprint +package gotosl import ( "fmt" @@ -16,6 +23,7 @@ import ( "text/tabwriter" "unicode" + "cogentcore.org/core/base/fsx" "golang.org/x/tools/go/packages" ) @@ -54,7 +62,8 @@ type commentInfo struct { type printer struct { // Configuration (does not change after initialization) - Config + PrintConfig + fset *token.FileSet pkg *packages.Package // gosl: extra @@ -99,9 +108,11 @@ type printer struct { // current arguments to function that are pointers and thus need dereferencing // when accessing fields - curPtrArgs []*ast.Ident - curMethRecv *ast.Field // current method receiver, also included in curPtrArgs if ptr - curReturnType *ast.Ident + curPtrArgs []*ast.Ident + curFunc *Function + curMethRecv *ast.Field // current method receiver, also included in curPtrArgs if ptr + curReturnType *ast.Ident + curMethIsAtomic bool // current method an atomic.* function -- marks arg as atomic } func (p *printer) internalError(msg ...any) { @@ -112,6 +123,12 @@ func (p *printer) internalError(msg ...any) { } } +func (p *printer) userError(err error) { + fname := fsx.DirAndFile(p.pos.String()) + fmt.Print(fname + ": ") + fmt.Println(err.Error()) +} + // commentsHaveNewline reports whether a list of comments belonging to // an *ast.CommentGroup contains newlines. Because the position information // may only be partially correct, we also have to read the comment text. @@ -225,7 +242,7 @@ func (p *printer) writeLineDirective(pos token.Position) { func (p *printer) writeIndent() { // use "hard" htabs - indentation columns // must not be discarded by the tabwriter - n := p.Config.Indent + p.indent // include base indentation + n := p.PrintConfig.Indent + p.indent // include base indentation for i := 0; i < n; i++ { p.output = append(p.output, '\t') } @@ -287,7 +304,7 @@ func (p *printer) writeByte(ch byte, n int) { // printer benchmark by up to 10%. 
func (p *printer) writeString(pos token.Position, s string, isLit bool) { if p.out.Column == 1 { - if p.Config.Mode&SourcePos != 0 { + if p.PrintConfig.Mode&SourcePos != 0 { p.writeLineDirective(pos) } p.writeIndent() @@ -1321,11 +1338,12 @@ const ( normalizeNumbers Mode = 1 << 30 ) -// A Config node controls the output of Fprint. -type Config struct { - Mode Mode // default: 0 - Tabwidth int // default: 8 - Indent int // default: 0 (all code is indented at least by this much) +// A PrintConfig node controls the output of Fprint. +type PrintConfig struct { + Mode Mode // default: 0 + Tabwidth int // default: 8 + Indent int // default: 0 (all code is indented at least by this much) + GoToSL *State // gosl: ExcludeFunctions map[string]bool } @@ -1342,18 +1360,18 @@ var printerPool = sync.Pool{ }, } -func newPrinter(cfg *Config, pkg *packages.Package, nodeSizes map[ast.Node]int) *printer { +func newPrinter(cfg *PrintConfig, pkg *packages.Package, nodeSizes map[ast.Node]int) *printer { p := printerPool.Get().(*printer) *p = printer{ - Config: *cfg, - pkg: pkg, - fset: pkg.Fset, - pos: token.Position{Line: 1, Column: 1}, - out: token.Position{Line: 1, Column: 1}, - wsbuf: p.wsbuf[:0], - nodeSizes: nodeSizes, - cachedPos: -1, - output: p.output[:0], + PrintConfig: *cfg, + pkg: pkg, + fset: pkg.Fset, + pos: token.Position{Line: 1, Column: 1}, + out: token.Position{Line: 1, Column: 1}, + wsbuf: p.wsbuf[:0], + nodeSizes: nodeSizes, + cachedPos: -1, + output: p.output[:0], } return p } @@ -1368,7 +1386,7 @@ func (p *printer) free() { } // fprint implements Fprint and takes a nodesSizes map for setting up the printer state. 
-func (cfg *Config) fprint(output io.Writer, pkg *packages.Package, node any, nodeSizes map[ast.Node]int) (err error) { +func (cfg *PrintConfig) fprint(output io.Writer, pkg *packages.Package, node any, nodeSizes map[ast.Node]int) (err error) { // print node p := newPrinter(cfg, pkg, nodeSizes) defer p.free() @@ -1431,14 +1449,14 @@ type CommentedNode struct { // Position information is interpreted relative to the file set fset. // The node type must be *[ast.File], *[CommentedNode], [][ast.Decl], [][ast.Stmt], // or assignment-compatible to [ast.Expr], [ast.Decl], [ast.Spec], or [ast.Stmt]. -func (cfg *Config) Fprint(output io.Writer, pkg *packages.Package, node any) error { +func (cfg *PrintConfig) Fprint(output io.Writer, pkg *packages.Package, node any) error { return cfg.fprint(output, pkg, node, make(map[ast.Node]int)) } // Fprint "pretty-prints" an AST node to output. -// It calls [Config.Fprint] with default settings. +// It calls [PrintConfig.Fprint] with default settings. // Note that gofmt uses tabs for indentation but spaces for alignment; // use format.Node (package go/format) for output that matches gofmt. func Fprint(output io.Writer, pkg *packages.Package, node any) error { - return (&Config{Tabwidth: 8}).Fprint(output, pkg, node) + return (&PrintConfig{Tabwidth: 8}).Fprint(output, pkg, node) } diff --git a/gpu/gosl/sledits.go b/goal/gosl/gotosl/sledits.go similarity index 85% rename from gpu/gosl/sledits.go rename to goal/gosl/gotosl/sledits.go index b9ea5e2318..308d13682d 100644 --- a/gpu/gosl/sledits.go +++ b/goal/gosl/gotosl/sledits.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package main +package gotosl import ( "bytes" @@ -21,19 +21,15 @@ func MoveLines(lines *[][]byte, to, st, ed int) { *lines = nln } -// SlEdits performs post-generation edits for wgsl -// * moves wgsl segments around, e.g., methods -// into their proper classes -// * fixes printf, slice other common code +// SlEdits performs post-generation edits for wgsl, +// replacing type names, slbool, function calls, etc. // returns true if a slrand. or sltype. prefix was found, // driveing copying of those files. -func SlEdits(src []byte) ([]byte, bool, bool) { - // return src // uncomment to show original without edits +func SlEdits(src []byte) (lines [][]byte, hasSlrand bool, hasSltype bool) { nl := []byte("\n") - lines := bytes.Split(src, nl) - hasSlrand, hasSltype := SlEditsReplace(lines) - - return bytes.Join(lines, nl), hasSlrand, hasSltype + lines = bytes.Split(src, nl) + hasSlrand, hasSltype = SlEditsReplace(lines) + return } type Replace struct { @@ -99,6 +95,22 @@ func MathReplaceAll(mat, ln []byte) []byte { } } +func SlRemoveComments(lines [][]byte) [][]byte { + comm := []byte("//") + olns := make([][]byte, 0, len(lines)) + for _, ln := range lines { + ts := bytes.TrimSpace(ln) + if len(ts) == 0 { + continue + } + if bytes.HasPrefix(ts, comm) { + continue + } + olns = append(olns, ln) + } + return olns +} + // SlEditsReplace replaces Go with equivalent WGSL code // returns true if has slrand. or sltype. // to auto include that header file if so. 
@@ -114,13 +126,13 @@ func SlEditsReplace(lines [][]byte) (bool, bool) { if bytes.Contains(ln, include) { continue } + if !hasSlrand && bytes.Contains(ln, slr) { + hasSlrand = true + } + if !hasSltype && bytes.Contains(ln, styp) { + hasSltype = true + } for _, r := range Replaces { - if !hasSlrand && bytes.Contains(ln, slr) { - hasSlrand = true - } - if !hasSltype && bytes.Contains(ln, styp) { - hasSltype = true - } ln = bytes.ReplaceAll(ln, r.From, r.To) } ln = MathReplaceAll(mt32, ln) diff --git a/goal/gosl/gotosl/testdata/Compute.golden b/goal/gosl/gotosl/testdata/Compute.golden new file mode 100644 index 0000000000..fb028d31ee --- /dev/null +++ b/goal/gosl/gotosl/testdata/Compute.golden @@ -0,0 +1,102 @@ +// Code generated by "gosl"; DO NOT EDIT +// kernel: Compute + +// // Params are the parameters for the computation. +@group(0) @binding(0) +var Params: array; +@group(0) @binding(1) +var Data: array; + +alias GPUVars = i32; + +@compute @workgroup_size(64, 1, 1) +fn main(@builtin(workgroup_id) wgid: vec3, @builtin(num_workgroups) nwg: vec3, @builtin(local_invocation_index) loci: u32) { + let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64; + Compute(idx); +} + + +//////// import: "basic.go" +alias NeuronFlags = i32; +const NeuronOff: NeuronFlags = 0x01; +const NeuronHasExt: NeuronFlags = 0x02; // note: 1<<2 does NOT work +const NeuronHasTarg: NeuronFlags = 0x04; +const NeuronHasCmpr: NeuronFlags = 0x08; +alias Modes = i32; +const NoEvalMode: Modes = 0; +const AllModes: Modes = 1; +const Train: Modes = 2; +const Test: Modes = 3; +const testSlice = array(NoEvalMode, AllModes, Train, Test); +struct DataStruct { + Raw: f32, + Integ: f32, + Exp: f32, + Int: i32, +} +struct SubParamStruct { + A: f32, + B: f32, + C: f32, + D: f32, +} +fn SubParamStruct_Sum(sp: ptr) -> f32 { + return (*sp).A + (*sp).B + (*sp).C + (*sp).D; +} +fn SubParamStruct_SumPlus(sp: ptr, extra: f32) -> f32 { + return SubParamStruct_Sum(sp) + extra; +} +struct ParamStruct { + 
Tau: f32, + Dt: f32, + Option: i32, // note: standard bool doesn't work + pad: f32, // comment this out to trigger alignment warning + Subs: SubParamStruct, +} +fn ParamStruct_IntegFromRaw(ps: ptr, ds: ptr) -> f32 { + var newVal = (*ps).Dt * ((*ds).Raw - (*ds).Integ); + if (newVal < -10 || (*ps).Option == 1) { + newVal = f32(-10); + } + (*ds).Integ += newVal; + (*ds).Exp = exp(-(*ds).Integ); + var a: f32; + ParamStruct_AnotherMeth(ps, ds, &a);return (*ds).Exp; +} +fn ParamStruct_AnotherMeth(ps: ptr, ds: ptr, ptrarg: ptr) { + for (var i = 0; i < 10; i++) { + (*ds).Integ *= f32(0.99); + } + var flag: NeuronFlags; + flag &= ~NeuronHasExt; // clear flag -- op doesn't exist in C + var mode = Test; + switch (mode) { // note: no fallthrough! + case Test: { + var ab = f32(42); + (*ds).Exp /= ab; + } + case Train: { + var ab = f32(.5); + (*ds).Exp *= ab; + } + default: { + var ab = f32(1); + (*ds).Exp *= ab; + } + } + var a: f32; + var b: f32; + b = f32(42); + a = SubParamStruct_Sum(&(*ps).Subs); + (*ds).Exp = SubParamStruct_SumPlus(&(*ps).Subs, b); + (*ds).Integ = a; + for (var i=0; i<10; i++) { + (*ds).Exp *= f32(0.99); + } + *ptrarg = f32(-1); +} +fn Compute(i: u32) { //gosl:kernel + var data = Data[i]; + var params=Params[0]; ParamStruct_IntegFromRaw(¶ms, &data); + Data[i] = data; +} \ No newline at end of file diff --git a/gpu/gosl/testdata/basic.go b/goal/gosl/gotosl/testdata/basic.go similarity index 76% rename from gpu/gosl/testdata/basic.go rename to goal/gosl/gotosl/testdata/basic.go index 72a3610c95..c808b0a173 100644 --- a/gpu/gosl/testdata/basic.go +++ b/goal/gosl/gotosl/testdata/basic.go @@ -3,31 +3,21 @@ package test import ( "math" + "cogentcore.org/core/goal/gosl/slbool" "cogentcore.org/core/math32" - "cogentcore.org/core/vgpu/gosl/slbool" ) -// note: this code is included in the go pre-processing output but -// then removed from the final wgsl output. -// Use when you need different versions of the same function for CPU vs. 
GPU +//gosl:start -// MyTrickyFun this is the CPU version of the tricky function -func MyTrickyFun(x float32) float32 { - return 10 // ok actually not tricky here, but whatever -} - -//gosl:wgsl basic - -// // note: here is the wgsl version, only included in wgsl - -// // MyTrickyFun this is the GPU version of the tricky function -// fn MyTrickyFun(x: f32) -> f32 { -// return 16.0; // ok actually not tricky here, but whatever -// } +//gosl:vars +var ( + // Params are the parameters for the computation. + //gosl:read-only + Params []ParamStruct -//gosl:end basic - -//gosl:start basic + // Data is the data on which the computation operates. + Data []DataStruct +) // FastExp is a quartic spline approximation to the Exp function, by N.N. Schraudolph // It does not have any of the sanity checking of a standard method -- returns @@ -79,6 +69,9 @@ const ( Test ) +// testSlice is a global array: will be const = array(...); +var testSlice = [NVars]Modes{NoEvalMode, AllModes, Train, Test} + // DataStruct has the test data type DataStruct struct { @@ -91,7 +84,7 @@ type DataStruct struct { // exp of integ Exp float32 - pad float32 + Int int32 } // SubParamStruct has the test sub-params @@ -129,6 +122,7 @@ func (ps *ParamStruct) IntegFromRaw(ds *DataStruct) float32 { if newVal < -10 || ps.Option.IsTrue() { newVal = -10 } + // atomic.AddInt32(&(ds.Int), int32(newVal)) ds.Integ += newVal ds.Exp = math32.Exp(-ds.Integ) var a float32 @@ -163,10 +157,14 @@ func (ps *ParamStruct) AnotherMeth(ds *DataStruct, ptrarg *float32) { ds.Exp = ps.Subs.SumPlus(b) ds.Integ = a + for i := range 10 { + ds.Exp *= 0.99 + } + *ptrarg = -1 } -//gosl:end basic +//gosl:end // note: only core compute code needs to be in shader -- all init is done CPU-side @@ -179,21 +177,16 @@ func (ps *ParamStruct) Update() { ps.Dt = 1.0 / ps.Tau } -//gosl:wgsl basic -/* -@group(0) @binding(0) -var Params: array; - -@group(0) @binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn 
main(@builtin(global_invocation_id) idx: vec3) { - var pars = Params[0]; - var data = Data[idx.x]; - ParamStruct_IntegFromRaw(&pars, &data); - Data[idx.x] = data; +func (ps *ParamStruct) String() string { + return "params!" +} + +//gosl:start + +// Compute does the main computation +func Compute(i uint32) { //gosl:kernel + data := GetData(i) + Params[0].IntegFromRaw(data) } -*/ -//gosl:end basic + +//gosl:end diff --git a/goal/gosl/gotosl/testdata/gosl.golden b/goal/gosl/gotosl/testdata/gosl.golden new file mode 100644 index 0000000000..5fe89c15d3 --- /dev/null +++ b/goal/gosl/gotosl/testdata/gosl.golden @@ -0,0 +1,212 @@ +// Code generated by "gosl"; DO NOT EDIT + +package test + +import ( + "embed" + "unsafe" + "cogentcore.org/core/gpu" + "cogentcore.org/core/tensor" +) + +//go:embed shaders/*.wgsl +var shaders embed.FS + +// ComputeGPU is the compute gpu device +var ComputeGPU *gpu.GPU + +// UseGPU indicates whether to use GPU vs. CPU. +var UseGPU bool + +// GPUSystem is a GPU compute System with kernels operating on the +// same set of data variables. +var GPUSystem *gpu.ComputeSystem + +// GPUVars is an enum for GPU variables, for specifying what to sync. +type GPUVars int32 //enums:enum + +const ( + ParamsVar GPUVars = 0 + DataVar GPUVars = 1 +) + +// Dummy tensor stride variable to avoid import error +var __TensorStrides tensor.Uint32 + +// GPUInit initializes the GPU compute system, +// configuring system(s), variables and kernels. +// It is safe to call multiple times: detects if already run. 
+func GPUInit() { + if ComputeGPU != nil { + return + } + gp := gpu.NewComputeGPU() + ComputeGPU = gp + { + sy := gpu.NewComputeSystem(gp, "Default") + GPUSystem = sy + gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy) + vars := sy.Vars() + { + sgp := vars.AddGroup(gpu.Storage, "Group_0") + var vr *gpu.Var + _ = vr + vr = sgp.AddStruct("Params", int(unsafe.Sizeof(ParamStruct{})), 1, gpu.ComputeShader) + vr.ReadOnly = true + vr = sgp.AddStruct("Data", int(unsafe.Sizeof(DataStruct{})), 1, gpu.ComputeShader) + sgp.SetNValues(1) + } + sy.Config() + } +} + +// GPURelease releases the GPU compute system resources. +// Call this at program exit. +func GPURelease() { + if GPUSystem != nil { + GPUSystem.Release() + GPUSystem = nil + } + + if ComputeGPU != nil { + ComputeGPU.Release() + ComputeGPU = nil + } +} + +// RunCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. +// Can call multiple Run* kernels in a row, which are then all launched +// in the same command submission on the GPU, which is by far the most efficient. +// MUST call RunDone (with optional vars to sync) after all Run calls. +// Alternatively, a single-shot RunOneCompute call does Run and Done for a +// single run-and-sync case. +func RunCompute(n int) { + if UseGPU { + RunComputeGPU(n) + } else { + RunComputeCPU(n) + } +} + +// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info. +func RunComputeGPU(n int) { + sy := GPUSystem + pl := sy.ComputePipelines["Compute"] + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, 64) +} + +// RunComputeCPU runs the Compute kernel on the CPU. +func RunComputeCPU(n int) { + gpu.VectorizeFunc(0, n, Compute) +} + +// RunOneCompute runs the Compute kernel with given number of elements, +// on either the CPU or GPU depending on the UseGPU variable. 
+// This version then calls RunDone with the given variables to sync +// after the Run, for a single-shot Run-and-Done call. If multiple kernels +// can be run in sequence, it is much more efficient to do multiple Run* +// calls followed by a RunDone call. +func RunOneCompute(n int, syncVars ...GPUVars) { + if UseGPU { + RunComputeGPU(n) + RunDone(syncVars...) + } else { + RunComputeCPU(n) + } +} +// RunDone must be called after Run* calls to start compute kernels. +// This actually submits the kernel jobs to the GPU, and adds commands +// to synchronize the given variables back from the GPU to the CPU. +// After this function completes, the GPU results will be available in +// the specified variables. +func RunDone(syncVars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + sy.ComputeEncoder.End() + ReadFromGPU(syncVars...) + sy.EndComputePass() + SyncFromGPU(syncVars...) +} + +// ToGPU copies given variables to the GPU for the system. +func ToGPU(vars ...GPUVars) { + if !UseGPU { + return + } + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + gpu.SetValueFrom(v, Params) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + gpu.SetValueFrom(v, Data) + } + } +} +// RunGPUSync can be called to synchronize data between CPU and GPU. +// Any prior ToGPU* calls will execute to send data to the GPU, +// and any subsequent RunDone* calls will copy data back from the GPU. +func RunGPUSync() { + if !UseGPU { + return + } + sy := GPUSystem + sy.BeginComputePass() +} + +// ReadFromGPU starts the process of copying vars to the GPU. 
+func ReadFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + v.GPUToRead(sy.CommandEncoder) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + v.GPUToRead(sy.CommandEncoder) + } + } +} + +// SyncFromGPU synchronizes vars from the GPU to the actual variable. +func SyncFromGPU(vars ...GPUVars) { + sy := GPUSystem + syVars := sy.Vars() + for _, vr := range vars { + switch vr { + case ParamsVar: + v, _ := syVars.ValueByIndex(0, "Params", 0) + v.ReadSync() + gpu.ReadToBytes(v, Params) + case DataVar: + v, _ := syVars.ValueByIndex(0, "Data", 0) + v.ReadSync() + gpu.ReadToBytes(v, Data) + } + } +} + +// GetParams returns a pointer to the given global variable: +// [Params] []ParamStruct at given index. +// To ensure that values are updated on the GPU, you must call [SetParams]. +// after all changes have been made. +func GetParams(idx uint32) *ParamStruct { + return &Params[idx] +} + +// GetData returns a pointer to the given global variable: +// [Data] []DataStruct at given index. +// To ensure that values are updated on the GPU, you must call [SetData]. +// after all changes have been made. +func GetData(idx uint32) *DataStruct { + return &Data[idx] +} diff --git a/goal/gosl/gotosl/translate.go b/goal/gosl/gotosl/translate.go new file mode 100644 index 0000000000..cfb4a2fcad --- /dev/null +++ b/goal/gosl/gotosl/translate.go @@ -0,0 +1,218 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gotosl + +import ( + "bytes" + "fmt" + "go/ast" + "go/token" + "log" + "os" + "os/exec" + "path/filepath" + "sort" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/goal/gosl/alignsl" + "golang.org/x/exp/maps" + "golang.org/x/tools/go/packages" +) + +// TranslateDir translate all .Go files in given directory to WGSL. 
+func (st *State) TranslateDir(pf string) error { + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes | packages.NeedTypesInfo}, pf) + // pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadAllSyntax}, pf) + if err != nil { + return errors.Log(err) + } + if len(pkgs) != 1 { + err := fmt.Errorf("More than one package for path: %v", pf) + return errors.Log(err) + } + pkg := pkgs[0] + if len(pkg.GoFiles) == 0 { + err := fmt.Errorf("No Go files found in package: %+v", pkg) + return errors.Log(err) + } + + // fmt.Printf("go files: %+v", pkg.GoFiles) + // return nil, err + files := pkg.GoFiles + + serr := alignsl.CheckPackage(pkg) + + if serr != nil { + fmt.Println(serr) + } + + st.FuncGraph = make(map[string]*Function) + st.GetFuncGraph = true + + doFile := func(gofp string, buf *bytes.Buffer) { + _, gofn := filepath.Split(gofp) + if st.Config.Debug { + fmt.Printf("###################################\nTranslating Go file: %s\n", gofn) + } + var afile *ast.File + var fpos token.Position + for _, sy := range pkg.Syntax { + pos := pkg.Fset.Position(sy.Package) + _, posfn := filepath.Split(pos.Filename) + if posfn == gofn { + fpos = pos + afile = sy + break + } + } + if afile == nil { + fmt.Printf("Warning: File named: %s not found in Loaded package\n", gofn) + return + } + + pcfg := PrintConfig{GoToSL: st, Mode: printerMode, Tabwidth: tabWidth, ExcludeFunctions: st.ExcludeMap} + pcfg.Fprint(buf, pkg, afile) + if !st.GetFuncGraph && !st.Config.Keep { + os.Remove(fpos.Filename) + } + } + + // first pass is just to get the call graph: + for fn := range st.GoVarsFiles { // do varsFiles first!! 
+ var buf bytes.Buffer + doFile(fn, &buf) + } + for _, gofp := range files { + _, gofn := filepath.Split(gofp) + if _, ok := st.GoVarsFiles[gofn]; ok { + continue + } + var buf bytes.Buffer + doFile(gofp, &buf) + } + + // st.PrintFuncGraph() + + doKernelFile := func(fname string, lines [][]byte) ([][]byte, bool, bool) { + _, gofn := filepath.Split(fname) + var buf bytes.Buffer + doFile(fname, &buf) + slfix, hasSlrand, hasSltype := SlEdits(buf.Bytes()) + slfix = SlRemoveComments(slfix) + exsl := st.ExtractWGSL(slfix) + lines = append(lines, []byte("")) + lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", gofn))) + lines = append(lines, exsl...) + return lines, hasSlrand, hasSltype + } + + // next pass is per kernel + st.GetFuncGraph = false + sys := maps.Keys(st.Systems) + sort.Strings(sys) + for _, snm := range sys { + sy := st.Systems[snm] + kns := maps.Keys(sy.Kernels) + sort.Strings(kns) + for _, knm := range kns { + kn := sy.Kernels[knm] + st.KernelFuncs = st.AllFuncs(kn.Name) + if st.KernelFuncs == nil { + continue + } + var hasSlrand, hasSltype, hasR, hasT bool + avars := st.AtomicVars(st.KernelFuncs) + // if st.Config.Debug { + fmt.Printf("###################################\nTranslating Kernel file: %s\n", kn.Name) + // } + hdr := st.GenKernelHeader(sy, kn, avars) + lines := bytes.Split([]byte(hdr), []byte("\n")) + for fn := range st.GoVarsFiles { // do varsFiles first!! 
+ lines, hasR, hasT = doKernelFile(fn, lines) + if hasR { + hasSlrand = true + } + if hasT { + hasSltype = true + } + } + for _, gofp := range files { + _, gofn := filepath.Split(gofp) + if _, ok := st.GoVarsFiles[gofn]; ok { + continue + } + lines, hasR, hasT = doKernelFile(gofp, lines) + if hasR { + hasSlrand = true + } + if hasT { + hasSltype = true + } + } + if hasSlrand { + st.CopyPackageFile("slrand.wgsl", "cogentcore.org/core/goal/gosl/slrand") + hasSltype = true + } + if hasSltype { + st.CopyPackageFile("sltype.wgsl", "cogentcore.org/core/goal/gosl/sltype") + } + for _, im := range st.SLImportFiles { + lines = append(lines, []byte("")) + lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", im.Name))) + lines = append(lines, im.Lines...) + } + kn.Lines = lines + kfn := kn.Name + ".wgsl" + fn := filepath.Join(st.Config.Output, kfn) + kn.Filename = fn + WriteFileLines(fn, lines) + st.CompileFile(kfn) + } + } + + return nil +} + +var ( + nagaWarned = false + tintWarned = false +) + +func (st *State) CompileFile(fn string) error { + dir, _ := filepath.Abs(st.Config.Output) + if _, err := exec.LookPath("naga"); err == nil { + // cmd := exec.Command("naga", "--compact", fn, fn) // produces some pretty weird code actually + cmd := exec.Command("naga", fn) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + fmt.Printf("\n-----------------------------------------------------\nnaga output for: %s\n%s", fn, out) + if err != nil { + log.Println(err) + return err + } + } else { + if !nagaWarned { + fmt.Println("\nImportant: you should install the 'naga' WGSL compiler from https://github.com/gfx-rs/wgpu to get immediate validation") + nagaWarned = true + } + } + if _, err := exec.LookPath("tint"); err == nil { + cmd := exec.Command("tint", "--validate", "--format", "wgsl", "-o", "/dev/null", fn) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + fmt.Printf("\n-----------------------------------------------------\ntint output for: %s\n%s", fn, out) + if err 
!= nil { + log.Println(err) + return err + } + } else { + if !tintWarned { + fmt.Println("\nImportant: you should install the 'tint' WGSL compiler from https://dawn.googlesource.com/dawn/ to get immediate validation") + tintWarned = true + } + } + + return nil +} diff --git a/goal/gosl/gotosl/typegen.go b/goal/gosl/gotosl/typegen.go new file mode 100644 index 0000000000..9c995c34a5 --- /dev/null +++ b/goal/gosl/gotosl/typegen.go @@ -0,0 +1,155 @@ +// Code generated by "core generate -add-types -add-funcs"; DO NOT EDIT. + +package gotosl + +import ( + "cogentcore.org/core/types" +) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Function", IDName: "function", Doc: "Function represents the call graph of functions", Fields: []types.Field{{Name: "Name"}, {Name: "Funcs"}, {Name: "Atomics"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Config", IDName: "config", Doc: "Config has the configuration info for the gosl system.", Fields: []types.Field{{Name: "Output", Doc: "Output is the output directory for shader code,\nrelative to where gosl is invoked; must not be an empty string."}, {Name: "Exclude", Doc: "Exclude is a comma-separated list of names of functions to exclude from exporting to WGSL."}, {Name: "Keep", Doc: "Keep keeps temporary converted versions of the source files, for debugging."}, {Name: "Debug", Doc: "\tDebug enables debugging messages while running."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.System", IDName: "system", Doc: "System represents a ComputeSystem, and its kernels and variables.", Fields: []types.Field{{Name: "Name"}, {Name: "Kernels", Doc: "Kernels are the kernels using this compute system."}, {Name: "Groups", Doc: "Groups are the variables for this compute system."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Kernel", IDName: "kernel", Doc: "Kernel represents a kernel function, which is the 
basis for\neach wgsl generated code file.", Fields: []types.Field{{Name: "Name"}, {Name: "Args"}, {Name: "Filename", Doc: "Filename is the name of the kernel shader file, e.g., shaders/Compute.wgsl"}, {Name: "FuncCode", Doc: "function code"}, {Name: "Lines", Doc: "Lines is full shader code"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Var", IDName: "var", Doc: "Var represents one global system buffer variable.", Fields: []types.Field{{Name: "Name"}, {Name: "Doc", Doc: "comment docs about this var."}, {Name: "Type", Doc: "Type of variable: either []Type or F32, U32 for tensors"}, {Name: "ReadOnly", Doc: "ReadOnly indicates that this variable is never read back from GPU,\nspecified by the gosl:read-only property in the variable comments.\nIt is important to optimize GPU memory usage to indicate this."}, {Name: "Tensor", Doc: "True if a tensor type"}, {Name: "TensorDims", Doc: "Number of dimensions"}, {Name: "TensorKind", Doc: "data kind of the tensor"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Group", IDName: "group", Doc: "Group represents one variable group.", Fields: []types.Field{{Name: "Name"}, {Name: "Doc", Doc: "comment docs about this group"}, {Name: "Uniform", Doc: "Uniform indicates a uniform group; else default is Storage"}, {Name: "Vars"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.File", IDName: "file", Doc: "File has contents of a file as lines of bytes.", Fields: []types.Field{{Name: "Name"}, {Name: "Lines"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.GetGlobalVar", IDName: "get-global-var", Doc: "GetGlobalVar holds GetVar expression, to Set variable back when done.", Fields: []types.Field{{Name: "Var", Doc: "global variable"}, {Name: "TmpVar", Doc: "name of temporary variable"}, {Name: "IdxExpr", Doc: "index passed to the Get function"}}}) + +var _ = types.AddType(&types.Type{Name: 
"cogentcore.org/core/goal/gosl/gotosl.State", IDName: "state", Doc: "State holds the current Go -> WGSL processing state.", Fields: []types.Field{{Name: "Config", Doc: "Config options."}, {Name: "ImportsDir", Doc: "path to shaders/imports directory."}, {Name: "Package", Doc: "name of the package"}, {Name: "GoFiles", Doc: "GoFiles are all the files with gosl content in current directory."}, {Name: "GoVarsFiles", Doc: "GoVarsFiles are all the files with gosl:vars content in current directory.\nThese must be processed first! they are moved from GoFiles to here."}, {Name: "GoImports", Doc: "GoImports has all the imported files."}, {Name: "ImportPackages", Doc: "ImportPackages has short package names, to remove from go code\nso everything lives in same main package."}, {Name: "Systems", Doc: "Systems has the kernels and variables for each system.\nThere is an initial \"Default\" system when system is not specified."}, {Name: "GetFuncs", Doc: "GetFuncs is a map of GetVar, SetVar function names for global vars."}, {Name: "SLImportFiles", Doc: "SLImportFiles are all the extracted and translated WGSL files in shaders/imports,\nwhich are copied into the generated shader kernel files."}, {Name: "GPUFile", Doc: "generated Go GPU gosl.go file contents"}, {Name: "ExcludeMap", Doc: "ExcludeMap is the compiled map of functions to exclude in Go -> WGSL translation."}, {Name: "GetVarStack", Doc: "GetVarStack is a stack per function definition of GetVar variables\nthat need to be set at the end."}, {Name: "GetFuncGraph", Doc: "GetFuncGraph is true if getting the function graph (first pass)"}, {Name: "KernelFuncs", Doc: "KernelFuncs are the list of functions to include for current kernel."}, {Name: "FuncGraph", Doc: "FuncGraph is the call graph of functions, for dead code elimination"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.exprListMode", IDName: "expr-list-mode"}) + +var _ = types.AddType(&types.Type{Name: 
"cogentcore.org/core/goal/gosl/gotosl.paramMode", IDName: "param-mode"}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.rwArg", IDName: "rw-arg", Fields: []types.Field{{Name: "idx"}, {Name: "tmpVar"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.sizeCounter", IDName: "size-counter", Doc: "sizeCounter is an io.Writer which counts the number of bytes written,\nas well as whether a newline character was seen.", Fields: []types.Field{{Name: "hasNewline"}, {Name: "size"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.whiteSpace", IDName: "white-space"}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.pmode", IDName: "pmode", Doc: "A pmode value represents the current printer mode."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.commentInfo", IDName: "comment-info", Fields: []types.Field{{Name: "cindex"}, {Name: "comment"}, {Name: "commentOffset"}, {Name: "commentNewline"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.printer", IDName: "printer", Embeds: []types.Field{{Name: "PrintConfig", Doc: "Configuration (does not change after initialization)"}, {Name: "commentInfo", Doc: "Information about p.comments[p.cindex]; set up by nextComment."}}, Fields: []types.Field{{Name: "fset"}, {Name: "pkg"}, {Name: "output", Doc: "Current state"}, {Name: "indent"}, {Name: "level"}, {Name: "mode"}, {Name: "endAlignment"}, {Name: "impliedSemi"}, {Name: "lastTok"}, {Name: "prevOpen"}, {Name: "wsbuf"}, {Name: "goBuild"}, {Name: "plusBuild"}, {Name: "pos", Doc: "Positions\nThe out position differs from the pos position when the result\nformatting differs from the source formatting (in the amount of\nwhite space). 
If there's a difference and SourcePos is set in\nConfigMode, //line directives are used in the output to restore\noriginal source positions for a reader."}, {Name: "out"}, {Name: "last"}, {Name: "linePtr"}, {Name: "sourcePosErr"}, {Name: "comments", Doc: "The list of all source comments, in order of appearance."}, {Name: "useNodeComments"}, {Name: "nodeSizes", Doc: "Cache of already computed node sizes."}, {Name: "cachedPos", Doc: "Cache of most recently computed line position."}, {Name: "cachedLine"}, {Name: "curPtrArgs", Doc: "current arguments to function that are pointers and thus need dereferencing\nwhen accessing fields"}, {Name: "curFunc"}, {Name: "curMethRecv"}, {Name: "curReturnType"}, {Name: "curMethIsAtomic"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.trimmer", IDName: "trimmer", Doc: "A trimmer is an io.Writer filter for stripping tabwriter.Escape\ncharacters, trailing blanks and tabs, and for converting formfeed\nand vtab characters into newlines and htabs (in case no tabwriter\nis used). Text bracketed by tabwriter.Escape characters is passed\nthrough unchanged.", Fields: []types.Field{{Name: "output"}, {Name: "state"}, {Name: "space"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Mode", IDName: "mode", Doc: "A Mode value is a set of flags (or 0). 
They control printing."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.PrintConfig", IDName: "print-config", Doc: "A PrintConfig node controls the output of Fprint.", Fields: []types.Field{{Name: "Mode"}, {Name: "Tabwidth"}, {Name: "Indent"}, {Name: "GoToSL"}, {Name: "ExcludeFunctions"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.CommentedNode", IDName: "commented-node", Doc: "A CommentedNode bundles an AST node and corresponding comments.\nIt may be provided as argument to any of the [Fprint] functions.", Fields: []types.Field{{Name: "Node"}, {Name: "Comments"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/gosl/gotosl.Replace", IDName: "replace", Fields: []types.Field{{Name: "From"}, {Name: "To"}}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.NewFunction", Args: []string{"name"}, Returns: []string{"Function"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.getAllFuncs", Args: []string{"f", "all"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.formatDocComment", Doc: "formatDocComment reformats the doc comment list,\nreturning the canonical formatting.", Args: []string{"list"}, Returns: []string{"Comment"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isDirective", Doc: "isDirective reports whether c is a comment directive.\nSee go.dev/issue/37974.\nThis code is also in go/ast.", Args: []string{"c"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.allStars", Doc: "allStars reports whether text is the interior of an\nold-style /* */ comment with a star at the start of each line.", Args: []string{"text"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.Run", Directives: []types.Directive{{Tool: "cli", 
Directive: "cmd", Args: []string{"-root"}}, {Tool: "types", Directive: "add"}}, Args: []string{"cfg"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.wgslFile", Doc: "wgslFile returns the file with a \".wgsl\" extension", Args: []string{"fn"}, Returns: []string{"string"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.bareFile", Doc: "bareFile returns the file with no extention", Args: []string{"fn"}, Returns: []string{"string"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.ReadFileLines", Args: []string{"fn"}, Returns: []string{"[][]byte", "error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.WriteFileLines", Args: []string{"fn", "lines"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.IsGoFile", Args: []string{"f"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.IsWGSLFile", Args: []string{"f"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.RemoveGenFiles", Doc: "RemoveGenFiles removes .go, .wgsl, .spv files in shader generated dir", Args: []string{"dir"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.CopyFile", Args: []string{"src", "dst"}, Returns: []string{"[][]byte", "error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.appendLines", Doc: "appendLines is like append(x, y...)\nbut it avoids creating doubled blank lines,\nwhich would not be gofmt-standard output.\nIt assumes that only whole blocks of lines are being appended,\nnot line fragments.", Args: []string{"x", "y"}, Returns: []string{"[]byte"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isNL", Args: []string{"b"}, Returns: 
[]string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.NewSystem", Args: []string{"name"}, Returns: []string{"System"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.hasDirective", Doc: "gosl: hasDirective returns whether directive(s) contains string", Args: []string{"dirs", "dir"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.directiveAfter", Doc: "gosl: directiveAfter returns the directive after given leading text,\nand a bool indicating if the string was found.", Args: []string{"dirs", "dir"}, Returns: []string{"string", "bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.combinesWithName", Doc: "combinesWithName reports whether a name followed by the expression x\nsyntactically combines to another valid (value) expression. For instance\nusing *T for x, \"name *T\" syntactically appears as the expression x*T.\nOn the other hand, using P|Q or *P|~Q for x, \"name P|Q\" or name *P|~Q\"\ncannot be combined into a valid (value) expression.", Args: []string{"x"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isTypeElem", Doc: "isTypeElem reports whether x is a (possibly parenthesized) type element expression.\nThe result is false if x could be a type element OR an ordinary (value) expression.", Args: []string{"x"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.identListSize", Args: []string{"list", "maxSize"}, Returns: []string{"size"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.walkBinary", Args: []string{"e"}, Returns: []string{"has4", "has5", "maxProblem"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.cutoff", Args: []string{"e", "depth"}, Returns: []string{"int"}}) + +var _ = 
types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.diffPrec", Args: []string{"expr", "prec"}, Returns: []string{"int"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.reduceDepth", Args: []string{"depth"}, Returns: []string{"int"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isBinary", Args: []string{"expr"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.normalizedNumber", Doc: "normalizedNumber rewrites base prefixes and exponents\nof numbers to use lower-case letters (0X123 to 0x123 and 1.2E3 to 1.2e3),\nand removes leading 0's from integer imaginary literals (0765i to 765i).\nIt leaves hexadecimal digits alone.\n\nnormalizedNumber doesn't modify the ast.BasicLit value lit points to.\nIf lit is not a number or a number in canonical format already,\nlit is returned as is. Otherwise a new ast.BasicLit is created.", Args: []string{"lit"}, Returns: []string{"BasicLit"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.fieldByName", Args: []string{"st", "name"}, Returns: []string{"Var"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.getLocalTypeName", Args: []string{"typ"}, Returns: []string{"string"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isTypeName", Args: []string{"x"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.stripParens", Args: []string{"x"}, Returns: []string{"Expr"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.stripParensAlways", Args: []string{"x"}, Returns: []string{"Expr"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.keepTypeColumn", Doc: "The keepTypeColumn function determines if the type column of a series of\nconsecutive const or var declarations must 
be kept, or if initialization\nvalues (V) can be placed in the type column (T) instead. The i'th entry\nin the result slice is true if the type column in spec[i] must be kept.\n\nFor example, the declaration:\n\n\t\tconst (\n\t\t\tfoobar int = 42 // comment\n\t\t\tx = 7 // comment\n\t\t\tfoo\n\t bar = 991\n\t\t)\n\nleads to the type/values matrix below. A run of value columns (V) can\nbe moved into the type column if there is no type for any of the values\nin that column (we only move entire columns so that they align properly).\n\n\t\tmatrix formatted result\n\t matrix\n\t\tT V -> T V -> true there is a T and so the type\n\t\t- V - V true column must be kept\n\t\t- - - - false\n\t\t- V V - false V is moved into T column", Args: []string{"specs"}, Returns: []string{"[]bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.sanitizeImportPath", Args: []string{"lit"}, Returns: []string{"BasicLit"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.declToken", Args: []string{"decl"}, Returns: []string{"tok"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.isBlank", Doc: "Returns true if s contains only white space\n(only tabs and blanks can appear in the printer's context).", Args: []string{"s"}, Returns: []string{"bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.commonPrefix", Doc: "commonPrefix returns the common prefix of a and b.", Args: []string{"a", "b"}, Returns: []string{"string"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.trimRight", Doc: "trimRight returns s with trailing whitespace removed.", Args: []string{"s"}, Returns: []string{"string"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.stripCommonPrefix", Doc: "stripCommonPrefix removes a common prefix from /*-style comment lines (unless no\ncomment line is indented, all but the first line 
have some form of space prefix).\nThe prefix is computed using heuristics such that is likely that the comment\ncontents are nicely laid out after re-printing each line using the printer's\ncurrent indentation.", Args: []string{"lines"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.nlimit", Doc: "nlimit limits n to maxNewlines.", Args: []string{"n"}, Returns: []string{"int"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.mayCombine", Args: []string{"prev", "next"}, Returns: []string{"b"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.getDoc", Doc: "getDoc returns the ast.CommentGroup associated with n, if any.", Args: []string{"n"}, Returns: []string{"CommentGroup"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.getLastComment", Args: []string{"n"}, Returns: []string{"CommentGroup"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.newPrinter", Args: []string{"cfg", "pkg", "nodeSizes"}, Returns: []string{"printer"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.Fprint", Doc: "Fprint \"pretty-prints\" an AST node to output.\nIt calls [PrintConfig.Fprint] with default settings.\nNote that gofmt uses tabs for indentation but spaces for alignment;\nuse format.Node (package go/format) for output that matches gofmt.", Args: []string{"output", "pkg", "node"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.MoveLines", Doc: "MoveLines moves the st,ed region to 'to' line", Args: []string{"lines", "to", "st", "ed"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.SlEdits", Doc: "SlEdits performs post-generation edits for wgsl,\nreplacing type names, slbool, function calls, etc.\nreturns true if a slrand. or sltype. 
prefix was found,\ndriving copying of those files.", Args: []string{"src"}, Returns: []string{"lines", "hasSlrand", "hasSltype"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.MathReplaceAll", Args: []string{"mat", "ln"}, Returns: []string{"[]byte"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.SlRemoveComments", Args: []string{"lines"}, Returns: []string{"[][]byte"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.SlEditsReplace", Doc: "SlEditsReplace replaces Go with equivalent WGSL code\nreturns true if has slrand. or sltype.\nto auto include that header file if so.", Args: []string{"lines"}, Returns: []string{"bool", "bool"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/gosl/gotosl.SlBoolReplace", Doc: "SlBoolReplace replaces all the slbool methods with literal int32 expressions.", Args: []string{"lines"}}) diff --git a/gpu/gosl/slbool/README.md b/goal/gosl/slbool/README.md similarity index 100% rename from gpu/gosl/slbool/README.md rename to goal/gosl/slbool/README.md diff --git a/gpu/gosl/slbool/slbool.go b/goal/gosl/slbool/slbool.go similarity index 100% rename from gpu/gosl/slbool/slbool.go rename to goal/gosl/slbool/slbool.go diff --git a/gpu/gosl/slbool/slboolcore/slboolcore.go b/goal/gosl/slbool/slboolcore/slboolcore.go similarity index 87% rename from gpu/gosl/slbool/slboolcore/slboolcore.go rename to goal/gosl/slbool/slboolcore/slboolcore.go index a9d44c489a..d34a6e2e5e 100644 --- a/gpu/gosl/slbool/slboolcore/slboolcore.go +++ b/goal/gosl/slbool/slboolcore/slboolcore.go @@ -6,7 +6,7 @@ package slboolcore import ( "cogentcore.org/core/core" - "cogentcore.org/core/gpu/gosl/slbool" + "cogentcore.org/core/goal/gosl/slbool" ) func init() { diff --git a/gpu/gosl/slrand/README.md b/goal/gosl/slrand/README.md similarity index 81% rename from gpu/gosl/slrand/README.md rename to goal/gosl/slrand/README.md index 
15e6ca02ff..0268ef99d9 100644 --- a/gpu/gosl/slrand/README.md +++ b/goal/gosl/slrand/README.md @@ -16,9 +16,7 @@ The key advantage of this algorithm is its *stateless* nature, where the result ``` where the WGSL `vec2` type is 2 `uint32` 32-bit unsigned integers. For GPU usage, the `key` is always set to the unique element being processed (e.g., the index of the data structure being updated), ensuring that different numbers are generated for each such element, and the `counter` should be configured as a shared global value that is incremented after each iteration of computation. - - - For example, if 4 RNG calls happen within a given set of GPU code, each thread starts with the same starting `counter` value, which is passed around as a local `vec2` variable and incremented locally for each RNG. Then, after all threads have been performed, the shared starting `counter` is incremented using `CounterAdd` by 4. +For example, if 4 RNG calls happen within a given set of GPU code, each thread starts with the same starting `counter` value, which is passed around as a local `vec2` variable and incremented locally for each RNG. Then, after all threads have been performed, the shared starting `counter` is incremented using `CounterAdd` by 4. The `Float` and `Uint32` etc wrapper functions around Philox2x32 will automatically increment the counter var passed to it, using the `CounterIncr()` method that manages the two 32 bit numbers as if they are a full 64 bit uint. @@ -30,11 +28,4 @@ See the [axon](https://github.com/emer/gosl/v2/tree/main/examples/axon) and [ran Critically, these examples show that the CPU and GPU code produce identical random number sequences, which is otherwise quite difficult to achieve without this specific form of RNG. 
-# Implementational details - -Unfortunately, vulkan `glslang` does not support 64 bit integers, even though the shader language model has somehow been updated to support them: https://github.com/KhronosGroup/glslang/issues/2965 -- https://github.com/microsoft/DirectXShaderCompiler/issues/2067. This would also greatly speed up the impl: https://github.com/microsoft/DirectXShaderCompiler/issues/2821. - -The result is that we have to use the slower version of the MulHiLo algorithm using only 32 bit uints. - - diff --git a/gpu/gosl/slrand/slrand.go b/goal/gosl/slrand/slrand.go similarity index 98% rename from gpu/gosl/slrand/slrand.go rename to goal/gosl/slrand/slrand.go index 05453cdb60..0266f165bc 100644 --- a/gpu/gosl/slrand/slrand.go +++ b/goal/gosl/slrand/slrand.go @@ -5,7 +5,7 @@ package slrand import ( - "cogentcore.org/core/gpu/gosl/sltype" + "cogentcore.org/core/goal/gosl/sltype" "cogentcore.org/core/math32" ) @@ -54,7 +54,7 @@ func Philox2x32(counter uint64, key uint32) sltype.Uint32Vec2 { return sltype.Uint64ToLoHi(Philox2x32round(counter, key)) // 10 } -//////////////////////////////////////////////////////////// +///////// // Methods below provide a standard interface with more // readable names, mapping onto the Go rand methods. // diff --git a/gpu/gosl/examples/rand/shaders/slrand.wgsl b/goal/gosl/slrand/slrand.wgsl similarity index 99% rename from gpu/gosl/examples/rand/shaders/slrand.wgsl rename to goal/gosl/slrand/slrand.wgsl index 820e7bdf62..372959fcb9 100644 --- a/gpu/gosl/examples/rand/shaders/slrand.wgsl +++ b/goal/gosl/slrand/slrand.wgsl @@ -9,7 +9,7 @@ // use on the GPU, with equivalent Go versions available in slrand.go. // This is using the Philox2x32 counter-based RNG. -#include "sltype.wgsl" +// #include "sltype.wgsl" // Philox2x32round does one round of updating of the counter. 
fn Philox2x32round(counter: su64, key: u32) -> su64 { diff --git a/gpu/gosl/slrand/slrand_test.go b/goal/gosl/slrand/slrand_test.go similarity index 97% rename from gpu/gosl/slrand/slrand_test.go rename to goal/gosl/slrand/slrand_test.go index 82c3d976af..38114148f7 100644 --- a/gpu/gosl/slrand/slrand_test.go +++ b/goal/gosl/slrand/slrand_test.go @@ -8,7 +8,7 @@ import ( "fmt" "testing" - "cogentcore.org/core/gpu/gosl/sltype" + "cogentcore.org/core/goal/gosl/sltype" "github.com/stretchr/testify/assert" ) diff --git a/gpu/gosl/sltype/README.md b/goal/gosl/sltype/README.md similarity index 100% rename from gpu/gosl/sltype/README.md rename to goal/gosl/sltype/README.md diff --git a/gpu/gosl/sltype/float.go b/goal/gosl/sltype/float.go similarity index 100% rename from gpu/gosl/sltype/float.go rename to goal/gosl/sltype/float.go diff --git a/gpu/gosl/sltype/int.go b/goal/gosl/sltype/int.go similarity index 94% rename from gpu/gosl/sltype/int.go rename to goal/gosl/sltype/int.go index 1e45a558d3..70468d4c7a 100644 --- a/gpu/gosl/sltype/int.go +++ b/goal/gosl/sltype/int.go @@ -20,8 +20,7 @@ type Int32Vec4 struct { W int32 } -//////////////////////////////////////// -// Unsigned +//////// Unsigned // Uint32Vec2 is a length 2 vector of uint32 type Uint32Vec2 struct { diff --git a/gpu/gosl/sltype/sltype.go b/goal/gosl/sltype/sltype.go similarity index 100% rename from gpu/gosl/sltype/sltype.go rename to goal/gosl/sltype/sltype.go diff --git a/gpu/gosl/examples/rand/shaders/sltype.wgsl b/goal/gosl/sltype/sltype.wgsl similarity index 100% rename from gpu/gosl/examples/rand/shaders/sltype.wgsl rename to goal/gosl/sltype/sltype.wgsl diff --git a/gpu/gosl/sltype/sltype_test.go b/goal/gosl/sltype/sltype_test.go similarity index 100% rename from gpu/gosl/sltype/sltype_test.go rename to goal/gosl/sltype/sltype_test.go diff --git a/gpu/gosl/threading/threading.go b/goal/gosl/threading/threading.go similarity index 100% rename from gpu/gosl/threading/threading.go rename to 
goal/gosl/threading/threading.go diff --git a/shell/cmd/cosh/cosh.go b/goal/interpreter/config.go similarity index 64% rename from shell/cmd/cosh/cosh.go rename to goal/interpreter/config.go index 9ef4b3cd2d..eb7dda9f7b 100644 --- a/shell/cmd/cosh/cosh.go +++ b/goal/interpreter/config.go @@ -2,26 +2,26 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Command cosh is an interactive cli for running and compiling Cogent Shell (cosh). -package main +package interpreter import ( "fmt" + "log/slog" "os" "path/filepath" "strings" "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/exec" "cogentcore.org/core/base/fsx" - "cogentcore.org/core/cli" - "cogentcore.org/core/shell" - "cogentcore.org/core/shell/interpreter" + "cogentcore.org/core/base/logx" + "cogentcore.org/core/goal" "github.com/cogentcore/yaegi/interp" ) //go:generate core generate -add-types -add-funcs -// Config is the configuration information for the cosh cli. +// Config is the configuration information for the goal cli. type Config struct { // Input is the input file to run/compile. @@ -37,7 +37,7 @@ type Config struct { Expr string `flag:"e,expr"` // Args is an optional list of arguments to pass in the run command. - // These arguments will be turned into an "args" local variable in the shell. + // These arguments will be turned into an "args" local variable in the goal. // These are automatically processed from any leftover arguments passed, so // you should not need to specify this flag manually. Args []string `cmd:"run" posarg:"leftover" required:"-"` @@ -46,24 +46,22 @@ type Config struct { // Interactive mode is the default mode for the run command unless an input file // is specified. 
Interactive bool `cmd:"run" flag:"i,interactive"` -} -func main() { //types:skip - opts := cli.DefaultOptions("cosh", "An interactive tool for running and compiling Cogent Shell (cosh).") - cli.Run(opts, &Config{}, Run, Build) + // InteractiveFunc is the function to run in interactive mode. + // set it to your own function as needed. + InteractiveFunc func(c *Config, in *Interpreter) error } -// Run runs the specified cosh file. If no file is specified, -// it runs an interactive shell that allows the user to input cosh. +// Run runs the specified goal file. If no file is specified, +// it runs an interactive shell that allows the user to input goal. func Run(c *Config) error { //cli:cmd -root - in := interpreter.NewInterpreter(interp.Options{}) - in.Config() + in := NewInterpreter(interp.Options{}) if len(c.Args) > 0 { - in.Eval("args := cosh.StringsToAnys(" + fmt.Sprintf("%#v)", c.Args)) + in.Eval("args := goalib.StringsToAnys(" + fmt.Sprintf("%#v)", c.Args)) } if c.Input == "" { - return Interactive(c, in) + return c.InteractiveFunc(c, in) } code := "" if errors.Log1(fsx.FileExists(c.Input)) { @@ -82,16 +80,17 @@ func Run(c *Config) error { //cli:cmd -root _, _, err := in.Eval(code) if err == nil { - err = in.Shell.DepthError() + err = in.Goal.TrState.DepthError() } if c.Interactive { - return Interactive(c, in) + return c.InteractiveFunc(c, in) } return err } -// Interactive runs an interactive shell that allows the user to input cosh. -func Interactive(c *Config, in *interpreter.Interpreter) error { +// Interactive runs an interactive shell that allows the user to input goal. 
+func Interactive(c *Config, in *Interpreter) error { + in.Config() if c.Expr != "" { in.Eval(c.Expr) } @@ -99,25 +98,48 @@ func Interactive(c *Config, in *interpreter.Interpreter) error { return nil } -// Build builds the specified input cosh file, or all .cosh files in the current +// Build builds the specified input goal file, or all .goal files in the current // directory if no input is specified, to corresponding .go file name(s). // If the file does not already contain a "package" specification, then // "package main; func main()..." wrappers are added, which allows the same // code to be used in interactive and Go compiled modes. +// go build is run after this. func Build(c *Config) error { var fns []string + verbose := logx.UserLevel <= slog.LevelInfo if c.Input != "" { fns = []string{c.Input} } else { - fns = fsx.Filenames(".", ".cosh") + fns = fsx.Filenames(".", ".goal") } + curpkg, _ := exec.Minor().Output("go", "list", "./") var errs []error for _, fn := range fns { + fpath := filepath.Join(curpkg, fn) + if verbose { + fmt.Println(fpath) + } ofn := strings.TrimSuffix(fn, filepath.Ext(fn)) + ".go" - err := shell.NewShell().TranspileFile(fn, ofn) + err := goal.NewGoal().TranspileFile(fn, ofn) if err != nil { errs = append(errs, err) } } + + // cfg := &gotosl.Config{} + // cfg.Debug = verbose + // err := gotosl.Run(cfg) + // if err != nil { + // errs = append(errs, err) + // } + args := []string{"build"} + if verbose { + args = append(args, "-v") + } + err := exec.Verbose().Run("go", args...) + if err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) } diff --git a/goal/interpreter/imports.go b/goal/interpreter/imports.go new file mode 100644 index 0000000000..b46f9f789e --- /dev/null +++ b/goal/interpreter/imports.go @@ -0,0 +1,28 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package interpreter + +import ( + "reflect" + + "github.com/cogentcore/yaegi/interp" +) + +var Symbols = map[string]map[string]reflect.Value{} + +// ImportGoal imports special symbols from the goal package. +func (in *Interpreter) ImportGoal() { + in.Interp.Use(interp.Exports{ + "cogentcore.org/core/goal/goal": map[string]reflect.Value{ + "Run": reflect.ValueOf(in.Goal.Run), + "RunErrOK": reflect.ValueOf(in.Goal.RunErrOK), + "Output": reflect.ValueOf(in.Goal.Output), + "OutputErrOK": reflect.ValueOf(in.Goal.OutputErrOK), + "Start": reflect.ValueOf(in.Goal.Start), + "AddCommand": reflect.ValueOf(in.Goal.AddCommand), + "RunCommands": reflect.ValueOf(in.Goal.RunCommands), + }, + }) +} diff --git a/shell/interpreter/imports_test.go b/goal/interpreter/imports_test.go similarity index 100% rename from shell/interpreter/imports_test.go rename to goal/interpreter/imports_test.go diff --git a/shell/interpreter/interpreter.go b/goal/interpreter/interpreter.go similarity index 66% rename from shell/interpreter/interpreter.go rename to goal/interpreter/interpreter.go index a70c5079f5..4e43cd338a 100644 --- a/shell/interpreter/interpreter.go +++ b/goal/interpreter/interpreter.go @@ -17,7 +17,12 @@ import ( "syscall" "cogentcore.org/core/base/errors" - "cogentcore.org/core/shell" + "cogentcore.org/core/goal" + _ "cogentcore.org/core/tensor/stats/metric" + _ "cogentcore.org/core/tensor/stats/stats" + "cogentcore.org/core/tensor/tensorfs" + _ "cogentcore.org/core/tensor/tmath" + "cogentcore.org/core/yaegicore/nogui" "github.com/cogentcore/yaegi/interp" "github.com/cogentcore/yaegi/stdlib" "github.com/ergochat/readline" @@ -25,11 +30,11 @@ import ( // Interpreter represents one running shell context type Interpreter struct { - // the cosh shell - Shell *shell.Shell + // the goal shell + Goal *goal.Goal // HistFile is the name of the history file to open / save. - // Defaults to ~/.cosh-history for the default cosh shell. 
+ // Defaults to ~/.goal-history for the default goal shell. // Update this prior to running Config() to take effect. HistFile string @@ -46,36 +51,41 @@ func init() { // functions. End user app must call [Interp.Config] after importing any additional // symbols, prior to running the interpreter. func NewInterpreter(options interp.Options) *Interpreter { - in := &Interpreter{HistFile: "~/.cosh-history"} - in.Shell = shell.NewShell() + in := &Interpreter{HistFile: "~/.goal-history"} + in.Goal = goal.NewGoal() if options.Stdin != nil { - in.Shell.Config.StdIO.In = options.Stdin + in.Goal.Config.StdIO.In = options.Stdin } if options.Stdout != nil { - in.Shell.Config.StdIO.Out = options.Stdout + in.Goal.Config.StdIO.Out = options.Stdout } if options.Stderr != nil { - in.Shell.Config.StdIO.Err = options.Stderr + in.Goal.Config.StdIO.Err = options.Stderr } - in.Shell.SaveOrigStdIO() - options.Stdout = in.Shell.StdIOWrappers.Out - options.Stderr = in.Shell.StdIOWrappers.Err - options.Stdin = in.Shell.StdIOWrappers.In + in.Goal.SaveOrigStdIO() + options.Stdout = in.Goal.StdIOWrappers.Out + options.Stderr = in.Goal.StdIOWrappers.Err + options.Stdin = in.Goal.StdIOWrappers.In in.Interp = interp.New(options) - errors.Log(in.Interp.Use(stdlib.Symbols)) - errors.Log(in.Interp.Use(Symbols)) - in.ImportShell() + errors.Log(in.Interp.Use(nogui.Symbols)) + in.ImportGoal() go in.MonitorSignals() return in } // Prompt returns the appropriate REPL prompt to show the user. func (in *Interpreter) Prompt() string { - dp := in.Shell.TotalDepth() + dp := in.Goal.TrState.TotalDepth() + pc := ">" + dir := in.Goal.HostAndDir() + if in.Goal.TrState.MathMode { + pc = "#" + dir = tensorfs.CurDir.Path() + } if dp == 0 { - return in.Shell.HostAndDir() + " > " + return dir + " " + pc + " " } - res := "> " + res := pc + " " for range dp { res += " " // note: /t confuses readline } @@ -89,21 +99,21 @@ func (in *Interpreter) Prompt() string { // whether to print the result in interactive mode. 
// It automatically logs any error in addition to returning it. func (in *Interpreter) Eval(code string) (v reflect.Value, hasPrint bool, err error) { - in.Shell.TranspileCode(code) + in.Goal.TranspileCode(code) source := false - if in.Shell.SSHActive == "" { + if in.Goal.SSHActive == "" { source = strings.HasPrefix(code, "source") } - if in.Shell.TotalDepth() == 0 { - nl := len(in.Shell.Lines) + if in.Goal.TrState.TotalDepth() == 0 { + nl := len(in.Goal.TrState.Lines) if nl > 0 { - ln := in.Shell.Lines[nl-1] + ln := in.Goal.TrState.Lines[nl-1] if strings.Contains(strings.ToLower(ln), "print") { hasPrint = true } } v, err = in.RunCode() - in.Shell.Errors = nil + in.Goal.Errors = nil } if source { v, err = in.RunCode() // run accumulated code @@ -115,27 +125,28 @@ func (in *Interpreter) Eval(code string) (v reflect.Value, hasPrint bool, err er // and clears the stack of code lines. // It automatically logs any error in addition to returning it. func (in *Interpreter) RunCode() (reflect.Value, error) { - if len(in.Shell.Errors) > 0 { - return reflect.Value{}, errors.Join(in.Shell.Errors...) + if len(in.Goal.Errors) > 0 { + return reflect.Value{}, errors.Join(in.Goal.Errors...) 
} - in.Shell.AddChunk() - code := in.Shell.Chunks - in.Shell.ResetCode() + in.Goal.TrState.AddChunk() + code := in.Goal.TrState.Chunks + in.Goal.TrState.ResetCode() var v reflect.Value var err error for _, ch := range code { - ctx := in.Shell.StartContext() + ctx := in.Goal.StartContext() v, err = in.Interp.EvalWithContext(ctx, ch) - in.Shell.EndContext() + in.Goal.EndContext() if err != nil { cancelled := errors.Is(err, context.Canceled) // fmt.Println("cancelled:", cancelled) - in.Shell.RestoreOrigStdIO() - in.Shell.ResetDepth() + in.Goal.DeleteAllJobs() + in.Goal.RestoreOrigStdIO() + in.Goal.TrState.ResetDepth() if !cancelled { - in.Shell.AddError(err) + in.Goal.AddError(err) } else { - in.Shell.Errors = nil + in.Goal.Errors = nil } break } @@ -143,10 +154,10 @@ func (in *Interpreter) RunCode() (reflect.Value, error) { return v, err } -// RunConfig runs the .cosh startup config file in the user's +// RunConfig runs the .goal startup config file in the user's // home directory if it exists. func (in *Interpreter) RunConfig() error { - err := in.Shell.TranspileConfig() + err := in.Goal.TranspileConfig() if err != nil { errors.Log(err) } @@ -159,10 +170,12 @@ func (in *Interpreter) RunConfig() error { // It is called automatically in another goroutine in [NewInterpreter]. 
func (in *Interpreter) MonitorSignals() { c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) + // todo: syscall.SIGSEGV not defined on web + signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + // signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGSEGV) for { <-c - in.Shell.CancelExecution() + in.Goal.CancelExecution() } } @@ -175,9 +188,9 @@ func (in *Interpreter) Config() { // OpenHistory opens history from the current HistFile // and loads it into the readline history for given rl instance func (in *Interpreter) OpenHistory(rl *readline.Instance) error { - err := in.Shell.OpenHistory(in.HistFile) + err := in.Goal.OpenHistory(in.HistFile) if err == nil { - for _, h := range in.Shell.Hist { + for _, h := range in.Goal.Hist { rl.SaveToHistory(h) } } @@ -191,19 +204,20 @@ func (in *Interpreter) SaveHistory() error { if hfs := os.Getenv("HISTFILESIZE"); hfs != "" { en, err := strconv.Atoi(hfs) if err != nil { - in.Shell.Config.StdIO.ErrPrintf("SaveHistory: environment variable HISTFILESIZE: %q not a number: %s", hfs, err.Error()) + in.Goal.Config.StdIO.ErrPrintf("SaveHistory: environment variable HISTFILESIZE: %q not a number: %s", hfs, err.Error()) } else { n = en } } - return in.Shell.SaveHistory(n, in.HistFile) + return in.Goal.SaveHistory(n, in.HistFile) } -// Interactive runs an interactive shell that allows the user to input cosh. +// Interactive runs an interactive shell that allows the user to input goal. // Must have done in.Config() prior to calling. func (in *Interpreter) Interactive() error { + in.Goal.TrState.MathRecord = true rl, err := readline.NewFromConfig(&readline.Config{ - AutoComplete: &shell.ReadlineCompleter{Shell: in.Shell}, + AutoComplete: &goal.ReadlineCompleter{Goal: in.Goal}, Undo: true, }) if err != nil { @@ -229,21 +243,21 @@ func (in *Interpreter) Interactive() error { } if len(line) > 0 && line[0] == '!' 
{ // history command hl, err := strconv.Atoi(line[1:]) - nh := len(in.Shell.Hist) + nh := len(in.Goal.Hist) if err != nil { - in.Shell.Config.StdIO.ErrPrintf("history number: %q not a number: %s", line[1:], err.Error()) + in.Goal.Config.StdIO.ErrPrintf("history number: %q not a number: %s", line[1:], err.Error()) line = "" } else if hl >= nh { - in.Shell.Config.StdIO.ErrPrintf("history number: %d not in range: [0:%d]", hl, nh) + in.Goal.Config.StdIO.ErrPrintf("history number: %d not in range: [0:%d]", hl, nh) line = "" } else { - line = in.Shell.Hist[hl] + line = in.Goal.Hist[hl] fmt.Printf("h:%d\t%s\n", hl, line) } } else if line != "" && !strings.HasPrefix(line, "history") && line != "h" { - in.Shell.AddHistory(line) + in.Goal.AddHistory(line) } - in.Shell.Errors = nil + in.Goal.Errors = nil v, hasPrint, err := in.Eval(line) if err == nil && !hasPrint && v.IsValid() && !v.IsZero() && v.Kind() != reflect.Func { fmt.Println(v.Interface()) diff --git a/goal/interpreter/typegen.go b/goal/interpreter/typegen.go new file mode 100644 index 0000000000..2060f498c6 --- /dev/null +++ b/goal/interpreter/typegen.go @@ -0,0 +1,21 @@ +// Code generated by "core generate -add-types -add-funcs"; DO NOT EDIT. 
+ +package interpreter + +import ( + "cogentcore.org/core/types" +) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/interpreter.Config", IDName: "config", Doc: "Config is the configuration information for the goal cli.", Directives: []types.Directive{{Tool: "go", Directive: "generate", Args: []string{"core", "generate", "-add-types", "-add-funcs"}}}, Fields: []types.Field{{Name: "Input", Doc: "Input is the input file to run/compile.\nIf this is provided as the first argument,\nthen the program will exit after running,\nunless the Interactive mode is flagged."}, {Name: "Expr", Doc: "Expr is an optional expression to evaluate, which can be used\nin addition to the Input file to run, to execute commands\ndefined within that file for example, or as a command to run\nprior to starting interactive mode if no Input is specified."}, {Name: "Args", Doc: "Args is an optional list of arguments to pass in the run command.\nThese arguments will be turned into an \"args\" local variable in the goal.\nThese are automatically processed from any leftover arguments passed, so\nyou should not need to specify this flag manually."}, {Name: "Interactive", Doc: "Interactive runs the interactive command line after processing any input file.\nInteractive mode is the default mode for the run command unless an input file\nis specified."}, {Name: "InteractiveFunc", Doc: "InteractiveFunc is the function to run in interactive mode.\nset it to your own function as needed."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/goal/interpreter.Interpreter", IDName: "interpreter", Doc: "Interpreter represents one running shell context", Fields: []types.Field{{Name: "Goal", Doc: "the goal shell"}, {Name: "HistFile", Doc: "HistFile is the name of the history file to open / save.\nDefaults to ~/.goal-history for the default goal shell.\nUpdate this prior to running Config() to take effect."}, {Name: "Interp", Doc: "the yaegi interpreter"}}}) + +var _ = 
types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/interpreter.Run", Doc: "Run runs the specified goal file. If no file is specified,\nit runs an interactive shell that allows the user to input goal.", Directives: []types.Directive{{Tool: "cli", Directive: "cmd", Args: []string{"-root"}}}, Args: []string{"c"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/interpreter.Interactive", Doc: "Interactive runs an interactive shell that allows the user to input goal.", Args: []string{"c", "in"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/interpreter.Build", Doc: "Build builds the specified input goal file, or all .goal files in the current\ndirectory if no input is specified, to corresponding .go file name(s).\nIf the file does not already contain a \"package\" specification, then\n\"package main; func main()...\" wrappers are added, which allows the same\ncode to be used in interactive and Go compiled modes.\ngo build is run after this.", Args: []string{"c"}, Returns: []string{"error"}}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/interpreter.init"}) + +var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/goal/interpreter.NewInterpreter", Doc: "NewInterpreter returns a new [Interpreter] initialized with the given options.\nIt automatically imports the standard library and configures necessary shell\nfunctions. End user app must call [Interp.Config] after importing any additional\nsymbols, prior to running the interpreter.", Args: []string{"options"}, Returns: []string{"Interpreter"}}) diff --git a/goal/math_test.go b/goal/math_test.go new file mode 100644 index 0000000000..250f196d0d --- /dev/null +++ b/goal/math_test.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package goal + +import ( + "testing" + + "cogentcore.org/core/goal/goalib" + "github.com/stretchr/testify/assert" +) + +var test = ` +// # x := [3, 5, 4] +# x := zeros(3, 4) +# nd := x.ndim +# sz := x.size +# sh := x.shape + +fmt.Println(x) +fmt.Println(nd) +fmt.Println(sz) +fmt.Println(sh) + +type MyStru struct { + Name string + Doc string +} + +var VarCategories = []MyStru{ + {"Act", "basic activation variables, including conductances, current, Vm, spiking"}, + {"Learn", "calcium-based learning variables and other related learning factors"}, +} +` + +func TestMath(t *testing.T) { + gl := NewGoal() + tfile := "testdata/test.goal" + ofile := "testdata/test.go" + goalib.WriteFile(tfile, test) + err := gl.TranspileFile(tfile, ofile) + assert.NoError(t, err) +} diff --git a/shell/run.go b/goal/run.go similarity index 56% rename from shell/run.go rename to goal/run.go index d6347a962f..465e784310 100644 --- a/shell/run.go +++ b/goal/run.go @@ -2,43 +2,43 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package shell +package goal // Run executes the given command string, waiting for the command to finish, // handling the given arguments appropriately. -// If there is any error, it adds it to the shell, and triggers CancelExecution. +// If there is any error, it adds it to the goal, and triggers CancelExecution. // It forwards output to [exec.Config.Stdout] and [exec.Config.Stderr] appropriately. -func (sh *Shell) Run(cmd any, args ...any) { - sh.Exec(false, false, false, cmd, args...) +func (gl *Goal) Run(cmd any, args ...any) { + gl.Exec(false, false, false, cmd, args...) } // RunErrOK executes the given command string, waiting for the command to finish, // handling the given arguments appropriately. // It does not stop execution if there is an error. -// If there is any error, it adds it to the shell. It forwards output to +// If there is any error, it adds it to the goal. 
It forwards output to // [exec.Config.Stdout] and [exec.Config.Stderr] appropriately. -func (sh *Shell) RunErrOK(cmd any, args ...any) { - sh.Exec(true, false, false, cmd, args...) +func (gl *Goal) RunErrOK(cmd any, args ...any) { + gl.Exec(true, false, false, cmd, args...) } // Start starts the given command string for running in the background, // handling the given arguments appropriately. -// If there is any error, it adds it to the shell. It forwards output to +// If there is any error, it adds it to the goal. It forwards output to // [exec.Config.Stdout] and [exec.Config.Stderr] appropriately. -func (sh *Shell) Start(cmd any, args ...any) { - sh.Exec(false, true, false, cmd, args...) +func (gl *Goal) Start(cmd any, args ...any) { + gl.Exec(false, true, false, cmd, args...) } // Output executes the given command string, handling the given arguments -// appropriately. If there is any error, it adds it to the shell. It returns +// appropriately. If there is any error, it adds it to the goal. It returns // the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately. -func (sh *Shell) Output(cmd any, args ...any) string { - return sh.Exec(false, false, true, cmd, args...) +func (gl *Goal) Output(cmd any, args ...any) string { + return gl.Exec(false, false, true, cmd, args...) } // OutputErrOK executes the given command string, handling the given arguments -// appropriately. If there is any error, it adds it to the shell. It returns +// appropriately. If there is any error, it adds it to the goal. It returns // the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately. -func (sh *Shell) OutputErrOK(cmd any, args ...any) string { - return sh.Exec(true, false, true, cmd, args...) +func (gl *Goal) OutputErrOK(cmd any, args ...any) string { + return gl.Exec(true, false, true, cmd, args...) 
} diff --git a/goal/testdata/test.goal b/goal/testdata/test.goal new file mode 100644 index 0000000000..a93be7ec5b --- /dev/null +++ b/goal/testdata/test.goal @@ -0,0 +1,21 @@ + +// # x := [3, 5, 4] +# x := zeros(3, 4) +# nd := x.ndim +# sz := x.size +# sh := x.shape + +fmt.Println(x) +fmt.Println(nd) +fmt.Println(sz) +fmt.Println(sh) + +type MyStru struct { + Name string + Doc string +} + +var VarCategories = []MyStru{ + {"Act", "basic activation variables, including conductances, current, Vm, spiking"}, + {"Learn", "calcium-based learning variables and other related learning factors"}, +} diff --git a/goal/transpile/addfuncs.go b/goal/transpile/addfuncs.go new file mode 100644 index 0000000000..a608fbba76 --- /dev/null +++ b/goal/transpile/addfuncs.go @@ -0,0 +1,38 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "path" + "reflect" + "strings" + + "cogentcore.org/core/tensor" + "cogentcore.org/core/yaegicore/nogui" +) + +func init() { + AddYaegiTensorFuncs() +} + +// AddYaegiTensorFuncs grabs all tensor* package functions registered +// in yaegicore and adds them to the `tensor.Funcs` map so we can +// properly convert symbols to either tensors or basic literals, +// depending on the arg types for the current function. +func AddYaegiTensorFuncs() { + for pth, symap := range nogui.Symbols { + if !strings.Contains(pth, "/core/tensor/") { + continue + } + _, pkg := path.Split(pth) + for name, val := range symap { + if val.Kind() != reflect.Func { + continue + } + pnm := pkg + "." + name + tensor.AddFunc(pnm, val.Interface()) + } + } +} diff --git a/goal/transpile/datafs.go b/goal/transpile/datafs.go new file mode 100644 index 0000000000..4c6be28f4a --- /dev/null +++ b/goal/transpile/datafs.go @@ -0,0 +1,50 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "errors" + "go/token" +) + +var tensorfsCommands = map[string]func(mp *mathParse) error{ + "cd": cd, + "mkdir": mkdir, + "ls": ls, +} + +func cd(mp *mathParse) error { + var dir string + if len(mp.ewords) > 1 { + dir = mp.ewords[1] + } + mp.out.Add(token.IDENT, "tensorfs.Chdir") + mp.out.Add(token.LPAREN) + mp.out.Add(token.STRING, `"`+dir+`"`) + mp.out.Add(token.RPAREN) + return nil +} + +func mkdir(mp *mathParse) error { + if len(mp.ewords) == 1 { + return errors.New("tensorfs mkdir requires a directory name") + } + dir := mp.ewords[1] + mp.out.Add(token.IDENT, "tensorfs.Mkdir") + mp.out.Add(token.LPAREN) + mp.out.Add(token.STRING, `"`+dir+`"`) + mp.out.Add(token.RPAREN) + return nil +} + +func ls(mp *mathParse) error { + mp.out.Add(token.IDENT, "tensorfs.List") + mp.out.Add(token.LPAREN) + for i := 1; i < len(mp.ewords); i++ { + mp.out.Add(token.STRING, `"`+mp.ewords[i]+`"`) + } + mp.out.Add(token.RPAREN) + return nil +} diff --git a/shell/execwords.go b/goal/transpile/execwords.go similarity index 87% rename from shell/execwords.go rename to goal/transpile/execwords.go index e5515195fe..6a66cb1ace 100644 --- a/shell/execwords.go +++ b/goal/transpile/execwords.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package shell +package transpile import ( "fmt" @@ -17,6 +17,21 @@ func ExecWords(ln string) ([]string, error) { return nil, nil } + if ln[0] == '$' { + ln = strings.TrimSpace(ln[1:]) + n = len(ln) + if n == 0 { + return nil, nil + } + if ln[n-1] == '$' { + ln = strings.TrimSpace(ln[:n-1]) + n = len(ln) + if n == 0 { + return nil, nil + } + } + } + word := "" esc := false dQuote := false @@ -139,7 +154,7 @@ func ExecWords(ln string) ([]string, error) { } addWord() if dQuote || bQuote || brack > 0 { - return words, fmt.Errorf("cosh: exec command has unterminated quotes (\": %v, `: %v) or brackets [ %v ]", dQuote, bQuote, brack > 0) + return words, fmt.Errorf("goal: exec command has unterminated quotes (\": %v, `: %v) or brackets [ %v ]", dQuote, bQuote, brack > 0) } return words, nil } @@ -147,7 +162,7 @@ func ExecWords(ln string) ([]string, error) { // ExecWordIsCommand returns true if given exec word is a command-like string // (excluding any paths) func ExecWordIsCommand(f string) bool { - if strings.Contains(f, "(") || strings.Contains(f, "=") { + if strings.Contains(f, "(") || strings.Contains(f, "[") || strings.Contains(f, "=") { return false } return true diff --git a/goal/transpile/math.go b/goal/transpile/math.go new file mode 100644 index 0000000000..0e7cc2d433 --- /dev/null +++ b/goal/transpile/math.go @@ -0,0 +1,988 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "fmt" + "go/ast" + "go/token" + "strings" + + "cogentcore.org/core/base/stack" + "cogentcore.org/core/tensor" +) + +// TranspileMath does math mode transpiling. fullLine indicates code should be +// full statement(s). 
+func (st *State) TranspileMath(toks Tokens, code string, fullLine bool) Tokens { + nt := len(toks) + if nt == 0 { + return nil + } + // fmt.Println(nt, toks) + + str := code[toks[0].Pos-1 : toks[nt-1].Pos] + if toks[nt-1].Str != "" { + str += toks[nt-1].Str[1:] + } + // fmt.Println(str) + mp := mathParse{state: st, toks: toks, code: code} + // mp.trace = true + + mods := AllErrors // | Trace + + if fullLine { + ewords, err := ExecWords(str) + if len(ewords) > 0 { + if cmd, ok := tensorfsCommands[ewords[0]]; ok { + mp.ewords = ewords + err := cmd(&mp) + if err != nil { + fmt.Println(ewords[0]+":", err.Error()) + return nil + } else { + return mp.out + } + } + } + + stmts, err := ParseLine(str, mods) + if err != nil { + fmt.Println("line code:", str) + fmt.Println("parse err:", err) + } + if len(stmts) == 0 { + return toks + } + mp.stmtList(stmts) + } else { + ex, err := ParseExpr(str, mods) + if err != nil { + fmt.Println("expr:", str) + fmt.Println("parse err:", err) + } + mp.expr(ex) + } + + if mp.idx != len(toks) { + fmt.Println(code) + fmt.Println(mp.out.Code()) + fmt.Printf("parsing error: index: %d != len(toks): %d\n", mp.idx, len(toks)) + } + + return mp.out +} + +// funcInfo is info about the function being processed +type funcInfo struct { + tensor.Func + + // current arg index we are processing + curArg int +} + +// mathParse has the parsing state, active only during a parsing pass +// on one specific chunk of code and tokens. 
type mathParse struct {
	state  *State   // enclosing transpiler state
	code   string   // code string
	toks   Tokens   // source tokens we are parsing
	ewords []string // exec words
	idx    int      // current index in source tokens -- critical to sync as we "use" source
	out    Tokens   // output tokens we generate
	trace  bool     // trace of parsing -- turn on to see alignment

	// stack of function info -- top of stack reflects the current function
	funcs stack.Stack[*funcInfo]
}

// curArg returns the current argument spec for the function at the top
// of the funcs stack, or nil if there is no active function or the
// current arg index is out of range.
func (mp *mathParse) curArg() *tensor.Arg {
	cfun := mp.funcs.Peek()
	if cfun == nil {
		return nil
	}
	if cfun.curArg < len(cfun.Args) {
		return cfun.Args[cfun.curArg]
	}
	return nil
}

// nextArg advances the current function's argument index. On the last
// registered argument, it stays there (and warns unless that argument
// is variadic, which legitimately absorbs extra args).
func (mp *mathParse) nextArg() {
	cfun := mp.funcs.Peek()
	if cfun == nil || len(cfun.Args) == 0 {
		// fmt.Println("next arg no fun or no args")
		return
	}
	n := len(cfun.Args)
	if cfun.curArg == n-1 {
		carg := cfun.Args[n-1]
		if !carg.IsVariadic {
			fmt.Println("math transpile: args exceed registered function number:", cfun)
		}
		return
	}
	cfun.curArg++
}

// curArgIsTensor reports whether the current function argument is
// registered as taking a tensor (so literals must be wrapped).
func (mp *mathParse) curArgIsTensor() bool {
	carg := mp.curArg()
	if carg == nil {
		return false
	}
	return carg.IsTensor
}

// curArgIsInts reports whether the current function argument is a
// variadic ints argument (e.g. shape sizes), which takes plain int
// values rather than tensors.
func (mp *mathParse) curArgIsInts() bool {
	carg := mp.curArg()
	if carg == nil {
		return false
	}
	return carg.IsInt && carg.IsVariadic
}

// startFunc is called when starting a new function.
// empty is "dummy" assign case using Inc.
// optional noLookup indicates to not lookup type and just
// push the name -- for internal cases to prevent arg conversions.
+func (mp *mathParse) startFunc(name string, noLookup ...bool) *funcInfo { + fi := &funcInfo{} + sname := name + if name == "" || name == "tensor.Tensor" { + sname = "tmath.Inc" // one arg tensor fun + } + if len(noLookup) == 1 && noLookup[0] { + fi.Name = name + } else { + if tf, err := tensor.FuncByName(sname); err == nil { + fi.Func = *tf + } else { + fi.Name = name + } + } + mp.funcs.Push(fi) + if name != "" { + mp.out.Add(token.IDENT, name) + } + return fi +} + +func (mp *mathParse) endFunc() { + mp.funcs.Pop() +} + +// addToken adds output token and increments idx +func (mp *mathParse) addToken(tok token.Token) { + mp.out.Add(tok) + if mp.trace { + ctok := &Token{} + if mp.idx < len(mp.toks) { + ctok = mp.toks[mp.idx] + } + fmt.Printf("%d\ttok: %s \t replaces: %s\n", mp.idx, tok, ctok) + } + mp.idx++ +} + +func (mp *mathParse) addCur() { + if len(mp.toks) > mp.idx { + mp.out.AddTokens(mp.toks[mp.idx]) + mp.idx++ + return + } + fmt.Println("out of tokens!", mp.idx, mp.toks) +} + +func (mp *mathParse) stmtList(sts []ast.Stmt) { + for _, st := range sts { + mp.stmt(st) + } +} + +func (mp *mathParse) stmt(st ast.Stmt) { + if st == nil { + return + } + switch x := st.(type) { + case *ast.BadStmt: + fmt.Println("bad stmt!") + + case *ast.DeclStmt: + + case *ast.ExprStmt: + mp.expr(x.X) + + case *ast.SendStmt: + mp.expr(x.Chan) + mp.addToken(token.ARROW) + mp.expr(x.Value) + + case *ast.IncDecStmt: + fn := "Inc" + if x.Tok == token.DEC { + fn = "Dec" + } + mp.startFunc("tmath." 
+ fn) + mp.out.Add(token.LPAREN) + mp.expr(x.X) + mp.addToken(token.RPAREN) + + case *ast.AssignStmt: + switch x.Tok { + case token.DEFINE: + mp.defineStmt(x) + default: + mp.assignStmt(x) + } + + case *ast.GoStmt: + mp.addToken(token.GO) + mp.callExpr(x.Call) + + case *ast.DeferStmt: + mp.addToken(token.DEFER) + mp.callExpr(x.Call) + + case *ast.ReturnStmt: + mp.addToken(token.RETURN) + mp.exprList(x.Results) + + case *ast.BranchStmt: + mp.addToken(x.Tok) + mp.ident(x.Label) + + case *ast.BlockStmt: + mp.addToken(token.LBRACE) + mp.stmtList(x.List) + mp.addToken(token.RBRACE) + + case *ast.IfStmt: + mp.addToken(token.IF) + mp.stmt(x.Init) + if x.Init != nil { + mp.addToken(token.SEMICOLON) + } + mp.expr(x.Cond) + mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool + if x.Body != nil && len(x.Body.List) > 0 { + mp.addToken(token.LBRACE) + mp.stmtList(x.Body.List) + mp.addToken(token.RBRACE) + } else { + mp.addToken(token.LBRACE) + } + if x.Else != nil { + mp.addToken(token.ELSE) + mp.stmt(x.Else) + } + + case *ast.ForStmt: + mp.addToken(token.FOR) + mp.stmt(x.Init) + if x.Init != nil { + mp.addToken(token.SEMICOLON) + } + mp.expr(x.Cond) + if x.Cond != nil { + mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool + mp.addToken(token.SEMICOLON) + } + mp.stmt(x.Post) + if x.Body != nil && len(x.Body.List) > 0 { + mp.addToken(token.LBRACE) + mp.stmtList(x.Body.List) + mp.addToken(token.RBRACE) + } else { + mp.addToken(token.LBRACE) + } + + case *ast.RangeStmt: + if x.Key == nil || x.Value == nil { + fmt.Println("for range statement requires both index and value variables") + return + } + ki, _ := x.Key.(*ast.Ident) + vi, _ := x.Value.(*ast.Ident) + ei, _ := x.X.(*ast.Ident) + if ki == nil || vi == nil || ei == nil { + fmt.Println("for range statement requires all variables (index, value, range) to be variable names, not other expressions") + return + } + knm := ki.Name + vnm := vi.Name + enm := ei.Name + + mp.addToken(token.FOR) 
+ mp.expr(x.Key) + mp.idx += 2 + mp.addToken(token.DEFINE) + mp.out.Add(token.IDENT, "0") + mp.out.Add(token.SEMICOLON) + mp.out.Add(token.IDENT, knm) + mp.out.Add(token.IDENT, "<") + mp.out.Add(token.IDENT, enm) + mp.out.Add(token.PERIOD) + mp.out.Add(token.IDENT, "Len") + mp.idx++ + mp.out.AddMulti(token.LPAREN, token.RPAREN) + mp.idx++ + mp.out.Add(token.SEMICOLON) + mp.idx++ + mp.out.Add(token.IDENT, knm) + mp.out.AddMulti(token.INC, token.LBRACE) + + mp.out.Add(token.IDENT, vnm) + mp.out.Add(token.DEFINE) + mp.out.Add(token.IDENT, enm) + mp.out.Add(token.IDENT, ".Float1D") + mp.out.Add(token.LPAREN) + mp.out.Add(token.IDENT, knm) + mp.out.Add(token.RPAREN) + + if x.Body != nil && len(x.Body.List) > 0 { + mp.stmtList(x.Body.List) + mp.addToken(token.RBRACE) + } + + // TODO + // CaseClause: SwitchStmt:, TypeSwitchStmt:, CommClause:, SelectStmt: + } +} + +func (mp *mathParse) expr(ex ast.Expr) { + if ex == nil { + return + } + switch x := ex.(type) { + case *ast.BadExpr: + fmt.Println("bad expr!") + + case *ast.Ident: + mp.ident(x) + + case *ast.UnaryExpr: + mp.unaryExpr(x) + + case *ast.Ellipsis: + cfun := mp.funcs.Peek() + if cfun != nil && cfun.Name == "tensor.Reslice" { + mp.out.Add(token.IDENT, "tensor.Ellipsis") + mp.idx++ + } else { + mp.addToken(token.ELLIPSIS) + } + + case *ast.StarExpr: + mp.addToken(token.MUL) + mp.expr(x.X) + + case *ast.BinaryExpr: + mp.binaryExpr(x) + + case *ast.BasicLit: + mp.basicLit(x) + + case *ast.FuncLit: + + case *ast.ParenExpr: + mp.addToken(token.LPAREN) + mp.expr(x.X) + mp.addToken(token.RPAREN) + + case *ast.SelectorExpr: + mp.selectorExpr(x) + + case *ast.TypeAssertExpr: + + case *ast.IndexExpr: + mp.indexExpr(x) + + case *ast.IndexListExpr: + if x.X == nil { // array literal + mp.arrayLiteral(x) + } else { + mp.indexListExpr(x) + } + + case *ast.SliceExpr: + mp.sliceExpr(x) + + case *ast.CallExpr: + mp.callExpr(x) + + case *ast.ArrayType: + // note: shouldn't happen normally: + fmt.Println("array type:", x, x.Len) + 
fmt.Printf("%#v\n", x.Len) + } +} + +func (mp *mathParse) exprList(ex []ast.Expr) { + n := len(ex) + if n == 0 { + return + } + if n == 1 { + mp.expr(ex[0]) + return + } + for i := range n { + mp.expr(ex[i]) + if i < n-1 { + mp.addToken(token.COMMA) + } + } +} + +func (mp *mathParse) argsList(ex []ast.Expr) { + n := len(ex) + if n == 0 { + return + } + if n == 1 { + mp.expr(ex[0]) + return + } + for i := range n { + // cfun := mp.funcs.Peek() + // if i != cfun.curArg { + // fmt.Println(cfun, "arg should be:", i, "is:", cfun.curArg) + // } + mp.expr(ex[i]) + if i < n-1 { + mp.nextArg() + mp.addToken(token.COMMA) + } + } +} + +func (mp *mathParse) exprIsBool(ex ast.Expr) bool { + switch x := ex.(type) { + case *ast.BinaryExpr: + if (x.Op >= token.EQL && x.Op <= token.GTR) || (x.Op >= token.NEQ && x.Op <= token.GEQ) { + return true + } + case *ast.ParenExpr: + return mp.exprIsBool(x.X) + } + return false +} + +func (mp *mathParse) exprsAreBool(ex []ast.Expr) bool { + for _, x := range ex { + if mp.exprIsBool(x) { + return true + } + } + return false +} + +func (mp *mathParse) binaryExpr(ex *ast.BinaryExpr) { + if ex.Op == token.ILLEGAL { // @ = matmul + mp.startFunc("matrix.Mul") + mp.out.Add(token.LPAREN) + mp.expr(ex.X) + mp.out.Add(token.COMMA) + mp.idx++ + mp.expr(ex.Y) + mp.out.Add(token.RPAREN) + mp.endFunc() + return + } + + fn := "" + switch ex.Op { + case token.ADD: + fn = "Add" + case token.SUB: + fn = "Sub" + case token.MUL: + fn = "Mul" + if un, ok := ex.Y.(*ast.StarExpr); ok { // ** power operator + ex.Y = un.X + fn = "Pow" + } + case token.QUO: + fn = "Div" + case token.EQL: + fn = "Equal" + case token.LSS: + fn = "Less" + case token.GTR: + fn = "Greater" + case token.NEQ: + fn = "NotEqual" + case token.LEQ: + fn = "LessEqual" + case token.GEQ: + fn = "GreaterEqual" + case token.LOR: + fn = "Or" + case token.LAND: + fn = "And" + default: + fmt.Println("binary token:", ex.Op) + } + mp.startFunc("tmath." 
+ fn) + mp.out.Add(token.LPAREN) + mp.expr(ex.X) + mp.out.Add(token.COMMA) + mp.idx++ + if fn == "Pow" { + mp.idx++ + } + mp.expr(ex.Y) + mp.out.Add(token.RPAREN) + mp.endFunc() +} + +func (mp *mathParse) unaryExpr(ex *ast.UnaryExpr) { + if _, isbl := ex.X.(*ast.BasicLit); isbl { + mp.addToken(ex.Op) + mp.expr(ex.X) + return + } + fn := "" + switch ex.Op { + case token.NOT: + fn = "Not" + case token.SUB: + fn = "Negate" + case token.ADD: + mp.expr(ex.X) + return + default: // * goes to StarExpr -- not sure what else could happen here? + mp.addToken(ex.Op) + mp.expr(ex.X) + return + } + mp.startFunc("tmath." + fn) + mp.addToken(token.LPAREN) + mp.expr(ex.X) + mp.out.Add(token.RPAREN) + mp.endFunc() +} + +func (mp *mathParse) defineStmt(as *ast.AssignStmt) { + firstStmt := mp.idx == 0 + mp.exprList(as.Lhs) + mp.addToken(as.Tok) + mp.startFunc("tensor.Tensor") + mp.out.Add(token.LPAREN) + mp.exprList(as.Rhs) + mp.out.Add(token.RPAREN) + mp.endFunc() + if firstStmt && mp.state.MathRecord { + nvar, ok := as.Lhs[0].(*ast.Ident) + if ok { + mp.out.Add(token.SEMICOLON) + mp.out.Add(token.IDENT, "tensorfs.Record("+nvar.Name+",`"+nvar.Name+"`)") + } + } +} + +func (mp *mathParse) assignStmt(as *ast.AssignStmt) { + if as.Tok == token.ASSIGN { + if _, ok := as.Lhs[0].(*ast.Ident); ok { + mp.exprList(as.Lhs) + mp.addToken(as.Tok) + mp.startFunc("") + mp.exprList(as.Rhs) + mp.endFunc() + return + } + } + fn := "" + switch as.Tok { + case token.ASSIGN: + fn = "Assign" + case token.ADD_ASSIGN: + fn = "AddAssign" + case token.SUB_ASSIGN: + fn = "SubAssign" + case token.MUL_ASSIGN: + fn = "MulAssign" + case token.QUO_ASSIGN: + fn = "DivAssign" + } + mp.startFunc("tmath." 
+ fn) + mp.out.Add(token.LPAREN) + mp.exprList(as.Lhs) + mp.out.Add(token.COMMA) + mp.idx++ + mp.exprList(as.Rhs) + mp.out.Add(token.RPAREN) + mp.endFunc() +} + +func (mp *mathParse) basicLit(lit *ast.BasicLit) { + if mp.curArgIsTensor() { + mp.tensorLit(lit) + return + } + mp.out.Add(lit.Kind, lit.Value) + if mp.trace { + fmt.Printf("%d\ttok: %s literal\n", mp.idx, lit.Value) + } + mp.idx++ + return +} + +func (mp *mathParse) tensorLit(lit *ast.BasicLit) { + switch lit.Kind { + case token.INT: + mp.out.Add(token.IDENT, "tensor.NewIntScalar("+lit.Value+")") + mp.idx++ + case token.FLOAT: + mp.out.Add(token.IDENT, "tensor.NewFloat64Scalar("+lit.Value+")") + mp.idx++ + case token.STRING: + mp.out.Add(token.IDENT, "tensor.NewStringScalar("+lit.Value+")") + mp.idx++ + } +} + +// funWrap is a function wrapper for simple numpy property / functions +type funWrap struct { + fun string // function to call on tensor + wrap string // code for wrapping function for results of call +} + +// nis: NewIntScalar, niv: NewIntFromValues, etc +var numpyProps = map[string]funWrap{ + "ndim": {"NumDims()", "nis"}, + "len": {"Len()", "nis"}, + "size": {"Len()", "nis"}, + "shape": {"Shape().Sizes", "niv"}, + "T": {"", "tensor.Transpose"}, +} + +// tensorFunc outputs the wrapping function and whether it needs ellipsis +func (fw *funWrap) wrapFunc(mp *mathParse) bool { + ellip := false + wrapFun := fw.wrap + switch fw.wrap { + case "nis": + wrapFun = "tensor.NewIntScalar" + case "nfs": + wrapFun = "tensor.NewFloat64Scalar" + case "nss": + wrapFun = "tensor.NewStringScalar" + case "niv": + wrapFun = "tensor.NewIntFromValues" + ellip = true + case "nfv": + wrapFun = "tensor.NewFloat64FromValues" + ellip = true + case "nsv": + wrapFun = "tensor.NewStringFromValues" + ellip = true + default: + wrapFun = fw.wrap + } + mp.startFunc(wrapFun, true) // don't lookup -- don't auto-convert args + mp.out.Add(token.LPAREN) + return ellip +} + +func (mp *mathParse) selectorExpr(ex *ast.SelectorExpr) { + 
fw, ok := numpyProps[ex.Sel.Name] + if !ok { + mp.expr(ex.X) + mp.addToken(token.PERIOD) + mp.out.Add(token.IDENT, ex.Sel.Name) + mp.idx++ + return + } + ellip := fw.wrapFunc(mp) + mp.expr(ex.X) + if fw.fun != "" { + mp.addToken(token.PERIOD) + mp.out.Add(token.IDENT, fw.fun) + mp.idx++ + } else { + mp.idx += 2 + } + if ellip { + mp.out.Add(token.ELLIPSIS) + } + mp.out.Add(token.RPAREN) + mp.endFunc() +} + +func (mp *mathParse) indexListExpr(il *ast.IndexListExpr) { + // fmt.Println("slice expr", se) +} + +func (mp *mathParse) indexExpr(il *ast.IndexExpr) { + if _, ok := il.Index.(*ast.IndexListExpr); ok { + mp.basicSlicingExpr(il) + } +} + +func (mp *mathParse) basicSlicingExpr(il *ast.IndexExpr) { + iil := il.Index.(*ast.IndexListExpr) + fun := "tensor.Reslice" + if mp.exprsAreBool(iil.Indices) { + fun = "tensor.Mask" + } + mp.startFunc(fun) + mp.out.Add(token.LPAREN) + mp.expr(il.X) + mp.nextArg() + mp.addToken(token.COMMA) // use the [ -- can't use ( to preserve X + mp.exprList(iil.Indices) + mp.addToken(token.RPAREN) // replaces ] + mp.endFunc() +} + +func (mp *mathParse) sliceExpr(se *ast.SliceExpr) { + if se.Low == nil && se.High == nil && se.Max == nil { + mp.out.Add(token.IDENT, "tensor.FullAxis") + mp.idx++ + return + } + mp.out.Add(token.IDENT, "tensor.Slice") + mp.out.Add(token.LBRACE) + prev := false + if se.Low != nil { + mp.out.Add(token.IDENT, "Start:") + mp.expr(se.Low) + prev = true + if se.High == nil && se.Max == nil { + mp.idx++ + } + } + if se.High != nil { + if prev { + mp.out.Add(token.COMMA) + } + mp.out.Add(token.IDENT, "Stop:") + mp.idx++ + mp.expr(se.High) + prev = true + } + if se.Max != nil { + if prev { + mp.out.Add(token.COMMA) + } + mp.idx++ + if se.Low == nil && se.High == nil { + mp.idx++ + } + mp.out.Add(token.IDENT, "Step:") + mp.expr(se.Max) + } + mp.out.Add(token.RBRACE) +} + +func (mp *mathParse) arrayLiteral(il *ast.IndexListExpr) { + kind := inferKindExprList(il.Indices) + if kind == token.ILLEGAL { + kind = token.FLOAT // 
default + } + // todo: look for sub-arrays etc. + typ := "float64" + fun := "Float64" + switch kind { + case token.FLOAT: + case token.INT: + typ = "int" + fun = "Int" + case token.STRING: + typ = "string" + fun = "String" + } + if mp.curArgIsInts() { + mp.idx++ // opening brace we're not using + mp.exprList(il.Indices) + mp.idx++ // closing brace we're not using + return + } + var sh []int + mp.arrayShape(il.Indices, &sh) + if len(sh) > 1 { + mp.startFunc("tensor.Reshape") + mp.out.Add(token.LPAREN) + } + mp.startFunc("tensor.New" + fun + "FromValues") + mp.out.Add(token.LPAREN) + mp.out.Add(token.IDENT, "[]"+typ) + mp.addToken(token.LBRACE) + mp.exprList(il.Indices) + mp.addToken(token.RBRACE) + mp.out.AddMulti(token.ELLIPSIS, token.RPAREN) + mp.endFunc() + if len(sh) > 1 { + mp.out.Add(token.COMMA) + nsh := len(sh) + for i, s := range sh { + mp.out.Add(token.INT, fmt.Sprintf("%d", s)) + if i < nsh-1 { + mp.out.Add(token.COMMA) + } + } + mp.out.Add(token.RPAREN) + mp.endFunc() + } +} + +func (mp *mathParse) arrayShape(ex []ast.Expr, sh *[]int) { + n := len(ex) + if n == 0 { + return + } + *sh = append(*sh, n) + for i := range n { + if il, ok := ex[i].(*ast.IndexListExpr); ok { + mp.arrayShape(il.Indices, sh) + return + } + } +} + +// nofun = do not accept a function version, just a method +var numpyFuncs = map[string]funWrap{ + // "array": {"tensor.NewFloatFromValues", ""}, // todo: probably not right, maybe don't have? 
+ "zeros": {"tensor.NewFloat64", ""}, + "full": {"tensor.NewFloat64Full", ""}, + "ones": {"tensor.NewFloat64Ones", ""}, + "rand": {"tensor.NewFloat64Rand", ""}, + "arange": {"tensor.NewIntRange", ""}, + "linspace": {"tensor.NewFloat64SpacedLinear", ""}, + "reshape": {"tensor.Reshape", ""}, + "copy": {"tensor.Clone", ""}, + "get": {"tensorfs.Get", ""}, + "set": {"tensorfs.Set", ""}, + "flatten": {"tensor.Flatten", "nofun"}, + "squeeze": {"tensor.Squeeze", "nofun"}, +} + +func (mp *mathParse) callExpr(ex *ast.CallExpr) { + switch x := ex.Fun.(type) { + case *ast.Ident: + if fw, ok := numpyProps[x.Name]; ok && fw.wrap != "nofun" { + mp.callPropFun(ex, fw) + return + } + mp.callName(ex, x.Name, "") + case *ast.SelectorExpr: + fun := x.Sel.Name + if pkg, ok := x.X.(*ast.Ident); ok { + if fw, ok := numpyFuncs[fun]; ok { + mp.callPropSelFun(ex, x.X, fw) + return + } else { + // fmt.Println("call name:", fun, pkg.Name) + mp.callName(ex, fun, pkg.Name) + } + } else { + if fw, ok := numpyFuncs[fun]; ok { + mp.callPropSelFun(ex, x.X, fw) + return + } + // todo: dot fun? + mp.expr(ex) + } + default: + mp.expr(ex.Fun) + } + mp.argsList(ex.Args) + // todo: ellipsis + mp.addToken(token.RPAREN) + mp.endFunc() +} + +// this calls a "prop" function like ndim(a) on the object. 
+func (mp *mathParse) callPropFun(cf *ast.CallExpr, fw funWrap) { + ellip := fw.wrapFunc(mp) + mp.idx += 2 + mp.exprList(cf.Args) // this is the tensor + mp.addToken(token.PERIOD) + mp.out.Add(token.IDENT, fw.fun) + if ellip { + mp.out.Add(token.ELLIPSIS) + } + mp.out.Add(token.RPAREN) + mp.endFunc() +} + +// this calls global function through selector like: a.reshape() +func (mp *mathParse) callPropSelFun(cf *ast.CallExpr, ex ast.Expr, fw funWrap) { + mp.startFunc(fw.fun) + mp.out.Add(token.LPAREN) // use the ( + mp.expr(ex) + mp.idx += 2 + if len(cf.Args) > 0 { + mp.nextArg() // did first + mp.addToken(token.COMMA) + mp.argsList(cf.Args) + } else { + mp.idx++ + } + mp.addToken(token.RPAREN) + mp.endFunc() +} + +func (mp *mathParse) callName(cf *ast.CallExpr, funName, pkgName string) { + if fw, ok := numpyFuncs[funName]; ok { + mp.startFunc(fw.fun) + mp.addToken(token.LPAREN) // use the ( + mp.idx++ // paren too + return + } + var err error // validate name + if pkgName != "" { + funName = pkgName + "." + funName + _, err = tensor.FuncByName(funName) + } else { // non-package qualified names are _only_ in tmath! can be lowercase + _, err = tensor.FuncByName("tmath." + funName) + if err != nil { + funName = strings.ToUpper(funName[:1]) + funName[1:] // first letter uppercased + _, err = tensor.FuncByName("tmath." + funName) + } + if err == nil { // registered, must be in tmath + funName = "tmath." + funName + } + } + if err != nil { // not a registered tensor function + // fmt.Println("regular fun", funName) + mp.startFunc(funName) + mp.addToken(token.LPAREN) // use the ( + mp.idx += 3 + return + } + mp.startFunc(funName) + mp.idx += 1 + if pkgName != "" { + mp.idx += 2 // . 
and selector + } + mp.addToken(token.LPAREN) +} + +// basic ident replacements +var consts = map[string]string{ + "newaxis": "tensor.NewAxis", + "pi": "tensor.NewFloat64Scalar(math.Pi)", +} + +func (mp *mathParse) ident(id *ast.Ident) { + if id == nil { + return + } + if cn, ok := consts[id.Name]; ok { + mp.out.Add(token.IDENT, cn) + mp.idx++ + return + } + if mp.curArgIsInts() { + mp.out.Add(token.IDENT, "tensor.AsIntSlice") + mp.out.Add(token.LPAREN) + mp.addCur() + mp.out.AddMulti(token.RPAREN, token.ELLIPSIS) + } else { + mp.addCur() + } +} diff --git a/goal/transpile/parser.go b/goal/transpile/parser.go new file mode 100644 index 0000000000..8a09e65ea3 --- /dev/null +++ b/goal/transpile/parser.go @@ -0,0 +1,2993 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// mparse is a hacked version of go/parser: +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "fmt" + "go/ast" + "go/build/constraint" + "go/scanner" + "go/token" + "strings" +) + +// ParseLine parses a line of code that could contain one or more statements +func ParseLine(code string, mode Mode) (stmts []ast.Stmt, err error) { + fset := token.NewFileSet() + var p parser + defer func() { + if e := recover(); e != nil { + // resume same panic if it's not a bailout + bail, ok := e.(bailout) + if !ok { + panic(e) + } else if bail.msg != "" { + p.errors.Add(p.file.Position(bail.pos), bail.msg) + } + } + p.errors.Sort() + err = p.errors.Err() + }() + p.init(fset, "", []byte(code), mode) + + stmts = p.parseStmtList() + + // If a semicolon was inserted, consume it; + // report an error if there's more tokens. 
+ if p.tok == token.SEMICOLON && p.lit == "\n" { + p.next() + } + if p.tok == token.RBRACE { + return + } + p.expect(token.EOF) + + return +} + +// ParseExpr parses an expression +func ParseExpr(code string, mode Mode) (expr ast.Expr, err error) { + fset := token.NewFileSet() + var p parser + defer func() { + if e := recover(); e != nil { + // resume same panic if it's not a bailout + bail, ok := e.(bailout) + if !ok { + panic(e) + } else if bail.msg != "" { + p.errors.Add(p.file.Position(bail.pos), bail.msg) + } + } + p.errors.Sort() + err = p.errors.Err() + }() + p.init(fset, "", []byte(code), mode) + + expr = p.parseRhs() + + // If a semicolon was inserted, consume it; + // report an error if there's more tokens. + if p.tok == token.SEMICOLON && p.lit == "\n" { + p.next() + } + p.expect(token.EOF) + + return +} + +// A Mode value is a set of flags (or 0). +// They control the amount of source code parsed and other optional +// parser functionality. +type Mode uint + +const ( + ParseComments Mode = 1 << iota // parse comments and add them to AST + Trace // print a trace of parsed productions + DeclarationErrors // report declaration errors + SpuriousErrors // same as AllErrors, for backward-compatibility + AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines) +) + +// The parser structure holds the parser's internal state. 
type parser struct {
	file    *token.File       // source file handle; used to translate token.Pos to line/column
	errors  scanner.ErrorList // accumulated parse errors
	scanner scanner.Scanner

	// Tracing/debugging
	mode   Mode // parsing mode
	trace  bool // == (mode&Trace != 0)
	indent int  // indentation used for tracing output

	// Comments
	comments    []*ast.CommentGroup
	leadComment *ast.CommentGroup // last lead comment
	lineComment *ast.CommentGroup // last line comment
	top         bool              // in top of file (before package clause)
	goVersion   string            // minimum Go version found in //go:build comment

	// Next token
	pos token.Pos   // token position
	tok token.Token // one token look-ahead
	lit string      // token literal

	// Error recovery
	// (used to limit the number of calls to parser.advance
	// w/o making scanning progress - avoids potential endless
	// loops across multiple parser functions during error recovery)
	syncPos token.Pos // last synchronization position
	syncCnt int       // number of parser.advance calls without progress

	// Non-syntactic parser control
	exprLev int  // < 0: in control clause, >= 0: in expression
	inRhs   bool // if set, the parser is parsing a rhs expression

	imports []*ast.ImportSpec // list of imports

	// nestLev is used to track and limit the recursion depth
	// during parsing.
	nestLev int
}

// init registers a (possibly unnamed) file for src with fset, wires up the
// scanner with an error handler, and advances to the first token so that
// p.pos/p.tok/p.lit are valid before the first parse call.
func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode Mode) {
	p.file = fset.AddFile(filename, -1, len(src))
	eh := func(pos token.Position, msg string) {
		// NOTE(review): scanner errors whose message contains "@" are
		// silently dropped — presumably goal-specific syntax produces
		// spurious scanner errors for '@'; confirm against the transpiler.
		if !strings.Contains(msg, "@") {
			p.errors.Add(pos, msg)
		}
	}
	p.scanner.Init(p.file, src, eh, scanner.ScanComments)

	p.top = true
	p.mode = mode
	p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
	p.next()
}

// ----------------------------------------------------------------------------
// Parsing support

// printTrace prints a single trace line, prefixed by the current source
// position and indented by 2*p.indent columns.
func (p *parser) printTrace(a ...any) {
	const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
	const n = len(dots)
	pos := p.file.Position(p.pos)
	fmt.Printf("%5d:%3d: ", pos.Line, pos.Column)
	i := 2 * p.indent
	for i > n {
		fmt.Print(dots)
		i -= n
	}
	// i <= n
	fmt.Print(dots[0:i])
	fmt.Println(a...)
}

func trace(p *parser, msg string) *parser {
	p.printTrace(msg, "(")
	p.indent++
	return p
}

// Usage pattern: defer un(trace(p, "..."))
func un(p *parser) {
	p.indent--
	p.printTrace(")")
}

// maxNestLev is the deepest we're willing to recurse during parsing
const maxNestLev int = 1e5

// incNestLev bumps the recursion counter and bails out of parsing entirely
// (via the bailout panic) if the limit is exceeded, preventing stack overflow
// on pathological input. Usage pattern: defer decNestLev(incNestLev(p))
func incNestLev(p *parser) *parser {
	p.nestLev++
	if p.nestLev > maxNestLev {
		p.error(p.pos, "exceeded max nesting depth")
		panic(bailout{})
	}
	return p
}

// decNestLev is used to track nesting depth during parsing to prevent stack exhaustion.
// It is used along with incNestLev in a similar fashion to how un and trace are used.
func decNestLev(p *parser) {
	p.nestLev--
}

// Advance to the next token.
func (p *parser) next0() {
	// Because of one-token look-ahead, print the previous token
	// when tracing as it provides a more readable output. The
	// very first token (!p.pos.IsValid()) is not initialized
	// (it is token.ILLEGAL), so don't print it.
	if p.trace && p.pos.IsValid() {
		s := p.tok.String()
		switch {
		case p.tok.IsLiteral():
			p.printTrace(s, p.lit)
		case p.tok.IsOperator(), p.tok.IsKeyword():
			p.printTrace("\"" + s + "\"")
		default:
			p.printTrace(s)
		}
	}

	for {
		p.pos, p.tok, p.lit = p.scanner.Scan()
		if p.tok == token.COMMENT {
			// Record the //go:build constraint's minimum Go version while
			// still in the file header.
			if p.top && strings.HasPrefix(p.lit, "//go:build") {
				if x, err := constraint.Parse(p.lit); err == nil {
					p.goVersion = constraint.GoVersion(x)
				}
			}
			if p.mode&ParseComments == 0 {
				continue
			}
		} else {
			// Found a non-comment; top of file is over.
			p.top = false
		}
		break
	}
}

// Consume a comment and return it and the line on which it ends.
+func (p *parser) consumeComment() (comment *ast.Comment, endline int) { + // /*-style comments may end on a different line than where they start. + // Scan the comment for '\n' chars and adjust endline accordingly. + endline = p.file.Line(p.pos) + if p.lit[1] == '*' { + // don't use range here - no need to decode Unicode code points + for i := 0; i < len(p.lit); i++ { + if p.lit[i] == '\n' { + endline++ + } + } + } + + comment = &ast.Comment{Slash: p.pos, Text: p.lit} + p.next0() + + return +} + +// Consume a group of adjacent comments, add it to the parser's +// comments list, and return it together with the line at which +// the last comment in the group ends. A non-comment token or n +// empty lines terminate a comment group. +func (p *parser) consumeCommentGroup(n int) (comments *ast.CommentGroup, endline int) { + var list []*ast.Comment + endline = p.file.Line(p.pos) + for p.tok == token.COMMENT && p.file.Line(p.pos) <= endline+n { + var comment *ast.Comment + comment, endline = p.consumeComment() + list = append(list, comment) + } + + // add comment group to the comments list + comments = &ast.CommentGroup{List: list} + p.comments = append(p.comments, comments) + + return +} + +// Advance to the next non-comment token. In the process, collect +// any comment groups encountered, and remember the last lead and +// line comments. +// +// A lead comment is a comment group that starts and ends in a +// line without any other tokens and that is followed by a non-comment +// token on the line immediately after the comment group. +// +// A line comment is a comment group that follows a non-comment +// token on the same line, and that has no tokens after it on the line +// where it ends. +// +// Lead and line comments may be considered documentation that is +// stored in the AST. 
// next advances to the next non-comment token, attributing any comments seen
// along the way as the lead comment (immediately preceding the next token) or
// line comment (trailing the previous token) per the rules documented above.
func (p *parser) next() {
	p.leadComment = nil
	p.lineComment = nil
	prev := p.pos
	p.next0()

	if p.tok == token.COMMENT {
		var comment *ast.CommentGroup
		var endline int

		if p.file.Line(p.pos) == p.file.Line(prev) {
			// The comment is on same line as the previous token; it
			// cannot be a lead comment but may be a line comment.
			comment, endline = p.consumeCommentGroup(0)
			if p.file.Line(p.pos) != endline || p.tok == token.SEMICOLON || p.tok == token.EOF {
				// The next token is on a different line, thus
				// the last comment group is a line comment.
				p.lineComment = comment
			}
		}

		// consume successor comments, if any
		endline = -1
		for p.tok == token.COMMENT {
			comment, endline = p.consumeCommentGroup(1)
		}

		if endline+1 == p.file.Line(p.pos) {
			// The next token is following on the line immediately after the
			// comment group, thus the last comment group is a lead comment.
			p.leadComment = comment
		}
	}
}

// A bailout panic is raised to indicate early termination. pos and msg are
// only populated when bailing out of object resolution.
type bailout struct {
	pos token.Pos
	msg string
}

// error records a parse error at pos, subject to the AllErrors-mode
// deduplication and the 10-error bailout described below.
func (p *parser) error(pos token.Pos, msg string) {
	if p.trace {
		defer un(trace(p, "error: "+msg))
	}

	epos := p.file.Position(pos)

	// If AllErrors is not set, discard errors reported on the same line
	// as the last recorded error and stop parsing if there are more than
	// 10 errors.
	if p.mode&AllErrors == 0 {
		n := len(p.errors)
		if n > 0 && p.errors[n-1].Pos.Line == epos.Line {
			return // discard - likely a spurious error
		}
		if n > 10 {
			panic(bailout{})
		}
	}

	p.errors.Add(epos, msg)
}

// errorExpected reports that msg was expected at pos, refining the message
// with what was actually found when the error is at the current token.
func (p *parser) errorExpected(pos token.Pos, msg string) {
	msg = "expected " + msg
	if pos == p.pos {
		// the error happened at the current position;
		// make the error message more specific
		switch {
		case p.tok == token.SEMICOLON && p.lit == "\n":
			msg += ", found newline"
		case p.tok.IsLiteral():
			// print 123 rather than 'INT', etc.
			msg += ", found " + p.lit
		default:
			msg += ", found '" + p.tok.String() + "'"
		}
	}
	p.error(pos, msg)
}

// expect consumes the current token, reporting an error if it is not tok,
// and returns its position (valid even on mismatch).
func (p *parser) expect(tok token.Token) token.Pos {
	pos := p.pos
	if p.tok != tok {
		p.errorExpected(pos, "'"+tok.String()+"'")
	}
	p.next() // make progress
	return pos
}

// expect2 is like expect, but it returns an invalid position
// if the expected token is not found.
func (p *parser) expect2(tok token.Token) (pos token.Pos) {
	if p.tok == tok {
		pos = p.pos
	} else {
		p.errorExpected(p.pos, "'"+tok.String()+"'")
	}
	p.next() // make progress
	return
}

// expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline.
func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
	if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
		p.error(p.pos, "missing ',' before newline in "+context)
		p.next()
	}
	return p.expect(tok)
}

// expectSemi consumes a semicolon and returns the applicable line comment.
func (p *parser) expectSemi() (comment *ast.CommentGroup) {
	// semicolon is optional before a closing ')' or '}'
	if p.tok != token.RPAREN && p.tok != token.RBRACE {
		switch p.tok {
		case token.COMMA:
			// permit a ',' instead of a ';' but complain
			p.errorExpected(p.pos, "';'")
			fallthrough
		case token.SEMICOLON:
			if p.lit == ";" {
				// explicit semicolon
				p.next()
				comment = p.lineComment // use following comments
			} else {
				// artificial semicolon
				comment = p.lineComment // use preceding comments
				p.next()
			}
			return comment
		default:
			// math: allow unexpected endings..
			// (deliberate divergence from upstream go/parser, which would
			// report "expected ';'" and advance to a statement start here)
			// p.errorExpected(p.pos, "';'")
			// p.advance(stmtStart)
		}
	}
	return nil
}

// atComma reports whether the parser sits at a list separator: true at an
// actual ',' and also (after complaining) when a ',' is plausibly missing
// before the closing follow token, so list parsing can continue.
func (p *parser) atComma(context string, follow token.Token) bool {
	if p.tok == token.COMMA {
		return true
	}
	if p.tok != follow {
		msg := "missing ','"
		if p.tok == token.SEMICOLON && p.lit == "\n" {
			msg += " before newline"
		}
		p.error(p.pos, msg+" in "+context)
		return true // "insert" comma and continue
	}
	return false
}

// passert panics on violated internal invariants (parser bugs, not user errors).
func passert(cond bool, msg string) {
	if !cond {
		panic("go/parser internal error: " + msg)
	}
}

// advance consumes tokens until the current token p.tok
// is in the 'to' set, or token.EOF. For error recovery.
func (p *parser) advance(to map[token.Token]bool) {
	for ; p.tok != token.EOF; p.next() {
		if to[p.tok] {
			// Return only if parser made some progress since last
			// sync or if it has not reached 10 advance calls without
			// progress. Otherwise consume at least one token to
			// avoid an endless parser loop (it is possible that
			// both parseOperand and parseStmt call advance and
			// correctly do not advance, thus the need for the
			// invocation limit p.syncCnt).
			if p.pos == p.syncPos && p.syncCnt < 10 {
				p.syncCnt++
				return
			}
			if p.pos > p.syncPos {
				p.syncPos = p.pos
				p.syncCnt = 0
				return
			}
			// Reaching here indicates a parser bug, likely an
			// incorrect token list in this function, but it only
			// leads to skipping of possibly correct code if a
			// previous error is present, and thus is preferred
			// over a non-terminating parse.
		}
	}
}

// stmtStart is the set of tokens that can begin a statement;
// used as a recovery target for advance.
var stmtStart = map[token.Token]bool{
	token.BREAK:       true,
	token.CONST:       true,
	token.CONTINUE:    true,
	token.DEFER:       true,
	token.FALLTHROUGH: true,
	token.FOR:         true,
	token.GO:          true,
	token.GOTO:        true,
	token.IF:          true,
	token.RETURN:      true,
	token.SELECT:      true,
	token.SWITCH:      true,
	token.TYPE:        true,
	token.VAR:         true,
}

// declStart is the set of tokens that can begin a declaration;
// used as a recovery target for advance.
var declStart = map[token.Token]bool{
	token.IMPORT: true,
	token.CONST:  true,
	token.TYPE:   true,
	token.VAR:    true,
}

// exprEnd is the set of tokens that can follow an expression;
// used as a recovery target for advance.
var exprEnd = map[token.Token]bool{
	token.COMMA:     true,
	token.COLON:     true,
	token.SEMICOLON: true,
	token.RPAREN:    true,
	token.RBRACK:    true,
	token.RBRACE:    true,
}

// safePos returns a valid file position for a given position: If pos
// is valid to begin with, safePos returns pos. If pos is out-of-range,
// safePos returns the EOF position.
//
// This is hack to work around "artificial" end positions in the AST which
// are computed by adding 1 to (presumably valid) token positions. If the
// token positions are invalid due to parse errors, the resulting end position
// may be past the file's EOF position, which would lead to panics if used
// later on.
func (p *parser) safePos(pos token.Pos) (res token.Pos) {
	// Relies on the named result: the deferred recover rewrites res to the
	// EOF position if Offset panics on an out-of-range pos.
	defer func() {
		if recover() != nil {
			res = token.Pos(p.file.Base() + p.file.Size()) // EOF position
		}
	}()
	_ = p.file.Offset(pos) // trigger a panic if position is out-of-range
	return pos
}

// ----------------------------------------------------------------------------
// Identifiers

// parseIdent consumes and returns an identifier; on mismatch it reports an
// error and returns a placeholder ident named "_" at the current position.
func (p *parser) parseIdent() *ast.Ident {
	pos := p.pos
	name := "_"
	if p.tok == token.IDENT {
		name = p.lit
		p.next()
	} else {
		p.expect(token.IDENT) // use expect() error handling
	}
	return &ast.Ident{NamePos: pos, Name: name}
}

// parseIdentList parses a comma-separated list of one or more identifiers.
func (p *parser) parseIdentList() (list []*ast.Ident) {
	if p.trace {
		defer un(trace(p, "IdentList"))
	}

	list = append(list, p.parseIdent())
	for p.tok == token.COMMA {
		p.next()
		list = append(list, p.parseIdent())
	}

	return
}

// ----------------------------------------------------------------------------
// Common productions

// If lhs is set, result list elements which are identifiers are not resolved.
// parseExprList parses a comma-separated list of one or more expressions.
func (p *parser) parseExprList() (list []ast.Expr) {
	if p.trace {
		defer un(trace(p, "ExpressionList"))
	}

	list = append(list, p.parseExpr())
	for p.tok == token.COMMA {
		p.next()
		list = append(list, p.parseExpr())
	}

	return
}

// parseList parses an expression list with p.inRhs temporarily set to inRhs,
// restoring the previous value afterwards.
func (p *parser) parseList(inRhs bool) []ast.Expr {
	old := p.inRhs
	p.inRhs = inRhs
	list := p.parseExprList()
	p.inRhs = old
	return list
}

// math: allow full array list expressions
// parseArrayList is a goal/math extension: with '[' already consumed at
// lbrack, it parses a full expression list up to ']' and packs it into an
// IndexListExpr for the transpiler to process.
func (p *parser) parseArrayList(lbrack token.Pos) *ast.IndexListExpr {
	if p.trace {
		defer un(trace(p, "ArrayList"))
	}
	p.exprLev++
	// x := p.parseRhs()
	x := p.parseExprList()
	p.exprLev--
	rbrack := p.expect(token.RBRACK)
	return &ast.IndexListExpr{Lbrack: lbrack, Indices: x, Rbrack: rbrack}
}

// ----------------------------------------------------------------------------
// Types

// parseType parses a type, producing a BadExpr (after error recovery)
// when no type is present.
func (p *parser) parseType() ast.Expr {
	if p.trace {
		defer un(trace(p, "Type"))
	}

	typ := p.tryIdentOrType()

	if typ == nil {
		pos := p.pos
		p.errorExpected(pos, "type")
		p.advance(exprEnd)
		return &ast.BadExpr{From: pos, To: p.pos}
	}

	return typ
}

// parseQualifiedIdent parses a (possibly package-qualified) type name,
// followed by an optional type instantiation.
func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
	if p.trace {
		defer un(trace(p, "QualifiedIdent"))
	}

	typ := p.parseTypeName(ident)
	if p.tok == token.LBRACK {
		typ = p.parseTypeInstance(typ)
	}

	return typ
}

// If the result is an identifier, it is not resolved.
func (p *parser) parseTypeName(ident *ast.Ident) ast.Expr {
	if p.trace {
		defer un(trace(p, "TypeName"))
	}

	if ident == nil {
		ident = p.parseIdent()
	}

	if p.tok == token.PERIOD {
		// ident is a package name
		p.next()
		sel := p.parseIdent()
		return &ast.SelectorExpr{X: ident, Sel: sel}
	}

	return ident
}

// "[" has already been consumed, and lbrack is its position.
// If len != nil it is the already consumed array length.
+func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType { + if p.trace { + defer un(trace(p, "ArrayType")) + } + + if len == nil { + p.exprLev++ + // always permit ellipsis for more fault-tolerant parsing + if p.tok == token.ELLIPSIS { + len = &ast.Ellipsis{Ellipsis: p.pos} + p.next() + } else if p.tok != token.RBRACK { + len = p.parseRhs() + } + p.exprLev-- + } + if p.tok == token.COMMA { + // Trailing commas are accepted in type parameter + // lists but not in array type declarations. + // Accept for better error handling but complain. + p.error(p.pos, "unexpected comma; expecting ]") + p.next() + } + p.expect(token.RBRACK) + elt := p.parseType() + return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt} +} + +func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Expr) { + if p.trace { + defer un(trace(p, "ArrayFieldOrTypeInstance")) + } + + lbrack := p.expect(token.LBRACK) + trailingComma := token.NoPos // if valid, the position of a trailing comma preceding the ']' + var args []ast.Expr + if p.tok != token.RBRACK { + p.exprLev++ + args = append(args, p.parseRhs()) + for p.tok == token.COMMA { + comma := p.pos + p.next() + if p.tok == token.RBRACK { + trailingComma = comma + break + } + args = append(args, p.parseRhs()) + } + p.exprLev-- + } + rbrack := p.expect(token.RBRACK) + _ = rbrack + + if len(args) == 0 { + // x []E + elt := p.parseType() + return x, &ast.ArrayType{Lbrack: lbrack, Elt: elt} + } + + // x [P]E or x[P] + if len(args) == 1 { + elt := p.tryIdentOrType() + if elt != nil { + // x [P]E + if trailingComma.IsValid() { + // Trailing commas are invalid in array type fields. + p.error(trailingComma, "unexpected comma; expecting ]") + } + return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt} + } + } + + // x[P], x[P1, P2], ... 
+ return nil, nil // typeparams.PackIndexExpr(x, lbrack, args, rbrack) +} + +func (p *parser) parseFieldDecl() *ast.Field { + if p.trace { + defer un(trace(p, "FieldDecl")) + } + + doc := p.leadComment + + var names []*ast.Ident + var typ ast.Expr + switch p.tok { + case token.IDENT: + name := p.parseIdent() + if p.tok == token.PERIOD || p.tok == token.STRING || p.tok == token.SEMICOLON || p.tok == token.RBRACE { + // embedded type + typ = name + if p.tok == token.PERIOD { + typ = p.parseQualifiedIdent(name) + } + } else { + // name1, name2, ... T + names = []*ast.Ident{name} + for p.tok == token.COMMA { + p.next() + names = append(names, p.parseIdent()) + } + // Careful dance: We don't know if we have an embedded instantiated + // type T[P1, P2, ...] or a field T of array type []E or [P]E. + if len(names) == 1 && p.tok == token.LBRACK { + name, typ = p.parseArrayFieldOrTypeInstance(name) + if name == nil { + names = nil + } + } else { + // T P + typ = p.parseType() + } + } + case token.MUL: + star := p.pos + p.next() + if p.tok == token.LPAREN { + // *(T) + p.error(p.pos, "cannot parenthesize embedded type") + p.next() + typ = p.parseQualifiedIdent(nil) + // expect closing ')' but no need to complain if missing + if p.tok == token.RPAREN { + p.next() + } + } else { + // *T + typ = p.parseQualifiedIdent(nil) + } + typ = &ast.StarExpr{Star: star, X: typ} + + case token.LPAREN: + p.error(p.pos, "cannot parenthesize embedded type") + p.next() + if p.tok == token.MUL { + // (*T) + star := p.pos + p.next() + typ = &ast.StarExpr{Star: star, X: p.parseQualifiedIdent(nil)} + } else { + // (T) + typ = p.parseQualifiedIdent(nil) + } + // expect closing ')' but no need to complain if missing + if p.tok == token.RPAREN { + p.next() + } + + default: + pos := p.pos + p.errorExpected(pos, "field name or embedded type") + p.advance(exprEnd) + typ = &ast.BadExpr{From: pos, To: p.pos} + } + + var tag *ast.BasicLit + if p.tok == token.STRING { + tag = &ast.BasicLit{ValuePos: p.pos, 
Kind: p.tok, Value: p.lit} + p.next() + } + + comment := p.expectSemi() + + field := &ast.Field{Doc: doc, Names: names, Type: typ, Tag: tag, Comment: comment} + return field +} + +func (p *parser) parseStructType() *ast.StructType { + if p.trace { + defer un(trace(p, "StructType")) + } + + pos := p.expect(token.STRUCT) + lbrace := p.expect(token.LBRACE) + var list []*ast.Field + for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN { + // a field declaration cannot start with a '(' but we accept + // it here for more robust parsing and better error messages + // (parseFieldDecl will check and complain if necessary) + list = append(list, p.parseFieldDecl()) + } + rbrace := p.expect(token.RBRACE) + + return &ast.StructType{ + Struct: pos, + Fields: &ast.FieldList{ + Opening: lbrace, + List: list, + Closing: rbrace, + }, + } +} + +func (p *parser) parsePointerType() *ast.StarExpr { + if p.trace { + defer un(trace(p, "PointerType")) + } + + star := p.expect(token.MUL) + base := p.parseType() + + return &ast.StarExpr{Star: star, X: base} +} + +func (p *parser) parseDotsType() *ast.Ellipsis { + if p.trace { + defer un(trace(p, "DotsType")) + } + + pos := p.expect(token.ELLIPSIS) + elt := p.parseType() + + return &ast.Ellipsis{Ellipsis: pos, Elt: elt} +} + +type field struct { + name *ast.Ident + typ ast.Expr +} + +func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) { + // TODO(rFindley) refactor to be more similar to paramDeclOrNil in the syntax + // package + if p.trace { + defer un(trace(p, "ParamDeclOrNil")) + } + + ptok := p.tok + if name != nil { + p.tok = token.IDENT // force token.IDENT case in switch below + } else if typeSetsOK && p.tok == token.TILDE { + // "~" ... 
+ return field{nil, p.embeddedElem(nil)} + } + + switch p.tok { + case token.IDENT: + // name + if name != nil { + f.name = name + p.tok = ptok + } else { + f.name = p.parseIdent() + } + switch p.tok { + case token.IDENT, token.MUL, token.ARROW, token.FUNC, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN: + // name type + f.typ = p.parseType() + + case token.LBRACK: + // name "[" type1, ..., typeN "]" or name "[" n "]" type + f.name, f.typ = p.parseArrayFieldOrTypeInstance(f.name) + + case token.ELLIPSIS: + // name "..." type + f.typ = p.parseDotsType() + return // don't allow ...type "|" ... + + case token.PERIOD: + // name "." ... + f.typ = p.parseQualifiedIdent(f.name) + f.name = nil + + case token.TILDE: + if typeSetsOK { + f.typ = p.embeddedElem(nil) + return + } + + case token.OR: + if typeSetsOK { + // name "|" typeset + f.typ = p.embeddedElem(f.name) + f.name = nil + return + } + } + + case token.MUL, token.ARROW, token.FUNC, token.LBRACK, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN: + // type + f.typ = p.parseType() + + case token.ELLIPSIS: + // "..." type + // (always accepted) + f.typ = p.parseDotsType() + return // don't allow ...type "|" ... + + default: + // TODO(rfindley): this is incorrect in the case of type parameter lists + // (should be "']'" in that case) + p.errorExpected(p.pos, "')'") + p.advance(exprEnd) + } + + // [name] type "|" + if typeSetsOK && p.tok == token.OR && f.typ != nil { + f.typ = p.embeddedElem(f.typ) + } + + return +} + +func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token) (params []*ast.Field) { + if p.trace { + defer un(trace(p, "ParameterList")) + } + + // Type parameters are the only parameter list closed by ']'. 
+ tparams := closing == token.RBRACK + + pos0 := p.pos + if name0 != nil { + pos0 = name0.Pos() + } else if typ0 != nil { + pos0 = typ0.Pos() + } + + // Note: The code below matches the corresponding code in the syntax + // parser closely. Changes must be reflected in either parser. + // For the code to match, we use the local []field list that + // corresponds to []syntax.Field. At the end, the list must be + // converted into an []*ast.Field. + + var list []field + var named int // number of parameters that have an explicit name and type + var typed int // number of parameters that have an explicit type + + for name0 != nil || p.tok != closing && p.tok != token.EOF { + var par field + if typ0 != nil { + if tparams { + typ0 = p.embeddedElem(typ0) + } + par = field{name0, typ0} + } else { + par = p.parseParamDecl(name0, tparams) + } + name0 = nil // 1st name was consumed if present + typ0 = nil // 1st typ was consumed if present + if par.name != nil || par.typ != nil { + list = append(list, par) + if par.name != nil && par.typ != nil { + named++ + } + if par.typ != nil { + typed++ + } + } + if !p.atComma("parameter list", closing) { + break + } + p.next() + } + + if len(list) == 0 { + return // not uncommon + } + + // distribute parameter types (len(list) > 0) + if named == 0 { + // all unnamed => found names are type names + for i := 0; i < len(list); i++ { + par := &list[i] + if typ := par.name; typ != nil { + par.typ = typ + par.name = nil + } + } + if tparams { + // This is the same error handling as below, adjusted for type parameters only. + // See comment below for details. 
(go.dev/issue/64534) + var errPos token.Pos + var msg string + if named == typed /* same as typed == 0 */ { + errPos = p.pos // position error at closing ] + msg = "missing type constraint" + } else { + errPos = pos0 // position at opening [ or first name + msg = "missing type parameter name" + if len(list) == 1 { + msg += " or invalid array length" + } + } + p.error(errPos, msg) + } + } else if named != len(list) { + // some named or we're in a type parameter list => all must be named + var errPos token.Pos // left-most error position (or invalid) + var typ ast.Expr // current type (from right to left) + for i := len(list) - 1; i >= 0; i-- { + if par := &list[i]; par.typ != nil { + typ = par.typ + if par.name == nil { + errPos = typ.Pos() + n := ast.NewIdent("_") + n.NamePos = errPos // correct position + par.name = n + } + } else if typ != nil { + par.typ = typ + } else { + // par.typ == nil && typ == nil => we only have a par.name + errPos = par.name.Pos() + par.typ = &ast.BadExpr{From: errPos, To: p.pos} + } + } + if errPos.IsValid() { + var msg string + if tparams { + // Not all parameters are named because named != len(list). + // If named == typed we must have parameters that have no types, + // and they must be at the end of the parameter list, otherwise + // the types would have been filled in by the right-to-left sweep + // above and we wouldn't have an error. Since we are in a type + // parameter list, the missing types are constraints. + if named == typed { + errPos = p.pos // position error at closing ] + msg = "missing type constraint" + } else { + msg = "missing type parameter name" + // go.dev/issue/60812 + if len(list) == 1 { + msg += " or invalid array length" + } + } + } else { + msg = "mixed named and unnamed parameters" + } + p.error(errPos, msg) + } + } + + // Convert list to []*ast.Field. + // If list contains types only, each type gets its own ast.Field. 
+ if named == 0 { + // parameter list consists of types only + for _, par := range list { + passert(par.typ != nil, "nil type in unnamed parameter list") + params = append(params, &ast.Field{Type: par.typ}) + } + return + } + + // If the parameter list consists of named parameters with types, + // collect all names with the same types into a single ast.Field. + var names []*ast.Ident + var typ ast.Expr + addParams := func() { + passert(typ != nil, "nil type in named parameter list") + field := &ast.Field{Names: names, Type: typ} + params = append(params, field) + names = nil + } + for _, par := range list { + if par.typ != typ { + if len(names) > 0 { + addParams() + } + typ = par.typ + } + names = append(names, par.name) + } + if len(names) > 0 { + addParams() + } + return +} + +func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.FieldList) { + if p.trace { + defer un(trace(p, "Parameters")) + } + + if acceptTParams && p.tok == token.LBRACK { + opening := p.pos + p.next() + // [T any](params) syntax + list := p.parseParameterList(nil, nil, token.RBRACK) + rbrack := p.expect(token.RBRACK) + tparams = &ast.FieldList{Opening: opening, List: list, Closing: rbrack} + // Type parameter lists must not be empty. 
+ if tparams.NumFields() == 0 { + p.error(tparams.Closing, "empty type parameter list") + tparams = nil // avoid follow-on errors + } + } + + opening := p.expect(token.LPAREN) + + var fields []*ast.Field + if p.tok != token.RPAREN { + fields = p.parseParameterList(nil, nil, token.RPAREN) + } + + rparen := p.expect(token.RPAREN) + params = &ast.FieldList{Opening: opening, List: fields, Closing: rparen} + + return +} + +func (p *parser) parseResult() *ast.FieldList { + if p.trace { + defer un(trace(p, "Result")) + } + + if p.tok == token.LPAREN { + _, results := p.parseParameters(false) + return results + } + + typ := p.tryIdentOrType() + if typ != nil { + list := make([]*ast.Field, 1) + list[0] = &ast.Field{Type: typ} + return &ast.FieldList{List: list} + } + + return nil +} + +func (p *parser) parseFuncType() *ast.FuncType { + if p.trace { + defer un(trace(p, "FuncType")) + } + + pos := p.expect(token.FUNC) + tparams, params := p.parseParameters(true) + if tparams != nil { + p.error(tparams.Pos(), "function type must have no type parameters") + } + results := p.parseResult() + + return &ast.FuncType{Func: pos, Params: params, Results: results} +} + +func (p *parser) parseMethodSpec() *ast.Field { + if p.trace { + defer un(trace(p, "MethodSpec")) + } + + doc := p.leadComment + var idents []*ast.Ident + var typ ast.Expr + x := p.parseTypeName(nil) + if ident, _ := x.(*ast.Ident); ident != nil { + switch { + case p.tok == token.LBRACK: + // generic method or embedded instantiated type + lbrack := p.pos + p.next() + p.exprLev++ + x := p.parseExpr() + p.exprLev-- + if name0, _ := x.(*ast.Ident); name0 != nil && p.tok != token.COMMA && p.tok != token.RBRACK { + // generic method m[T any] + // + // Interface methods do not have type parameters. We parse them for a + // better error message and improved error recovery. 
+ _ = p.parseParameterList(name0, nil, token.RBRACK) + _ = p.expect(token.RBRACK) + p.error(lbrack, "interface method must have no type parameters") + + // TODO(rfindley) refactor to share code with parseFuncType. + _, params := p.parseParameters(false) + results := p.parseResult() + idents = []*ast.Ident{ident} + typ = &ast.FuncType{ + Func: token.NoPos, + Params: params, + Results: results, + } + } else { + // embedded instantiated type + // TODO(rfindley) should resolve all identifiers in x. + list := []ast.Expr{x} + if p.atComma("type argument list", token.RBRACK) { + p.exprLev++ + p.next() + for p.tok != token.RBRACK && p.tok != token.EOF { + list = append(list, p.parseType()) + if !p.atComma("type argument list", token.RBRACK) { + break + } + p.next() + } + p.exprLev-- + } + // rbrack := p.expectClosing(token.RBRACK, "type argument list") + // typ = typeparams.PackIndexExpr(ident, lbrack, list, rbrack) + } + case p.tok == token.LPAREN: + // ordinary method + // TODO(rfindley) refactor to share code with parseFuncType. + _, params := p.parseParameters(false) + results := p.parseResult() + idents = []*ast.Ident{ident} + typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results} + default: + // embedded type + typ = x + } + } else { + // embedded, possibly instantiated type + typ = x + if p.tok == token.LBRACK { + // embedded instantiated interface + typ = p.parseTypeInstance(typ) + } + } + + // Comment is added at the callsite: the field below may joined with + // additional type specs using '|'. + // TODO(rfindley) this should be refactored. + // TODO(rfindley) add more tests for comment handling. 
+ return &ast.Field{Doc: doc, Names: idents, Type: typ} +} + +func (p *parser) embeddedElem(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "EmbeddedElem")) + } + if x == nil { + x = p.embeddedTerm() + } + for p.tok == token.OR { + t := new(ast.BinaryExpr) + t.OpPos = p.pos + t.Op = token.OR + p.next() + t.X = x + t.Y = p.embeddedTerm() + x = t + } + return x +} + +func (p *parser) embeddedTerm() ast.Expr { + if p.trace { + defer un(trace(p, "EmbeddedTerm")) + } + if p.tok == token.TILDE { + t := new(ast.UnaryExpr) + t.OpPos = p.pos + t.Op = token.TILDE + p.next() + t.X = p.parseType() + return t + } + + t := p.tryIdentOrType() + if t == nil { + pos := p.pos + p.errorExpected(pos, "~ term or type") + p.advance(exprEnd) + return &ast.BadExpr{From: pos, To: p.pos} + } + + return t +} + +func (p *parser) parseInterfaceType() *ast.InterfaceType { + if p.trace { + defer un(trace(p, "InterfaceType")) + } + + pos := p.expect(token.INTERFACE) + lbrace := p.expect(token.LBRACE) + + var list []*ast.Field + +parseElements: + for { + switch { + case p.tok == token.IDENT: + f := p.parseMethodSpec() + if f.Names == nil { + f.Type = p.embeddedElem(f.Type) + } + f.Comment = p.expectSemi() + list = append(list, f) + case p.tok == token.TILDE: + typ := p.embeddedElem(nil) + comment := p.expectSemi() + list = append(list, &ast.Field{Type: typ, Comment: comment}) + default: + if t := p.tryIdentOrType(); t != nil { + typ := p.embeddedElem(t) + comment := p.expectSemi() + list = append(list, &ast.Field{Type: typ, Comment: comment}) + } else { + break parseElements + } + } + } + + // TODO(rfindley): the error produced here could be improved, since we could + // accept an identifier, 'type', or a '}' at this point. 
+ rbrace := p.expect(token.RBRACE) + + return &ast.InterfaceType{ + Interface: pos, + Methods: &ast.FieldList{ + Opening: lbrace, + List: list, + Closing: rbrace, + }, + } +} + +func (p *parser) parseMapType() *ast.MapType { + if p.trace { + defer un(trace(p, "MapType")) + } + + pos := p.expect(token.MAP) + p.expect(token.LBRACK) + key := p.parseType() + p.expect(token.RBRACK) + value := p.parseType() + + return &ast.MapType{Map: pos, Key: key, Value: value} +} + +func (p *parser) parseChanType() *ast.ChanType { + if p.trace { + defer un(trace(p, "ChanType")) + } + + pos := p.pos + dir := ast.SEND | ast.RECV + var arrow token.Pos + if p.tok == token.CHAN { + p.next() + if p.tok == token.ARROW { + arrow = p.pos + p.next() + dir = ast.SEND + } + } else { + arrow = p.expect(token.ARROW) + p.expect(token.CHAN) + dir = ast.RECV + } + value := p.parseType() + + return &ast.ChanType{Begin: pos, Arrow: arrow, Dir: dir, Value: value} +} + +func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "TypeInstance")) + } + + opening := p.expect(token.LBRACK) + p.exprLev++ + var list []ast.Expr + for p.tok != token.RBRACK && p.tok != token.EOF { + list = append(list, p.parseType()) + if !p.atComma("type argument list", token.RBRACK) { + break + } + p.next() + } + p.exprLev-- + + closing := p.expectClosing(token.RBRACK, "type argument list") + + if len(list) == 0 { + p.errorExpected(closing, "type argument list") + return &ast.IndexExpr{ + X: typ, + Lbrack: opening, + Index: &ast.BadExpr{From: opening + 1, To: closing}, + Rbrack: closing, + } + } + + return nil // typeparams.PackIndexExpr(typ, opening, list, closing) +} + +func (p *parser) tryIdentOrType() ast.Expr { + defer decNestLev(incNestLev(p)) + + switch p.tok { + case token.IDENT: + typ := p.parseTypeName(nil) + if p.tok == token.LBRACK { + typ = p.parseTypeInstance(typ) + } + return typ + case token.LBRACK: + lbrack := p.expect(token.LBRACK) + return p.parseArrayList(lbrack) // math: 
full array exprs + // return p.parseArrayType(lbrack, nil) + case token.STRUCT: + return p.parseStructType() + case token.MUL: + return p.parsePointerType() + case token.FUNC: + return p.parseFuncType() + case token.INTERFACE: + return p.parseInterfaceType() + case token.MAP: + return p.parseMapType() + case token.CHAN, token.ARROW: + return p.parseChanType() + case token.LPAREN: + lparen := p.pos + p.next() + typ := p.parseType() + rparen := p.expect(token.RPAREN) + return &ast.ParenExpr{Lparen: lparen, X: typ, Rparen: rparen} + } + + // no type found + return nil +} + +// ---------------------------------------------------------------------------- +// Blocks + +func (p *parser) parseStmtList() (list []ast.Stmt) { + if p.trace { + defer un(trace(p, "StatementList")) + } + + for p.tok != token.CASE && p.tok != token.DEFAULT && p.tok != token.RBRACE && p.tok != token.EOF { + list = append(list, p.parseStmt()) + } + + return +} + +func (p *parser) parseBody() *ast.BlockStmt { + if p.trace { + defer un(trace(p, "Body")) + } + + lbrace := p.expect(token.LBRACE) + list := p.parseStmtList() + rbrace := p.expect2(token.RBRACE) + + return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace} +} + +func (p *parser) parseBlockStmt() *ast.BlockStmt { + if p.trace { + defer un(trace(p, "BlockStmt")) + } + + lbrace := p.expect(token.LBRACE) + if p.tok == token.EOF { // math: allow start only + return &ast.BlockStmt{Lbrace: lbrace} + } + list := p.parseStmtList() + rbrace := p.expect2(token.RBRACE) + + return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace} +} + +// ---------------------------------------------------------------------------- +// Expressions + +func (p *parser) parseFuncTypeOrLit() ast.Expr { + if p.trace { + defer un(trace(p, "FuncTypeOrLit")) + } + + typ := p.parseFuncType() + if p.tok != token.LBRACE { + // function type only + return typ + } + + p.exprLev++ + body := p.parseBody() + p.exprLev-- + + return &ast.FuncLit{Type: typ, Body: body} +} 
+ +// parseOperand may return an expression or a raw type (incl. array +// types of the form [...]T). Callers must verify the result. +func (p *parser) parseOperand() ast.Expr { + if p.trace { + defer un(trace(p, "Operand")) + } + + switch p.tok { + case token.IDENT: + x := p.parseIdent() + return x + + case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING: + x := &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit} + // fmt.Println("operand lit:", p.lit) + p.next() + if p.tok == token.COLON { + return p.parseSliceExpr(x) + } else { + return x + } + + case token.LPAREN: + lparen := p.pos + p.next() + p.exprLev++ + x := p.parseRhs() // types may be parenthesized: (some type) + p.exprLev-- + rparen := p.expect(token.RPAREN) + return &ast.ParenExpr{Lparen: lparen, X: x, Rparen: rparen} + + case token.FUNC: + return p.parseFuncTypeOrLit() + + case token.COLON: + p.expect(token.COLON) + return p.parseSliceExpr(nil) + } + + if typ := p.tryIdentOrType(); typ != nil { // do not consume trailing type parameters + // could be type for composite literal or conversion + _, isIdent := typ.(*ast.Ident) + passert(!isIdent, "type cannot be identifier") + return typ + } + + // we have an error + pos := p.pos + p.errorExpected(pos, "operand") + p.advance(stmtStart) + return &ast.BadExpr{From: pos, To: p.pos} +} + +func (p *parser) parseSelector(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "Selector")) + } + + sel := p.parseIdent() + + return &ast.SelectorExpr{X: x, Sel: sel} +} + +func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "TypeAssertion")) + } + + lparen := p.expect(token.LPAREN) + var typ ast.Expr + if p.tok == token.TYPE { + // type switch: typ == nil + p.next() + } else { + typ = p.parseType() + } + rparen := p.expect(token.RPAREN) + + return &ast.TypeAssertExpr{X: x, Type: typ, Lparen: lparen, Rparen: rparen} +} + +func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr { + if p.trace { + 
defer un(trace(p, "parseIndexOrSliceOrInstance")) + } + + lbrack := p.expect(token.LBRACK) + if p.tok == token.RBRACK { + // empty index, slice or index expressions are not permitted; + // accept them for parsing tolerance, but complain + p.errorExpected(p.pos, "operand") + rbrack := p.pos + p.next() + return &ast.IndexExpr{ + X: x, + Lbrack: lbrack, + Index: &ast.BadExpr{From: rbrack, To: rbrack}, + Rbrack: rbrack, + } + } + + ix := p.parseArrayList(lbrack) + return &ast.IndexExpr{ + X: x, + Lbrack: lbrack, + Index: ix, + Rbrack: ix.Rbrack, + } +} + +func (p *parser) parseSliceExpr(ex ast.Expr) *ast.SliceExpr { + if p.trace { + defer un(trace(p, "parseSliceExpr")) + } + lbrack := p.pos + p.exprLev++ + + const N = 3 // change the 3 to 2 to disable 3-index slices + var index [N]ast.Expr + index[0] = ex + var colons [N - 1]token.Pos + ncolons := 0 + if ex == nil { + ncolons++ + } + var rpos token.Pos + // fmt.Println(ncolons, p.tok) + switch p.tok { + case token.COLON: + // slice expression + for p.tok == token.COLON && ncolons < len(colons) { + colons[ncolons] = p.pos + ncolons++ + p.next() + if p.tok != token.COMMA && p.tok != token.COLON && p.tok != token.RBRACK && p.tok != token.EOF { + ix := p.parseRhs() + if se, ok := ix.(*ast.SliceExpr); ok { + index[ncolons] = se.Low + if ncolons == 1 && se.High != nil { + ncolons++ + index[ncolons] = se.High + } + // fmt.Printf("nc: %d low: %#v hi: %#v max: %#v\n", ncolons, se.Low, se.High, se.Max) + } else { + // fmt.Printf("nc: %d low: %#v\n", ncolons, ix) + if _, ok := ix.(*ast.BadExpr); !ok { + index[ncolons] = ix + } + } + // } else { + // fmt.Println(ncolons, "else") + } + } + case token.COMMA: + rpos = p.pos // expect(token.COMMA) + case token.RBRACK: + rpos = p.pos // expect(token.RBRACK) + // instance expression + // args = append(args, index[0]) + // for p.tok == token.COMMA { + // p.next() + // if p.tok != token.RBRACK && p.tok != token.EOF { + // args = append(args, p.parseType()) + // } + // } + default: + ix := 
p.parseRhs() + // fmt.Printf("nc: %d ix: %#v\n", ncolons, ix) + index[ncolons] = ix + } + + p.exprLev-- + // rbrack := p.expect(token.RBRACK) + + // slice expression + slice3 := false + if ncolons == 2 { + slice3 = true + // Check presence of middle and final index here rather than during type-checking + // to prevent erroneous programs from passing through gofmt (was go.dev/issue/7305). + // if index[1] == nil { + // p.error(colons[0], "middle index required in 3-index slice") + // index[1] = &ast.BadExpr{From: colons[0] + 1, To: colons[1]} + // } + // if index[2] == nil { + // p.error(colons[1], "final index required in 3-index slice") + // index[2] = &ast.BadExpr{From: colons[1] + 1} // , To: rbrack + // } + } + se := &ast.SliceExpr{Lbrack: lbrack, Low: index[0], High: index[1], Max: index[2], Slice3: slice3, Rbrack: rpos} + // fmt.Printf("final: %#v\n", se) + return se + // + // if len(args) == 0 { + // // index expression + // return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack} + // } + + // instance expression + return nil // typeparams.PackIndexExpr(x, lbrack, args, rbrack) +} + +func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr { + if p.trace { + defer un(trace(p, "CallOrConversion")) + } + + lparen := p.expect(token.LPAREN) + p.exprLev++ + var list []ast.Expr + var ellipsis token.Pos + for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() { + list = append(list, p.parseRhs()) // builtins may expect a type: make(some type, ...) 
+ if p.tok == token.ELLIPSIS { + ellipsis = p.pos + p.next() + } + if !p.atComma("argument list", token.RPAREN) { + break + } + p.next() + } + p.exprLev-- + rparen := p.expectClosing(token.RPAREN, "argument list") + + return &ast.CallExpr{Fun: fun, Lparen: lparen, Args: list, Ellipsis: ellipsis, Rparen: rparen} +} + +func (p *parser) parseValue() ast.Expr { + if p.trace { + defer un(trace(p, "Element")) + } + + if p.tok == token.LBRACE { + return p.parseLiteralValue(nil) + } + + x := p.parseExpr() + + return x +} + +func (p *parser) parseElement() ast.Expr { + if p.trace { + defer un(trace(p, "Element")) + } + + x := p.parseValue() + if p.tok == token.COLON { + colon := p.pos + p.next() + x = &ast.KeyValueExpr{Key: x, Colon: colon, Value: p.parseValue()} + } + + return x +} + +func (p *parser) parseElementList() (list []ast.Expr) { + if p.trace { + defer un(trace(p, "ElementList")) + } + + for p.tok != token.RBRACE && p.tok != token.EOF { + list = append(list, p.parseElement()) + if !p.atComma("composite literal", token.RBRACE) { + break + } + p.next() + } + + return +} + +func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr { + defer decNestLev(incNestLev(p)) + + if p.trace { + defer un(trace(p, "LiteralValue")) + } + + lbrace := p.expect(token.LBRACE) + var elts []ast.Expr + p.exprLev++ + if p.tok != token.RBRACE { + elts = p.parseElementList() + } + p.exprLev-- + rbrace := p.expectClosing(token.RBRACE, "composite literal") + return &ast.CompositeLit{Type: typ, Lbrace: lbrace, Elts: elts, Rbrace: rbrace} +} + +func (p *parser) parsePrimaryExpr(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "PrimaryExpr")) + } + + // math: ellipses can show up in index expression. 
+ if p.tok == token.ELLIPSIS { + p.next() + return &ast.Ellipsis{Ellipsis: p.pos} + } + + if x == nil { + x = p.parseOperand() + } + // We track the nesting here rather than at the entry for the function, + // since it can iteratively produce a nested output, and we want to + // limit how deep a structure we generate. + var n int + defer func() { p.nestLev -= n }() + for n = 1; ; n++ { + incNestLev(p) + switch p.tok { + case token.PERIOD: + p.next() + switch p.tok { + case token.IDENT: + x = p.parseSelector(x) + case token.LPAREN: + x = p.parseTypeAssertion(x) + default: + pos := p.pos + p.errorExpected(pos, "selector or type assertion") + // TODO(rFindley) The check for token.RBRACE below is a targeted fix + // to error recovery sufficient to make the x/tools tests to + // pass with the new parsing logic introduced for type + // parameters. Remove this once error recovery has been + // more generally reconsidered. + if p.tok != token.RBRACE { + p.next() // make progress + } + sel := &ast.Ident{NamePos: pos, Name: "_"} + x = &ast.SelectorExpr{X: x, Sel: sel} + } + case token.LBRACK: + x = p.parseIndexOrSliceOrInstance(x) + case token.LPAREN: + x = p.parseCallOrConversion(x) + case token.LBRACE: + // operand may have returned a parenthesized complit + // type; accept it but complain if we have a complit + t := ast.Unparen(x) + // determine if '{' belongs to a composite literal or a block statement + switch t.(type) { + case *ast.BadExpr, *ast.Ident, *ast.SelectorExpr: + if p.exprLev < 0 { + return x + } + // x is possibly a composite literal type + case *ast.IndexExpr, *ast.IndexListExpr: + if p.exprLev < 0 { + return x + } + // x is possibly a composite literal type + case *ast.ArrayType, *ast.StructType, *ast.MapType: + // x is a composite literal type + default: + return x + } + if t != x { + p.error(t.Pos(), "cannot parenthesize type in composite literal") + // already progressed, no need to advance + } + x = p.parseLiteralValue(x) + default: + return x + } + } 
+} + +func (p *parser) parseUnaryExpr() ast.Expr { + defer decNestLev(incNestLev(p)) + + if p.trace { + defer un(trace(p, "UnaryExpr")) + } + + switch p.tok { + case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.TILDE: + pos, op := p.pos, p.tok + p.next() + x := p.parseUnaryExpr() + return &ast.UnaryExpr{OpPos: pos, Op: op, X: x} + + case token.ARROW: + // channel type or receive expression + arrow := p.pos + p.next() + + // If the next token is token.CHAN we still don't know if it + // is a channel type or a receive operation - we only know + // once we have found the end of the unary expression. There + // are two cases: + // + // <- type => (<-type) must be channel type + // <- expr => <-(expr) is a receive from an expression + // + // In the first case, the arrow must be re-associated with + // the channel type parsed already: + // + // <- (chan type) => (<-chan type) + // <- (chan<- type) => (<-chan (<-type)) + + x := p.parseUnaryExpr() + + // determine which case we have + if typ, ok := x.(*ast.ChanType); ok { + // (<-type) + + // re-associate position info and <- + dir := ast.SEND + for ok && dir == ast.SEND { + if typ.Dir == ast.RECV { + // error: (<-type) is (<-(<-chan T)) + p.errorExpected(typ.Arrow, "'chan'") + } + arrow, typ.Begin, typ.Arrow = typ.Arrow, arrow, arrow + dir, typ.Dir = typ.Dir, ast.RECV + typ, ok = typ.Value.(*ast.ChanType) + } + if dir == ast.SEND { + p.errorExpected(arrow, "channel type") + } + + return x + } + + // <-(expr) + return &ast.UnaryExpr{OpPos: arrow, Op: token.ARROW, X: x} + + case token.MUL: + // pointer type or unary "*" expression + pos := p.pos + p.next() + x := p.parseUnaryExpr() + return &ast.StarExpr{Star: pos, X: x} + } + + return p.parsePrimaryExpr(nil) +} + +func (p *parser) tokPrec() (token.Token, int) { + tok := p.tok + if p.inRhs && tok == token.ASSIGN { + tok = token.EQL + } + if p.tok == token.ILLEGAL && p.lit == "@" { + // fmt.Println("@ token") + return token.ILLEGAL, 5 + } + return tok, 
tok.Precedence() +} + +// parseBinaryExpr parses a (possibly) binary expression. +// If x is non-nil, it is used as the left operand. +// +// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring. +func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr { + if p.trace { + defer un(trace(p, "BinaryExpr")) + } + + if x == nil { + x = p.parseUnaryExpr() + } + // We track the nesting here rather than at the entry for the function, + // since it can iteratively produce a nested output, and we want to + // limit how deep a structure we generate. + var n int + defer func() { p.nestLev -= n }() + for n = 1; ; n++ { + incNestLev(p) + op, oprec := p.tokPrec() + if oprec < prec1 { + return x + } + pos := p.pos + if op == token.ILLEGAL { + p.next() + } else { + pos = p.expect(op) + } + y := p.parseBinaryExpr(nil, oprec+1) + x = &ast.BinaryExpr{X: x, OpPos: pos, Op: op, Y: y} + } +} + +// The result may be a type or even a raw type ([...]int). +func (p *parser) parseExpr() ast.Expr { + if p.trace { + defer un(trace(p, "Expression")) + } + + return p.parseBinaryExpr(nil, token.LowestPrec+1) +} + +func (p *parser) parseRhs() ast.Expr { + old := p.inRhs + p.inRhs = true + x := p.parseExpr() + p.inRhs = old + return x +} + +// ---------------------------------------------------------------------------- +// Statements + +// Parsing modes for parseSimpleStmt. +const ( + basic = iota + labelOk + rangeOk +) + +// parseSimpleStmt returns true as 2nd result if it parsed the assignment +// of a range clause (with mode == rangeOk). The returned statement is an +// assignment with a right-hand side that is a single unary expression of +// the form "range x". No guarantees are given for the left-hand side. 
+func (p *parser) parseSimpleStmt(mode int) (ast.Stmt, bool) { + if p.trace { + defer un(trace(p, "SimpleStmt")) + } + + x := p.parseList(false) + + switch p.tok { + case + token.DEFINE, token.ASSIGN, token.ADD_ASSIGN, + token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, + token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, + token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN: + // assignment statement, possibly part of a range clause + pos, tok := p.pos, p.tok + p.next() + var y []ast.Expr + isRange := false + if mode == rangeOk && p.tok == token.RANGE && (tok == token.DEFINE || tok == token.ASSIGN) { + pos := p.pos + p.next() + y = []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}} + isRange = true + } else { + y = p.parseList(true) + } + return &ast.AssignStmt{Lhs: x, TokPos: pos, Tok: tok, Rhs: y}, isRange + } + + if len(x) > 1 { + p.errorExpected(x[0].Pos(), "1 expression") + // continue with first expression + } + + switch p.tok { + case token.COLON: + // labeled statement + colon := p.pos + p.next() + if label, isIdent := x[0].(*ast.Ident); mode == labelOk && isIdent { + // Go spec: The scope of a label is the body of the function + // in which it is declared and excludes the body of any nested + // function. + stmt := &ast.LabeledStmt{Label: label, Colon: colon, Stmt: p.parseStmt()} + return stmt, false + } + // The label declaration typically starts at x[0].Pos(), but the label + // declaration may be erroneous due to a token after that position (and + // before the ':'). If SpuriousErrors is not set, the (only) error + // reported for the line is the illegal label error instead of the token + // before the ':' that caused the problem. Thus, use the (latest) colon + // position for error reporting. 
+ p.error(colon, "illegal label declaration") + return &ast.BadStmt{From: x[0].Pos(), To: colon + 1}, false + + case token.ARROW: + // send statement + arrow := p.pos + p.next() + y := p.parseRhs() + return &ast.SendStmt{Chan: x[0], Arrow: arrow, Value: y}, false + + case token.INC, token.DEC: + // increment or decrement + s := &ast.IncDecStmt{X: x[0], TokPos: p.pos, Tok: p.tok} + p.next() + return s, false + } + + // expression + return &ast.ExprStmt{X: x[0]}, false +} + +func (p *parser) parseCallExpr(callType string) *ast.CallExpr { + x := p.parseRhs() // could be a conversion: (some type)(x) + if t := ast.Unparen(x); t != x { + p.error(x.Pos(), fmt.Sprintf("expression in %s must not be parenthesized", callType)) + x = t + } + if call, isCall := x.(*ast.CallExpr); isCall { + return call + } + if _, isBad := x.(*ast.BadExpr); !isBad { + // only report error if it's a new one + p.error(p.safePos(x.End()), fmt.Sprintf("expression in %s must be function call", callType)) + } + return nil +} + +func (p *parser) parseGoStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "GoStmt")) + } + + pos := p.expect(token.GO) + call := p.parseCallExpr("go") + p.expectSemi() + if call == nil { + return &ast.BadStmt{From: pos, To: pos + 2} // len("go") + } + + return &ast.GoStmt{Go: pos, Call: call} +} + +func (p *parser) parseDeferStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "DeferStmt")) + } + + pos := p.expect(token.DEFER) + call := p.parseCallExpr("defer") + p.expectSemi() + if call == nil { + return &ast.BadStmt{From: pos, To: pos + 5} // len("defer") + } + + return &ast.DeferStmt{Defer: pos, Call: call} +} + +func (p *parser) parseReturnStmt() *ast.ReturnStmt { + if p.trace { + defer un(trace(p, "ReturnStmt")) + } + + pos := p.pos + p.expect(token.RETURN) + var x []ast.Expr + if p.tok != token.SEMICOLON && p.tok != token.RBRACE { + x = p.parseList(true) + } + p.expectSemi() + + return &ast.ReturnStmt{Return: pos, Results: x} +} + +func (p *parser) parseBranchStmt(tok 
token.Token) *ast.BranchStmt { + if p.trace { + defer un(trace(p, "BranchStmt")) + } + + pos := p.expect(tok) + var label *ast.Ident + if tok != token.FALLTHROUGH && p.tok == token.IDENT { + label = p.parseIdent() + } + p.expectSemi() + + return &ast.BranchStmt{TokPos: pos, Tok: tok, Label: label} +} + +func (p *parser) makeExpr(s ast.Stmt, want string) ast.Expr { + if s == nil { + return nil + } + if es, isExpr := s.(*ast.ExprStmt); isExpr { + return es.X + } + found := "simple statement" + if _, isAss := s.(*ast.AssignStmt); isAss { + found = "assignment" + } + p.error(s.Pos(), fmt.Sprintf("expected %s, found %s (missing parentheses around composite literal?)", want, found)) + return &ast.BadExpr{From: s.Pos(), To: p.safePos(s.End())} +} + +// parseIfHeader is an adjusted version of parser.header +// in cmd/compile/internal/syntax/parser.go, which has +// been tuned for better error handling. +func (p *parser) parseIfHeader() (init ast.Stmt, cond ast.Expr) { + if p.tok == token.LBRACE { + p.error(p.pos, "missing condition in if statement") + cond = &ast.BadExpr{From: p.pos, To: p.pos} + return + } + // p.tok != token.LBRACE + + prevLev := p.exprLev + p.exprLev = -1 + + if p.tok != token.SEMICOLON { + // accept potential variable declaration but complain + if p.tok == token.VAR { + p.next() + p.error(p.pos, "var declaration not allowed in if initializer") + } + init, _ = p.parseSimpleStmt(basic) + } + + var condStmt ast.Stmt + var semi struct { + pos token.Pos + lit string // ";" or "\n"; valid if pos.IsValid() + } + if p.tok != token.LBRACE { + if p.tok == token.SEMICOLON { + semi.pos = p.pos + semi.lit = p.lit + p.next() + } else { + p.expect(token.SEMICOLON) + } + if p.tok != token.LBRACE { + condStmt, _ = p.parseSimpleStmt(basic) + } + } else { + condStmt = init + init = nil + } + + if condStmt != nil { + cond = p.makeExpr(condStmt, "boolean expression") + } else if semi.pos.IsValid() { + if semi.lit == "\n" { + p.error(semi.pos, "unexpected newline, expecting 
{ after if clause") + } else { + p.error(semi.pos, "missing condition in if statement") + } + } + + // make sure we have a valid AST + if cond == nil { + cond = &ast.BadExpr{From: p.pos, To: p.pos} + } + + p.exprLev = prevLev + return +} + +func (p *parser) parseIfStmt() *ast.IfStmt { + defer decNestLev(incNestLev(p)) + + if p.trace { + defer un(trace(p, "IfStmt")) + } + + pos := p.expect(token.IF) + + init, cond := p.parseIfHeader() + body := p.parseBlockStmt() + + var else_ ast.Stmt + if p.tok == token.ELSE { + p.next() + switch p.tok { + case token.IF: + else_ = p.parseIfStmt() + case token.LBRACE: + else_ = p.parseBlockStmt() + p.expectSemi() + default: + p.errorExpected(p.pos, "if statement or block") + else_ = &ast.BadStmt{From: p.pos, To: p.pos} + } + } else { + p.expectSemi() + } + + return &ast.IfStmt{If: pos, Init: init, Cond: cond, Body: body, Else: else_} +} + +func (p *parser) parseCaseClause() *ast.CaseClause { + if p.trace { + defer un(trace(p, "CaseClause")) + } + + pos := p.pos + var list []ast.Expr + if p.tok == token.CASE { + p.next() + list = p.parseList(true) + } else { + p.expect(token.DEFAULT) + } + + colon := p.expect(token.COLON) + body := p.parseStmtList() + + return &ast.CaseClause{Case: pos, List: list, Colon: colon, Body: body} +} + +func isTypeSwitchAssert(x ast.Expr) bool { + a, ok := x.(*ast.TypeAssertExpr) + return ok && a.Type == nil +} + +func (p *parser) isTypeSwitchGuard(s ast.Stmt) bool { + switch t := s.(type) { + case *ast.ExprStmt: + // x.(type) + return isTypeSwitchAssert(t.X) + case *ast.AssignStmt: + // v := x.(type) + if len(t.Lhs) == 1 && len(t.Rhs) == 1 && isTypeSwitchAssert(t.Rhs[0]) { + switch t.Tok { + case token.ASSIGN: + // permit v = x.(type) but complain + p.error(t.TokPos, "expected ':=', found '='") + fallthrough + case token.DEFINE: + return true + } + } + } + return false +} + +func (p *parser) parseSwitchStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "SwitchStmt")) + } + + pos := 
p.expect(token.SWITCH) + + var s1, s2 ast.Stmt + if p.tok != token.LBRACE { + prevLev := p.exprLev + p.exprLev = -1 + if p.tok != token.SEMICOLON { + s2, _ = p.parseSimpleStmt(basic) + } + if p.tok == token.SEMICOLON { + p.next() + s1 = s2 + s2 = nil + if p.tok != token.LBRACE { + // A TypeSwitchGuard may declare a variable in addition + // to the variable declared in the initial SimpleStmt. + // Introduce extra scope to avoid redeclaration errors: + // + // switch t := 0; t := x.(T) { ... } + // + // (this code is not valid Go because the first t + // cannot be accessed and thus is never used, the extra + // scope is needed for the correct error message). + // + // If we don't have a type switch, s2 must be an expression. + // Having the extra nested but empty scope won't affect it. + s2, _ = p.parseSimpleStmt(basic) + } + } + p.exprLev = prevLev + } + + typeSwitch := p.isTypeSwitchGuard(s2) + lbrace := p.expect(token.LBRACE) + var list []ast.Stmt + for p.tok == token.CASE || p.tok == token.DEFAULT { + list = append(list, p.parseCaseClause()) + } + rbrace := p.expect(token.RBRACE) + p.expectSemi() + body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace} + + if typeSwitch { + return &ast.TypeSwitchStmt{Switch: pos, Init: s1, Assign: s2, Body: body} + } + + return &ast.SwitchStmt{Switch: pos, Init: s1, Tag: p.makeExpr(s2, "switch expression"), Body: body} +} + +func (p *parser) parseCommClause() *ast.CommClause { + if p.trace { + defer un(trace(p, "CommClause")) + } + + pos := p.pos + var comm ast.Stmt + if p.tok == token.CASE { + p.next() + lhs := p.parseList(false) + if p.tok == token.ARROW { + // SendStmt + if len(lhs) > 1 { + p.errorExpected(lhs[0].Pos(), "1 expression") + // continue with first expression + } + arrow := p.pos + p.next() + rhs := p.parseRhs() + comm = &ast.SendStmt{Chan: lhs[0], Arrow: arrow, Value: rhs} + } else { + // RecvStmt + if tok := p.tok; tok == token.ASSIGN || tok == token.DEFINE { + // RecvStmt with assignment + if 
len(lhs) > 2 { + p.errorExpected(lhs[0].Pos(), "1 or 2 expressions") + // continue with first two expressions + lhs = lhs[0:2] + } + pos := p.pos + p.next() + rhs := p.parseRhs() + comm = &ast.AssignStmt{Lhs: lhs, TokPos: pos, Tok: tok, Rhs: []ast.Expr{rhs}} + } else { + // lhs must be single receive operation + if len(lhs) > 1 { + p.errorExpected(lhs[0].Pos(), "1 expression") + // continue with first expression + } + comm = &ast.ExprStmt{X: lhs[0]} + } + } + } else { + p.expect(token.DEFAULT) + } + + colon := p.expect(token.COLON) + body := p.parseStmtList() + + return &ast.CommClause{Case: pos, Comm: comm, Colon: colon, Body: body} +} + +func (p *parser) parseSelectStmt() *ast.SelectStmt { + if p.trace { + defer un(trace(p, "SelectStmt")) + } + + pos := p.expect(token.SELECT) + lbrace := p.expect(token.LBRACE) + var list []ast.Stmt + for p.tok == token.CASE || p.tok == token.DEFAULT { + list = append(list, p.parseCommClause()) + } + rbrace := p.expect(token.RBRACE) + p.expectSemi() + body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace} + + return &ast.SelectStmt{Select: pos, Body: body} +} + +func (p *parser) parseForStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "ForStmt")) + } + + pos := p.expect(token.FOR) + + var s1, s2, s3 ast.Stmt + var isRange bool + if p.tok != token.LBRACE { + prevLev := p.exprLev + p.exprLev = -1 + if p.tok != token.SEMICOLON { + if p.tok == token.RANGE { + // "for range x" (nil lhs in assignment) + pos := p.pos + p.next() + y := []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}} + s2 = &ast.AssignStmt{Rhs: y} + isRange = true + } else { + s2, isRange = p.parseSimpleStmt(rangeOk) + } + } + if !isRange && p.tok == token.SEMICOLON { + p.next() + s1 = s2 + s2 = nil + if p.tok != token.SEMICOLON { + s2, _ = p.parseSimpleStmt(basic) + } + p.expectSemi() + if p.tok != token.LBRACE { + s3, _ = p.parseSimpleStmt(basic) + } + } + p.exprLev = prevLev + } + + body := p.parseBlockStmt() + p.expectSemi() + + 
if isRange { + as := s2.(*ast.AssignStmt) + // check lhs + var key, value ast.Expr + switch len(as.Lhs) { + case 0: + // nothing to do + case 1: + key = as.Lhs[0] + case 2: + key, value = as.Lhs[0], as.Lhs[1] + default: + p.errorExpected(as.Lhs[len(as.Lhs)-1].Pos(), "at most 2 expressions") + return &ast.BadStmt{From: pos, To: p.safePos(body.End())} + } + // parseSimpleStmt returned a right-hand side that + // is a single unary expression of the form "range x" + x := as.Rhs[0].(*ast.UnaryExpr).X + return &ast.RangeStmt{ + For: pos, + Key: key, + Value: value, + TokPos: as.TokPos, + Tok: as.Tok, + Range: as.Rhs[0].Pos(), + X: x, + Body: body, + } + } + + // regular for statement + return &ast.ForStmt{ + For: pos, + Init: s1, + Cond: p.makeExpr(s2, "boolean or range expression"), + Post: s3, + Body: body, + } +} + +func (p *parser) parseStmt() (s ast.Stmt) { + defer decNestLev(incNestLev(p)) + + if p.trace { + defer un(trace(p, "Statement")) + } + + switch p.tok { + case token.CONST, token.TYPE, token.VAR: + s = &ast.DeclStmt{Decl: p.parseDecl(stmtStart)} + case + // tokens that may start an expression + token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands + token.LBRACK, token.STRUCT, token.MAP, token.CHAN, token.INTERFACE, // composite types + token.ADD, token.SUB, token.MUL, token.AND, token.XOR, token.ARROW, token.NOT: // unary operators + s, _ = p.parseSimpleStmt(labelOk) + // because of the required look-ahead, labeled statements are + // parsed by parseSimpleStmt - don't expect a semicolon after + // them + if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt { + p.expectSemi() + } + case token.GO: + s = p.parseGoStmt() + case token.DEFER: + s = p.parseDeferStmt() + case token.RETURN: + s = p.parseReturnStmt() + case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH: + s = p.parseBranchStmt(p.tok) + case token.LBRACE: + s = p.parseBlockStmt() + p.expectSemi() + case token.IF: + s = 
p.parseIfStmt() + case token.SWITCH: + s = p.parseSwitchStmt() + case token.SELECT: + s = p.parseSelectStmt() + case token.FOR: + s = p.parseForStmt() + case token.SEMICOLON: + // Is it ever possible to have an implicit semicolon + // producing an empty statement in a valid program? + // (handle correctly anyway) + s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: p.lit == "\n"} + p.next() + case token.RBRACE: + // a semicolon may be omitted before a closing "}" + s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: true} + default: + // no statement found + pos := p.pos + p.errorExpected(pos, "statement") + p.advance(stmtStart) + s = &ast.BadStmt{From: pos, To: p.pos} + } + + return +} + +// ---------------------------------------------------------------------------- +// Declarations + +type parseSpecFunction func(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec + +func (p *parser) parseImportSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec { + if p.trace { + defer un(trace(p, "ImportSpec")) + } + + var ident *ast.Ident + switch p.tok { + case token.IDENT: + ident = p.parseIdent() + case token.PERIOD: + ident = &ast.Ident{NamePos: p.pos, Name: "."} + p.next() + } + + pos := p.pos + var path string + if p.tok == token.STRING { + path = p.lit + p.next() + } else if p.tok.IsLiteral() { + p.error(pos, "import path must be a string") + p.next() + } else { + p.error(pos, "missing import path") + p.advance(exprEnd) + } + comment := p.expectSemi() + + // collect imports + spec := &ast.ImportSpec{ + Doc: doc, + Name: ident, + Path: &ast.BasicLit{ValuePos: pos, Kind: token.STRING, Value: path}, + Comment: comment, + } + p.imports = append(p.imports, spec) + + return spec +} + +func (p *parser) parseValueSpec(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec { + if p.trace { + defer un(trace(p, keyword.String()+"Spec")) + } + + idents := p.parseIdentList() + var typ ast.Expr + var values []ast.Expr + switch keyword { + case token.CONST: + // 
always permit optional type and initialization for more tolerant parsing + if p.tok != token.EOF && p.tok != token.SEMICOLON && p.tok != token.RPAREN { + typ = p.tryIdentOrType() + if p.tok == token.ASSIGN { + p.next() + values = p.parseList(true) + } + } + case token.VAR: + if p.tok != token.ASSIGN { + typ = p.parseType() + } + if p.tok == token.ASSIGN { + p.next() + values = p.parseList(true) + } + default: + panic("unreachable") + } + comment := p.expectSemi() + + spec := &ast.ValueSpec{ + Doc: doc, + Names: idents, + Type: typ, + Values: values, + Comment: comment, + } + return spec +} + +func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) { + if p.trace { + defer un(trace(p, "parseGenericType")) + } + + list := p.parseParameterList(name0, typ0, token.RBRACK) + closePos := p.expect(token.RBRACK) + spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos} + // Let the type checker decide whether to accept type parameters on aliases: + // see go.dev/issue/46477. + if p.tok == token.ASSIGN { + // type alias + spec.Assign = p.pos + p.next() + } + spec.Type = p.parseType() +} + +func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec { + if p.trace { + defer un(trace(p, "TypeSpec")) + } + + name := p.parseIdent() + spec := &ast.TypeSpec{Doc: doc, Name: name} + + if p.tok == token.LBRACK { + // spec.Name "[" ... + // array/slice type or type parameter list + lbrack := p.pos + p.next() + if p.tok == token.IDENT { + // We may have an array type or a type parameter list. + // In either case we expect an expression x (which may + // just be a name, or a more complex expression) which + // we can analyze further. + // + // A type parameter list may have a type bound starting + // with a "[" as in: P []E. In that case, simply parsing + // an expression would lead to an error: P[] is invalid. 
+ // But since index or slice expressions are never constant + // and thus invalid array length expressions, if the name + // is followed by "[" it must be the start of an array or + // slice constraint. Only if we don't see a "[" do we + // need to parse a full expression. Notably, name <- x + // is not a concern because name <- x is a statement and + // not an expression. + var x ast.Expr = p.parseIdent() + if p.tok != token.LBRACK { + // To parse the expression starting with name, expand + // the call sequence we would get by passing in name + // to parser.expr, and pass in name to parsePrimaryExpr. + p.exprLev++ + lhs := p.parsePrimaryExpr(x) + x = p.parseBinaryExpr(lhs, token.LowestPrec+1) + p.exprLev-- + } + // Analyze expression x. If we can split x into a type parameter + // name, possibly followed by a type parameter type, we consider + // this the start of a type parameter list, with some caveats: + // a single name followed by "]" tilts the decision towards an + // array declaration; a type parameter type that could also be + // an ordinary expression but which is followed by a comma tilts + // the decision towards a type parameter list. + if pname, ptype := extractName(x, p.tok == token.COMMA); pname != nil && (ptype != nil || p.tok != token.RBRACK) { + // spec.Name "[" pname ... + // spec.Name "[" pname ptype ... + // spec.Name "[" pname ptype "," ... + p.parseGenericType(spec, lbrack, pname, ptype) // ptype may be nil + } else { + // spec.Name "[" pname "]" ... + // spec.Name "[" x ... + spec.Type = p.parseArrayType(lbrack, x) + } + } else { + // array type + spec.Type = p.parseArrayType(lbrack, nil) + } + } else { + // no type parameters + if p.tok == token.ASSIGN { + // type alias + spec.Assign = p.pos + p.next() + } + spec.Type = p.parseType() + } + + spec.Comment = p.expectSemi() + + return spec +} + +// extractName splits the expression x into (name, expr) if syntactically +// x can be written as name expr. 
The split only happens if expr is a type +// element (per the isTypeElem predicate) or if force is set. +// If x is just a name, the result is (name, nil). If the split succeeds, +// the result is (name, expr). Otherwise the result is (nil, x). +// Examples: +// +// x force name expr +// ------------------------------------ +// P*[]int T/F P *[]int +// P*E T P *E +// P*E F nil P*E +// P([]int) T/F P []int +// P(E) T P E +// P(E) F nil P(E) +// P*E|F|~G T/F P *E|F|~G +// P*E|F|G T P *E|F|G +// P*E|F|G F nil P*E|F|G +func extractName(x ast.Expr, force bool) (*ast.Ident, ast.Expr) { + switch x := x.(type) { + case *ast.Ident: + return x, nil + case *ast.BinaryExpr: + switch x.Op { + case token.MUL: + if name, _ := x.X.(*ast.Ident); name != nil && (force || isTypeElem(x.Y)) { + // x = name *x.Y + return name, &ast.StarExpr{Star: x.OpPos, X: x.Y} + } + case token.OR: + if name, lhs := extractName(x.X, force || isTypeElem(x.Y)); name != nil && lhs != nil { + // x = name lhs|x.Y + op := *x + op.X = lhs + return name, &op + } + } + case *ast.CallExpr: + if name, _ := x.Fun.(*ast.Ident); name != nil { + if len(x.Args) == 1 && x.Ellipsis == token.NoPos && (force || isTypeElem(x.Args[0])) { + // x = name "(" x.ArgList[0] ")" + return name, x.Args[0] + } + } + } + return nil, x +} + +// isTypeElem reports whether x is a (possibly parenthesized) type element expression. +// The result is false if x could be a type element OR an ordinary (value) expression. 
+func isTypeElem(x ast.Expr) bool { + switch x := x.(type) { + case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType: + return true + case *ast.BinaryExpr: + return isTypeElem(x.X) || isTypeElem(x.Y) + case *ast.UnaryExpr: + return x.Op == token.TILDE + case *ast.ParenExpr: + return isTypeElem(x.X) + } + return false +} + +func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl { + if p.trace { + defer un(trace(p, "GenDecl("+keyword.String()+")")) + } + + doc := p.leadComment + pos := p.expect(keyword) + var lparen, rparen token.Pos + var list []ast.Spec + if p.tok == token.LPAREN { + lparen = p.pos + p.next() + for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ { + list = append(list, f(p.leadComment, keyword, iota)) + } + rparen = p.expect(token.RPAREN) + p.expectSemi() + } else { + list = append(list, f(nil, keyword, 0)) + } + + return &ast.GenDecl{ + Doc: doc, + TokPos: pos, + Tok: keyword, + Lparen: lparen, + Specs: list, + Rparen: rparen, + } +} + +func (p *parser) parseFuncDecl() *ast.FuncDecl { + if p.trace { + defer un(trace(p, "FunctionDecl")) + } + + doc := p.leadComment + pos := p.expect(token.FUNC) + + var recv *ast.FieldList + if p.tok == token.LPAREN { + _, recv = p.parseParameters(false) + } + + ident := p.parseIdent() + + tparams, params := p.parseParameters(true) + if recv != nil && tparams != nil { + // Method declarations do not have type parameters. We parse them for a + // better error message and improved error recovery. 
+ p.error(tparams.Opening, "method must have no type parameters") + tparams = nil + } + results := p.parseResult() + + var body *ast.BlockStmt + switch p.tok { + case token.LBRACE: + body = p.parseBody() + p.expectSemi() + case token.SEMICOLON: + p.next() + if p.tok == token.LBRACE { + // opening { of function declaration on next line + p.error(p.pos, "unexpected semicolon or newline before {") + body = p.parseBody() + p.expectSemi() + } + default: + p.expectSemi() + } + + decl := &ast.FuncDecl{ + Doc: doc, + Recv: recv, + Name: ident, + Type: &ast.FuncType{ + Func: pos, + TypeParams: tparams, + Params: params, + Results: results, + }, + Body: body, + } + return decl +} + +func (p *parser) parseDecl(sync map[token.Token]bool) ast.Decl { + if p.trace { + defer un(trace(p, "Declaration")) + } + + var f parseSpecFunction + switch p.tok { + case token.IMPORT: + f = p.parseImportSpec + + case token.CONST, token.VAR: + f = p.parseValueSpec + + case token.TYPE: + f = p.parseTypeSpec + + case token.FUNC: + return p.parseFuncDecl() + + default: + pos := p.pos + p.errorExpected(pos, "declaration") + p.advance(sync) + return &ast.BadDecl{From: pos, To: p.pos} + } + + return p.parseGenDecl(p.tok, f) +} diff --git a/shell/paths.go b/goal/transpile/paths.go similarity index 94% rename from shell/paths.go rename to goal/transpile/paths.go index 3dc2029484..669da539d4 100644 --- a/shell/paths.go +++ b/goal/transpile/paths.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package shell +package transpile import ( "go/token" @@ -57,6 +57,11 @@ func (tk Tokens) Path(idx0 bool) (string, int) { ci += tin + 1 str = tid + tk[tin].String() lastEnd += len(str) + if ci >= n || int(tk[ci].Pos) > lastEnd { // just 2 or 2 and a space + if tk[tin].Tok == token.COLON { // go Ident: static initializer + return "", 0 + } + } } prevWasDelim := true for { diff --git a/goal/transpile/state.go b/goal/transpile/state.go new file mode 100644 index 0000000000..e953eea3dd --- /dev/null +++ b/goal/transpile/state.go @@ -0,0 +1,225 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "errors" + "fmt" + "go/format" + "log/slog" + "os" + "strings" + + "cogentcore.org/core/base/logx" + "cogentcore.org/core/base/num" + "cogentcore.org/core/base/stringsx" + "golang.org/x/tools/imports" +) + +// State holds the transpiling state +type State struct { + // FuncToVar translates function definitions into variable definitions, + // which is the default for interactive use of random code fragments + // without the complete go formatting. + // For pure transpiling of a complete codebase with full proper Go formatting + // this should be turned off. + FuncToVar bool + + // MathMode is on when math mode is turned on. + MathMode bool + + // MathRecord is state of auto-recording of data into current directory + // in math mode. + MathRecord bool + + // depth of delim at the end of the current line. if 0, was complete. + ParenDepth, BraceDepth, BrackDepth, TypeDepth, DeclDepth int + + // Chunks of code lines that are accumulated during Transpile, + // each of which should be evaluated separately, to avoid + // issues with contextual effects from import, package etc. + Chunks []string + + // current stack of transpiled lines, that are accumulated into + // code Chunks. + Lines []string + + // stack of runtime errors. 
+ Errors []error + + // if this is non-empty, it is the name of the last command defined. + // triggers insertion of the AddCommand call to add to list of defined commands. + lastCommand string +} + +// NewState returns a new transpiling state; mostly for testing +func NewState() *State { + st := &State{FuncToVar: true} + return st +} + +// TranspileCode processes each line of given code, +// adding the results to the LineStack +func (st *State) TranspileCode(code string) { + lns := strings.Split(code, "\n") + n := len(lns) + if n == 0 { + return + } + for _, ln := range lns { + hasDecl := st.DeclDepth > 0 + tl := st.TranspileLine(ln) + st.AddLine(tl) + if st.BraceDepth == 0 && st.BrackDepth == 0 && st.ParenDepth == 1 && st.lastCommand != "" { + st.lastCommand = "" + nl := len(st.Lines) + st.Lines[nl-1] = st.Lines[nl-1] + ")" + st.ParenDepth-- + } + if hasDecl && st.DeclDepth == 0 { // break at decl + st.AddChunk() + } + } +} + +// TranspileFile transpiles the given input goal file to the +// given output Go file. If no existing package declaration +// is found, then package main and func main declarations are +// added. This also affects how functions are interpreted. 
+func (st *State) TranspileFile(in string, out string) error { + b, err := os.ReadFile(in) + if err != nil { + return err + } + code := string(b) + lns := stringsx.SplitLines(code) + hasPackage := false + for _, ln := range lns { + if strings.HasPrefix(ln, "package ") { + hasPackage = true + break + } + } + if hasPackage { + st.FuncToVar = false // use raw functions + } + st.TranspileCode(code) + st.FuncToVar = true + if err != nil { + return err + } + + hdr := `package main +import ( + "cogentcore.org/core/goal" + "cogentcore.org/core/goal/goalib" + "cogentcore.org/core/tensor" + _ "cogentcore.org/core/tensor/tmath" + _ "cogentcore.org/core/tensor/stats/stats" + _ "cogentcore.org/core/tensor/stats/metric" +) + +func main() { + goal := goal.NewGoal() + _ = goal +` + + src := st.Code() + res := []byte(src) + bsrc := res + gen := fmt.Sprintf("// Code generated by \"goal build\"; DO NOT EDIT.\n//line %s:1\n", in) + if hasPackage { + bsrc = []byte(gen + src) + res, err = format.Source(bsrc) + } else { + bsrc = []byte(gen + hdr + src + "\n}") + res, err = imports.Process(out, bsrc, nil) + } + if err != nil { + res = bsrc + fmt.Println(err.Error()) + } else { + err = st.DepthError() + } + werr := os.WriteFile(out, res, 0666) + return errors.Join(err, werr) +} + +// TotalDepth returns the sum of any unresolved paren, brace, or bracket depths. +func (st *State) TotalDepth() int { + return num.Abs(st.ParenDepth) + num.Abs(st.BraceDepth) + num.Abs(st.BrackDepth) +} + +// ResetCode resets the stack of transpiled code +func (st *State) ResetCode() { + st.Chunks = nil + st.Lines = nil +} + +// ResetDepth resets the current depths to 0 +func (st *State) ResetDepth() { + st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth = 0, 0, 0, 0, 0 +} + +// DepthError reports an error if any of the parsing depths are not zero, +// to be called at the end of transpiling a complete block of code. 
+func (st *State) DepthError() error { + if st.TotalDepth() == 0 { + return nil + } + str := "" + if st.ParenDepth != 0 { + str += fmt.Sprintf("Incomplete parentheses (), remaining depth: %d\n", st.ParenDepth) + } + if st.BraceDepth != 0 { + str += fmt.Sprintf("Incomplete braces {}, remaining depth: %d\n", st.BraceDepth) + } + if st.BrackDepth != 0 { + str += fmt.Sprintf("Incomplete brackets [], remaining depth: %d\n", st.BrackDepth) + } + if str != "" { + slog.Error(str) + return errors.New(str) + } + return nil +} + +// AddLine adds line on the stack +func (st *State) AddLine(ln string) { + st.Lines = append(st.Lines, ln) +} + +// Code returns the current transpiled lines, +// split into chunks that should be compiled separately. +func (st *State) Code() string { + st.AddChunk() + if len(st.Chunks) == 0 { + return "" + } + return strings.Join(st.Chunks, "\n") +} + +// AddChunk adds current lines into a chunk of code +// that should be compiled separately. +func (st *State) AddChunk() { + if len(st.Lines) == 0 { + return + } + st.Chunks = append(st.Chunks, strings.Join(st.Lines, "\n")) + st.Lines = nil +} + +// AddError adds the given error to the error stack if it is non-nil, +// and calls the Cancel function if set, to stop execution. +// This is the main way that goal errors are handled. +// It also prints the error. +func (st *State) AddError(err error) error { + if err == nil { + return nil + } + st.Errors = append(st.Errors, err) + logx.PrintlnError(err) + return err +} diff --git a/shell/token.go b/goal/transpile/token.go similarity index 83% rename from shell/token.go rename to goal/transpile/token.go index a84c01a790..335c63279d 100644 --- a/shell/token.go +++ b/goal/transpile/token.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package shell +package transpile import ( "go/scanner" @@ -28,6 +28,30 @@ type Token struct { Pos token.Pos } +// Tokens converts the string into tokens +func TokensFromString(ln string) Tokens { + fset := token.NewFileSet() + f := fset.AddFile("", fset.Base(), len(ln)) + var sc scanner.Scanner + sc.Init(f, []byte(ln), errHandler, scanner.ScanComments|2) // 2 is non-exported dontInsertSemis + // note to Go team: just export this stuff. seriously. + + var toks Tokens + for { + pos, tok, lit := sc.Scan() + if tok == token.EOF { + break + } + // logx.PrintfDebug(" token: %s\t%s\t%q\n", fset.Position(pos), tok, lit) + toks = append(toks, &Token{Tok: tok, Pos: pos, Str: lit}) + } + return toks +} + +func errHandler(pos token.Position, msg string) { + logx.PrintlnDebug("Scan Error:", pos, msg) +} + // Tokens is a slice of Token type Tokens []*Token @@ -47,8 +71,15 @@ func (tk *Tokens) Add(tok token.Token, str ...string) *Token { return nt } +// AddMulti adds new basic tokens (not IDENT). +func (tk *Tokens) AddMulti(tok ...token.Token) { + for _, t := range tok { + tk.Add(t) + } +} + // AddTokens adds given tokens to our list -func (tk *Tokens) AddTokens(toks Tokens) *Tokens { +func (tk *Tokens) AddTokens(toks ...*Token) *Tokens { *tk = append(*tk, toks...) return tk } @@ -174,6 +205,11 @@ func (tk Tokens) Code() string { } str += tok.String() prvIdent = true + case tok.Tok == token.COMMENT: + if str != "" { + str += " " + } + str += tok.String() case tok.IsGo(): if prvIdent { str += " " @@ -308,26 +344,44 @@ func (tk Tokens) BracketDepths() (paren, brace, brack int) { return } -// Tokens converts the string into tokens -func (sh *Shell) Tokens(ln string) Tokens { - fset := token.NewFileSet() - f := fset.AddFile("", fset.Base(), len(ln)) - var sc scanner.Scanner - sc.Init(f, []byte(ln), sh.errHandler, scanner.ScanComments|2) // 2 is non-exported dontInsertSemis - // note to Go team: just export this stuff. seriously. 
- - var toks Tokens - for { - pos, tok, lit := sc.Scan() - if tok == token.EOF { - break +// ModeEnd returns the position (or -1 if not found) for the +// next ILLEGAL mode token ($ or #) given the starting one that +// is at the 0 position of the current set of tokens. +func (tk Tokens) ModeEnd() int { + n := len(tk) + if n == 0 { + return -1 + } + st := tk[0].Str + for i := 1; i < n; i++ { + if tk[i].Tok != token.ILLEGAL { + continue + } + if tk[i].Str == st { + return i } - // logx.PrintfDebug(" token: %s\t%s\t%q\n", fset.Position(pos), tok, lit) - toks = append(toks, &Token{Tok: tok, Pos: pos, Str: lit}) } - return toks + return -1 } -func (sh *Shell) errHandler(pos token.Position, msg string) { - logx.PrintlnDebug("Scan Error:", pos, msg) +// IsAssignExpr checks if there are any Go assignment or define tokens +// outside of { } Go code. +func (tk Tokens) IsAssignExpr() bool { + n := len(tk) + if n == 0 { + return false + } + for i := 1; i < n; i++ { + tok := tk[i].Tok + if tok == token.ASSIGN || tok == token.DEFINE || (tok >= token.ADD_ASSIGN && tok <= token.AND_NOT_ASSIGN) { + return true + } + if tok == token.LBRACE { // skip Go mode + rp := tk[i:n].RightMatching() + if rp > 0 { + i += rp + } + } + } + return false } diff --git a/goal/transpile/transpile.go b/goal/transpile/transpile.go new file mode 100644 index 0000000000..2e86475639 --- /dev/null +++ b/goal/transpile/transpile.go @@ -0,0 +1,431 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "fmt" + "go/token" + "slices" + "strings" + + "cogentcore.org/core/base/logx" +) + +// TranspileLine is the main function for parsing a single line of goal input, +// returning a new transpiled line of code that converts Exec code into corresponding +// Go function calls. 
+func (st *State) TranspileLine(code string) string { + if len(code) == 0 { + return code + } + if strings.HasPrefix(code, "#!") { + return "" + } + toks := st.TranspileLineTokens(code) + paren, brace, brack := toks.BracketDepths() + st.ParenDepth += paren + st.BraceDepth += brace + st.BrackDepth += brack + // logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth) + if st.TypeDepth > 0 && st.BraceDepth == 0 { + st.TypeDepth = 0 + } + if st.DeclDepth > 0 && (st.ParenDepth == 0 && st.BraceDepth == 0) { + st.DeclDepth = 0 + } + // logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth) + return toks.Code() +} + +// TranspileLineTokens returns the tokens for the full line +func (st *State) TranspileLineTokens(code string) Tokens { + if code == "" { + return nil + } + toks := TokensFromString(code) + n := len(toks) + if n == 0 { + return toks + } + if st.MathMode { + if len(toks) >= 2 { + if toks[0].Tok == token.ILLEGAL && toks[0].Str == "#" && toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" { + st.MathMode = false + return nil + } + } + return st.TranspileMath(toks, code, true) + } + ewords, err := ExecWords(code) + if err != nil { + st.AddError(err) + return nil + } + logx.PrintlnDebug("\n########## line:\n", code, "\nTokens:", len(toks), "\n", toks.String(), "\nWords:", len(ewords), "\n", ewords) + + if toks[0].Tok == token.TYPE { + st.TypeDepth++ + } + if toks[0].Tok == token.IMPORT || toks[0].Tok == token.VAR || toks[0].Tok == token.CONST { + st.DeclDepth++ + } + // logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth) + if st.TypeDepth > 0 || st.DeclDepth > 0 { + logx.PrintlnDebug("go: type / decl defn") + return st.TranspileGo(toks, code) + } + + t0 := toks[0] + _, t0pn := toks.Path(true) // true = first position + en := len(ewords) + + f0exec := (t0.Tok == token.IDENT && ExecWordIsCommand(ewords[0])) + if f0exec && n 
> 1 && toks[1].Tok == token.COLON { // go Ident: static initializer + f0exec = false + } + + switch { + case t0.Tok == token.ILLEGAL: + if t0.Str == "#" { + logx.PrintlnDebug("math #") + if toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" { + st.MathMode = true + return nil + } + return st.TranspileMath(toks[1:], code, true) + } + return st.TranspileExec(ewords, false) + case t0.Tok == token.LBRACE: + logx.PrintlnDebug("go: { } line") + return st.TranspileGoRange(toks, code, 1, n-1) + case t0.Tok == token.LBRACK: + logx.PrintlnDebug("exec: [ ] line") + return st.TranspileExec(ewords, false) // it processes the [ ] + case t0.Tok == token.IDENT && t0.Str == "command": + st.lastCommand = toks[1].Str // 1 is the name -- triggers AddCommand + toks = toks[2:] // get rid of first + toks.Insert(0, token.IDENT, "goal.AddCommand") + toks.Insert(1, token.LPAREN) + toks.Insert(2, token.STRING, `"`+st.lastCommand+`"`) + toks.Insert(3, token.COMMA) + toks.Insert(4, token.FUNC) + toks.Insert(5, token.LPAREN) + toks.Insert(6, token.IDENT, "args") + toks.Insert(7, token.ELLIPSIS) + toks.Insert(8, token.IDENT, "string") + toks.Insert(9, token.RPAREN) + toks.AddTokens(st.TranspileGo(toks[11:], code)...) 
+ case t0.IsGo(): + if t0.Tok == token.GO { + if !toks.Contains(token.LPAREN) { + logx.PrintlnDebug("exec: go command") + return st.TranspileExec(ewords, false) + } + } + logx.PrintlnDebug("go keyword") + return st.TranspileGo(toks, code) + case toks[n-1].Tok == token.INC || toks[n-1].Tok == token.DEC: + logx.PrintlnDebug("go ++ / --") + return st.TranspileGo(toks, code) + case t0pn > 0: // path expr + logx.PrintlnDebug("exec: path...") + return st.TranspileExec(ewords, false) + case t0.Tok == token.STRING: + logx.PrintlnDebug("exec: string...") + return st.TranspileExec(ewords, false) + case f0exec && en == 1: + logx.PrintlnDebug("exec: 1 word") + return st.TranspileExec(ewords, false) + case !f0exec: // exec must be IDENT + logx.PrintlnDebug("go: not ident") + return st.TranspileGo(toks, code) + case f0exec && en > 1 && ewords[0] != "set" && toks.IsAssignExpr(): + logx.PrintlnDebug("go: assignment or defn") + return st.TranspileGo(toks, code) + case f0exec && en > 1 && ewords[0] != "set" && toks.IsAssignExpr(): + logx.PrintlnDebug("go: assignment or defn") + return st.TranspileGo(toks, code) + case f0exec: // now any ident + logx.PrintlnDebug("exec: ident..") + return st.TranspileExec(ewords, false) + default: + logx.PrintlnDebug("go: default") + return st.TranspileGo(toks, code) + } + return toks +} + +// TranspileGoRange returns transpiled tokens assuming Go code, +// for given start, end (exclusive) range of given tokens and code. +// In general the positions in the tokens applies to the _original_ code +// so you should just keep the original code string. However, this is +// needed for a specific case. +func (st *State) TranspileGoRange(toks Tokens, code string, start, end int) Tokens { + codeSt := toks[start].Pos - 1 + codeEd := token.Pos(len(code)) + if end <= len(toks)-1 { + codeEd = toks[end].Pos - 1 + } + return st.TranspileGo(toks[start:end], code[codeSt:codeEd]) +} + +// TranspileGo returns transpiled tokens assuming Go code. 
+// Unpacks any encapsulated shell or math expressions. +func (st *State) TranspileGo(toks Tokens, code string) Tokens { + n := len(toks) + if n == 0 { + return toks + } + if st.FuncToVar && toks[0].Tok == token.FUNC { // reorder as an assignment + if len(toks) > 1 && toks[1].Tok == token.IDENT { + toks[0] = toks[1] + toks.Insert(1, token.DEFINE) + toks[2] = &Token{Tok: token.FUNC} + n = len(toks) + } + } + gtoks := make(Tokens, 0, len(toks)) // return tokens + for i := 0; i < n; i++ { + tok := toks[i] + switch { + case tok.Tok == token.ILLEGAL: + et := toks[i:].ModeEnd() + if et > 0 { + if tok.Str == "#" { + gtoks.AddTokens(st.TranspileMath(toks[i+1:i+et], code, false)...) + } else { + gtoks.AddTokens(st.TranspileExecTokens(toks[i+1:i+et+1], code, true)...) + } + i += et + continue + } else { + gtoks = append(gtoks, tok) + } + case tok.Tok == token.LBRACK && i > 0 && toks[i-1].Tok == token.IDENT: // index expr + ixtoks := toks[i:] + rm := ixtoks.RightMatching() + if rm < 3 { + gtoks = append(gtoks, tok) + continue + } + idx := st.TranspileGoNDimIndex(toks, code, >oks, i-1, rm+i) + if idx > 0 { + i = idx + } else { + gtoks = append(gtoks, tok) + } + default: + gtoks = append(gtoks, tok) + } + } + return gtoks +} + +// TranspileExecString returns transpiled tokens assuming Exec code, +// from a string, with the given bool indicating whether [Output] is needed. +func (st *State) TranspileExecString(str string, output bool) Tokens { + if len(str) <= 1 { + return nil + } + ewords, err := ExecWords(str) + if err != nil { + st.AddError(err) + } + return st.TranspileExec(ewords, output) +} + +// TranspileExecTokens returns transpiled tokens assuming Exec code, +// from given tokens, with the given bool indicating +// whether [Output] is needed. 
+func (st *State) TranspileExecTokens(toks Tokens, code string, output bool) Tokens { + nt := len(toks) + if nt == 0 { + return nil + } + str := code[toks[0].Pos-1 : toks[nt-1].Pos-1] + return st.TranspileExecString(str, output) +} + +// TranspileExec returns transpiled tokens assuming Exec code, +// with the given bools indicating the type of run to execute. +func (st *State) TranspileExec(ewords []string, output bool) Tokens { + n := len(ewords) + if n == 0 { + return nil + } + etoks := make(Tokens, 0, n+5) // return tokens + var execTok *Token + bgJob := false + noStop := false + if ewords[0] == "[" { + ewords = ewords[1:] + n-- + noStop = true + } + startExec := func() { + bgJob = false + etoks.Add(token.IDENT, "goal") + etoks.Add(token.PERIOD) + switch { + case output && noStop: + execTok = etoks.Add(token.IDENT, "OutputErrOK") + case output && !noStop: + execTok = etoks.Add(token.IDENT, "Output") + case !output && noStop: + execTok = etoks.Add(token.IDENT, "RunErrOK") + case !output && !noStop: + execTok = etoks.Add(token.IDENT, "Run") + } + etoks.Add(token.LPAREN) + } + endExec := func() { + if bgJob { + execTok.Str = "Start" + } + etoks.DeleteLastComma() + etoks.Add(token.RPAREN) + } + + startExec() + + for i := 0; i < n; i++ { + f := ewords[i] + switch { + case f == "{": // embedded go + if n < i+3 { + st.AddError(fmt.Errorf("goal: no matching right brace } found in exec command line")) + } else { + gstr := ewords[i+1] + etoks.AddTokens(st.TranspileGo(TokensFromString(gstr), gstr)...) + etoks.Add(token.COMMA) + i += 2 + } + case f == "[": + noStop = true + case f == "]": // solo is def end + // just skip + noStop = false + case f == "&": + bgJob = true + case f[0] == '|': + execTok.Str = "Start" + etoks.Add(token.IDENT, AddQuotes(f)) + etoks.Add(token.COMMA) + endExec() + etoks.Add(token.SEMICOLON) + etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...) 
+ return etoks + case f == ";": + endExec() + etoks.Add(token.SEMICOLON) + etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...) + return etoks + default: + if f[0] == '"' || f[0] == '`' { + etoks.Add(token.STRING, f) + } else { + etoks.Add(token.IDENT, AddQuotes(f)) // mark as an IDENT but add quotes! + } + etoks.Add(token.COMMA) + } + } + endExec() + return etoks +} + +// TranspileGoNDimIndex processes an ident[*] sequence of tokens, +// translating it into a corresponding tensor Value or Set expression, +// if it is a multi-dimensional indexing expression which is not valid in Go, +// to support simple n-dimensional tensor indexing in Go (not math) mode. +// Gets the current sequence of toks tokens, where the ident starts at idIdx +// and the ] is at rbIdx. It puts the results in gtoks generated tokens. +// Returns a positive index to resume processing at, if it is actually an +// n-dimensional expr, and -1 if not, in which case the normal process resumes. +func (st *State) TranspileGoNDimIndex(toks Tokens, code string, gtoks *Tokens, idIdx, rbIdx int) int { + var commas []int + for i := idIdx + 2; i < rbIdx; i++ { + tk := toks[i] + if tk.Tok == token.COMMA { + commas = append(commas, i) + } + if tk.Tok == token.LPAREN || tk.Tok == token.LBRACE || tk.Tok == token.LBRACK { + rp := toks[i:rbIdx].RightMatching() + if rp > 0 { + i += rp + } + } + } + if len(commas) == 0 { // not multidim + return -1 + } + isPtr := false + if idIdx > 0 && toks[idIdx-1].Tok == token.AND { + isPtr = true + lgt := len(*gtoks) + *gtoks = slices.Delete(*gtoks, lgt-2, lgt-1) // get rid of & + } + // now we need to determine if it is a Set based on what happens after rb + isSet := false + stok := token.ILLEGAL + n := len(toks) + hasComment := false + if toks[n-1].Tok == token.COMMENT { + hasComment = true + n-- + } + if n-rbIdx > 1 { + ntk := toks[rbIdx+1].Tok + if ntk == token.ASSIGN || (ntk >= token.ADD_ASSIGN && ntk <= token.QUO_ASSIGN) { + isSet = true + stok = ntk + } + } + fun := 
"Value" + if isPtr { + fun = "ValuePtr" + isSet = false + } else if isSet { + fun = "Set" + switch stok { + case token.ADD_ASSIGN: + fun += "Add" + case token.SUB_ASSIGN: + fun += "Sub" + case token.MUL_ASSIGN: + fun += "Mul" + case token.QUO_ASSIGN: + fun += "Div" + } + } + gtoks.Add(token.PERIOD) + gtoks.Add(token.IDENT, fun) + gtoks.Add(token.LPAREN) + if isSet { + gtoks.AddTokens(st.TranspileGo(toks[rbIdx+2:n], code)...) + gtoks.Add(token.COMMA) + } + sti := idIdx + 2 + for _, cp := range commas { + gtoks.Add(token.IDENT, "int") + gtoks.Add(token.LPAREN) + gtoks.AddTokens(st.TranspileGo(toks[sti:cp], code)...) + gtoks.Add(token.RPAREN) + gtoks.Add(token.COMMA) + sti = cp + 1 + } + gtoks.Add(token.IDENT, "int") + gtoks.Add(token.LPAREN) + gtoks.AddTokens(st.TranspileGo(toks[sti:rbIdx], code)...) + gtoks.Add(token.RPAREN) + gtoks.Add(token.RPAREN) + if isSet { + if hasComment { + gtoks.AddTokens(toks[len(toks)-1]) + } + return len(toks) + } else { + return rbIdx + } +} diff --git a/goal/transpile/transpile_test.go b/goal/transpile/transpile_test.go new file mode 100644 index 0000000000..b9d8841ed5 --- /dev/null +++ b/goal/transpile/transpile_test.go @@ -0,0 +1,310 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package transpile + +import ( + "testing" + + _ "cogentcore.org/core/tensor/stats/metric" + _ "cogentcore.org/core/tensor/stats/stats" + _ "cogentcore.org/core/tensor/tmath" + "github.com/stretchr/testify/assert" +) + +type exIn struct { + i string + e string +} + +type wexIn struct { + i string + isErr bool + e []string +} + +// these are more general tests of full-line statements of various forms +func TestExecWords(t *testing.T) { + tests := []wexIn{ + {`ls`, false, []string{`ls`}}, + {`cat "be"`, false, []string{`cat`, `"be"`}}, + {`cat "be`, true, []string{`cat`, `"be`}}, + {`cat "be a thing"`, false, []string{`cat`, `"be a thing"`}}, + {`cat "{be \"a\" thing}"`, false, []string{`cat`, `"{be \"a\" thing}"`}}, + {`cat {vals[1:10]}`, false, []string{`cat`, `{`, `vals[1:10]`, `}`}}, + {`cat {myfunc(vals[1:10], "test", false)}`, false, []string{`cat`, `{`, `myfunc(vals[1:10],"test",false)`, `}`}}, + {`cat vals[1:10]`, false, []string{`cat`, `vals[1:10]`}}, + {`cat vals...`, false, []string{`cat`, `vals...`}}, + {`[cat vals...]`, false, []string{`[`, `cat`, `vals...`, `]`}}, + {`[cat vals...]; ls *.tsv`, false, []string{`[`, `cat`, `vals...`, `]`, `;`, `ls`, `*.tsv`}}, + {`cat vals... 
| grep -v "b"`, false, []string{`cat`, `vals...`, `|`, `grep`, `-v`, `"b"`}}, + {`cat vals...>&file.out`, false, []string{`cat`, `vals...`, `>&`, `file.out`}}, + {`cat vals...>&@0:file.out`, false, []string{`cat`, `vals...`, `>&`, `@0:file.out`}}, + {`./"Cogent Code"`, false, []string{`./"Cogent Code"`}}, + {`Cogent\ Code`, false, []string{`Cogent Code`}}, + {`./Cogent\ Code`, false, []string{`./Cogent Code`}}, + } + for _, test := range tests { + o, err := ExecWords(test.i) + assert.Equal(t, test.e, o) + if err != nil { + if !test.isErr { + t.Error("should not have been an error:", test.i) + } + } else if test.isErr { + t.Error("was supposed to be an error:", test.i) + } + } +} + +// Paths tests the Path() code +func TestPaths(t *testing.T) { + // logx.UserLevel = slog.LevelDebug + tests := []exIn{ + {`fmt.Println("hi")`, `fmt.Println`}, + {"./goal -i", `./goal`}, + {"main.go", `main.go`}, + {"cogent/", `cogent/`}, + {`./"Cogent Code"`, `./\"Cogent Code\"`}, + {`Cogent\ Code`, ``}, + {`./Cogent\ Code`, `./Cogent Code`}, + {"./ios-deploy", `./ios-deploy`}, + {"ios_deploy/sub", `ios_deploy/sub`}, + {"C:/ios_deploy/sub", `C:/ios_deploy/sub`}, + {"..", `..`}, + {"../another/dir/to/go_to", `../another/dir/to/go_to`}, + {"../an-other/dir/", `../an-other/dir/`}, + {"https://google.com/search?q=hello%20world#body", `https://google.com/search?q=hello%20world#body`}, + } + for _, test := range tests { + toks := TokensFromString(test.i) + p, _ := toks.Path(false) + assert.Equal(t, test.e, p) + } +} + +// these are more general tests of full-line statements of various forms +func TestTranspile(t *testing.T) { + // logx.UserLevel = slog.LevelDebug + tests := []exIn{ + {"ls", `goal.Run("ls")`}, + {"$ls -la$", `goal.Run("ls", "-la")`}, + {"ls -la", `goal.Run("ls", "-la")`}, + {"chmod +x file", `goal.Run("chmod", "+x", "file")`}, + {"ls --help", `goal.Run("ls", "--help")`}, + {"ls go", `goal.Run("ls", "go")`}, + {"cd go", `goal.Run("cd", "go")`}, + {`var name string`, `var name 
string`}, + {`name = "test"`, `name = "test"`}, + {`echo {name}`, `goal.Run("echo", name)`}, + {`echo "testing"`, `goal.Run("echo", "testing")`}, + {`number := 1.23`, `number := 1.23`}, + {`res1, res2 := FunTwoRet()`, `res1, res2 := FunTwoRet()`}, + {`res1, res2, res3 := FunThreeRet()`, `res1, res2, res3 := FunThreeRet()`}, + {`println("hi")`, `println("hi")`}, + {`fmt.Println("hi")`, `fmt.Println("hi")`}, + {`for i := 0; i < 3; i++ { fmt.Println(i, "\n")`, `for i := 0; i < 3; i++ { fmt.Println(i, "\n")`}, + {"for i, v := range $ls -la$ {", `for i, v := range goal.Output("ls", "-la") {`}, + {`// todo: fixit`, `// todo: fixit`}, + {"$go build$", `goal.Run("go", "build")`}, + {"{go build()}", `go build()`}, + {"go build", `goal.Run("go", "build")`}, + {"go build()", `go build()`}, + {"go build &", `goal.Start("go", "build")`}, + {"[mkdir subdir]", `goal.RunErrOK("mkdir", "subdir")`}, + {"set something hello-1", `goal.Run("set", "something", "hello-1")`}, + {"set something = hello", `goal.Run("set", "something", "=", "hello")`}, + {`set something = "hello"`, `goal.Run("set", "something", "=", "hello")`}, + {`set something=hello`, `goal.Run("set", "something=hello")`}, + {`set "something=hello"`, `goal.Run("set", "something=hello")`}, + {`set something="hello"`, `goal.Run("set", "something=\"hello\"")`}, + {`add-path /opt/sbin /opt/homebrew/bin`, `goal.Run("add-path", "/opt/sbin", "/opt/homebrew/bin")`}, + {`cat file > test.out`, `goal.Run("cat", "file", ">", "test.out")`}, + {`cat file | grep -v exe > test.out`, `goal.Start("cat", "file", "|"); goal.Run("grep", "-v", "exe", ">", "test.out")`}, + {`cd sub; pwd; ls -la`, `goal.Run("cd", "sub"); goal.Run("pwd"); goal.Run("ls", "-la")`}, + {`cd sub; [mkdir sub]; ls -la`, `goal.Run("cd", "sub"); goal.RunErrOK("mkdir", "sub"); goal.Run("ls", "-la")`}, + {`cd sub; mkdir names[4]`, `goal.Run("cd", "sub"); goal.Run("mkdir", "names[4]")`}, + {"ls -la > test.out", `goal.Run("ls", "-la", ">", "test.out")`}, + {"ls > test.out", 
`goal.Run("ls", ">", "test.out")`}, + {"ls -la >test.out", `goal.Run("ls", "-la", ">", "test.out")`}, + {"ls -la >> test.out", `goal.Run("ls", "-la", ">>", "test.out")`}, + {"ls -la >& test.out", `goal.Run("ls", "-la", ">&", "test.out")`}, + {"ls -la >>& test.out", `goal.Run("ls", "-la", ">>&", "test.out")`}, + {"ls | grep ev", `goal.Start("ls", "|"); goal.Run("grep", "ev")`}, + {"@1 ls -la", `goal.Run("@1", "ls", "-la")`}, + {"git switch main", `goal.Run("git", "switch", "main")`}, + {"git checkout 123abc", `goal.Run("git", "checkout", "123abc")`}, + {"go get cogentcore.org/core@main", `goal.Run("go", "get", "cogentcore.org/core@main")`}, + {"ls *.go", `goal.Run("ls", "*.go")`}, + {"ls ??.go", `goal.Run("ls", "??.go")`}, + {`fmt.Println("hi")`, `fmt.Println("hi")`}, + {"goal -i", `goal.Run("goal", "-i")`}, + {"./goal -i", `goal.Run("./goal", "-i")`}, + {"cat main.go", `goal.Run("cat", "main.go")`}, + {"cd cogent", `goal.Run("cd", "cogent")`}, + {"cd cogent/", `goal.Run("cd", "cogent/")`}, + {"echo $PATH", `goal.Run("echo", "$PATH")`}, + {`"./Cogent Code"`, `goal.Run("./Cogent Code")`}, + {`./"Cogent Code"`, `goal.Run("./\"Cogent Code\"")`}, + {`Cogent\ Code`, `goal.Run("Cogent Code")`}, + {`./Cogent\ Code`, `goal.Run("./Cogent Code")`}, + {`ios\ deploy -i`, `goal.Run("ios deploy", "-i")`}, + {"./ios-deploy -i", `goal.Run("./ios-deploy", "-i")`}, + {"ios_deploy -i tree_file", `goal.Run("ios_deploy", "-i", "tree_file")`}, + {"ios_deploy/sub -i tree_file", `goal.Run("ios_deploy/sub", "-i", "tree_file")`}, + {"C:/ios_deploy/sub -i tree_file", `goal.Run("C:/ios_deploy/sub", "-i", "tree_file")`}, + {"ios_deploy -i tree_file/path", `goal.Run("ios_deploy", "-i", "tree_file/path")`}, + {"ios-deploy -i", `goal.Run("ios-deploy", "-i")`}, + {"ios-deploy -i tree-file", `goal.Run("ios-deploy", "-i", "tree-file")`}, + {"ios-deploy -i tree-file/path/here", `goal.Run("ios-deploy", "-i", "tree-file/path/here")`}, + {"cd ..", `goal.Run("cd", "..")`}, + {"cd ../another/dir/to/go_to", 
`goal.Run("cd", "../another/dir/to/go_to")`}, + {"cd ../an-other/dir/", `goal.Run("cd", "../an-other/dir/")`}, + {"curl https://google.com/search?q=hello%20world#body", `goal.Run("curl", "https://google.com/search?q=hello%20world#body")`}, + {"func splitLines(str string) []string {", `splitLines := func(str string)[]string {`}, + {"type Result struct {", `type Result struct {`}, + {"var Jobs *table.Table", `var Jobs *table.Table`}, + {"type Result struct { JobID string", `type Result struct { JobID string`}, + {"type Result struct { JobID string `width:\"60\"`", "type Result struct { JobID string `width:\"60\"`"}, + {"func RunInExamples(fun func()) {", "RunInExamples := func(fun func()) {"}, + {"ctr++", "ctr++"}, + {"ctr--", "ctr--"}, + {"stru.ctr++", "stru.ctr++"}, + {"stru.ctr--", "stru.ctr--"}, + {"meta += ln", "meta += ln"}, + {"var data map[string]any", "var data map[string]any"}, + // non-math-mode tensor indexing: + {"x = a[1,f(2,3)]", `x = a.Value(int(1), int(f(2, 3)))`}, + {"x = a[1]", `x = a[1]`}, + {"x = a[f(2,3)]", `x = a[f(2, 3)]`}, + {"a[1,2] = 55", `a.Set(55, int(1), int(2))`}, + {"a[1,2] = 55 // and that is good", `a.Set(55, int(1), int(2)) // and that is good`}, + {"a[1,2] += f(2,55)", `a.SetAdd(f(2, 55), int(1), int(2))`}, + {"a[1,2] *= f(2,55)", `a.SetMul(f(2, 55), int(1), int(2))`}, + {"Data[idx, Integ] = integ", `Data.Set(integ, int(idx), int(Integ))`}, + {"Data[Idxs[idx, 25], Integ] = integ", `Data.Set(integ, int(Idxs.Value(int(idx), int(25))), int(Integ))`}, + {"Layers[NeuronIxs[NrnLayIndex, ni]].GatherSpikes(&Ctx[0], ni, di)", `Layers[NeuronIxs.Value(int(NrnLayIndex), int(ni))].GatherSpikes( & Ctx[0], ni, di)`}, + } + + st := NewState() + for _, test := range tests { + o := st.TranspileLine(test.i) + assert.Equal(t, test.e, o) + } +} + +// tests command generation +func TestCommand(t *testing.T) { + // logx.UserLevel = slog.LevelDebug + tests := []exIn{ + { + `command list { + ls -la args... 
+ }`, + `goal.AddCommand("list", func(args ...string) { +goal.Run("ls", "-la", "args...") +})`}, + { + ` ss.GUI.AddToolbarItem(p, egui.ToolbarItem{ + Label: "Reset RunLog", + }) +`, + `ss.GUI.AddToolbarItem(p, egui.ToolbarItem { +Label: "Reset RunLog", +} ) +`}, + } + + for _, test := range tests { + st := NewState() + st.TranspileCode(test.i) + o := st.Code() + assert.Equal(t, test.e, o) + } +} + +// Use this for testing the current thing working on. +func TestCur(t *testing.T) { + // logx.UserLevel = slog.LevelDebug + tests := []exIn{ + {`Label: "Reset RunLog",`, `Label: "Reset RunLog",`}, + } + st := NewState() + st.MathRecord = false + for _, test := range tests { + o := st.TranspileLine(test.i) + assert.Equal(t, test.e, o) + } +} + +func TestMath(t *testing.T) { + // logx.UserLevel = slog.LevelDebug + tests := []exIn{ + {"# x := 1", `x := tensor.Tensor(tensor.NewIntScalar(1))`}, + {"# x := a + 1", `x := tensor.Tensor(tmath.Add(a, tensor.NewIntScalar(1)))`}, + {"# x = x * 4", `x = tmath.Mul(x, tensor.NewIntScalar(4))`}, + {"# a = x + y", `a = tmath.Add(x, y)`}, + {"# a := x ** 2", `a := tensor.Tensor(tmath.Pow(x, tensor.NewIntScalar(2)))`}, + {"# a = -x", `a = tmath.Negate(x)`}, + {"# x @ a", `matrix.Mul(x, a)`}, + {"# x += 1", `tmath.AddAssign(x, tensor.NewIntScalar(1))`}, + {"# a := [1,2,3,4]", `a := tensor.Tensor(tensor.NewIntFromValues([]int { 1, 2, 3, 4 } ...))`}, + {"# a := [1.,2,3,4]", `a := tensor.Tensor(tensor.NewFloat64FromValues([]float64 { 1., 2, 3, 4 } ...))`}, + {"# a := [[1,2], [3,4]]", `a := tensor.Tensor(tensor.Reshape(tensor.NewIntFromValues([]int { 1, 2, 3, 4 } ...), 2, 2))`}, + {"# a.ndim", `tensor.NewIntScalar(a.NumDims())`}, + {"# ndim(a)", `tensor.NewIntScalar(a.NumDims())`}, + {"# x.shape", `tensor.NewIntFromValues(x.Shape().Sizes ...)`}, + {"# x.T", `tensor.Transpose(x)`}, + {"# zeros(3, 4)", `tensor.NewFloat64(3, 4)`}, + {"# full(5.5, 3, 4)", `tensor.NewFloat64Full(5.5, 3, 4)`}, + {"# zeros(sh)", 
`tensor.NewFloat64(tensor.AsIntSlice(sh) ...)`}, + {"# arange(36)", `tensor.NewIntRange(36)`}, + {"# arange(36, 0, -1)", `tensor.NewIntRange(36, 0, - 1)`}, + {"# linspace(0, 5, 6, true)", `tensor.NewFloat64SpacedLinear(tensor.NewIntScalar(0), tensor.NewIntScalar(5), 6, true)`}, + {"# reshape(x, 6, 6)", `tensor.Reshape(x, 6, 6)`}, + {"# reshape(x, [6, 6])", `tensor.Reshape(x, 6, 6)`}, + {"# reshape(x, sh)", `tensor.Reshape(x, tensor.AsIntSlice(sh) ...)`}, + {"# reshape(arange(36), 6, 6)", `tensor.Reshape(tensor.NewIntRange(36), 6, 6)`}, + {"# a.reshape(6, 6)", `tensor.Reshape(a, 6, 6)`}, + {"# a[1, 2]", `tensor.Reslice(a, 1, 2)`}, + {"# a[:, 2]", `tensor.Reslice(a, tensor.FullAxis, 2)`}, + {"# a[1:3:1, 2]", `tensor.Reslice(a, tensor.Slice { Start:1, Stop:3, Step:1 } , 2)`}, + {"# a[::-1, 2]", `tensor.Reslice(a, tensor.Slice { Step: - 1 } , 2)`}, + {"# a[:3, 2]", `tensor.Reslice(a, tensor.Slice { Stop:3 } , 2)`}, + {"# a[2:, 2]", `tensor.Reslice(a, tensor.Slice { Start:2 } , 2)`}, + {"# a[2:, 2, newaxis]", `tensor.Reslice(a, tensor.Slice { Start:2 } , 2, tensor.NewAxis)`}, + {"# a[..., 2:]", `tensor.Reslice(a, tensor.Ellipsis, tensor.Slice { Start:2 } )`}, + {"# a[:, 2] = b", `tmath.Assign(tensor.Reslice(a, tensor.FullAxis, 2), b)`}, + {"# a[:, 2] += b", `tmath.AddAssign(tensor.Reslice(a, tensor.FullAxis, 2), b)`}, + {"# cos(a)", `tmath.Cos(a)`}, + {"# stats.Mean(a)", `stats.Mean(a)`}, + {"# (stats.Mean(a))", `(stats.Mean(a))`}, + {"# stats.Mean(reshape(a,36))", `stats.Mean(tensor.Reshape(a, 36))`}, + {"# z = a[1:5,1:5] - stats.Mean(ra)", `z = tmath.Sub(tensor.Reslice(a, tensor.Slice { Start:1, Stop:5 } , tensor.Slice { Start:1, Stop:5 } ), stats.Mean(ra))`}, + {"# metric.Matrix(metric.Cosine, a)", `metric.Matrix(metric.Cosine, a)`}, + {"# a > 5", `tmath.Greater(a, tensor.NewIntScalar(5))`}, + {"# !a", `tmath.Not(a)`}, + {"# a[a > 5]", `tensor.Mask(a, tmath.Greater(a, tensor.NewIntScalar(5)))`}, + {"# a[a > 5].flatten()", `tensor.Flatten(tensor.Mask(a, 
tmath.Greater(a, tensor.NewIntScalar(5))))`}, + {"# a[:3, 2].copy()", `tensor.Clone(tensor.Reslice(a, tensor.Slice { Stop:3 } , 2))`}, + {"# a[:3, 2].reshape(4,2)", `tensor.Reshape(tensor.Reslice(a, tensor.Slice { Stop:3 } , 2), 4, 2)`}, + {"# a > 5 || a < 1", `tmath.Or(tmath.Greater(a, tensor.NewIntScalar(5)), tmath.Less(a, tensor.NewIntScalar(1)))`}, + {"# fmt.Println(a)", `fmt.Println(a)`}, + {"# }", `}`}, + {"# if a[1,2] == 2 {", `if tmath.Equal(tensor.Reslice(a, 1, 2), tensor.NewIntScalar(2)).Bool1D(0) {`}, + {"# for i := 0; i < 3; i++ {", `for i := tensor.Tensor(tensor.NewIntScalar(0)); tmath.Less(i, tensor.NewIntScalar(3)).Bool1D(0); tmath.Inc(i) {`}, + {"# for i, v := range a {", `for i := 0; i < a.Len(); i++ { v := a .Float1D(i)`}, + {`# x := get("item")`, `x := tensor.Tensor(tensorfs.Get("item"))`}, + {`# set("item", x)`, `tensorfs.Set("item", x)`}, + {`# set("item", 5)`, `tensorfs.Set("item", tensor.NewIntScalar(5))`}, + {`fmt.Println(#zeros(3,4)#)`, `fmt.Println(tensor.NewFloat64(3, 4))`}, + } + + st := NewState() + st.MathRecord = false + for _, test := range tests { + o := st.TranspileLine(test.i) + assert.Equal(t, test.e, o) + } +} diff --git a/goal/transpile/types.go b/goal/transpile/types.go new file mode 100644 index 0000000000..471e961044 --- /dev/null +++ b/goal/transpile/types.go @@ -0,0 +1,72 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package transpile + +import ( + "go/ast" + "go/token" +) + +// inferKindExpr infers the basic Kind level type from given expression +func inferKindExpr(ex ast.Expr) token.Token { + if ex == nil { + return token.ILLEGAL + } + switch x := ex.(type) { + case *ast.BadExpr: + return token.ILLEGAL + + case *ast.Ident: + // todo: get type of object is not possible! 
+ + case *ast.BinaryExpr: + ta := inferKindExpr(x.X) + tb := inferKindExpr(x.Y) + if ta == tb { + return ta + } + if ta != token.ILLEGAL { + return ta + } else { + return tb + } + + case *ast.BasicLit: + return x.Kind // key grounding + + case *ast.FuncLit: + + case *ast.ParenExpr: + return inferKindExpr(x.X) + + case *ast.SelectorExpr: + + case *ast.TypeAssertExpr: + + case *ast.IndexListExpr: + if x.X == nil { // array literal + return inferKindExprList(x.Indices) + } else { + return inferKindExpr(x.X) + } + + case *ast.SliceExpr: + + case *ast.CallExpr: + + } + return token.ILLEGAL +} + +func inferKindExprList(ex []ast.Expr) token.Token { + n := len(ex) + for i := range n { + t := inferKindExpr(ex[i]) + if t != token.ILLEGAL { + return t + } + } + return token.ILLEGAL +} diff --git a/gpu/README.md b/gpu/README.md index 53000fbb83..fae96005a9 100644 --- a/gpu/README.md +++ b/gpu/README.md @@ -152,13 +152,15 @@ Here's how it works: * Each WebGPU `Pipeline` holds **1** compute `shader` program, which is equivalent to a `kernel` in CUDA. This is the basic unit of computation, accomplishing one parallel sweep of processing across some number of identical data structures. -* You must organize at the outset your `Vars` and `Values` in the `System` to hold the data structures your shaders operate on. In general, you want to have a single static set of Vars that cover everything you'll need, and different shaders can operate on different subsets of these. You want to minimize the amount of memory transfer. +* The `Vars` and `Values` in the `System` hold all the data structures your shaders operate on, and must be configured and data uploaded before running. In general, it is best to have a single static set of Vars that cover everything you'll need, and different shaders can operate on different subsets of these, minimizing the amount of memory transfer. -* Because the `Queue.Submit` call is by far the most expensive call in WebGPU, you want to minimize those. 
This means you want to combine as much of your computation into one big Command sequence, with calls to various different `Pipeline` shaders (which can all be put in one command buffer) that gets submitted *once*, rather than submitting separate commands for each shader. Ideally this also involves combining memory transfers to / from the GPU in the same command buffer as well.
+* Because the `Queue.Submit` call is by far the most expensive call in WebGPU, it should be minimized. This means combining as much of your computation into one big Command sequence, with calls to various different `Pipeline` shaders (which can all be put in one command buffer) that gets submitted *once*, rather than submitting separate commands for each shader. Ideally this also involves combining memory transfers to / from the GPU in the same command buffer as well.
 
-* There are no explicit sync mechanisms in WebGPU, but it is designed so that shader compute is automatically properly synced with prior and subsequent memory transfer commands, so it automatically does the right thing for most use cases.
+* There are no explicit sync mechanisms on the command (CPU) side of WebGPU (they only exist in the WGSL shaders), but it is designed so that shader compute is automatically properly synced with prior and subsequent memory transfer commands, so it automatically does the right thing for most use cases.
 
-* Compute is particularly taxing on memory transfer in general, and as far as I can tell, the best strategy is to rely on the optimized `WriteBuffer` command to transfer from CPU to GPU, and then use a staging buffer to read data back from the GPU. E.g., see [this reddit post](https://www.reddit.com/r/wgpu/comments/13zqe1u/can_someone_please_explain_to_me_the_whole_buffer/). Critically, the write commands are queued and any staging buffers are managed internally, so it shouldn't be much slower than manually doing all the staging. 
For reading, we have to implement everything ourselves, and here it is critical to batch the `ReadSync` calls for all relevant values, so they all happen at once. Use ad-hoc `ValueGroup`s to organize these batched read operations efficiently for the different groups of values that need to be read back in the different compute stages. +* Compute is particularly taxing on memory transfer in general, and overall the best strategy is to rely on the optimized `WriteBuffer` command to transfer from CPU to GPU, and then use a staging buffer to read data back from the GPU. E.g., see [this reddit post](https://www.reddit.com/r/wgpu/comments/13zqe1u/can_someone_please_explain_to_me_the_whole_buffer/). Critically, the write commands are queued and any staging buffers are managed internally, so it shouldn't be much slower than manually doing all the staging. For reading, we have to implement everything ourselves, and here it is critical to batch the `ReadSync` calls for all relevant values, so they all happen at once. Use ad-hoc `ValueGroup`s to organize these batched read operations efficiently for the different groups of values that need to be read back in the different compute stages. + +* For large numbers of items to compute, there is a strong constraint that only 65_536 (2^16) workgroups can be submitted, _per dimension_ at a time. For unstructured 1D indexing, we typically use `[64,1,1]` for the workgroup size (which must be hard-coded into the shader and coordinated with the Go side code), which gives 64 * 65_536 = 4_194_304 max items. For more than that number, more than 1 needs to be used for the second dimension. The NumWorkgroups* functions return appropriate sizes with a minimum remainder. See [examples/compute](examples/compute) for the logic needed to get the overall global index from the workgroup sizes. 
# Gamma Correction (sRGB vs Linear) and Headless / Offscreen Rendering diff --git a/gpu/compute.go b/gpu/compute.go index da7ee8213b..5bc4e18dc0 100644 --- a/gpu/compute.go +++ b/gpu/compute.go @@ -7,6 +7,8 @@ package gpu import ( "fmt" "math" + "runtime" + "sync" "cogentcore.org/core/base/errors" "github.com/cogentcore/webgpu/wgpu" @@ -25,9 +27,13 @@ type ComputeSystem struct { // Access through the System.Vars() method. vars Vars - // ComputePipelines by name + // ComputePipelines by name. ComputePipelines map[string]*ComputePipeline + // ComputeEncoder is the compute specific command encoder for the + // current [BeginComputePass], and released in [EndComputePass]. + ComputeEncoder *wgpu.ComputePassEncoder + // CommandEncoder is the command encoder created in // [BeginComputePass], and released in [EndComputePass]. CommandEncoder *wgpu.CommandEncoder @@ -116,22 +122,32 @@ func (sy *ComputeSystem) NewCommandEncoder() (*wgpu.CommandEncoder, error) { // to start the compute pass, returning the encoder object // to which further compute commands should be added. // Call [EndComputePass] when done. +// If an existing [ComputeSystem.ComputeEncoder] is already set from +// a prior BeginComputePass call, then that is returned, so this +// is safe and efficient to call for every compute shader dispatch, +// where the first call will create and the rest add to the ongoing job. func (sy *ComputeSystem) BeginComputePass() (*wgpu.ComputePassEncoder, error) { + if sy.ComputeEncoder != nil { + return sy.ComputeEncoder, nil + } cmd, err := sy.NewCommandEncoder() if errors.Log(err) != nil { return nil, err } sy.CommandEncoder = cmd - return cmd.BeginComputePass(nil), nil // note: optional name in the descriptor + sy.ComputeEncoder = cmd.BeginComputePass(nil) // optional name in the encoder + return sy.ComputeEncoder, nil } // EndComputePass submits the current compute commands to the device -// Queue and releases the [CommandEncoder] and the given -// ComputePassEncoder. 
You must call ce.End prior to calling this. +// Queue and releases the [ComputeSystem.CommandEncoder] and +// [ComputeSystem.ComputeEncoder]. You must call ce.End prior to calling this. // Can insert other commands after ce.End, e.g., to copy data back // from the GPU, prior to calling EndComputePass. -func (sy *ComputeSystem) EndComputePass(ce *wgpu.ComputePassEncoder) error { +func (sy *ComputeSystem) EndComputePass() error { + ce := sy.ComputeEncoder cmd := sy.CommandEncoder + sy.ComputeEncoder = nil sy.CommandEncoder = nil ce.Release() // must happen before Finish cmdBuffer, err := cmd.Finish(nil) @@ -144,11 +160,47 @@ func (sy *ComputeSystem) EndComputePass(ce *wgpu.ComputePassEncoder) error { return nil } -// Warps returns the number of warps (work goups of compute threads) -// that is sufficient to compute n elements, given specified number -// of threads per this dimension. -// It just rounds up to nearest even multiple of n divided by threads: -// Ceil(n / threads) -func Warps(n, threads int) int { - return int(math.Ceil(float64(n) / float64(threads))) +// NumThreads is the number of threads to use for parallel threading, +// in the [VectorizeFunc] that is used for CPU versions of GPU functions. +// The default of 0 causes the [runtime.GOMAXPROCS] to be used. +var NumThreads = 0 + +// DefaultNumThreads returns the default number of threads to use: +// NumThreads if non-zero, otherwise [runtime.GOMAXPROCS]. +func DefaultNumThreads() int { + if NumThreads > 0 { + return NumThreads + } + return runtime.GOMAXPROCS(0) +} + +// VectorizeFunc runs given GPU kernel function taking a uint32 index +// on the CPU, using given number of threads with goroutines, for n iterations. +// If threads is 0, then GOMAXPROCS is used. 
+func VectorizeFunc(threads, n int, fun func(idx uint32)) { + if threads == 0 { + threads = DefaultNumThreads() + } + if threads <= 1 { + for idx := range n { + fun(uint32(idx)) + } + return + } + nper := int(math.Ceil(float64(n) / float64(threads))) + wait := sync.WaitGroup{} + for start := 0; start < n; start += nper { + end := start + nper + if end > n { + end = n + } + wait.Add(1) + go func() { + for idx := start; idx < end; idx++ { + fun(uint32(idx)) + } + wait.Done() + }() + } + wait.Wait() } diff --git a/gpu/compute_test.go b/gpu/compute_test.go new file mode 100644 index 0000000000..d4d4cec960 --- /dev/null +++ b/gpu/compute_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gpu + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNumWorkgroups(t *testing.T) { + nx, ny := NumWorkgroups1D(4_194_304, 64) + assert.Equal(t, 65536, nx) + assert.Equal(t, 1, ny) + assert.Equal(t, 4_194_304, nx*ny*64) + + nx, ny = NumWorkgroups1D(4_194_304+64, 64) + assert.Equal(t, 32769, nx) + assert.Equal(t, 2, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304+64) + + nx, ny = NumWorkgroups1D(4_194_304+90, 64) + assert.Equal(t, 32769, nx) + assert.Equal(t, 2, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304+90) + + nx, ny = NumWorkgroups1D(4_194_304+129, 64) + assert.Equal(t, 32770, nx) + assert.Equal(t, 2, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304+129) + + nx, ny = NumWorkgroups1D(4_194_304-64, 64) + assert.Equal(t, 65535, nx) + assert.Equal(t, 1, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304-64) + + nx, ny = NumWorkgroups1D(4_194_304-90, 64) + assert.Equal(t, 65535, nx) + assert.Equal(t, 1, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304-90) + + nx, ny = NumWorkgroups1D(4_194_304*64, 64) + assert.Equal(t, 65536, nx) + assert.Equal(t, 64, ny) + assert.GreaterOrEqual(t, nx*ny*64, 
4_194_304*64) + + nx, ny = NumWorkgroups1D(4_194_304*64, 64) + assert.Equal(t, 65536, nx) + assert.Equal(t, 64, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304*64) + + nx, ny = NumWorkgroups2D(4_194_304*64, 4, 16) + assert.Equal(t, 65536, nx) + assert.Equal(t, 64, ny) + assert.GreaterOrEqual(t, nx*ny*64, 4_194_304*64) +} diff --git a/gpu/cpipeline.go b/gpu/cpipeline.go index 290919c9f6..2c3ef36eb2 100644 --- a/gpu/cpipeline.go +++ b/gpu/cpipeline.go @@ -6,6 +6,7 @@ package gpu import ( "io/fs" + "math" "path" "cogentcore.org/core/base/errors" @@ -43,6 +44,7 @@ func NewComputePipelineShaderFS(fsys fs.FS, fname string, sy *ComputeSystem) *Co sh := pl.AddShader(name) errors.Log(sh.OpenFileFS(fsys, fname)) pl.AddEntry(sh, ComputeShader, "main") + sy.ComputePipelines[pl.Name] = pl return pl } @@ -73,9 +75,45 @@ func (pl *ComputePipeline) Dispatch(ce *wgpu.ComputePassEncoder, nx, ny, nz int) // (X) dimension, for given number *elements* (threads) per warp (typically 64). // See [Dispatch] for full info. // This is just a convenience method for common 1D case that calls -// the Warps method for you. +// the NumWorkgroups1D function with threads for you. func (pl *ComputePipeline) Dispatch1D(ce *wgpu.ComputePassEncoder, n, threads int) error { - return pl.Dispatch(ce, Warps(n, threads), 1, 1) + nx, ny := NumWorkgroups1D(n, threads) + return pl.Dispatch(ce, nx, ny, 1) +} + +// NumWorkgroups1D() returns the number of work groups of compute threads +// that is sufficient to compute n total elements, given specified number +// of threads in the x dimension, subject to constraint that no more than +// 65536 work groups can be deployed per dimension. 
+func NumWorkgroups1D(n, threads int) (nx, ny int) { + mxn := 65536 + ny = 1 + nx = int(math.Ceil(float64(n) / float64(threads))) + if nx <= 65536 { + return + } + xsz := mxn * threads + ny = int(math.Ceil(float64(n) / float64(xsz))) + nx = int(math.Ceil(float64(n) / float64(ny*threads))) + return +} + +// NumWorkgroups2D() returns the number of work groups of compute threads +// that is sufficient to compute n total elements, given specified number +// of threads per x, y dimension, subject to constraint that no more than +// 65536 work groups can be deployed per dimension. +func NumWorkgroups2D(n, x, y int) (nx, ny int) { + mxn := 65536 + sz := x * y + ny = 1 + nx = int(math.Ceil(float64(n) / float64(sz))) + if nx <= 65536 { + return + } + xsz := mxn * x // size of full x chunk + ny = int(math.Ceil(float64(n) / float64(xsz*y))) + nx = int(math.Ceil(float64(n) / float64(x*ny*y))) + return } // BindAllGroups binds the Current Value for all variables across all diff --git a/gpu/device.go b/gpu/device.go index 2d8a94dfb7..ecc6901e02 100644 --- a/gpu/device.go +++ b/gpu/device.go @@ -5,6 +5,8 @@ package gpu import ( + "fmt" + "cogentcore.org/core/base/errors" "github.com/cogentcore/webgpu/wgpu" ) @@ -38,14 +40,25 @@ func NewDevice(gpu *GPU) (*Device, error) { func NewComputeDevice(gpu *GPU) (*Device, error) { // we only request max buffer sizes so compute can go as big as it needs to limits := wgpu.DefaultLimits() - const maxv = 0xFFFFFFFF + // Per https://github.com/cogentcore/core/issues/1362 -- this may cause issues on "downlevel" + // hardware, so we may need to detect that. 
OTOH it probably won't be useful for compute anyway, + // but we can just sort that out later + // note: on web / chromium / dawn, limited to 10: https://issues.chromium.org/issues/366151398?pli=1 + limits.MaxStorageBuffersPerShaderStage = gpu.Limits.Limits.MaxStorageBuffersPerShaderStage + // fmt.Println("MaxStorageBuffersPerShaderStage:", gpu.Limits.Limits.MaxStorageBuffersPerShaderStage) // note: these limits are being processed and allow the MaxBufferSize to be the // controlling factor -- if we don't set these, then the slrand example doesn't // work above a smaller limit. - limits.MaxUniformBufferBindingSize = min(gpu.Limits.Limits.MaxUniformBufferBindingSize, maxv) - limits.MaxStorageBufferBindingSize = min(gpu.Limits.Limits.MaxStorageBufferBindingSize, maxv) + limits.MaxUniformBufferBindingSize = uint64(MemSizeAlignDown(int(gpu.Limits.Limits.MaxUniformBufferBindingSize), int(gpu.Limits.Limits.MinUniformBufferOffsetAlignment))) + + limits.MaxStorageBufferBindingSize = uint64(MemSizeAlignDown(int(gpu.Limits.Limits.MaxStorageBufferBindingSize), int(gpu.Limits.Limits.MinStorageBufferOffsetAlignment))) // note: this limit is not working properly: - limits.MaxBufferSize = min(gpu.Limits.Limits.MaxBufferSize, maxv) + limits.MaxBufferSize = uint64(MemSizeAlignDown(int(gpu.Limits.Limits.MaxBufferSize), int(gpu.Limits.Limits.MinStorageBufferOffsetAlignment))) + // limits.MaxBindGroups = gpu.Limits.Limits.MaxBindGroups // note: no point in changing -- web constraint + + if Debug { + fmt.Printf("Requesting sizes: MaxStorageBufferBindingSize: %X MaxBufferSize: %X\n", limits.MaxStorageBufferBindingSize, limits.MaxBufferSize) + } desc := wgpu.DeviceDescriptor{ RequiredLimits: &wgpu.RequiredLimits{ Limits: limits, diff --git a/gpu/examples/compute/compute.go b/gpu/examples/compute/compute.go index 5737215b71..500e578a60 100644 --- a/gpu/examples/compute/compute.go +++ b/gpu/examples/compute/compute.go @@ -11,7 +11,9 @@ import ( "runtime" "unsafe" + 
"cogentcore.org/core/base/timer" "cogentcore.org/core/gpu" + // "cogentcore.org/core/system/driver/web/jsfs" ) //go:embed squares.wgsl @@ -30,7 +32,19 @@ type Data struct { } func main() { - gpu.Debug = true + // errors.Log1(jsfs.Config(js.Global().Get("fs"))) // needed for printing etc to work + // time.Sleep(1 * time.Second) + // b := core.NewBody() + // bt := core.NewButton(b).SetText("Run Compute") + // bt.OnClick(func(e events.Event) { + compute() + // }) + // b.RunMainWindow() + // select {} +} + +func compute() { + // gpu.SetDebug(true) gp := gpu.NewComputeGPU() fmt.Printf("Running on GPU: %s\n", gp.DeviceName) @@ -42,8 +56,11 @@ func main() { vars := sy.Vars() sgp := vars.AddGroup(gpu.Storage) - n := 20 // note: not necc to spec up-front, but easier if so + // n := 16_000_000 // near max capacity on Mac M* + n := 200_000 // should fit in any webgpu threads := 64 + nx, ny := gpu.NumWorkgroups1D(n, threads) + fmt.Printf("workgroup sizes: %d, %d storage mem bytes: %X\n", nx, ny, n*int(unsafe.Sizeof(Data{}))) dv := sgp.AddStruct("Data", int(unsafe.Sizeof(Data{})), n, gpu.ComputeShader) @@ -59,18 +76,28 @@ func main() { } gpu.SetValueFrom(dvl, sd) - sgp.CreateReadBuffers() - - ce, _ := sy.BeginComputePass() - pl.Dispatch1D(ce, n, threads) - ce.End() - dvl.GPUToRead(sy.CommandEncoder) - sy.EndComputePass(ce) + gpuTmr := timer.Time{} + cpyTmr := timer.Time{} + gpuTmr.Start() + nItr := 1 + + for range nItr { + ce, _ := sy.BeginComputePass() + pl.Dispatch1D(ce, n, threads) + ce.End() + dvl.GPUToRead(sy.CommandEncoder) + sy.EndComputePass() + + cpyTmr.Start() + dvl.ReadSync() + cpyTmr.Stop() + gpu.ReadToBytes(dvl, sd) + } - dvl.ReadSync() - gpu.ReadToBytes(dvl, sd) + gpuTmr.Stop() - for i := 0; i < n; i++ { + mx := min(n, 10) + for i := 0; i < mx; i++ { tc := sd[i].A + sd[i].B td := tc * tc dc := sd[i].C - tc @@ -78,6 +105,7 @@ func main() { fmt.Printf("%d\t A: %g\t B: %g\t C: %g\t trg: %g\t D: %g \t trg: %g \t difC: %g \t difD: %g\n", i, sd[i].A, sd[i].B, sd[i].C, 
tc, sd[i].D, td, dc, dd) } fmt.Printf("\n") + fmt.Println("total:", gpuTmr.Total, "copy:", cpyTmr.Total) sy.Release() gp.Release() diff --git a/gpu/examples/compute/squares.wgsl b/gpu/examples/compute/squares.wgsl index aa390b0494..69ba23ee96 100644 --- a/gpu/examples/compute/squares.wgsl +++ b/gpu/examples/compute/squares.wgsl @@ -16,11 +16,19 @@ fn compute(d: ptr) { } @compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - // compute(&In[idx.x]); - var d = In[idx.x]; +@workgroup_size(64,1,1) +fn main(@builtin(workgroup_id) wgid: vec3, @builtin(num_workgroups) nwg: vec3, @builtin(local_invocation_index) loci: u32) { + // note: wgid.x is the inner loop, then y, then z + let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64; + // note: array access is clamped so it doesn't exceed bounds, but best to check here + // and skip anything beyond the max size of buffer. + var d = In[idx]; compute(&d); - In[idx.x] = d; + In[idx] = d; + // the following is for testing indexing: uncomment to see. + // In[idx].A = f32(loci); + // In[idx].B = f32(wgid.x); + // In[idx].C = f32(wgid.y); + // In[idx].D = f32(idx); } diff --git a/gpu/gosl/README.md b/gpu/gosl/README.md deleted file mode 100644 index 754697f848..0000000000 --- a/gpu/gosl/README.md +++ /dev/null @@ -1,126 +0,0 @@ -# gosl - -`gosl` implements _Go as a shader language_ for GPU compute shaders (using [WebGPU](https://www.w3.org/TR/webgpu/)), **enabling standard Go code to run on the GPU**. - -The relevant subsets of Go code are specifically marked using `//gosl:` comment directives, and this code must only use basic expressions and concrete types that will compile correctly in a shader (see [Restrictions](#restrictions) below). Method functions and pass-by-reference pointer arguments to `struct` types are supported and incur no additional compute cost due to inlining (see notes below for more detail). 
- -A large and complex biologically-based neural network simulation framework called [axon](https://github.com/emer/axon) has been implemented using `gosl`, allowing 1000's of lines of equations and data structures to run through standard Go on the CPU, and accelerated significantly on the GPU. This allows efficient debugging and unit testing of the code in Go, whereas debugging on the GPU is notoriously difficult. - -`gosl` converts Go code to WGSL which can then be loaded directly into a WebGPU compute shader. - -See [examples/basic](examples/basic) and [rand](examples/rand) for examples, using the [gpu](../../gpu) GPU compute shader system. It is also possible in principle to use gosl to generate shader files for any other WebGPU application, but this has not been tested. - -You must also install `goimports` which is used on the extracted subset of Go code, to get the imports right: -```bash -$ go install golang.org/x/tools/cmd/goimports@latest -``` - -To install the `gosl` command, do: -```bash -$ go install cogentcore.org/core/vgpu/gosl@latest -``` - -In your Go code, use these comment directives: -``` -//gosl:start - -< Go code to be translated > - -//gosl:end -``` - -to bracket code to be processed. The resulting converted code is copied into a `shaders` subdirectory created under the current directory where the `gosl` command is run, using the filenames specified in the comment directives. Each such filename should correspond to a complete shader program (i.e., a "kernel"), or a file that can be included into other shader programs. Code is appended to the target file names in the order of the source .go files on the command line, so multiple .go files can be combined into one resulting WGSL file. 
- -WGSL specific code, e.g., for the `main` compute function or to specify `#include` files, can be included either by specifying files with a `.wgsl` extension as arguments to the `gosl` command, or by using a `//gosl:wgsl` comment directive as follows: -``` -//gosl:wgsl - -// - -//gosl:end -``` -where the WGSL shader code is commented out in the .go file -- it will be copied into the target filename and uncommented. The WGSL code can be surrounded by `/*` `*/` comment blocks (each on a separate line) for multi-line code (though using a separate `.wgsl` file is generally preferable in this case). - -For `.wgsl` files, their filename is used to determine the `shaders` destination file name, and they are automatically appended to the end of the corresponding `.wgsl` file generated from the `Go` files -- this is where the `main` function and associated global variables should be specified. - -**IMPORTANT:** all `.go` and `.wgsl` files are removed from the `shaders` directory prior to processing to ensure everything there is current -- always specify a different source location for any custom `.wgsl` files that are included. - -# Usage - -``` -gosl [flags] [path ...] -``` - -The flags are: -``` - -debug - enable debugging messages while running - -exclude string - comma-separated list of names of functions to exclude from exporting to HLSL (default "Update,Defaults") - -keep - keep temporary converted versions of the source files, for debugging - -out string - output directory for shader code, relative to where gosl is invoked -- must not be an empty string (default "shaders") -``` - -`gosl` path args can include filenames, directory names, or Go package paths (e.g., `cogentcore.org/core/math32/fastexp.go` loads just that file from the given package) -- files without any `//gosl:` comment directives will be skipped up front before any expensive processing, so it is not a problem to specify entire directories where only some files are relevant. 
Also, you can specify a particular file from a directory, then the entire directory, to ensure that a particular file from that directory appears first -- otherwise alphabetical order is used. `gosl` ensures that only one copy of each file is included. - -Any `struct` types encountered will be checked for 16-byte alignment of sub-types and overall sizes as an even multiple of 16 bytes (4 `float32` or `int32` values), which is the alignment used in WGSL and glsl shader languages, and the underlying GPU hardware presumably. Look for error messages on the output from the gosl run. This ensures that direct byte-wise copies of data between CPU and GPU will be successful. The fact that `gosl` operates directly on the original CPU-side Go code uniquely enables it to perform these alignment checks, which are otherwise a major source of difficult-to-diagnose bugs. - -# Restrictions - -In general shader code should be simple mathematical expressions and data types, with minimal control logic via `if`, `for` statements, and only using the subset of Go that is consistent with C. Here are specific restrictions: - -* Can only use `float32`, `[u]int32` for basic types (`int` is converted to `int32` automatically), and `struct` types composed of these same types -- no other Go types (i.e., `map`, slices, `string`, etc) are compatible. There are strict alignment restrictions on 16 byte (e.g., 4 `float32`'s) intervals that are enforced via the `alignsl` sub-package. - -* WGSL does _not_ support 64 bit float or int. - -* Use `slbool.Bool` instead of `bool` -- it defines a Go-friendly interface based on a `int32` basic type. - -* Alignment and padding of `struct` fields is key -- this is automatically checked by `gosl`. - -* WGSL does not support enum types, but standard go `const` declarations will be converted. Use an `int32` or `uint32` data type. It will automatically deal with the simple incrementing `iota` values, but not more complex cases. 
Also, for bitflags, define explicitly, not using `bitflags` package, and use `0x01`, `0x02`, `0x04` etc instead of `1<<2` -- in theory the latter should be ok but in practice it complains. - -* Cannot use multiple return values, or multiple assignment of variables in a single `=` expression. - -* *Can* use multiple variable names with the same type (e.g., `min, max float32`) -- this will be properly converted to the more redundant C form with the type repeated. - -* `switch` `case` statements are _purely_ self-contained -- no `fallthrough` allowed! does support multiple items per `case` however. - -* TODO: WGSL does not do multi-pass compiling, so all dependent types must be specified *before* being used in other ones, and this also precludes referencing the *current* type within itself. todo: can you just use a forward declaration? - -* WGSL does specify that new variables are initialized to 0, like Go, but also somehow discourages that use-case. It is safer to initialize directly: -```Go - val := float32(0) // guaranteed 0 value - var val float32 // ok but generally avoid -``` - -## Other language features - -* [tour-of-wgsl](https://google.github.io/tour-of-wgsl/types/pointers/passing_pointers/) is a good reference to explain things more directly than the spec. - -* `ptr` provides a pointer arg -* `private` scope = within the shader code "module", i.e., one thread. -* `function` = within the function, not outside it. -* `workgroup` = shared across workgroup -- coudl be powerful (but slow!) -- need to learn more. - -## Random numbers: slrand - -See [slrand](https://github.com/emer/gosl/v2/tree/main/slrand) for a shader-optimized random number generation package, which is supported by `gosl` -- it will convert `slrand` calls into appropriate WGSL named function calls. 
`gosl` will also copy the `slrand.wgsl` file, which contains the full source code for the RNG, into the destination `shaders` directory, so it can be included with a simple local path: - -```Go -//gosl:wgsl mycode -// #include "slrand.wgsl" -//gosl:end mycode -``` - -# Performance - -With sufficiently large N, and ignoring the data copying setup time, around ~80x speedup is typical on a Macbook Pro with M1 processor. The `rand` example produces a 175x speedup! - -# Implementation / Design Notes - -# Links - -Key docs for WGSL as compute shaders: - diff --git a/gpu/gosl/examples/basic/compute.go b/gpu/gosl/examples/basic/compute.go deleted file mode 100644 index 27b43a7c27..0000000000 --- a/gpu/gosl/examples/basic/compute.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2022, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import "cogentcore.org/core/math32" - -//gosl:wgsl basic -// #include "fastexp.wgsl" -//gosl:end basic - -//gosl:start basic - -// DataStruct has the test data -type DataStruct struct { - - // raw value - Raw float32 - - // integrated value - Integ float32 - - // exp of integ - Exp float32 - - // must pad to multiple of 4 floats for arrays - pad float32 -} - -// ParamStruct has the test params -type ParamStruct struct { - - // rate constant in msec - Tau float32 - - // 1/Tau - Dt float32 - - pad float32 - pad1 float32 -} - -// IntegFromRaw computes integrated value from current raw value -func (ps *ParamStruct) IntegFromRaw(ds *DataStruct) { - ds.Integ += ps.Dt * (ds.Raw - ds.Integ) - ds.Exp = math32.FastExp(-ds.Integ) -} - -//gosl:end basic - -// note: only core compute code needs to be in shader -- all init is done CPU-side - -func (ps *ParamStruct) Defaults() { - ps.Tau = 5 - ps.Update() -} - -func (ps *ParamStruct) Update() { - ps.Dt = 1.0 / ps.Tau -} - -//gosl:wgsl basic -/* -@group(0) @binding(0) -var Params: array; - -@group(0) 
@binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - var pars = Params[0]; - var data = Data[idx.x]; - ParamStruct_IntegFromRaw(&pars, &data); - Data[idx.x] = data; -} -*/ -//gosl:end basic diff --git a/gpu/gosl/examples/basic/main.go b/gpu/gosl/examples/basic/main.go deleted file mode 100644 index 87a6be2469..0000000000 --- a/gpu/gosl/examples/basic/main.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) 2022, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// This example just does some basic calculations on data structures and -// reports the time difference between the CPU and GPU. -package main - -import ( - "embed" - "fmt" - "math/rand" - "runtime" - "unsafe" - - "cogentcore.org/core/base/timer" - "cogentcore.org/core/gpu" -) - -//go:generate ../../gosl cogentcore.org/core/math32/fastexp.go compute.go - -//go:embed shaders/basic.wgsl shaders/fastexp.wgsl -var shaders embed.FS - -func init() { - // must lock main thread for gpu! 
- runtime.LockOSThread() -} - -func main() { - gpu.Debug = true - gp := gpu.NewComputeGPU() - fmt.Printf("Running on GPU: %s\n", gp.DeviceName) - - // gp.PropertiesString(true) // print - - sy := gpu.NewComputeSystem(gp, "compute") - pl := gpu.NewComputePipelineShaderFS(shaders, "shaders/basic.wgsl", sy) - vars := sy.Vars() - sgp := vars.AddGroup(gpu.Storage) - - n := 2000000 // note: not necc to spec up-front, but easier if so - threads := 64 - - pv := sgp.AddStruct("Params", int(unsafe.Sizeof(ParamStruct{})), 1, gpu.ComputeShader) - dv := sgp.AddStruct("Data", int(unsafe.Sizeof(DataStruct{})), n, gpu.ComputeShader) - - sgp.SetNValues(1) - sy.Config() - - pvl := pv.Values.Values[0] - dvl := dv.Values.Values[0] - - pars := make([]ParamStruct, 1) - pars[0].Defaults() - - cd := make([]DataStruct, n) - for i := range cd { - cd[i].Raw = rand.Float32() - } - - sd := make([]DataStruct, n) - for i := range sd { - sd[i].Raw = cd[i].Raw - } - - cpuTmr := timer.Time{} - cpuTmr.Start() - for i := range cd { - pars[0].IntegFromRaw(&cd[i]) - } - cpuTmr.Stop() - - gpuFullTmr := timer.Time{} - gpuFullTmr.Start() - - gpu.SetValueFrom(pvl, pars) - gpu.SetValueFrom(dvl, sd) - - sgp.CreateReadBuffers() - - gpuTmr := timer.Time{} - gpuTmr.Start() - - ce, _ := sy.BeginComputePass() - pl.Dispatch1D(ce, n, threads) - ce.End() - dvl.GPUToRead(sy.CommandEncoder) - sy.EndComputePass(ce) - - gpuTmr.Stop() - - dvl.ReadSync() - gpu.ReadToBytes(dvl, sd) - - gpuFullTmr.Stop() - - mx := min(n, 5) - for i := 0; i < mx; i++ { - d := cd[i].Exp - sd[i].Exp - fmt.Printf("%d\t Raw: %g\t Integ: %g\t Exp: %6.4g\tTrg: %6.4g\tDiff: %g\n", i, sd[i].Raw, sd[i].Integ, sd[i].Exp, cd[i].Exp, d) - } - fmt.Printf("\n") - - cpu := cpuTmr.Total - gpu := gpuTmr.Total - gpuFull := gpuFullTmr.Total - fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFull, float64(cpu)/float64(gpu)) - - sy.Release() - gp.Release() -} diff --git a/gpu/gosl/examples/basic/shaders/basic.wgsl 
b/gpu/gosl/examples/basic/shaders/basic.wgsl deleted file mode 100644 index 929012d598..0000000000 --- a/gpu/gosl/examples/basic/shaders/basic.wgsl +++ /dev/null @@ -1,52 +0,0 @@ - -#include "fastexp.wgsl" - -// DataStruct has the test data -struct DataStruct { - - // raw value - Raw: f32, - - // integrated value - Integ: f32, - - // exp of integ - Exp: f32, - - // must pad to multiple of 4 floats for arrays - pad: f32, -} - -// ParamStruct has the test params -struct ParamStruct { - - // rate constant in msec - Tau: f32, - - // 1/Tau - Dt: f32, - - pad: f32, - pad1: f32, -} - -// IntegFromRaw computes integrated value from current raw value -fn ParamStruct_IntegFromRaw(ps: ptr, ds: ptr) { - (*ds).Integ += (*ps).Dt * ((*ds).Raw - (*ds).Integ); - (*ds).Exp = FastExp(-(*ds).Integ); -} - -@group(0) @binding(0) -var Params: array; - -@group(0) @binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - var pars = Params[0]; - var data = Data[idx.x]; - ParamStruct_IntegFromRaw(&pars, &data); - Data[idx.x] = data; -} diff --git a/gpu/gosl/examples/basic/shaders/fastexp.wgsl b/gpu/gosl/examples/basic/shaders/fastexp.wgsl deleted file mode 100644 index c5f278921b..0000000000 --- a/gpu/gosl/examples/basic/shaders/fastexp.wgsl +++ /dev/null @@ -1,14 +0,0 @@ - -// FastExp is a quartic spline approximation to the Exp function, by N.N. Schraudolph -// It does not have any of the sanity checking of a standard method -- returns -// nonsense when arg is out of range. Runs in 2.23ns vs. 6.3ns for 64bit which is faster -// than exp actually. 
-fn FastExp(x: f32) -> f32 { - if (x <= -88.02969) { // this doesn't add anything and -exp is main use-case anyway - return f32(0.0); - } - var i = i32(12102203*x) + i32(127)*(i32(1)<<23); - var m = i >> 7 & 0xFFFF; // copy mantissa - i += (((((((((((3537 * m) >> 16) + 13668) * m) >> 18) + 15817) * m) >> 14) - 80470) * m) >> 11); - return bitcast(u32(i)); -} diff --git a/gpu/gosl/examples/rand/rand.wgsl b/gpu/gosl/examples/rand/rand.wgsl deleted file mode 100644 index 780ae9ef26..0000000000 --- a/gpu/gosl/examples/rand/rand.wgsl +++ /dev/null @@ -1,16 +0,0 @@ - -@group(0) @binding(0) -var Counter: array; - -@group(0) @binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - var ctr = Counter[0]; - var data = Data[idx.x]; - Rnds_RndGen(&data, ctr, idx.x); - Data[idx.x] = data; -} - diff --git a/gpu/gosl/examples/rand/shaders/rand.wgsl b/gpu/gosl/examples/rand/shaders/rand.wgsl deleted file mode 100644 index baf5d96be3..0000000000 --- a/gpu/gosl/examples/rand/shaders/rand.wgsl +++ /dev/null @@ -1,371 +0,0 @@ - -// #include "slrand.wgsl" -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Original file is in Go package: github.com/cogentcore/core/gpu/gosl/slrand -// See README.md there for documentation. - -// These random number generation (RNG) functions are optimized for -// use on the GPU, with equivalent Go versions available in slrand.go. -// This is using the Philox2x32 counter-based RNG. - -// #include "sltype.wgsl" -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Original file is in Go package: github.com/cogentcore/core/gpu/gosl/sltype -// See README.md there for documentation. - -// This file emulates uint64 (u64) using 2 uint32 integers. 
-// and defines conversions between uint and float. - -// define a u64 type as an alias. -// if / when u64 actually happens, will make it easier to update. -alias su64 = vec2; - -// Uint32Mul64 multiplies two uint32 numbers into a uint64 (using vec2). -fn Uint32Mul64(a: u32, b: u32) -> su64 { - let LOMASK = (((u32(1))<<16)-1); - var r: su64; - r.x = a * b; /* full low multiply */ - let ahi = a >> 16; - let alo = a & LOMASK; - let bhi = b >> 16; - let blo = b & LOMASK; - - let ahbl = ahi * blo; - let albh = alo * bhi; - - let ahbl_albh = ((ahbl&LOMASK) + (albh&LOMASK)); - var hit = ahi*bhi + (ahbl>>16) + (albh>>16); - hit += ahbl_albh >> 16; /* carry from the sum of lo(ahbl) + lo(albh) ) */ - /* carry from the sum with alo*blo */ - if ((r.x >> u32(16)) < (ahbl_albh&LOMASK)) { - hit += u32(1); - } - r.y = hit; - return r; -} - -/* -// Uint32Mul64 multiplies two uint32 numbers into a uint64 (using su64). -fn Uint32Mul64(a: u32, b: u32) -> su64 { - return su64(a) * su64(b); -} -*/ - - -// Uint64Add32 adds given uint32 number to given uint64 (using vec2). -fn Uint64Add32(a: su64, b: u32) -> su64 { - if (b == 0) { - return a; - } - var s = a; - if (s.x > u32(0xffffffff) - b) { - s.y++; - s.x = (b - 1) - (u32(0xffffffff) - s.x); - } else { - s.x += b; - } - return s; -} - -// Uint64Incr returns increment of the given uint64 (using vec2). -fn Uint64Incr(a: su64) -> su64 { - var s = a; - if(s.x == 0xffffffff) { - s.y++; - s.x = u32(0); - } else { - s.x++; - } - return s; -} - -// Uint32ToFloat32 converts a uint32 integer into a float32 -// in the (0,1) interval (i.e., exclusive of 1). -// This differs from the Go standard by excluding 0, which is handy for passing -// directly to Log function, and from the reference Philox code by excluding 1 -// which is in the Go standard and most other standard RNGs. 
-fn Uint32ToFloat32(val: u32) -> f32 { - let factor = f32(1.0) / (f32(u32(0xffffffff)) + f32(1.0)); - let halffactor = f32(0.5) * factor; - var f = f32(val) * factor + halffactor; - if (f == 1.0) { // exclude 1 - return bitcast(0x3F7FFFFF); - } - return f; -} - -// note: there is no overloading of user-defined functions -// https://github.com/gpuweb/gpuweb/issues/876 - -// Uint32ToFloat32Vec2 converts two uint 32 bit integers -// into two corresponding 32 bit f32 values -// in the (0,1) interval (i.e., exclusive of 1). -fn Uint32ToFloat32Vec2(val: vec2) -> vec2 { - var r: vec2; - r.x = Uint32ToFloat32(val.x); - r.y = Uint32ToFloat32(val.y); - return r; -} - -// Uint32ToFloat32Range11 converts a uint32 integer into a float32 -// in the [-1..1] interval (inclusive of -1 and 1, never identically == 0). -fn Uint32ToFloat32Range11(val: u32) -> f32 { - let factor = f32(1.0) / (f32(i32(0x7fffffff)) + f32(1.0)); - let halffactor = f32(0.5) * factor; - return (f32(val) * factor + halffactor); -} - -// Uint32ToFloat32Range11Vec2 converts two uint32 integers into two float32 -// in the [-1,1] interval (inclusive of -1 and 1, never identically == 0). -fn Uint32ToFloat32Range11Vec2(val: vec2) -> vec2 { - var r: vec2; - r.x = Uint32ToFloat32Range11(val.x); - r.y = Uint32ToFloat32Range11(val.y); - return r; -} - - - - -// Philox2x32round does one round of updating of the counter. -fn Philox2x32round(counter: su64, key: u32) -> su64 { - let mul = Uint32Mul64(u32(0xD256D193), counter.x); - var ctr: su64; - ctr.x = mul.y ^ key ^ counter.y; - ctr.y = mul.x; - return ctr; -} - -// Philox2x32bumpkey does one round of updating of the key -fn Philox2x32bumpkey(key: u32) -> u32 { - return key + u32(0x9E3779B9); -} - -// Philox2x32 implements the stateless counter-based RNG algorithm -// returning a random number as two uint32 values, given a -// counter and key input that determine the result. -// The input counter is not modified. 
-fn Philox2x32(counter: su64, key: u32) -> vec2 { - // this is an unrolled loop of 10 updates based on initial counter and key, - // which produces the random deviation deterministically based on these inputs. - var ctr = Philox2x32round(counter, key); // 1 - var ky = Philox2x32bumpkey(key); - ctr = Philox2x32round(ctr, ky); // 2 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 3 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 4 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 5 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 6 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 7 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 8 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 9 - ky = Philox2x32bumpkey(ky); - - return Philox2x32round(ctr, ky); // 10 -} - -//////////////////////////////////////////////////////////// -// Methods below provide a standard interface with more -// readable names, mapping onto the Go rand methods. -// -// They assume a global shared counter, which is then -// incremented by a function index, defined for each function -// consuming random numbers that _could_ be called within a parallel -// processing loop. At the end of the loop, the global counter should -// be incremented by the total possible number of such functions. -// This results in fully resproducible results, invariant to -// specific processing order, and invariant to whether any one function -// actually calls the random number generator. - -// RandUint32Vec2 returns two uniformly distributed 32 unsigned integers, -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. 
-fn RandUint32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Philox2x32(Uint64Add32(counter, funcIndex), key); -} - -// RandUint32 returns a uniformly distributed 32 unsigned integer, -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandUint32(counter: su64, funcIndex: u32, key: u32) -> u32 { - return Philox2x32(Uint64Add32(counter, funcIndex), key).x; -} - -// RandFloat32Vec2 returns two uniformly distributed float32 values in range (0,1), -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); -} - -// RandFloat32 returns a uniformly distributed float32 value in range (0,1), -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32(counter: su64, funcIndex: u32, key: u32) -> f32 { - return Uint32ToFloat32(RandUint32(counter, funcIndex, key)); -} - -// RandFloat32Range11Vec2 returns two uniformly distributed float32 values in range [-1,1], -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. 
-fn RandFloat32Range11Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); -} - -// RandFloat32Range11 returns a uniformly distributed float32 value in range [-1,1], -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Range11(counter: su64, funcIndex: u32, key: u32) -> f32 { - return Uint32ToFloat32Range11(RandUint32(counter, funcIndex, key)); -} - -// RandBoolP returns a bool true value with probability p -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandBoolP(counter: su64, funcIndex: u32, key: u32, p: f32) -> bool { - return (RandFloat32(counter, funcIndex, key) < p); -} - -fn sincospi(x: f32) -> vec2 { - let PIf = 3.1415926535897932; - var r: vec2; - r.x = cos(PIf*x); - r.y = sin(PIf*x); - return r; -} - -// RandFloat32NormVec2 returns two random float32 numbers -// distributed according to the normal, Gaussian distribution -// with zero mean and unit variance. -// This is done very efficiently using the Box-Muller algorithm -// that consumes two random 32 bit uint values. -// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32NormVec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - let ur = RandUint32Vec2(counter, funcIndex, key); - var f = sincospi(Uint32ToFloat32Range11(ur.x)); - let r = sqrt(-2.0 * log(Uint32ToFloat32(ur.y))); // guaranteed to avoid 0. - return f * r; -} - -// RandFloat32Norm returns a random float32 number -// distributed according to the normal, Gaussian distribution -// with zero mean and unit variance. 
-// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Norm(counter: su64, funcIndex: u32, key: u32) -> f32 { - return RandFloat32Vec2(counter, funcIndex, key).x; -} - -// RandUint32N returns a uint32 in the range [0,N). -// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandUint32N(counter: su64, funcIndex: u32, key: u32, n: u32) -> u32 { - let v = RandFloat32(counter, funcIndex, key); - return u32(v * f32(n)); -} - -// Counter is used for storing the random counter using aligned 16 byte -// storage, with convenience functions for typical use cases. -// It retains a copy of the last Seed value, which is applied to -// the Hi uint32 value. -struct RandCounter { - Counter: su64, - HiSeed: u32, - pad: u32, -} - -// Reset resets counter to last set Seed state. -fn RandCounter_Reset(ct: ptr) { - (*ct).Counter.x = u32(0); - (*ct).Counter.y = (*ct).HiSeed; -} - -// Seed sets the Hi uint32 value from given seed, saving it in Seed field. -// Each increment in seed generates a unique sequence of over 4 billion numbers, -// so it is reasonable to just use incremental values there, but more widely -// spaced numbers will result in longer unique sequences. -// Resets Lo to 0. -// This same seed will be restored during Reset -fn RandCounter_Seed(ct: ptr, seed: u32) { - (*ct).HiSeed = seed; - RandCounter_Reset(ct); -} - -// Add increments the counter by given amount. -// Call this after completing a pass of computation -// where the value passed here is the max of funcIndex+1 -// used for any possible random calls during that pass. 
-fn RandCounter_Add(ct: ptr, inc: u32) { - (*ct).Counter = Uint64Add32((*ct).Counter, inc); -} - - -struct Rnds { - Uints: vec2, - pad: i32, - pad1: i32, - Floats: vec2, - pad2: i32, - pad3: i32, - Floats11: vec2, - pad4: i32, - pad5: i32, - Gauss: vec2, - pad6: i32, - pad7: i32, -} - -// RndGen calls random function calls to test generator. -// Note that the counter to the outer-most computation function -// is passed by *value*, so the same counter goes to each element -// as it is computed, but within this scope, counter is passed by -// reference (as a pointer) so subsequent calls get a new counter value. -// The counter should be incremented by the number of random calls -// outside of the overall update function. -fn Rnds_RndGen(r: ptr, counter: su64, idx: u32) { - (*r).Uints = RandUint32Vec2(counter, u32(0), idx); - (*r).Floats = RandFloat32Vec2(counter, u32(1), idx); - (*r).Floats11 = RandFloat32Range11Vec2(counter, u32(2), idx); - (*r).Gauss = RandFloat32NormVec2(counter, u32(3), idx); -} - -// from file: rand.wgsl - -@group(0) @binding(0) -var Counter: array; - -@group(0) @binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - var ctr = Counter[0]; - var data = Data[idx.x]; - Rnds_RndGen(&data, ctr, idx.x); - Data[idx.x] = data; -} - diff --git a/gpu/gosl/extract.go b/gpu/gosl/extract.go deleted file mode 100644 index 5a41aa1649..0000000000 --- a/gpu/gosl/extract.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2022, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package main - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - - "slices" -) - -func ReadFileLines(fn string) ([][]byte, error) { - nl := []byte("\n") - buf, err := os.ReadFile(fn) - if err != nil { - fmt.Println(err) - return nil, err - } - lines := bytes.Split(buf, nl) - return lines, nil -} - -// Extracts comment-directive tagged regions from .go files -func ExtractGoFiles(files []string) map[string][]byte { - sls := map[string][][]byte{} - key := []byte("//gosl:") - start := []byte("start") - wgsl := []byte("wgsl") - nowgsl := []byte("nowgsl") - end := []byte("end") - nl := []byte("\n") - include := []byte("#include") - - for _, fn := range files { - if !strings.HasSuffix(fn, ".go") { - continue - } - lines, err := ReadFileLines(fn) - if err != nil { - continue - } - - inReg := false - inHlsl := false - inNoHlsl := false - var outLns [][]byte - slFn := "" - for _, ln := range lines { - tln := bytes.TrimSpace(ln) - isKey := bytes.HasPrefix(tln, key) - var keyStr []byte - if isKey { - keyStr = tln[len(key):] - // fmt.Printf("key: %s\n", string(keyStr)) - } - switch { - case inReg && isKey && bytes.HasPrefix(keyStr, end): - if inHlsl || inNoHlsl { - outLns = append(outLns, ln) - } - sls[slFn] = outLns - inReg = false - inHlsl = false - inNoHlsl = false - case inReg: - for pkg := range LoadedPackageNames { // remove package prefixes - if !bytes.Contains(ln, include) { - ln = bytes.ReplaceAll(ln, []byte(pkg+"."), []byte{}) - } - } - outLns = append(outLns, ln) - case isKey && bytes.HasPrefix(keyStr, start): - inReg = true - slFn = string(keyStr[len(start)+1:]) - outLns = sls[slFn] - case isKey && bytes.HasPrefix(keyStr, nowgsl): - inReg = true - inNoHlsl = true - slFn = string(keyStr[len(nowgsl)+1:]) - outLns = sls[slFn] - outLns = append(outLns, ln) // key to include self here - case isKey && bytes.HasPrefix(keyStr, wgsl): - inReg = true - inHlsl = true - slFn = string(keyStr[len(wgsl)+1:]) - outLns = sls[slFn] - outLns = 
append(outLns, ln) - } - } - } - - rsls := make(map[string][]byte) - for fn, lns := range sls { - outfn := filepath.Join(*outDir, fn+".go") - olns := [][]byte{} - olns = append(olns, []byte("package main")) - olns = append(olns, []byte(`import ( - "math" - "cogentcore.org/core/gpu/gosl/slbool" - "cogentcore.org/core/gpu/gosl/slrand" - "cogentcore.org/core/gpu/gosl/sltype" -) -`)) - olns = append(olns, lns...) - SlBoolReplace(olns) - res := bytes.Join(olns, nl) - ioutil.WriteFile(outfn, res, 0644) - // not necessary and super slow: - // cmd := exec.Command("goimports", "-w", fn+".go") // get imports - // cmd.Dir, _ = filepath.Abs(*outDir) - // out, err := cmd.CombinedOutput() - // _ = out - // // fmt.Printf("\n################\ngoimports output for: %s\n%s\n", outfn, out) - // if err != nil { - // log.Println(err) - // } - rsls[fn] = bytes.Join(lns, nl) - } - - return rsls -} - -// ExtractWGSL extracts the WGSL code embedded within .Go files. -// Returns true if WGSL contains a void main( function. 
-func ExtractWGSL(buf []byte) ([]byte, bool) { - key := []byte("//gosl:") - wgsl := []byte("wgsl") - nowgsl := []byte("nowgsl") - end := []byte("end") - nl := []byte("\n") - stComment := []byte("/*") - edComment := []byte("*/") - comment := []byte("// ") - pack := []byte("package") - imp := []byte("import") - main := []byte("void main(") - lparen := []byte("(") - rparen := []byte(")") - - lines := bytes.Split(buf, nl) - - mx := min(10, len(lines)) - stln := 0 - gotImp := false - for li := 0; li < mx; li++ { - ln := lines[li] - switch { - case bytes.HasPrefix(ln, pack): - stln = li + 1 - case bytes.HasPrefix(ln, imp): - if bytes.HasSuffix(ln, lparen) { - gotImp = true - } else { - stln = li + 1 - } - case gotImp && bytes.HasPrefix(ln, rparen): - stln = li + 1 - } - } - - lines = lines[stln:] // get rid of package, import - - hasMain := false - inHlsl := false - inNoHlsl := false - noHlslStart := 0 - for li := 0; li < len(lines); li++ { - ln := lines[li] - isKey := bytes.HasPrefix(ln, key) - var keyStr []byte - if isKey { - keyStr = ln[len(key):] - // fmt.Printf("key: %s\n", string(keyStr)) - } - switch { - case inNoHlsl && isKey && bytes.HasPrefix(keyStr, end): - lines = slices.Delete(lines, noHlslStart, li+1) - li -= ((li + 1) - noHlslStart) - inNoHlsl = false - case inHlsl && isKey && bytes.HasPrefix(keyStr, end): - lines = slices.Delete(lines, li, li+1) - li-- - inHlsl = false - case inHlsl: - del := false - switch { - case bytes.HasPrefix(ln, stComment) || bytes.HasPrefix(ln, edComment): - lines = slices.Delete(lines, li, li+1) - li-- - del = true - case bytes.HasPrefix(ln, comment): - lines[li] = ln[3:] - } - if !del { - if bytes.HasPrefix(lines[li], main) { - hasMain = true - } - } - case isKey && bytes.HasPrefix(keyStr, wgsl): - inHlsl = true - lines = slices.Delete(lines, li, li+1) - li-- - case isKey && bytes.HasPrefix(keyStr, nowgsl): - inNoHlsl = true - noHlslStart = li - } - } - return bytes.Join(lines, nl), hasMain -} diff --git a/gpu/gosl/files.go 
b/gpu/gosl/files.go deleted file mode 100644 index c37feaf0e6..0000000000 --- a/gpu/gosl/files.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) 2022, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "fmt" - "io" - "io/fs" - "log" - "os" - "path/filepath" - "strings" - - "golang.org/x/tools/go/packages" -) - -// LoadedPackageNames are single prefix names of packages that were -// loaded in the list of files to process -var LoadedPackageNames = map[string]bool{} - -func IsGoFile(f fs.DirEntry) bool { - name := f.Name() - return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") && !f.IsDir() -} - -func IsWGSLFile(f fs.DirEntry) bool { - name := f.Name() - return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".wgsl") && !f.IsDir() -} - -func IsSPVFile(f fs.DirEntry) bool { - name := f.Name() - return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".spv") && !f.IsDir() -} - -func AddFile(fn string, fls []string, procd map[string]bool) []string { - if _, has := procd[fn]; has { - return fls - } - fls = append(fls, fn) - procd[fn] = true - dir, _ := filepath.Split(fn) - if dir != "" { - dir = dir[:len(dir)-1] - pd, sd := filepath.Split(dir) - if pd != "" { - dir = sd - } - if !(dir == "math32") { - if _, has := LoadedPackageNames[dir]; !has { - LoadedPackageNames[dir] = true - // fmt.Printf("package: %s\n", dir) - } - } - } - return fls -} - -// FilesFromPaths processes all paths and returns a full unique list of files -// for subsequent processing. 
-func FilesFromPaths(paths []string) []string { - fls := make([]string, 0, len(paths)) - procd := make(map[string]bool) - for _, path := range paths { - switch info, err := os.Stat(path); { - case err != nil: - var pkgs []*packages.Package - dir, fl := filepath.Split(path) - if dir != "" && fl != "" && strings.HasSuffix(fl, ".go") { - pkgs, err = packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, dir) - } else { - fl = "" - pkgs, err = packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, path) - } - if err != nil { - fmt.Println(err) - continue - } - pkg := pkgs[0] - gofls := pkg.GoFiles - if len(gofls) == 0 { - fmt.Printf("WARNING: no go files found in path: %s\n", path) - } - if fl != "" { - for _, gf := range gofls { - if strings.HasSuffix(gf, fl) { - fls = AddFile(gf, fls, procd) - // fmt.Printf("added file: %s from package: %s\n", gf, path) - break - } - } - } else { - for _, gf := range gofls { - fls = AddFile(gf, fls, procd) - // fmt.Printf("added file: %s from package: %s\n", gf, path) - } - } - case !info.IsDir(): - path := path - fls = AddFile(path, fls, procd) - default: - // Directories are walked, ignoring non-Go, non-WGSL files. - err := filepath.WalkDir(path, func(path string, f fs.DirEntry, err error) error { - if err != nil || !(IsGoFile(f) || IsWGSLFile(f)) { - return err - } - _, err = f.Info() - if err != nil { - return nil - } - fls = AddFile(path, fls, procd) - return nil - }) - if err != nil { - log.Println(err) - } - } - } - return fls -} - -func CopyFile(src, dst string) error { - in, err := os.Open(src) - if err != nil { - return err - } - defer in.Close() - out, err := os.Create(dst) - if err != nil { - return err - } - defer out.Close() - _, err = io.Copy(out, in) - return err -} - -// CopyPackageFile copies given file name from given package path -// into the current output directory. 
-// e.g., "slrand.wgsl", "cogentcore.org/core/gpu/gosl/slrand" -func CopyPackageFile(fnm, packagePath string) error { - tofn := filepath.Join(*outDir, fnm) - pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, packagePath) - if err != nil { - fmt.Println(err) - return err - } - if len(pkgs) != 1 { - err = fmt.Errorf("%s package not found", packagePath) - fmt.Println(err) - return err - } - pkg := pkgs[0] - var fn string - if len(pkg.GoFiles) > 0 { - fn = pkg.GoFiles[0] - } else if len(pkg.OtherFiles) > 0 { - fn = pkg.GoFiles[0] - } else { - err = fmt.Errorf("No files found in package: %s", packagePath) - fmt.Println(err) - return err - } - dir, _ := filepath.Split(fn) - fmfn := filepath.Join(dir, fnm) - CopyFile(fmfn, tofn) - return nil -} - -// RemoveGenFiles removes .go, .wgsl, .spv files in shader generated dir -func RemoveGenFiles(dir string) { - err := filepath.WalkDir(dir, func(path string, f fs.DirEntry, err error) error { - if err != nil { - return err - } - if IsGoFile(f) || IsWGSLFile(f) || IsSPVFile(f) { - os.Remove(path) - } - return nil - }) - if err != nil { - log.Println(err) - } -} diff --git a/gpu/gosl/gosl.go b/gpu/gosl/gosl.go deleted file mode 100644 index e7638857d2..0000000000 --- a/gpu/gosl/gosl.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2022, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// copied and heavily edited from go src/cmd/gofmt/gofmt.go: - -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package main - -import ( - "flag" - "fmt" - "os" - "strings" - - "cogentcore.org/core/gpu/gosl/slprint" -) - -// flags -var ( - outDir = flag.String("out", "shaders", "output directory for shader code, relative to where gosl is invoked; must not be an empty string") - excludeFunctions = flag.String("exclude", "Update,Defaults", "comma-separated list of names of functions to exclude from exporting to WGSL") - keepTmp = flag.Bool("keep", false, "keep temporary converted versions of the source files, for debugging") - debug = flag.Bool("debug", false, "enable debugging messages while running") - excludeFunctionMap = map[string]bool{} -) - -// Keep these in sync with go/format/format.go. -const ( - tabWidth = 8 - printerMode = slprint.UseSpaces | slprint.TabIndent | printerNormalizeNumbers - - // printerNormalizeNumbers means to canonicalize number literal prefixes - // and exponents while printing. See https://golang.org/doc/go1.13#gosl. - // - // This value is defined in go/printer specifically for go/format and cmd/gosl. 
- printerNormalizeNumbers = 1 << 30 -) - -func usage() { - fmt.Fprintf(os.Stderr, "usage: gosl [flags] [path ...]\n") - flag.PrintDefaults() -} - -func main() { - flag.Usage = usage - flag.Parse() - goslMain() -} - -func GoslArgs() { - exs := *excludeFunctions - ex := strings.Split(exs, ",") - for _, fn := range ex { - excludeFunctionMap[fn] = true - } -} - -func goslMain() { - if *outDir == "" { - fmt.Println("Must have an output directory (default shaders), specified in -out arg") - os.Exit(1) - return - } - - if gomod := os.Getenv("GO111MODULE"); gomod == "off" { - fmt.Println("gosl only works in go modules mode, but GO111MODULE=off") - os.Exit(1) - return - } - - os.MkdirAll(*outDir, 0755) - RemoveGenFiles(*outDir) - - args := flag.Args() - if len(args) == 0 { - fmt.Printf("at least one file name must be passed\n") - return - } - - GoslArgs() - ProcessFiles(args) -} diff --git a/gpu/gosl/gosl_test.go b/gpu/gosl/gosl_test.go deleted file mode 100644 index 8a8ec76b5a..0000000000 --- a/gpu/gosl/gosl_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package main - -import ( - "bytes" - "flag" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -var update = flag.Bool("update", false, "update .golden files") - -func runTest(t *testing.T, in, out string) { - // process flags - _, err := os.Lstat(in) - if err != nil { - t.Error(err) - return - } - - sls, err := ProcessFiles([]string{in}) - if err != nil { - t.Error(err) - return - } - - expected, err := os.ReadFile(out) - if err != nil { - t.Error(err) - return - } - - var got []byte - for _, b := range sls { - got = b - break - } - - if !bytes.Equal(got, expected) { - if *update { - if in != out { - if err := os.WriteFile(out, got, 0666); err != nil { - t.Error(err) - } - return - } - // in == out: don't accidentally destroy input - t.Errorf("WARNING: -update did not rewrite input file %s", in) - } - - assert.Equal(t, string(expected), string(got)) - if err := os.WriteFile(in+".gosl", got, 0666); err != nil { - t.Error(err) - } - } -} - -// TestRewrite processes testdata/*.input files and compares them to the -// corresponding testdata/*.golden files. The gosl flags used to process -// a file must be provided via a comment of the form -// -// //gosl flags -// -// in the processed file within the first 20 lines, if any. 
-func TestRewrite(t *testing.T) { - if gomod := os.Getenv("GO111MODULE"); gomod == "off" { - t.Error("gosl only works in go modules mode, but GO111MODULE=off") - return - } - - // determine input files - match, err := filepath.Glob("testdata/*.go") - if err != nil { - t.Fatal(err) - } - - if *outDir != "" { - os.MkdirAll(*outDir, 0755) - } - - for _, in := range match { - name := filepath.Base(in) - t.Run(name, func(t *testing.T) { - out := in // for files where input and output are identical - if strings.HasSuffix(in, ".go") { - out = in[:len(in)-len(".go")] + ".golden" - } - runTest(t, in, out) - }) - } -} diff --git a/gpu/gosl/process.go b/gpu/gosl/process.go deleted file mode 100644 index de818b4903..0000000000 --- a/gpu/gosl/process.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "bytes" - "fmt" - "go/ast" - "go/token" - "io/fs" - "io/ioutil" - "log" - "os" - "os/exec" - "path/filepath" - "strings" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/gpu" - "cogentcore.org/core/gpu/gosl/alignsl" - "cogentcore.org/core/gpu/gosl/slprint" - "golang.org/x/tools/go/packages" -) - -// does all the file processing -func ProcessFiles(paths []string) (map[string][]byte, error) { - fls := FilesFromPaths(paths) - gosls := ExtractGoFiles(fls) // extract Go files to shader/*.go - - wgslFiles := []string{} - for _, fn := range fls { - if strings.HasSuffix(fn, ".wgsl") { - wgslFiles = append(wgslFiles, fn) - } - } - - pf := "./" + *outDir - pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes | packages.NeedTypesInfo}, pf) - if err != nil { - log.Println(err) - return nil, err - } - if len(pkgs) != 1 { - err := fmt.Errorf("More than one package for path: %v", pf) - log.Println(err) - 
return nil, err - } - pkg := pkgs[0] - - if len(pkg.GoFiles) == 0 { - err := fmt.Errorf("No Go files found in package: %+v", pkg) - log.Println(err) - return nil, err - } - // fmt.Printf("go files: %+v", pkg.GoFiles) - // return nil, err - - // map of files with a main function that needs to be compiled - needsCompile := map[string]bool{} - - serr := alignsl.CheckPackage(pkg) - if serr != nil { - fmt.Println(serr) - } - - slrandCopied := false - sltypeCopied := false - for fn := range gosls { - gofn := fn + ".go" - if *debug { - fmt.Printf("###################################\nProcessing Go file: %s\n", gofn) - } - - var afile *ast.File - var fpos token.Position - for _, sy := range pkg.Syntax { - pos := pkg.Fset.Position(sy.Package) - _, posfn := filepath.Split(pos.Filename) - if posfn == gofn { - fpos = pos - afile = sy - break - } - } - if afile == nil { - fmt.Printf("Warning: File named: %s not found in processed package\n", gofn) - continue - } - - var buf bytes.Buffer - cfg := slprint.Config{Mode: printerMode, Tabwidth: tabWidth, ExcludeFunctions: excludeFunctionMap} - cfg.Fprint(&buf, pkg, afile) - // ioutil.WriteFile(filepath.Join(*outDir, fn+".tmp"), buf.Bytes(), 0644) - slfix, hasSltype, hasSlrand := SlEdits(buf.Bytes()) - if hasSlrand && !slrandCopied { - hasSltype = true - if *debug { - fmt.Printf("\tcopying slrand.wgsl to shaders\n") - } - CopyPackageFile("slrand.wgsl", "cogentcore.org/core/gpu/gosl/slrand") - slrandCopied = true - } - if hasSltype && !sltypeCopied { - if *debug { - fmt.Printf("\tcopying sltype.wgsl to shaders\n") - } - CopyPackageFile("sltype.wgsl", "cogentcore.org/core/gpu/gosl/sltype") - sltypeCopied = true - } - exsl, hasMain := ExtractWGSL(slfix) - gosls[fn] = exsl - - if hasMain { - needsCompile[fn] = true - } - if !*keepTmp { - os.Remove(fpos.Filename) - } - - // add wgsl code - for _, slfn := range wgslFiles { - if fn+".wgsl" != slfn { - continue - } - buf, err := os.ReadFile(slfn) - if err != nil { - fmt.Println(err) - 
continue - } - exsl = append(exsl, []byte(fmt.Sprintf("\n// from file: %s\n", slfn))...) - exsl = append(exsl, buf...) - gosls[fn] = exsl - needsCompile[fn] = true // assume any standalone has main - break - } - - slfn := filepath.Join(*outDir, fn+".wgsl") - ioutil.WriteFile(slfn, exsl, 0644) - } - - // check for wgsl files that had no go equivalent - for _, slfn := range wgslFiles { - hasGo := false - for fn := range gosls { - if fn+".wgsl" == slfn { - hasGo = true - break - } - } - if hasGo { - continue - } - _, slfno := filepath.Split(slfn) // could be in a subdir - tofn := filepath.Join(*outDir, slfno) - CopyFile(slfn, tofn) - fn := strings.TrimSuffix(slfno, ".wgsl") - needsCompile[fn] = true // assume any standalone wgsl is a main - } - - for fn := range needsCompile { - CompileFile(fn + ".wgsl") - } - return gosls, nil -} - -func CompileFile(fn string) error { - dir, _ := filepath.Abs(*outDir) - fsys := os.DirFS(dir) - b, err := fs.ReadFile(fsys, fn) - if errors.Log(err) != nil { - return err - } - is := gpu.IncludeFS(fsys, "", string(b)) - ofn := filepath.Join(dir, fn) - err = os.WriteFile(ofn, []byte(is), 0666) - if errors.Log(err) != nil { - return err - } - cmd := exec.Command("naga", fn) - cmd.Dir = dir - out, err := cmd.CombinedOutput() - fmt.Printf("\n-----------------------------------------------------\nnaga output for: %s\n%s", fn, out) - if err != nil { - log.Println(err) - return err - } - return nil -} diff --git a/gpu/gosl/slrand/slrand.wgsl b/gpu/gosl/slrand/slrand.wgsl deleted file mode 100644 index 820e7bdf62..0000000000 --- a/gpu/gosl/slrand/slrand.wgsl +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Original file is in Go package: github.com/cogentcore/core/gpu/gosl/slrand -// See README.md there for documentation. 
- -// These random number generation (RNG) functions are optimized for -// use on the GPU, with equivalent Go versions available in slrand.go. -// This is using the Philox2x32 counter-based RNG. - -#include "sltype.wgsl" - -// Philox2x32round does one round of updating of the counter. -fn Philox2x32round(counter: su64, key: u32) -> su64 { - let mul = Uint32Mul64(u32(0xD256D193), counter.x); - var ctr: su64; - ctr.x = mul.y ^ key ^ counter.y; - ctr.y = mul.x; - return ctr; -} - -// Philox2x32bumpkey does one round of updating of the key -fn Philox2x32bumpkey(key: u32) -> u32 { - return key + u32(0x9E3779B9); -} - -// Philox2x32 implements the stateless counter-based RNG algorithm -// returning a random number as two uint32 values, given a -// counter and key input that determine the result. -// The input counter is not modified. -fn Philox2x32(counter: su64, key: u32) -> vec2 { - // this is an unrolled loop of 10 updates based on initial counter and key, - // which produces the random deviation deterministically based on these inputs. - var ctr = Philox2x32round(counter, key); // 1 - var ky = Philox2x32bumpkey(key); - ctr = Philox2x32round(ctr, ky); // 2 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 3 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 4 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 5 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 6 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 7 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 8 - ky = Philox2x32bumpkey(ky); - ctr = Philox2x32round(ctr, ky); // 9 - ky = Philox2x32bumpkey(ky); - - return Philox2x32round(ctr, ky); // 10 -} - -//////////////////////////////////////////////////////////// -// Methods below provide a standard interface with more -// readable names, mapping onto the Go rand methods. 
-// -// They assume a global shared counter, which is then -// incremented by a function index, defined for each function -// consuming random numbers that _could_ be called within a parallel -// processing loop. At the end of the loop, the global counter should -// be incremented by the total possible number of such functions. -// This results in fully resproducible results, invariant to -// specific processing order, and invariant to whether any one function -// actually calls the random number generator. - -// RandUint32Vec2 returns two uniformly distributed 32 unsigned integers, -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandUint32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Philox2x32(Uint64Add32(counter, funcIndex), key); -} - -// RandUint32 returns a uniformly distributed 32 unsigned integer, -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandUint32(counter: su64, funcIndex: u32, key: u32) -> u32 { - return Philox2x32(Uint64Add32(counter, funcIndex), key).x; -} - -// RandFloat32Vec2 returns two uniformly distributed float32 values in range (0,1), -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); -} - -// RandFloat32 returns a uniformly distributed float32 value in range (0,1), -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. 
-fn RandFloat32(counter: su64, funcIndex: u32, key: u32) -> f32 { - return Uint32ToFloat32(RandUint32(counter, funcIndex, key)); -} - -// RandFloat32Range11Vec2 returns two uniformly distributed float32 values in range [-1,1], -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Range11Vec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - return Uint32ToFloat32Vec2(RandUint32Vec2(counter, funcIndex, key)); -} - -// RandFloat32Range11 returns a uniformly distributed float32 value in range [-1,1], -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Range11(counter: su64, funcIndex: u32, key: u32) -> f32 { - return Uint32ToFloat32Range11(RandUint32(counter, funcIndex, key)); -} - -// RandBoolP returns a bool true value with probability p -// based on given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandBoolP(counter: su64, funcIndex: u32, key: u32, p: f32) -> bool { - return (RandFloat32(counter, funcIndex, key) < p); -} - -fn sincospi(x: f32) -> vec2 { - let PIf = 3.1415926535897932; - var r: vec2; - r.x = cos(PIf*x); - r.y = sin(PIf*x); - return r; -} - -// RandFloat32NormVec2 returns two random float32 numbers -// distributed according to the normal, Gaussian distribution -// with zero mean and unit variance. -// This is done very efficiently using the Box-Muller algorithm -// that consumes two random 32 bit uint values. -// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. 
-fn RandFloat32NormVec2(counter: su64, funcIndex: u32, key: u32) -> vec2 { - let ur = RandUint32Vec2(counter, funcIndex, key); - var f = sincospi(Uint32ToFloat32Range11(ur.x)); - let r = sqrt(-2.0 * log(Uint32ToFloat32(ur.y))); // guaranteed to avoid 0. - return f * r; -} - -// RandFloat32Norm returns a random float32 number -// distributed according to the normal, Gaussian distribution -// with zero mean and unit variance. -// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandFloat32Norm(counter: su64, funcIndex: u32, key: u32) -> f32 { - return RandFloat32Vec2(counter, funcIndex, key).x; -} - -// RandUint32N returns a uint32 in the range [0,N). -// Uses given global shared counter, function index offset from that -// counter for this specific random number call, and key as unique -// index of the item being processed. -fn RandUint32N(counter: su64, funcIndex: u32, key: u32, n: u32) -> u32 { - let v = RandFloat32(counter, funcIndex, key); - return u32(v * f32(n)); -} - -// Counter is used for storing the random counter using aligned 16 byte -// storage, with convenience functions for typical use cases. -// It retains a copy of the last Seed value, which is applied to -// the Hi uint32 value. -struct RandCounter { - Counter: su64, - HiSeed: u32, - pad: u32, -} - -// Reset resets counter to last set Seed state. -fn RandCounter_Reset(ct: ptr) { - (*ct).Counter.x = u32(0); - (*ct).Counter.y = (*ct).HiSeed; -} - -// Seed sets the Hi uint32 value from given seed, saving it in Seed field. -// Each increment in seed generates a unique sequence of over 4 billion numbers, -// so it is reasonable to just use incremental values there, but more widely -// spaced numbers will result in longer unique sequences. -// Resets Lo to 0. 
-// This same seed will be restored during Reset -fn RandCounter_Seed(ct: ptr, seed: u32) { - (*ct).HiSeed = seed; - RandCounter_Reset(ct); -} - -// Add increments the counter by given amount. -// Call this after completing a pass of computation -// where the value passed here is the max of funcIndex+1 -// used for any possible random calls during that pass. -fn RandCounter_Add(ct: ptr, inc: u32) { - (*ct).Counter = Uint64Add32((*ct).Counter, inc); -} diff --git a/gpu/gosl/sltype/sltype.wgsl b/gpu/gosl/sltype/sltype.wgsl deleted file mode 100644 index e3ffe9e8e6..0000000000 --- a/gpu/gosl/sltype/sltype.wgsl +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Original file is in Go package: github.com/cogentcore/core/gpu/gosl/sltype -// See README.md there for documentation. - -// This file emulates uint64 (u64) using 2 uint32 integers. -// and defines conversions between uint and float. - -// define a u64 type as an alias. -// if / when u64 actually happens, will make it easier to update. -alias su64 = vec2; - -// Uint32Mul64 multiplies two uint32 numbers into a uint64 (using vec2). -fn Uint32Mul64(a: u32, b: u32) -> su64 { - let LOMASK = (((u32(1))<<16)-1); - var r: su64; - r.x = a * b; /* full low multiply */ - let ahi = a >> 16; - let alo = a & LOMASK; - let bhi = b >> 16; - let blo = b & LOMASK; - - let ahbl = ahi * blo; - let albh = alo * bhi; - - let ahbl_albh = ((ahbl&LOMASK) + (albh&LOMASK)); - var hit = ahi*bhi + (ahbl>>16) + (albh>>16); - hit += ahbl_albh >> 16; /* carry from the sum of lo(ahbl) + lo(albh) ) */ - /* carry from the sum with alo*blo */ - if ((r.x >> u32(16)) < (ahbl_albh&LOMASK)) { - hit += u32(1); - } - r.y = hit; - return r; -} - -/* -// Uint32Mul64 multiplies two uint32 numbers into a uint64 (using su64). 
-fn Uint32Mul64(a: u32, b: u32) -> su64 { - return su64(a) * su64(b); -} -*/ - - -// Uint64Add32 adds given uint32 number to given uint64 (using vec2). -fn Uint64Add32(a: su64, b: u32) -> su64 { - if (b == 0) { - return a; - } - var s = a; - if (s.x > u32(0xffffffff) - b) { - s.y++; - s.x = (b - 1) - (u32(0xffffffff) - s.x); - } else { - s.x += b; - } - return s; -} - -// Uint64Incr returns increment of the given uint64 (using vec2). -fn Uint64Incr(a: su64) -> su64 { - var s = a; - if(s.x == 0xffffffff) { - s.y++; - s.x = u32(0); - } else { - s.x++; - } - return s; -} - -// Uint32ToFloat32 converts a uint32 integer into a float32 -// in the (0,1) interval (i.e., exclusive of 1). -// This differs from the Go standard by excluding 0, which is handy for passing -// directly to Log function, and from the reference Philox code by excluding 1 -// which is in the Go standard and most other standard RNGs. -fn Uint32ToFloat32(val: u32) -> f32 { - let factor = f32(1.0) / (f32(u32(0xffffffff)) + f32(1.0)); - let halffactor = f32(0.5) * factor; - var f = f32(val) * factor + halffactor; - if (f == 1.0) { // exclude 1 - return bitcast(0x3F7FFFFF); - } - return f; -} - -// note: there is no overloading of user-defined functions -// https://github.com/gpuweb/gpuweb/issues/876 - -// Uint32ToFloat32Vec2 converts two uint 32 bit integers -// into two corresponding 32 bit f32 values -// in the (0,1) interval (i.e., exclusive of 1). -fn Uint32ToFloat32Vec2(val: vec2) -> vec2 { - var r: vec2; - r.x = Uint32ToFloat32(val.x); - r.y = Uint32ToFloat32(val.y); - return r; -} - -// Uint32ToFloat32Range11 converts a uint32 integer into a float32 -// in the [-1..1] interval (inclusive of -1 and 1, never identically == 0). 
-fn Uint32ToFloat32Range11(val: u32) -> f32 { - let factor = f32(1.0) / (f32(i32(0x7fffffff)) + f32(1.0)); - let halffactor = f32(0.5) * factor; - return (f32(val) * factor + halffactor); -} - -// Uint32ToFloat32Range11Vec2 converts two uint32 integers into two float32 -// in the [-1,1] interval (inclusive of -1 and 1, never identically == 0). -fn Uint32ToFloat32Range11Vec2(val: vec2) -> vec2 { - var r: vec2; - r.x = Uint32ToFloat32Range11(val.x); - r.y = Uint32ToFloat32Range11(val.y); - return r; -} - - diff --git a/gpu/gosl/testdata/basic.golden b/gpu/gosl/testdata/basic.golden deleted file mode 100644 index feabd4e16a..0000000000 --- a/gpu/gosl/testdata/basic.golden +++ /dev/null @@ -1,166 +0,0 @@ - - -// note: here is the wgsl version, only included in wgsl - -// MyTrickyFun this is the GPU version of the tricky function -fn MyTrickyFun(x: f32) -> f32 { - return 16.0; // ok actually not tricky here, but whatever -} - - -// FastExp is a quartic spline approximation to the Exp function, by N.N. Schraudolph -// It does not have any of the sanity checking of a standard method -- returns -// nonsense when arg is out of range. Runs in 2.23ns vs. 6.3ns for 64bit which is faster -// than exp actually. 
-fn FastExp(x: f32) -> f32 { - if (x <= -88.76731) { // this doesn't add anything and -exp is main use-case anyway - return f32(0); - } - var i = i32(12102203*x) + i32(127)*(i32(1)<<23); - var m = i >> 7 & 0xFFFF; // copy mantissa - i += (((((((((((3537 * m) >> 16) + 13668) * m) >> 18) + 15817) * m) >> 14) - 80470) * m) >> 11); - return bitcast(u32(i)); -} - -// NeuronFlags are bit-flags encoding relevant binary state for neurons -alias NeuronFlags = i32; - -// The neuron flags - -// NeuronOff flag indicates that this neuron has been turned off (i.e., lesioned) -const NeuronOff: NeuronFlags = 0x01; - -// NeuronHasExt means the neuron has external input in its Ext field -const NeuronHasExt: NeuronFlags = 0x02; // note: 1<<2 does NOT work - -// NeuronHasTarg means the neuron has external target input in its Target field -const NeuronHasTarg: NeuronFlags = 0x04; - -// NeuronHasCmpr means the neuron has external comparison input in its Target field -- used for computing -// comparison statistics but does not drive neural activity ever -const NeuronHasCmpr: NeuronFlags = 0x08; - -// Modes are evaluation modes (Training, Testing, etc) -alias Modes = i32; - -// The evaluation modes - -const NoEvalMode: Modes = 0; - -// AllModes indicates that the log should occur over all modes present in other items. 
-const AllModes: Modes = 1; - -// Train is this a training mode for the env -const Train: Modes = 2; - -// Test is this a test mode for the env -const Test: Modes = 3; - -// DataStruct has the test data -struct DataStruct { - - // raw value - Raw: f32, - - // integrated value - Integ: f32, - - // exp of integ - Exp: f32, - - pad: f32, -} - -// SubParamStruct has the test sub-params -struct SubParamStruct { - A: f32, - B: f32, - C: f32, - D: f32, -} - -fn SubParamStruct_Sum(sp: ptr) -> f32 { - return (*sp).A + (*sp).B + (*sp).C + (*sp).D; -} - -fn SubParamStruct_SumPlus(sp: ptr, extra: f32) -> f32 { - return SubParamStruct_Sum(sp) + extra; -} - -// ParamStruct has the test params -struct ParamStruct { - - // rate constant in msec - Tau: f32, - - // 1/Tau - Dt: f32, - Option: i32, // note: standard bool doesn't work - - pad: f32, // comment this out to trigger alignment warning - - // extra parameters - Subs: SubParamStruct, -} - -fn ParamStruct_IntegFromRaw(ps: ptr, ds: ptr) -> f32 { - // note: the following are just to test basic control structures - var newVal = (*ps).Dt * ((*ds).Raw - (*ds).Integ); - if (newVal < -10 || (*ps).Option == 1) { - newVal = f32(-10); - } - (*ds).Integ += newVal; - (*ds).Exp = exp(-(*ds).Integ); - var a: f32; - ParamStruct_AnotherMeth(ps, ds, &a); - return (*ds).Exp; -} - -// AnotherMeth does more computation -fn ParamStruct_AnotherMeth(ps: ptr, ds: ptr, ptrarg: ptr) { - for (var i = 0; i < 10; i++) { - (*ds).Integ *= f32(0.99); - } - var flag: NeuronFlags; - flag &= ~NeuronHasExt; // clear flag -- op doesn't exist in C - - var mode = Test; - switch (mode) { // note: no fallthrough! 
- case Test: { - var ab = f32(42); - (*ds).Exp /= ab; - } - case Train: { - var ab = f32(.5); - (*ds).Exp *= ab; - } - default: { - var ab = f32(1); - (*ds).Exp *= ab; - } - } - - var a: f32; - var b: f32; - b = f32(42); - a = SubParamStruct_Sum(&(*ps).Subs); - (*ds).Exp = SubParamStruct_SumPlus(&(*ps).Subs, b); - (*ds).Integ = a; - - *ptrarg = f32(-1); -} - -@group(0) @binding(0) -var Params: array; - -@group(0) @binding(1) -var Data: array; - -@compute -@workgroup_size(64) -fn main(@builtin(global_invocation_id) idx: vec3) { - var pars = Params[0]; - var data = Data[idx.x]; - ParamStruct_IntegFromRaw(&pars, &data); - Data[idx.x] = data; -} diff --git a/gpu/value.go b/gpu/value.go index 922f8cc261..fe01ba7db5 100644 --- a/gpu/value.go +++ b/gpu/value.go @@ -30,10 +30,17 @@ type Value struct { // index of this value within the Var list of values Index int - // VarSize is the size of each Var element, which includes any fixed ArrayN + // VarSize is the size of each Var element, which includes any fixed Var.ArrayN // array size specified on the Var. + // The actual buffer size is VarSize * Value.ArrayN (or DynamicN for dynamic). VarSize int + // ArrayN is the actual number of array elements, for Uniform or Storage + // variables without a fixed array size (i.e., the Var ArrayN = 1). + // This is set when the buffer is actually created, based on the data, + // or can be set directly prior to buffer creation. + ArrayN int + // DynamicIndex is the current index into a DynamicOffset variable // to use for the SetBindGroup call. Note that this is an index, // not an offset, so it indexes the DynamicN Vars in the Value, @@ -103,6 +110,16 @@ func MemSizeAlign(size, align int) int { return (nb + 1) * align } +// MemSizeAlignDown returns the size aligned according to align byte increments, +// rounding down, not up. 
+func MemSizeAlignDown(size, align int) int { + if size%align == 0 { + return size + } + nb := size / align + return nb * align +} + // init initializes value based on variable and index // within list of vals for this var. func (vl *Value) init(vr *Var, dev *Device, idx int) { @@ -112,6 +129,7 @@ func (vl *Value) init(vr *Var, dev *Device, idx int) { vl.Index = idx vl.Name = fmt.Sprintf("%s_%d", vr.Name, vl.Index) vl.VarSize = vr.MemSize() + vl.ArrayN = 1 vl.alignBytes = vr.alignBytes vl.AlignVarSize = MemSizeAlign(vl.VarSize, vl.alignBytes) vl.isDynamic = vl.role == Vertex || vl.role == Index || vr.DynamicOffset @@ -121,6 +139,10 @@ func (vl *Value) init(vr *Var, dev *Device, idx int) { } } +func (vl *Value) String() string { + return fmt.Sprintf("Bytes: 0x%X", vl.MemSize()) +} + // MemSize returns the memory allocation size for this value, in bytes. func (vl *Value) MemSize() int { if vl.Texture != nil { @@ -129,11 +151,12 @@ func (vl *Value) MemSize() int { if vl.isDynamic { return vl.AlignVarSize * vl.dynamicN } - return vl.VarSize + return vl.ArrayN * vl.VarSize } // CreateBuffer creates the GPU buffer for this value if it does not // yet exist or is not the right size. +// For !ReadOnly [Storage] buffers, calls [Value.CreateReadBuffer]. func (vl *Value) CreateBuffer() error { if vl.role == SampledTexture { return nil @@ -151,7 +174,7 @@ func (vl *Value) CreateBuffer() error { buf, err := vl.device.Device.CreateBuffer(&wgpu.BufferDescriptor{ Size: uint64(sz), Label: vl.Name, - Usage: vl.role.BufferUsages(), + Usage: vl.vvar.bufferUsages(), MappedAtCreation: false, }) if errors.Log(err) != nil { @@ -159,6 +182,9 @@ func (vl *Value) CreateBuffer() error { } vl.AllocSize = sz vl.buffer = buf + if vl.role == Storage && !vl.vvar.ReadOnly { + vl.CreateReadBuffer() + } return nil } @@ -214,6 +240,9 @@ func (vl *Value) SetDynamicN(n int) { // SetValueFrom copies given values into value buffer memory, // making the buffer if it has not yet been constructed. 
+// The actual ArrayN size of Storage or Uniform variables will +// be computed based on the size of the from bytes, relative to +// the variable size. // IMPORTANT: do not use this for dynamic offset Uniform or // Storage variables, as the alignment will not be correct; // See [SetDynamicFromBytes]. @@ -223,6 +252,7 @@ func SetValueFrom[E any](vl *Value, from []E) error { // SetFromBytes copies given bytes into value buffer memory, // making the buffer if it has not yet been constructed. +// For !ReadOnly [Storage] buffers, calls [Value.CreateReadBuffer]. // IMPORTANT: do not use this for dynamic offset Uniform or // Storage variables, as the alignment will not be correct; // See [SetDynamicFromBytes]. @@ -232,12 +262,19 @@ func (vl *Value) SetFromBytes(from []byte) error { return errors.Log(err) } nb := len(from) + an := nb / vl.VarSize + aover := nb % vl.VarSize + if aover != 0 { + err := fmt.Errorf("gpu.Value SetFromBytes %s, Size passed: %d is not an even multiple of the variable size: %d", vl.Name, nb, vl.VarSize) + return errors.Log(err) + } if vl.isDynamic { // Vertex, Index at this point - dn := nb / vl.VarSize - vl.SetDynamicN(dn) + vl.SetDynamicN(an) + } else { + vl.ArrayN = an } tb := vl.MemSize() - if nb != tb { + if nb != tb { // this should never happen, but justin case err := fmt.Errorf("gpu.Value SetFromBytes %s, Size passed: %d != Size expected %d", vl.Name, nb, tb) return errors.Log(err) } @@ -247,13 +284,16 @@ func (vl *Value) SetFromBytes(from []byte) error { buf, err := vl.device.Device.CreateBufferInit(&wgpu.BufferInitDescriptor{ Label: vl.Name, Contents: from, - Usage: vl.role.BufferUsages(), + Usage: vl.vvar.bufferUsages(), }) if errors.Log(err) != nil { return err } vl.buffer = buf vl.AllocSize = nb + if vl.role == Storage && !vl.vvar.ReadOnly { + vl.CreateReadBuffer() + } } else { err := vl.device.Queue.WriteBuffer(vl.buffer, 0, from) if errors.Log(err) != nil { @@ -320,7 +360,7 @@ func (vl *Value) WriteDynamicBuffer() error { buf, err := 
vl.device.Device.CreateBufferInit(&wgpu.BufferInitDescriptor{ Label: vl.Name, Contents: vl.dynamicBuffer, - Usage: vl.role.BufferUsages(), + Usage: vl.vvar.bufferUsages(), }) if errors.Log(err) != nil { return err @@ -406,11 +446,10 @@ func (vl *Value) SetFromTexture(tx *Texture) *Texture { } // CreateReadBuffer creates a read buffer for this value, -// if it does not yet exist or is not the right size. +// for [Storage] values only. Automatically called for !ReadOnly. // Read buffer is needed for reading values back from the GPU. -// Only for Storage role variables. func (vl *Value) CreateReadBuffer() error { - if !(vl.role == Storage || vl.role == StorageTexture) { + if !(vl.role == Storage || vl.role == StorageTexture) || vl.vvar.ReadOnly { return nil } sz := vl.MemSize() diff --git a/gpu/values.go b/gpu/values.go index 695db559ed..193c6da697 100644 --- a/gpu/values.go +++ b/gpu/values.go @@ -147,18 +147,6 @@ func (vs *Values) MemSize() int { return tsz } -// CreateReadBuffers creates read buffers for all values. -func (vs *Values) CreateReadBuffers() error { - var errs []error - for _, vl := range vs.Values { - err := vl.CreateReadBuffer() - if err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) -} - // bindGroupEntry returns the BindGroupEntry for Current // value for this variable. func (vs *Values) bindGroupEntry(vr *Var) []wgpu.BindGroupEntry { diff --git a/gpu/var.go b/gpu/var.go index 0845aa28ab..b7a4ef7f6e 100644 --- a/gpu/var.go +++ b/gpu/var.go @@ -35,13 +35,12 @@ type Var struct { // automatically be sent as 4 interleaved Float32Vector4 chuncks. Type Types - // number of elements, which is 1 for a single element, or a constant - // number for a fixed array of elements. For Vertex variables, the - // number is dynamic and does not need to be specified in advance, - // so you can leave it at 1. There can be alignment issues with arrays + // ArrayN is the number of elements in an array, only if there is a + // fixed array size. 
Otherwise, for single elements or dynamic arrays + // use a value of 1. There can be alignment issues with arrays // so make sure your elemental types are compatible. // Note that DynamicOffset variables can have Value buffers with multiple - // instances of the variable (with proper alignment stride), which is + // instances of the variable (with proper alignment stride), // which goes on top of any array value for the variable itself. ArrayN int @@ -87,6 +86,11 @@ type Var struct { // Only for Uniform and Storage variables. DynamicOffset bool + // ReadOnly applies only to [Storage] variables, and indicates that + // they are never read back from the GPU, so the additional staging + // buffers needed to do so are not created for these variables. + ReadOnly bool + // Values is the the array of Values allocated for this variable. // Each value has its own corresponding Buffer or Texture. // The currently-active Value is specified by the Current index, @@ -137,6 +141,9 @@ func (vr *Var) String() string { } } s := fmt.Sprintf("%d:\t%s\t%s\t(size: %d)\tValues: %d", vr.Binding, vr.Name, typ, vr.SizeOf, len(vr.Values.Values)) + if len(vr.Values.Values) == 1 { + s += "\t" + vr.Values.Values[0].String() + } return s } @@ -145,7 +152,6 @@ func (vr *Var) MemSize() int { if vr.ArrayN < 1 { vr.ArrayN = 1 } - // todo: may need to diagnose alignments here.. switch { case vr.Role >= SampledTexture: return 0 @@ -157,7 +163,6 @@ func (vr *Var) MemSize() int { // Release resets the MemPtr for values, resets any self-owned resources (Textures) func (vr *Var) Release() { vr.Values.Release() - // todo: free anything in var } // SetNValues sets specified number of Values for this var. 
@@ -177,3 +182,17 @@ func (vr *Var) SetCurrentValue(i int) { func (vr *Var) bindGroupEntry() []wgpu.BindGroupEntry { return vr.Values.bindGroupEntry(vr) } + +func (vr *Var) bindingType() wgpu.BufferBindingType { + if vr.Role == Storage && vr.ReadOnly { + return wgpu.BufferBindingTypeReadOnlyStorage + } + return vr.Role.BindingType() +} + +func (vr *Var) bufferUsages() wgpu.BufferUsage { + if vr.Role == Storage && vr.ReadOnly { + return wgpu.BufferUsageStorage | wgpu.BufferUsageCopyDst + } + return vr.Role.BufferUsages() +} diff --git a/gpu/vargroup.go b/gpu/vargroup.go index a676ad6610..6ab44a2c05 100644 --- a/gpu/vargroup.go +++ b/gpu/vargroup.go @@ -166,18 +166,6 @@ func (vg *VarGroup) SetAllCurrentValue(i int) { } } -// CreateReadBuffers creates read buffers for all values. -func (vg *VarGroup) CreateReadBuffers() error { - var errs []error - for _, vr := range vg.Vars { - err := vr.Values.CreateReadBuffers() - if err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) -} - // Config must be called after all variables have been added. // Configures binding / location for all vars based on sequential order. // also does validation and returns error message. @@ -277,7 +265,7 @@ func (vg *VarGroup) bindLayout(vs *Vars) (*wgpu.BindGroupLayout, error) { } default: bd.Buffer = wgpu.BufferBindingLayout{ - Type: vr.Role.BindingType(), + Type: vr.bindingType(), HasDynamicOffset: false, MinBindingSize: 0, // 0 is fine } diff --git a/gpu/vars.go b/gpu/vars.go index 9580a95ce3..2e1720b94d 100644 --- a/gpu/vars.go +++ b/gpu/vars.go @@ -169,27 +169,6 @@ func (vs *Vars) SetDynamicIndex(group int, name string, dynamicIndex int) *Var { return vr } -// CreateReadBuffers creates read buffers for all Storage variables. -// This is needed to be able to read values back from GPU (e.g., for Compute). 
-func (vs *Vars) CreateReadBuffers() error { - var errs []error - ns := vs.NGroups() - for gi := vs.StartGroup(); gi < ns; gi++ { - vg := vs.Groups[gi] - if vg == nil { - continue - } - if vg.Role != Storage { - continue - } - err := vg.CreateReadBuffers() - if err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) -} - // Config must be called after all variables have been added. // Configures all Groups and also does validation, returning error // does DescLayout too, so all ready for Pipeline config. diff --git a/math32/fastexp.go b/math32/fastexp.go index 8bd1f3046b..2af370458d 100644 --- a/math32/fastexp.go +++ b/math32/fastexp.go @@ -65,7 +65,7 @@ func FastExp3(x float32) float32 { } */ -//gosl:start fastexp +//gosl:start // FastExp is a quartic spline approximation to the Exp function, by N.N. Schraudolph // It does not have any of the sanity checking of a standard method -- returns @@ -76,9 +76,9 @@ func FastExp(x float32) float32 { return 0.0 } i := int32(12102203*x) + int32(127)*(int32(1)<<23) - m := i >> 7 & 0xFFFF // copy mantissa + m := (i >> 7) & 0xFFFF // copy mantissa i += (((((((((((3537 * m) >> 16) + 13668) * m) >> 18) + 15817) * m) >> 14) - 80470) * m) >> 11) return math.Float32frombits(uint32(i)) } -//gosl:end fastexp +//gosl:end diff --git a/math32/math.go b/math32/math.go index 0f4413e3f8..b67e02150e 100644 --- a/math32/math.go +++ b/math32/math.go @@ -15,6 +15,7 @@ package math32 //go:generate core generate import ( + "cmp" "math" "strconv" @@ -783,18 +784,7 @@ func Yn(n int, x float32) float32 { // Special additions to math. 
functions // Clamp clamps x to the provided closed interval [a, b] -func Clamp(x, a, b float32) float32 { - if x < a { - return a - } - if x > b { - return b - } - return x -} - -// ClampInt clamps x to the provided closed interval [a, b] -func ClampInt(x, a, b int) int { +func Clamp[T cmp.Ordered](x, a, b T) T { if x < a { return a } diff --git a/math32/minmax/avgmax.go b/math32/minmax/avgmax.go index 7877ccd4d2..72fb1dbe7a 100644 --- a/math32/minmax/avgmax.go +++ b/math32/minmax/avgmax.go @@ -4,9 +4,12 @@ package minmax -import "fmt" +import ( + "fmt" + "math" +) -//gosl:start minmax +//gosl:start const ( MaxFloat32 float32 = 3.402823466e+38 @@ -69,7 +72,7 @@ func (am *AvgMax32) CalcAvg() { } } -//gosl:end minmax +//gosl:end func (am *AvgMax32) String() string { return fmt.Sprintf("{Avg: %g, Max: %g, Sum: %g, MaxIndex: %d, N: %d}", am.Avg, am.Max, am.Sum, am.MaxIndex, am.N) @@ -114,7 +117,7 @@ func (am *AvgMax64) Init() { am.Avg = 0 am.Sum = 0 am.N = 0 - am.Max = -MaxFloat64 + am.Max = math.Inf(-1) am.MaxIndex = -1 } diff --git a/math32/minmax/minmax32.go b/math32/minmax/minmax32.go index 2f5d9fff36..8e8c6de7eb 100644 --- a/math32/minmax/minmax32.go +++ b/math32/minmax/minmax32.go @@ -4,9 +4,13 @@ package minmax -import "fmt" +import ( + "fmt" -//gosl:start minmax + "cogentcore.org/core/math32" +) + +//gosl:start // F32 represents a min / max range for float32 values. // Supports clipping, renormalizing, etc @@ -24,10 +28,10 @@ func (mr *F32) Set(mn, mx float32) { } // SetInfinity sets the Min to +MaxFloat, Max to -MaxFloat -- suitable for -// iteratively calling Fit*InRange +// iteratively calling Fit*InRange. See also Sanitize when done. 
func (mr *F32) SetInfinity() { - mr.Min = MaxFloat32 - mr.Max = -MaxFloat32 + mr.Min = math32.Inf(1) + mr.Max = math32.Inf(-1) } // IsValid returns true if Min <= Max @@ -87,7 +91,7 @@ func (mr *F32) FitValInRange(val float32) bool { // NormVal normalizes value to 0-1 unit range relative to current Min / Max range // Clips the value within Min-Max range first. func (mr *F32) NormValue(val float32) float32 { - return (mr.ClipValue(val) - mr.Min) * mr.Scale() + return (mr.ClampValue(val) - mr.Min) * mr.Scale() } // ProjVal projects a 0-1 normalized unit value into current Min / Max range (inverse of NormVal) @@ -95,9 +99,9 @@ func (mr *F32) ProjValue(val float32) float32 { return mr.Min + (val * mr.Range()) } -// ClipVal clips given value within Min / Max range -// Note: a NaN will remain as a NaN -func (mr *F32) ClipValue(val float32) float32 { +// ClampValue clamps given value within Min / Max range +// Note: a NaN will remain as a NaN. +func (mr *F32) ClampValue(val float32) float32 { if val < mr.Min { return mr.Min } @@ -119,7 +123,7 @@ func (mr *F32) ClipNormValue(val float32) float32 { return mr.NormValue(val) } -//gosl:end minmax +//gosl:end func (mr *F32) String() string { return fmt.Sprintf("{%g %g}", mr.Min, mr.Max) @@ -139,3 +143,20 @@ func (mr *F32) FitInRange(oth F32) bool { } return adj } + +// Sanitize ensures that the Min / Max range is not infinite or contradictory. +func (mr *F32) Sanitize() { + if math32.IsInf(mr.Min, 0) { + mr.Min = 0 + } + if math32.IsInf(mr.Max, 0) { + mr.Max = 0 + } + if mr.Min > mr.Max { + mr.Min, mr.Max = mr.Max, mr.Min + } + if mr.Min == mr.Max { + mr.Min-- + mr.Max++ + } +} diff --git a/math32/minmax/minmax64.go b/math32/minmax/minmax64.go index 046c7c3958..9386a9e72c 100644 --- a/math32/minmax/minmax64.go +++ b/math32/minmax/minmax64.go @@ -5,12 +5,9 @@ // Package minmax provides a struct that holds Min and Max values. 
package minmax -//go:generate core generate +import "math" -const ( - MaxFloat64 float64 = 1.7976931348623158e+308 - MinFloat64 float64 = 2.2250738585072014e-308 -) +//go:generate core generate // F64 represents a min / max range for float64 values. // Supports clipping, renormalizing, etc @@ -25,39 +22,39 @@ func (mr *F64) Set(mn, mx float64) { mr.Max = mx } -// SetInfinity sets the Min to +MaxFloat, Max to -MaxFloat -- suitable for -// iteratively calling Fit*InRange +// SetInfinity sets the Min to +Inf, Max to -Inf, suitable for +// iteratively calling Fit*InRange. See also Sanitize when done. func (mr *F64) SetInfinity() { - mr.Min = MaxFloat64 - mr.Max = -MaxFloat64 + mr.Min = math.Inf(1) + mr.Max = math.Inf(-1) } -// IsValid returns true if Min <= Max +// IsValid returns true if Min <= Max. func (mr *F64) IsValid() bool { return mr.Min <= mr.Max } -// InRange tests whether value is within the range (>= Min and <= Max) +// InRange tests whether value is within the range (>= Min and <= Max). func (mr *F64) InRange(val float64) bool { return ((val >= mr.Min) && (val <= mr.Max)) } -// IsLow tests whether value is lower than the minimum +// IsLow tests whether value is lower than the minimum. func (mr *F64) IsLow(val float64) bool { return (val < mr.Min) } -// IsHigh tests whether value is higher than the maximum +// IsHigh tests whether value is higher than the maximum. func (mr *F64) IsHigh(val float64) bool { return (val > mr.Min) } -// Range returns Max - Min +// Range returns Max - Min. func (mr *F64) Range() float64 { return mr.Max - mr.Min } -// Scale returns 1 / Range -- if Range = 0 then returns 0 +// Scale returns 1 / Range -- if Range = 0 then returns 0. func (mr *F64) Scale() float64 { r := mr.Range() if r != 0 { @@ -89,7 +86,7 @@ func (mr *F64) FitValInRange(val float64) bool { // NormVal normalizes value to 0-1 unit range relative to current Min / Max range // Clips the value within Min-Max range first. 
func (mr *F64) NormValue(val float64) float64 { - return (mr.ClipValue(val) - mr.Min) * mr.Scale() + return (mr.ClampValue(val) - mr.Min) * mr.Scale() } // ProjVal projects a 0-1 normalized unit value into current Min / Max range (inverse of NormVal) @@ -97,9 +94,9 @@ func (mr *F64) ProjValue(val float64) float64 { return mr.Min + (val * mr.Range()) } -// ClipVal clips given value within Min / Max range +// ClampValue clips given value within Min / Max range // Note: a NaN will remain as a NaN -func (mr *F64) ClipValue(val float64) float64 { +func (mr *F64) ClampValue(val float64) float64 { if val < mr.Min { return mr.Min } @@ -135,3 +132,20 @@ func (mr *F64) FitInRange(oth F64) bool { } return adj } + +// Sanitize ensures that the Min / Max range is not infinite or contradictory. +func (mr *F64) Sanitize() { + if math.IsInf(mr.Min, 0) { + mr.Min = 0 + } + if math.IsInf(mr.Max, 0) { + mr.Max = 0 + } + if mr.Min > mr.Max { + mr.Min, mr.Max = mr.Max, mr.Min + } + if mr.Min == mr.Max { + mr.Min-- + mr.Max++ + } +} diff --git a/math32/minmax/minmax_int.go b/math32/minmax/minmax_int.go index 704e5b9c79..fef6453e09 100644 --- a/math32/minmax/minmax_int.go +++ b/math32/minmax/minmax_int.go @@ -96,7 +96,7 @@ func (mr *Int) FitValInRange(val int) bool { // NormVal normalizes value to 0-1 unit range relative to current Min / Max range // Clips the value within Min-Max range first. 
func (mr *Int) NormValue(val int) float32 { - return float32(mr.ClipValue(val)-mr.Min) * mr.Scale() + return float32(mr.Clamp(val)-mr.Min) * mr.Scale() } // ProjVal projects a 0-1 normalized unit value into current Min / Max range (inverse of NormVal) @@ -105,7 +105,7 @@ func (mr *Int) ProjValue(val float32) float32 { } // ClipVal clips given value within Min / Max rangee -func (mr *Int) ClipValue(val int) int { +func (mr *Int) Clamp(val int) int { if val < mr.Min { return mr.Min } diff --git a/math32/minmax/range.go b/math32/minmax/range.go index 2d73b938f5..84e7c159e4 100644 --- a/math32/minmax/range.go +++ b/math32/minmax/range.go @@ -6,8 +6,6 @@ package minmax // Range32 represents a range of values for plotting, where the min or max can optionally be fixed type Range32 struct { - - // Min and Max range values F32 // fix the minimum end of the range @@ -18,15 +16,17 @@ type Range32 struct { } // SetMin sets a fixed min value -func (rr *Range32) SetMin(mn float32) { +func (rr *Range32) SetMin(mn float32) *Range32 { rr.FixMin = true rr.Min = mn + return rr } // SetMax sets a fixed max value -func (rr *Range32) SetMax(mx float32) { +func (rr *Range32) SetMax(mx float32) *Range32 { rr.FixMax = true rr.Max = mx + return rr } // Range returns Max - Min @@ -34,13 +34,23 @@ func (rr *Range32) Range() float32 { return rr.Max - rr.Min } +// Clamp returns min, max values clamped according to Fixed min / max of range. 
+func (rr *Range32) Clamp(mnIn, mxIn float32) (mn, mx float32) { + mn, mx = mnIn, mxIn + if rr.FixMin && rr.Min < mn { + mn = rr.Min + } + if rr.FixMax && rr.Max > mx { + mx = rr.Max + } + return +} + /////////////////////////////////////////////////////////////////////// // 64 // Range64 represents a range of values for plotting, where the min or max can optionally be fixed type Range64 struct { - - // Min and Max range values F64 // fix the minimum end of the range @@ -51,18 +61,32 @@ type Range64 struct { } // SetMin sets a fixed min value -func (rr *Range64) SetMin(mn float64) { +func (rr *Range64) SetMin(mn float64) *Range64 { rr.FixMin = true rr.Min = mn + return rr } // SetMax sets a fixed max value -func (rr *Range64) SetMax(mx float64) { +func (rr *Range64) SetMax(mx float64) *Range64 { rr.FixMax = true rr.Max = mx + return rr } // Range returns Max - Min func (rr *Range64) Range() float64 { return rr.Max - rr.Min } + +// Clamp returns min, max values clamped according to Fixed min / max of range. +func (rr *Range64) Clamp(mnIn, mxIn float64) (mn, mx float64) { + mn, mx = mnIn, mxIn + if rr.FixMin && rr.Min < mn { + mn = rr.Min + } + if rr.FixMax && rr.Max > mx { + mx = rr.Max + } + return +} diff --git a/plot/README.md b/plot/README.md index c5fe23479b..a6bbeff852 100644 --- a/plot/README.md +++ b/plot/README.md @@ -1,8 +1,173 @@ # Plot The `plot` package generates 2D plots of data using the Cogent Core `paint` rendering system. The `plotcore` sub-package has Cogent Core Widgets that can be used in applications. -* `Plot` is just a wrapper around a `plot.Plot`, for manually-configured plots. -* `PlotEditor` is an interactive plot viewer that supports selection of which data to plot, and configuration of many plot parameters. +* `Plot` is just a wrapper around a `plot.Plot`, for code-generated plots. +* `PlotEditor` is an interactive plot viewer that supports selection of which data to plot, and GUI configuration of plot parameters. 
+ +`plot` is designed to work in two potentially-conflicting ways: +* Code-based creation of a specific plot with specific data. +* GUI-based configuration of plots based on a `tensor.Table` of data columns (via `PlotEditor`). + +The GUI constraint requires a more systematic, factorial organization of the space of possible plot data and how it is organized to create a plot, so that it can be configured with a relatively simple set of GUI settings. The overall logic is as follows: + +* The overall plot has a single shared range of X, Y and optionally Z coordinate ranges (under the corresponding `Axis` field), that defines where a data value in any plot type is plotted. These ranges are set based on the DataRanger interface. + +* Plot content is driven by `Plotter` elements that each consume one or more sets of data, which is provided by a `Valuer` interface that maps onto a minimal subset of the `tensor.Tensor` interface, so a tensor directly satisfies the interface. + +* Each `Plotter` element can generally handle multiple different data elements, that are index-aligned. For example, the basic `XY` plotter requires `X` and `Y` Valuers, and optionally `Size` or `Color` Valuers that apply to the Point elements, while `Bar` gets at least a `Y` but also optionally a `High` Valuer for an error bar. The `plot.Data` = `map[Roles]Valuer` is used to create new Plotter elements, allowing an unordered and explicit way of specifying the `Roles` of each `Valuer` item. + +Here is a example for how a plotter element is created with the `plot.Data` map of roles to data: + +```Go +plt := plot.NewPlot() +plt.Add(plots.NewLine(plot.Data{plot.X: xd, plot.Y: yd, plot.Low: low, plot.High: high})) +``` + +The table-driven plotting case uses a `Group` name along with the `Roles` type (`X`, `Y` etc) and Plotter type names to organize different plots based on `Style` settings. 
Columns with the same Group name all provide data to the same plotter using their different Roles, making it easy to configure various statistical plots of multiple series of grouped data. + +Different plotter types (including custom ones) are registered along with their accepted input roles, to allow any type of plot to be generated. + +# Styling + +`plot.Style` contains the full set of styling parameters, which can be set using Styler functions that are attached to individual plot elements (e.g., lines, points etc) that drive the content of what is actually plotted (based on the `Plotter` interface). + +Each such plot element defines a `Styler` method, e.g.,: + +```Go +plt := plot.NewPlot() +ln := plots.NewLine(data).Styler(func(s *plot.Style) { + s.Plot.Title = "My Plot" // overall Plot styles + s.Line.Color = colors.Uniform(colors.Red) // line-specific styles +}) +plt.Add(ln) +``` + +The `Plot` field (of type `PlotStyle`) contains all the properties that apply to the plot as a whole. Each element can set these values, and they are applied in the order the elements are added, so the last one gets final say. Typically you want to just set these plot-level styles on one element only and avoid any conflicts. + +The rest of the style properties (e.g., `Line`, `Point`) apply to the element in question. There are also some default plot-level settings in `Plot` that apply to all elements, and the plot-level styles are updated first, so in this way it is possible to have plot-wide settings applied from one styler, that affect all plots (e.g., the line width, and whether lines and / or points are plotted or not). + +## Tensor metadata + +Styler functions can be attached directly to a `tensor.Tensor` via its metadata, and the `Plotter` elements will automatically grab these functions from any data source that has such metadata set. 
This allows the data generator to directly set default styling parameters, which can always be overridden later by adding more styler functions. Tying the plot styling directly to the source data allows all of the relevant logic to be put in one place, instead of spreading this logic across different places in the code. + +Here is an example of how this works: + +```Go + tx, ty := tensor.NewFloat64(21), tensor.NewFloat64(21) + for i := range tx.DimSize(0) { + tx.SetFloat1D(float64(i*5), i) + ty.SetFloat1D(50.0+40*math.Sin((float64(i)/8)*math.Pi), i) + } + // attach stylers to the Y axis data: that is where plotter looks for it + plot.SetStylersTo(ty, plot.Stylers{func(s *plot.Style) { + s.Plot.Title = "Test Line" + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plot.Scale = 2 + s.Plot.XAxis.Range.SetMax(105) + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.On) + s.Line.Color = colors.Uniform(colors.Red) + s.Point.Color = colors.Uniform(colors.Blue) + s.Range.SetMin(0).SetMax(100) + }}) + + // somewhere else in the code: + + plt := plot.New() + // NewLine automatically gets stylers from ty tensor metadata + plt.Add(plots.NewLine(plot.Data{plot.X: tx, plot.Y: ty})) + plt.Draw() +``` + +# Plot Types + +The following are the builtin standard plot types, in the `plots` package: + +## 1D and 2D XY Data + +### XY + +`XY` is the workhorse standard Plotter, taking at least `X` and `Y` inputs, and plotting lines and / or points at each X, Y point. + +Optionally `Size` and / or `Color` inputs can be provided, which apply to the points. Thus, by using a `Point.Shape` of `Ring` or `Circle`, you can create a bubble plot by providing Size and Color data. + +### Bar + +`Bar` takes `Y` inputs, and draws bars of corresponding height. + +An optional `High` input can be provided to also plot error bars above each bar. 
+ +To create a plot with multiple error bars, multiple Bar Plotters are created, with `Style.Width` parameters that have a shared `Stride = 1 / number of bars` and `Offset` that increments for each bar added. The `plots.NewBars` function handles this directly. + +### ErrorBar + +`XErrorBar` and `YErrorBar` take `X`, `Y`, `Low`, and `High` inputs, and draws an `I` shaped error bar at the X, Y coordinate with the error "handles" around it. + +### Labels + +`Labels` takes `X`, `Y` and `Labels` string inputs and plots labels at the given coordinates. + +### Box + +`Box` takes `X`, `Y` (median line), `U`, `V` (box first and 3rd quartile values), and `Low`, `High` (Min, Max) inputs, and renders a box plot with error bars. + +### XFill, YFill + +`XFill` and `YFill` are used to draw filled regions between pairs of X or Y points, using the `X`, `Y`, and `Low`, `High` values to specify the center point (X, Y) and the region below / left and above / right to fill around that central point. + +XFill along with an XY line can be used to draw the equivalent of the [matplotlib fill_between](https://matplotlib.org/stable/plot_types/basic/fill_between.html#sphx-glr-plot-types-basic-fill-between-py) plot. + +YFill can be used to draw the equivalent of the [matplotlib violin plot](https://matplotlib.org/stable/plot_types/stats/violin.html#sphx-glr-plot-types-stats-violin-py). + +### Pie + +`Pie` takes a list of `Y` values that are plotted as the size of segments of a circular pie plot. Y values are automatically normalized for plotting. + +TODO: implement, details on mapping, + +## 2D Grid-based + +### ColorGrid + +Input = Values and X, Y size + +### Contour + +?? + +### Vector + +X,Y,U,V + +Quiver? + +## 3D + +TODO: use math32 3D projection math and you can just take each 3d point and reduce to 2D. For stuff you want to actually be able to use in SVG, it needs to ultimately be 2D, so it makes sense to support basic versions here, including XYZ (points, lines), Bar3D, wireframe. 
+ +Could also have a separate plot3d package based on `xyz` that is true 3D for interactive 3D plots of surfaces or things that don't make sense in this more limited 2D world. + +# Statistical plots + +The `statplot` package provides functions taking `tensor` data that produce statistical plots of the data, including Quartiles (Box with Median, Quartile, Min, Max), Histogram (Bar), Violin (YFill), Range (XFill), Cluster... + +TODO: add a Data scatter that plots points to overlay on top of Violin or Box. + +## LegendGroups + +* implements current legend grouping logic -- ends up being a multi-table output -- not sure how to interface. + +## Histogram + +## Quartiles + +## Violin + +## Range + +## Cluster + +# History The code is adapted from the [gonum plot](https://github.com/gonum/plot) package (which in turn was adapted from google's [plotinum](https://code.google.com/archive/p/plotinum/), to use the Cogent Core [styles](../styles) and [paint](../paint) rendering framework, which also supports SVG output of the rendering. @@ -13,3 +178,9 @@ Here is the copyright notice for that package: // license that can be found in the LICENSE file. ``` +# TODO + +* points size incorporated into UpdateRange in XY +* tensor index +* Grid? in styling. + diff --git a/plot/axis.go b/plot/axis.go index 91c1d81daa..2e648f0a1d 100644 --- a/plot/axis.go +++ b/plot/axis.go @@ -10,33 +10,39 @@ package plot import ( + "math" + "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/styles" "cogentcore.org/core/styles/units" ) -// Normalizer rescales values from the data coordinate system to the -// normalized coordinate system. -type Normalizer interface { - // Normalize transforms a value x in the data coordinate system to - // the normalized coordinate system. - Normalize(min, max, x float32) float32 -} +// AxisScales are the scaling options for how values are distributed +// along an axis: Linear, Log, etc. 
+type AxisScales int32 //enums:enum -// Axis represents either a horizontal or vertical -// axis of a plot. -type Axis struct { - // Min and Max are the minimum and maximum data - // values represented by the axis. - Min, Max float32 +const ( + // Linear is a linear axis scale. + Linear AxisScales = iota - // specifies which axis this is: X or Y - Axis math32.Dims + // Log is a Logarithmic axis scale. + Log - // Label for the axis - Label Text + // InverseLinear is an inverted linear axis scale. + InverseLinear + + // InverseLog is an inverted log axis scale. + InverseLog +) + +// AxisStyle has style properties for the axis. +type AxisStyle struct { //types:add -setters - // Line styling properties for the axis line. + // Text has the text style parameters for the text label. + Text TextStyle + + // Line has styling properties for the axis line. Line LineStyle // Padding between the axis line and the data. Having @@ -44,14 +50,54 @@ type Axis struct { // on the axis, thus making it easier to see. Padding units.Value - // has the text style for rendering tick labels, and is shared for actual rendering - TickText Text + // NTicks is the desired number of ticks (actual likely will be different). + NTicks int + + // Scale specifies how values are scaled along the axis: + // Linear, Log, Inverted + Scale AxisScales + + // TickText has the text style for rendering tick labels, + // and is shared for actual rendering. + TickText TextStyle - // line style for drawing tick lines + // TickLine has line style for drawing tick lines. TickLine LineStyle - // length of tick lines + // TickLength is the length of tick lines. 
TickLength units.Value +} + +func (ax *AxisStyle) Defaults() { + ax.Line.Defaults() + ax.Text.Defaults() + ax.Text.Size.Dp(20) + ax.Padding.Pt(5) + ax.NTicks = 5 + ax.TickText.Defaults() + ax.TickText.Size.Dp(16) + ax.TickText.Padding.Dp(2) + ax.TickLine.Defaults() + ax.TickLength.Pt(8) +} + +// Axis represents either a horizontal or vertical +// axis of a plot. +type Axis struct { + // Range has the Min, Max range of values for the axis (in raw data units.) + Range minmax.F64 + + // specifies which axis this is: X, Y or Z. + Axis math32.Dims + + // Label for the axis. + Label Text + + // Style has the style parameters for the Axis. + Style AxisStyle + + // TickText is used for rendering the tick text labels. + TickText Text // Ticker generates the tick marks. Any tick marks // returned by the Marker function that are not in @@ -74,51 +120,49 @@ type Axis struct { // Sets Defaults, range is (∞, ­∞), and thus any finite // value is less than Min and greater than Max. func (ax *Axis) Defaults(dim math32.Dims) { - ax.Min = math32.Inf(+1) - ax.Max = math32.Inf(-1) + ax.Style.Defaults() ax.Axis = dim - ax.Line.Defaults() - ax.Label.Defaults() - ax.Label.Style.Size.Dp(20) - ax.Padding.Pt(5) - ax.TickText.Defaults() - ax.TickText.Style.Size.Dp(16) - ax.TickText.Style.Padding.Dp(2) - ax.TickLine.Defaults() - ax.TickLength.Pt(8) if dim == math32.Y { ax.Label.Style.Rotation = -90 - ax.TickText.Style.Align = styles.End + ax.Style.TickText.Align = styles.End } ax.Scale = LinearScale{} ax.Ticker = DefaultTicks{} } -// SanitizeRange ensures that the range of the axis makes sense. -func (ax *Axis) SanitizeRange() { - if math32.IsInf(ax.Min, 0) { - ax.Min = 0 - } - if math32.IsInf(ax.Max, 0) { - ax.Max = 0 - } - if ax.Min > ax.Max { - ax.Min, ax.Max = ax.Max, ax.Min - } - if ax.Min == ax.Max { - ax.Min-- - ax.Max++ +// drawConfig configures for drawing. 
+func (ax *Axis) drawConfig() { + switch ax.Style.Scale { + case Linear: + ax.Scale = LinearScale{} + case Log: + ax.Scale = LogScale{} + case InverseLinear: + ax.Scale = InvertedScale{LinearScale{}} + case InverseLog: + ax.Scale = InvertedScale{LogScale{}} } +} +// SanitizeRange ensures that the range of the axis makes sense. +func (ax *Axis) SanitizeRange() { + ax.Range.Sanitize() if ax.AutoRescale { - marks := ax.Ticker.Ticks(ax.Min, ax.Max) + marks := ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks) for _, t := range marks { - ax.Min = math32.Min(ax.Min, t.Value) - ax.Max = math32.Max(ax.Max, t.Value) + ax.Range.FitValInRange(t.Value) } } } +// Normalizer rescales values from the data coordinate system to the +// normalized coordinate system. +type Normalizer interface { + // Normalize transforms a value x in the data coordinate system to + // the normalized coordinate system. + Normalize(min, max, x float64) float64 +} + // LinearScale an be used as the value of an Axis.Scale function to // set the axis to a standard linear scale. type LinearScale struct{} @@ -126,7 +170,7 @@ type LinearScale struct{} var _ Normalizer = LinearScale{} // Normalize returns the fractional distance of x between min and max. -func (LinearScale) Normalize(min, max, x float32) float32 { +func (LinearScale) Normalize(min, max, x float64) float64 { return (x - min) / (max - min) } @@ -138,12 +182,12 @@ var _ Normalizer = LogScale{} // Normalize returns the fractional logarithmic distance of // x between min and max. 
-func (LogScale) Normalize(min, max, x float32) float32 { +func (LogScale) Normalize(min, max, x float64) float64 { if min <= 0 || max <= 0 || x <= 0 { panic("Values must be greater than 0 for a log scale.") } - logMin := math32.Log(min) - return (math32.Log(x) - logMin) / (math32.Log(max) - logMin) + logMin := math.Log(min) + return (math.Log(x) - logMin) / (math.Log(max) - logMin) } // InvertedScale can be used as the value of an Axis.Scale function to @@ -153,7 +197,7 @@ type InvertedScale struct{ Normalizer } var _ Normalizer = InvertedScale{} // Normalize returns a normalized [0, 1] value for the position of x. -func (is InvertedScale) Normalize(min, max, x float32) float32 { +func (is InvertedScale) Normalize(min, max, x float64) float64 { return is.Normalizer.Normalize(max, min, x) } @@ -161,6 +205,6 @@ func (is InvertedScale) Normalize(min, max, x float32) float32 { // system, normalized to its distance as a fraction of the // range of this axis. For example, if x is a.Min then the return // value is 0, and if x is a.Max then the return value is 1. -func (ax *Axis) Norm(x float32) float32 { - return ax.Scale.Normalize(ax.Min, ax.Max, x) +func (ax *Axis) Norm(x float64) float64 { + return ax.Scale.Normalize(ax.Range.Min, ax.Range.Max, x) } diff --git a/plot/data.go b/plot/data.go index adf8b662ff..4778a03d8f 100644 --- a/plot/data.go +++ b/plot/data.go @@ -10,28 +10,91 @@ package plot import ( - "errors" + "log/slog" + "math" + "strconv" - "cogentcore.org/core/math32" + "cogentcore.org/core/base/errors" + "cogentcore.org/core/math32/minmax" ) -// data defines the main data interfaces for plotting. -// Other more specific types of plots define their own interfaces. -// unlike gonum/plot, NaN values are treated as missing data here. +// data defines the main data interfaces for plotting +// and the different Roles for data. 
var ( ErrInfinity = errors.New("plotter: infinite data point") ErrNoData = errors.New("plotter: no data points") ) +// Data is a map of Roles and Data for that Role, providing the +// primary way of passing data to a Plotter +type Data map[Roles]Valuer + +// Valuer is the data interface for plotting, supporting either +// float64 or string representations. It is satisfied by the tensor.Tensor +// interface, so a tensor can be used directly for plot Data. +type Valuer interface { + // Len returns the number of values. + Len() int + + // Float1D(i int) returns float64 value at given index. + Float1D(i int) float64 + + // String1D(i int) returns string value at given index. + String1D(i int) string +} + +// Roles are the roles that a given set of data values can play, +// designed to be sufficiently generalizable across all different +// types of plots, even if sometimes it is a bit of a stretch. +type Roles int32 //enums:enum + +const ( + // NoRole is the default no-role specified case. + NoRole Roles = iota + + // X axis + X + + // Y axis + Y + + // Z axis + Z + + // U is the X component of a vector or first quartile in Box plot, etc. + U + + // V is the Y component of a vector or third quartile in a Box plot, etc. + V + + // W is the Z component of a vector + W + + // Low is a lower error bar or region. + Low + + // High is an upper error bar or region. + High + + // Size controls the size of points etc. + Size + + // Color controls the color of points or other elements. + Color + + // Label renders a label, typically from string data, but can also be used for values. + Label +) + // CheckFloats returns an error if any of the arguments are Infinity. // or if there are no non-NaN data points available for plotting. 
-func CheckFloats(fs ...float32) error { +func CheckFloats(fs ...float64) error { n := 0 for _, f := range fs { switch { - case math32.IsNaN(f): - case math32.IsInf(f, 0): + case math.IsNaN(f): + case math.IsInf(f, 0): return ErrInfinity default: n++ @@ -44,65 +107,78 @@ func CheckFloats(fs ...float32) error { } // CheckNaNs returns true if any of the floats are NaN -func CheckNaNs(fs ...float32) bool { +func CheckNaNs(fs ...float64) bool { for _, f := range fs { - if math32.IsNaN(f) { + if math.IsNaN(f) { return true } } return false } -////////////////////////////////////////////////// -// Valuer - -// Valuer provides an interface for a list of scalar values -type Valuer interface { - // Len returns the number of values. - Len() int +// Range updates given Range with values from data. +func Range(data Valuer, rng *minmax.F64) { + for i := 0; i < data.Len(); i++ { + v := data.Float1D(i) + if math.IsNaN(v) { + continue + } + rng.FitValInRange(v) + } +} - // Value returns a value. - Value(i int) float32 +// RangeClamp updates the given axis Min, Max range values based +// on the range of values in the given [Data], and the given style range. +func RangeClamp(data Valuer, axisRng *minmax.F64, styleRng *minmax.Range64) { + Range(data, axisRng) + axisRng.Min, axisRng.Max = styleRng.Clamp(axisRng.Min, axisRng.Max) } -// Range returns the minimum and maximum values. -func Range(vs Valuer) (min, max float32) { - min = math32.Inf(1) - max = math32.Inf(-1) - for i := 0; i < vs.Len(); i++ { - v := vs.Value(i) - if math32.IsNaN(v) { - continue +// CheckLengths checks that all the data elements have the same length. +// Logs and returns an error if not. 
+func (dt Data) CheckLengths() error { + n := 0 + for _, v := range dt { + if n == 0 { + n = v.Len() + } else { + if v.Len() != n { + err := errors.New("plot.Data has inconsistent lengths -- all data elements must have the same length -- plotting aborted") + return errors.Log(err) + } } - min = math32.Min(min, v) - max = math32.Max(max, v) } - return + return nil } -// Values implements the Valuer interface. -type Values []float32 +// Values provides a minimal implementation of the Data interface +// using a slice of float64. +type Values []float64 func (vs Values) Len() int { return len(vs) } -func (vs Values) Value(i int) float32 { +func (vs Values) Float1D(i int) float64 { return vs[i] } +func (vs Values) String1D(i int) string { + return strconv.FormatFloat(vs[i], 'g', -1, 64) +} + // CopyValues returns a Values that is a copy of the values -// from a Valuer, or an error if there are no values, or if one of +// from Data, or an error if there are no values, or if one of // the copied values is a Infinity. // NaN values are skipped in the copying process. -func CopyValues(vs Valuer) (Values, error) { - if vs.Len() == 0 { +func CopyValues(data Valuer) (Values, error) { + if data == nil { return nil, ErrNoData } - cpy := make(Values, 0, vs.Len()) - for i := 0; i < vs.Len(); i++ { - v := vs.Value(i) - if math32.IsNaN(v) { + cpy := make(Values, 0, data.Len()) + for i := 0; i < data.Len(); i++ { + v := data.Float1D(i) + if math.IsNaN(v) { continue } if err := CheckFloats(v); err != nil { @@ -113,160 +189,60 @@ func CopyValues(vs Valuer) (Values, error) { return cpy, nil } -////////////////////////////////////////////////// -// XYer - -// XYer provides an interface for a list of X,Y data pairs -type XYer interface { - // Len returns the number of x, y pairs. - Len() int - - // XY returns an x, y pair. - XY(i int) (x, y float32) -} - -// XYRange returns the minimum and maximum -// x and y values. 
-func XYRange(xys XYer) (xmin, xmax, ymin, ymax float32) { - xmin, xmax = Range(XValues{xys}) - ymin, ymax = Range(YValues{xys}) - return -} - -// XYs implements the XYer interface. -type XYs []math32.Vector2 - -func (xys XYs) Len() int { - return len(xys) -} - -func (xys XYs) XY(i int) (float32, float32) { - return xys[i].X, xys[i].Y -} - -// CopyXYs returns an XYs that is a copy of the x and y values from -// an XYer, or an error if one of the data points contains a NaN or -// Infinity. -func CopyXYs(data XYer) (XYs, error) { - if data.Len() == 0 { - return nil, ErrNoData +// MustCopyRole returns Values copy of given role from given data map, +// logging an error and returning nil if not present. +func MustCopyRole(data Data, role Roles) Values { + d, ok := data[role] + if !ok { + slog.Error("plot Data role not present, but is required", "role:", role) + return nil } - cpy := make(XYs, 0, data.Len()) - for i := range data.Len() { - x, y := data.XY(i) - if CheckNaNs(x, y) { - continue - } - if err := CheckFloats(x, y); err != nil { - return nil, err - } - cpy = append(cpy, math32.Vec2(x, y)) - } - return cpy, nil + return errors.Log1(CopyValues(d)) } -// PlotXYs returns plot coordinates for given set of XYs -func PlotXYs(plt *Plot, data XYs) XYs { - ps := make(XYs, len(data)) - for i := range data { - ps[i].X, ps[i].Y = plt.PX(data[i].X), plt.PY(data[i].Y) +// CopyRole returns Values copy of given role from given data map, +// returning nil if role not present. +func CopyRole(data Data, role Roles) Values { + d, ok := data[role] + if !ok { + return nil } - return ps -} - -// XValues implements the Valuer interface, -// returning the x value from an XYer. -type XValues struct { - XYer -} - -func (xs XValues) Value(i int) float32 { - x, _ := xs.XY(i) - return x -} - -// YValues implements the Valuer interface, -// returning the y value from an XYer. 
-type YValues struct { - XYer -} - -func (ys YValues) Value(i int) float32 { - _, y := ys.XY(i) - return y + v, _ := CopyValues(d) + return v } -////////////////////////////////////////////////// -// XYer - -// XYZer provides an interface for a list of X,Y,Z data triples. -// It also satisfies the XYer interface for the X,Y pairs. -type XYZer interface { - // Len returns the number of x, y, z triples. - Len() int - - // XYZ returns an x, y, z triple. - XYZ(i int) (float32, float32, float32) - - // XY returns an x, y pair. - XY(i int) (float32, float32) +// PlotX returns plot pixel X coordinate values for given data. +func PlotX(plt *Plot, data Valuer) []float32 { + px := make([]float32, data.Len()) + for i := range px { + px[i] = plt.PX(data.Float1D(i)) + } + return px } -// XYZs implements the XYZer interface using a slice. -type XYZs []XYZ - -// XYZ is an x, y and z value. -type XYZ struct{ X, Y, Z float32 } - -// Len implements the Len method of the XYZer interface. -func (xyz XYZs) Len() int { - return len(xyz) +// PlotY returns plot pixel Y coordinate values for given data. +func PlotY(plt *Plot, data Valuer) []float32 { + py := make([]float32, data.Len()) + for i := range py { + py[i] = plt.PY(data.Float1D(i)) + } + return py } -// XYZ implements the XYZ method of the XYZer interface. -func (xyz XYZs) XYZ(i int) (float32, float32, float32) { - return xyz[i].X, xyz[i].Y, xyz[i].Z -} +//////// Labels -// XY implements the XY method of the XYer interface. -func (xyz XYZs) XY(i int) (float32, float32) { - return xyz[i].X, xyz[i].Y -} +// Labels provides a minimal implementation of the Data interface +// using a slice of string. It always returns 0 for Float1D. +type Labels []string -// CopyXYZs copies an XYZer. 
-func CopyXYZs(data XYZer) (XYZs, error) { - if data.Len() == 0 { - return nil, ErrNoData - } - cpy := make(XYZs, 0, data.Len()) - for i := range data.Len() { - x, y, z := data.XYZ(i) - if CheckNaNs(x, y, z) { - continue - } - if err := CheckFloats(x, y, z); err != nil { - return nil, err - } - cpy = append(cpy, XYZ{X: x, Y: y, Z: z}) - } - return cpy, nil +func (lb Labels) Len() int { + return len(lb) } -// XYValues implements the XYer interface, returning -// the x and y values from an XYZer. -type XYValues struct{ XYZer } - -// XY implements the XY method of the XYer interface. -func (xy XYValues) XY(i int) (float32, float32) { - x, y, _ := xy.XYZ(i) - return x, y +func (lb Labels) Float1D(i int) float64 { + return 0 } -////////////////////////////////////////////////// -// Labeler - -// Labeler provides an interface for a list of string labels -type Labeler interface { - // Label returns a label. - Label(i int) string +func (lb Labels) String1D(i int) string { + return lb[i] } diff --git a/plot/draw.go b/plot/draw.go index 5d73bb95c6..a572509fca 100644 --- a/plot/draw.go +++ b/plot/draw.go @@ -51,10 +51,13 @@ func (pt *Plot) SVGToFile(filename string) error { return bw.Flush() } -// drawConfig configures everything for drawing +// drawConfig configures everything for drawing, applying styles etc. func (pt *Plot) drawConfig() { + pt.applyStyle() pt.Resize(pt.Size) // ensure - pt.Legend.TextStyle.openFont(pt) + pt.X.drawConfig() + pt.Y.drawConfig() + pt.Z.drawConfig() pt.Paint.ToDots() } @@ -70,8 +73,8 @@ func (pt *Plot) Draw() { ptb := image.Rectangle{Max: pt.Size} pc.PushBounds(ptb) - if pt.Background != nil { - pc.BlitBox(math32.Vector2{}, math32.FromPoint(pt.Size), pt.Background) + if pt.Style.Background != nil { + pc.BlitBox(math32.Vector2{}, math32.FromPoint(pt.Size), pt.Style.Background) } if pt.Title.Text != "" { @@ -131,28 +134,28 @@ func (pt *Plot) Draw() { // drawTicks returns true if the tick marks should be drawn. 
func (ax *Axis) drawTicks() bool { - return ax.TickLine.Width.Value > 0 && ax.TickLength.Value > 0 + return ax.Style.TickLine.Width.Value > 0 && ax.Style.TickLength.Value > 0 } // sizeX returns the total height of the axis, left and right padding func (ax *Axis) sizeX(pt *Plot, axw float32) (ht, lpad, rpad int) { pc := pt.Paint uc := &pc.UnitContext - ax.TickLength.ToDots(uc) - ax.ticks = ax.Ticker.Ticks(ax.Min, ax.Max) + ax.Style.TickLength.ToDots(uc) + ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks) h := float32(0) if ax.Label.Text != "" { // We assume that the label isn't rotated. ax.Label.Config(pt) h += ax.Label.PaintText.BBox.Size().Y h += ax.Label.Style.Padding.Dots } - lw := ax.Line.Width.Dots + lw := ax.Style.Line.Width.Dots lpad = int(math32.Ceil(lw)) + 2 rpad = int(math32.Ceil(lw)) + 10 tht := float32(0) if len(ax.ticks) > 0 { if ax.drawTicks() { - h += ax.TickLength.Dots + h += ax.Style.TickLength.Dots } ftk := ax.firstTickLabel() if ftk.Label != "" { @@ -177,7 +180,7 @@ func (ax *Axis) sizeX(pt *Plot, axw float32) (ht, lpad, rpad int) { } h += ax.TickText.Style.Padding.Dots } - h += tht + lw + ax.Padding.Dots + h += tht + lw + ax.Style.Padding.Dots ht = int(math32.Ceil(h)) return @@ -231,8 +234,8 @@ func (ax *Axis) longestTickLabel() string { func (ax *Axis) sizeY(pt *Plot) (ywidth, tickWidth, tpad, bpad int) { pc := pt.Paint uc := &pc.UnitContext - ax.ticks = ax.Ticker.Ticks(ax.Min, ax.Max) - ax.TickLength.ToDots(uc) + ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks) + ax.Style.TickLength.ToDots(uc) w := float32(0) if ax.Label.Text != "" { @@ -241,13 +244,13 @@ func (ax *Axis) sizeY(pt *Plot) (ywidth, tickWidth, tpad, bpad int) { w += ax.Label.Style.Padding.Dots } - lw := ax.Line.Width.Dots + lw := ax.Style.Line.Width.Dots tpad = int(math32.Ceil(lw)) + 2 bpad = int(math32.Ceil(lw)) + 2 if len(ax.ticks) > 0 { if ax.drawTicks() { - w += ax.TickLength.Dots + w += ax.Style.TickLength.Dots } ax.TickText.Text = 
ax.longestTickLabel() if ax.TickText.Text != "" { @@ -261,7 +264,7 @@ func (ax *Axis) sizeY(pt *Plot) (ywidth, tickWidth, tpad, bpad int) { bpad += tht } } - w += lw + ax.Padding.Dots + w += lw + ax.Style.Padding.Dots ywidth = int(math32.Ceil(w)) return } @@ -305,7 +308,7 @@ func (ax *Axis) drawX(pt *Plot, lpad, rpad int) { } if len(ax.ticks) > 0 && ax.drawTicks() { - ln := ax.TickLength.Dots + ln := ax.Style.TickLength.Dots for _, t := range ax.ticks { yoff := float32(0) if t.IsMinor() { @@ -316,12 +319,12 @@ func (ax *Axis) drawX(pt *Plot, lpad, rpad int) { continue } x += float32(ab.Min.X) - ax.TickLine.Draw(pt, math32.Vec2(x, float32(ab.Max.Y)-yoff), math32.Vec2(x, float32(ab.Max.Y)-ln)) + ax.Style.TickLine.Draw(pt, math32.Vec2(x, float32(ab.Max.Y)-yoff), math32.Vec2(x, float32(ab.Max.Y)-ln)) } - ab.Max.Y -= int(ln - 0.5*ax.Line.Width.Dots) + ab.Max.Y -= int(ln - 0.5*ax.Style.Line.Width.Dots) } - ax.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y)), math32.Vec2(float32(ab.Min.X)+axw, float32(ab.Max.Y))) + ax.Style.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y)), math32.Vec2(float32(ab.Min.X)+axw, float32(ab.Max.Y))) } // drawY draws the Y axis along the left side @@ -362,7 +365,7 @@ func (ax *Axis) drawY(pt *Plot, tickWidth, tpad, bpad int) { } if len(ax.ticks) > 0 && ax.drawTicks() { - ln := ax.TickLength.Dots + ln := ax.Style.TickLength.Dots for _, t := range ax.ticks { xoff := float32(0) if t.IsMinor() { @@ -373,12 +376,12 @@ func (ax *Axis) drawY(pt *Plot, tickWidth, tpad, bpad int) { continue } y += float32(ab.Min.Y) - ax.TickLine.Draw(pt, math32.Vec2(float32(ab.Min.X)+xoff, y), math32.Vec2(float32(ab.Min.X)+ln, y)) + ax.Style.TickLine.Draw(pt, math32.Vec2(float32(ab.Min.X)+xoff, y), math32.Vec2(float32(ab.Min.X)+ln, y)) } - ab.Min.X += int(ln + 0.5*ax.Line.Width.Dots) + ab.Min.X += int(ln + 0.5*ax.Style.Line.Width.Dots) } - ax.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Min.Y)), math32.Vec2(float32(ab.Min.X), 
float32(ab.Max.Y))) + ax.Style.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Min.Y)), math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y))) } //////////////////////////////////////////////// @@ -390,17 +393,17 @@ func (lg *Legend) draw(pt *Plot) { uc := &pc.UnitContext ptb := pc.Bounds - lg.ThumbnailWidth.ToDots(uc) - lg.TextStyle.ToDots(uc) - lg.Position.XOffs.ToDots(uc) - lg.Position.YOffs.ToDots(uc) - lg.TextStyle.openFont(pt) - - em := lg.TextStyle.Font.Face.Metrics.Em - pad := math32.Ceil(lg.TextStyle.Padding.Dots) + lg.Style.ThumbnailWidth.ToDots(uc) + lg.Style.Position.XOffs.ToDots(uc) + lg.Style.Position.YOffs.ToDots(uc) var ltxt Text - ltxt.Style = lg.TextStyle + ltxt.Defaults() + ltxt.Style = lg.Style.Text + ltxt.ToDots(uc) + pad := math32.Ceil(ltxt.Style.Padding.Dots) + ltxt.openFont(pt) + em := ltxt.font.Face.Metrics.Em var sz image.Point maxTht := 0 for _, e := range lg.Entries { @@ -413,25 +416,25 @@ func (lg *Legend) draw(pt *Plot) { sz.X += int(em) sz.Y = len(lg.Entries) * maxTht txsz := sz - sz.X += int(lg.ThumbnailWidth.Dots) + sz.X += int(lg.Style.ThumbnailWidth.Dots) pos := ptb.Min - if lg.Position.Left { - pos.X += int(lg.Position.XOffs.Dots) + if lg.Style.Position.Left { + pos.X += int(lg.Style.Position.XOffs.Dots) } else { - pos.X = ptb.Max.X - sz.X - int(lg.Position.XOffs.Dots) + pos.X = ptb.Max.X - sz.X - int(lg.Style.Position.XOffs.Dots) } - if lg.Position.Top { - pos.Y += int(lg.Position.YOffs.Dots) + if lg.Style.Position.Top { + pos.Y += int(lg.Style.Position.YOffs.Dots) } else { - pos.Y = ptb.Max.Y - sz.Y - int(lg.Position.YOffs.Dots) + pos.Y = ptb.Max.Y - sz.Y - int(lg.Style.Position.YOffs.Dots) } - if lg.Fill != nil { - pc.FillBox(math32.FromPoint(pos), math32.FromPoint(sz), lg.Fill) + if lg.Style.Fill != nil { + pc.FillBox(math32.FromPoint(pos), math32.FromPoint(sz), lg.Style.Fill) } cp := pos - thsz := image.Point{X: int(lg.ThumbnailWidth.Dots), Y: maxTht - 2*int(pad)} + thsz := image.Point{X: 
int(lg.Style.ThumbnailWidth.Dots), Y: maxTht - 2*int(pad)} for _, e := range lg.Entries { tp := cp tp.X += int(txsz.X) diff --git a/plot/enumgen.go b/plot/enumgen.go new file mode 100644 index 0000000000..ac6c7c0087 --- /dev/null +++ b/plot/enumgen.go @@ -0,0 +1,212 @@ +// Code generated by "core generate -add-types"; DO NOT EDIT. + +package plot + +import ( + "cogentcore.org/core/enums" +) + +var _AxisScalesValues = []AxisScales{0, 1, 2, 3} + +// AxisScalesN is the highest valid value for type AxisScales, plus one. +const AxisScalesN AxisScales = 4 + +var _AxisScalesValueMap = map[string]AxisScales{`Linear`: 0, `Log`: 1, `InverseLinear`: 2, `InverseLog`: 3} + +var _AxisScalesDescMap = map[AxisScales]string{0: `Linear is a linear axis scale.`, 1: `Log is a Logarithmic axis scale.`, 2: `InverseLinear is an inverted linear axis scale.`, 3: `InverseLog is an inverted log axis scale.`} + +var _AxisScalesMap = map[AxisScales]string{0: `Linear`, 1: `Log`, 2: `InverseLinear`, 3: `InverseLog`} + +// String returns the string representation of this AxisScales value. +func (i AxisScales) String() string { return enums.String(i, _AxisScalesMap) } + +// SetString sets the AxisScales value from its string representation, +// and returns an error if the string is invalid. +func (i *AxisScales) SetString(s string) error { + return enums.SetString(i, s, _AxisScalesValueMap, "AxisScales") +} + +// Int64 returns the AxisScales value as an int64. +func (i AxisScales) Int64() int64 { return int64(i) } + +// SetInt64 sets the AxisScales value from an int64. +func (i *AxisScales) SetInt64(in int64) { *i = AxisScales(in) } + +// Desc returns the description of the AxisScales value. +func (i AxisScales) Desc() string { return enums.Desc(i, _AxisScalesDescMap) } + +// AxisScalesValues returns all possible values for the type AxisScales. +func AxisScalesValues() []AxisScales { return _AxisScalesValues } + +// Values returns all possible values for the type AxisScales. 
+func (i AxisScales) Values() []enums.Enum { return enums.Values(_AxisScalesValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i AxisScales) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *AxisScales) UnmarshalText(text []byte) error { + return enums.UnmarshalText(i, text, "AxisScales") +} + +var _RolesValues = []Roles{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + +// RolesN is the highest valid value for type Roles, plus one. +const RolesN Roles = 12 + +var _RolesValueMap = map[string]Roles{`NoRole`: 0, `X`: 1, `Y`: 2, `Z`: 3, `U`: 4, `V`: 5, `W`: 6, `Low`: 7, `High`: 8, `Size`: 9, `Color`: 10, `Label`: 11} + +var _RolesDescMap = map[Roles]string{0: `NoRole is the default no-role specified case.`, 1: `X axis`, 2: `Y axis`, 3: `Z axis`, 4: `U is the X component of a vector or first quartile in Box plot, etc.`, 5: `V is the Y component of a vector or third quartile in a Box plot, etc.`, 6: `W is the Z component of a vector`, 7: `Low is a lower error bar or region.`, 8: `High is an upper error bar or region.`, 9: `Size controls the size of points etc.`, 10: `Color controls the color of points or other elements.`, 11: `Label renders a label, typically from string data, but can also be used for values.`} + +var _RolesMap = map[Roles]string{0: `NoRole`, 1: `X`, 2: `Y`, 3: `Z`, 4: `U`, 5: `V`, 6: `W`, 7: `Low`, 8: `High`, 9: `Size`, 10: `Color`, 11: `Label`} + +// String returns the string representation of this Roles value. +func (i Roles) String() string { return enums.String(i, _RolesMap) } + +// SetString sets the Roles value from its string representation, +// and returns an error if the string is invalid. +func (i *Roles) SetString(s string) error { return enums.SetString(i, s, _RolesValueMap, "Roles") } + +// Int64 returns the Roles value as an int64. 
+func (i Roles) Int64() int64 { return int64(i) } + +// SetInt64 sets the Roles value from an int64. +func (i *Roles) SetInt64(in int64) { *i = Roles(in) } + +// Desc returns the description of the Roles value. +func (i Roles) Desc() string { return enums.Desc(i, _RolesDescMap) } + +// RolesValues returns all possible values for the type Roles. +func RolesValues() []Roles { return _RolesValues } + +// Values returns all possible values for the type Roles. +func (i Roles) Values() []enums.Enum { return enums.Values(_RolesValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i Roles) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *Roles) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Roles") } + +var _StepKindValues = []StepKind{0, 1, 2, 3} + +// StepKindN is the highest valid value for type StepKind, plus one. +const StepKindN StepKind = 4 + +var _StepKindValueMap = map[string]StepKind{`NoStep`: 0, `PreStep`: 1, `MidStep`: 2, `PostStep`: 3} + +var _StepKindDescMap = map[StepKind]string{0: `NoStep connects two points by simple line.`, 1: `PreStep connects two points by following lines: vertical, horizontal.`, 2: `MidStep connects two points by following lines: horizontal, vertical, horizontal. Vertical line is placed in the middle of the interval.`, 3: `PostStep connects two points by following lines: horizontal, vertical.`} + +var _StepKindMap = map[StepKind]string{0: `NoStep`, 1: `PreStep`, 2: `MidStep`, 3: `PostStep`} + +// String returns the string representation of this StepKind value. +func (i StepKind) String() string { return enums.String(i, _StepKindMap) } + +// SetString sets the StepKind value from its string representation, +// and returns an error if the string is invalid. 
+func (i *StepKind) SetString(s string) error { + return enums.SetString(i, s, _StepKindValueMap, "StepKind") +} + +// Int64 returns the StepKind value as an int64. +func (i StepKind) Int64() int64 { return int64(i) } + +// SetInt64 sets the StepKind value from an int64. +func (i *StepKind) SetInt64(in int64) { *i = StepKind(in) } + +// Desc returns the description of the StepKind value. +func (i StepKind) Desc() string { return enums.Desc(i, _StepKindDescMap) } + +// StepKindValues returns all possible values for the type StepKind. +func StepKindValues() []StepKind { return _StepKindValues } + +// Values returns all possible values for the type StepKind. +func (i StepKind) Values() []enums.Enum { return enums.Values(_StepKindValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i StepKind) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *StepKind) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "StepKind") } + +var _ShapesValues = []Shapes{0, 1, 2, 3, 4, 5, 6, 7} + +// ShapesN is the highest valid value for type Shapes, plus one. +const ShapesN Shapes = 8 + +var _ShapesValueMap = map[string]Shapes{`Ring`: 0, `Circle`: 1, `Square`: 2, `Box`: 3, `Triangle`: 4, `Pyramid`: 5, `Plus`: 6, `Cross`: 7} + +var _ShapesDescMap = map[Shapes]string{0: `Ring is the outline of a circle`, 1: `Circle is a solid circle`, 2: `Square is the outline of a square`, 3: `Box is a filled square`, 4: `Triangle is the outline of a triangle`, 5: `Pyramid is a filled triangle`, 6: `Plus is a plus sign`, 7: `Cross is a big X`} + +var _ShapesMap = map[Shapes]string{0: `Ring`, 1: `Circle`, 2: `Square`, 3: `Box`, 4: `Triangle`, 5: `Pyramid`, 6: `Plus`, 7: `Cross`} + +// String returns the string representation of this Shapes value. 
+func (i Shapes) String() string { return enums.String(i, _ShapesMap) } + +// SetString sets the Shapes value from its string representation, +// and returns an error if the string is invalid. +func (i *Shapes) SetString(s string) error { return enums.SetString(i, s, _ShapesValueMap, "Shapes") } + +// Int64 returns the Shapes value as an int64. +func (i Shapes) Int64() int64 { return int64(i) } + +// SetInt64 sets the Shapes value from an int64. +func (i *Shapes) SetInt64(in int64) { *i = Shapes(in) } + +// Desc returns the description of the Shapes value. +func (i Shapes) Desc() string { return enums.Desc(i, _ShapesDescMap) } + +// ShapesValues returns all possible values for the type Shapes. +func ShapesValues() []Shapes { return _ShapesValues } + +// Values returns all possible values for the type Shapes. +func (i Shapes) Values() []enums.Enum { return enums.Values(_ShapesValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i Shapes) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *Shapes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Shapes") } + +var _DefaultOffOnValues = []DefaultOffOn{0, 1, 2} + +// DefaultOffOnN is the highest valid value for type DefaultOffOn, plus one. +const DefaultOffOnN DefaultOffOn = 3 + +var _DefaultOffOnValueMap = map[string]DefaultOffOn{`Default`: 0, `Off`: 1, `On`: 2} + +var _DefaultOffOnDescMap = map[DefaultOffOn]string{0: `Default means use the default value.`, 1: `Off means to override the default and turn Off.`, 2: `On means to override the default and turn On.`} + +var _DefaultOffOnMap = map[DefaultOffOn]string{0: `Default`, 1: `Off`, 2: `On`} + +// String returns the string representation of this DefaultOffOn value. 
+func (i DefaultOffOn) String() string { return enums.String(i, _DefaultOffOnMap) } + +// SetString sets the DefaultOffOn value from its string representation, +// and returns an error if the string is invalid. +func (i *DefaultOffOn) SetString(s string) error { + return enums.SetString(i, s, _DefaultOffOnValueMap, "DefaultOffOn") +} + +// Int64 returns the DefaultOffOn value as an int64. +func (i DefaultOffOn) Int64() int64 { return int64(i) } + +// SetInt64 sets the DefaultOffOn value from an int64. +func (i *DefaultOffOn) SetInt64(in int64) { *i = DefaultOffOn(in) } + +// Desc returns the description of the DefaultOffOn value. +func (i DefaultOffOn) Desc() string { return enums.Desc(i, _DefaultOffOnDescMap) } + +// DefaultOffOnValues returns all possible values for the type DefaultOffOn. +func DefaultOffOnValues() []DefaultOffOn { return _DefaultOffOnValues } + +// Values returns all possible values for the type DefaultOffOn. +func (i DefaultOffOn) Values() []enums.Enum { return enums.Values(_DefaultOffOnValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i DefaultOffOn) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *DefaultOffOn) UnmarshalText(text []byte) error { + return enums.UnmarshalText(i, text, "DefaultOffOn") +} diff --git a/plot/labelling.go b/plot/labelling.go index 29b2b1cc64..606cac3a2c 100644 --- a/plot/labelling.go +++ b/plot/labelling.go @@ -14,7 +14,7 @@ package plot -import "cogentcore.org/core/math32" +import "math" const ( // dlamchE is the machine epsilon. For IEEE this is 2^{-53}. @@ -49,7 +49,7 @@ const ( // By default, when nil, legbility will set the legibility score for each candidate // labelling scheme to 1. // See the paper for an explanation of the function of Q, w and legibility. 
-func talbotLinHanrahan(dMin, dMax float32, want int, containment int, Q []float32, w *weights, legibility func(lMin, lMax, lStep float32) float32) (values []float32, step, q float32, magnitude int) { +func talbotLinHanrahan(dMin, dMax float64, want int, containment int, Q []float64, w *weights, legibility func(lMin, lMax, lStep float64) float64) (values []float64, step, q float64, magnitude int) { const eps = dlamchP * 100 if dMin > dMax { @@ -57,7 +57,7 @@ func talbotLinHanrahan(dMin, dMax float32, want int, containment int, Q []float3 } if Q == nil { - Q = []float32{1, 5, 2, 2.5, 4, 3} + Q = []float64{1, 5, 2, 2.5, 4, 3} } if w == nil { w = &weights{ @@ -72,10 +72,10 @@ func talbotLinHanrahan(dMin, dMax float32, want int, containment int, Q []float3 } if r := dMax - dMin; r < eps { - l := make([]float32, want) - step := r / float32(want-1) + l := make([]float64, want) + step := r / float64(want-1) for i := range l { - l[i] = dMin + float32(i)*step + l[i] = dMin + float64(i)*step } magnitude = minAbsMag(dMin, dMax) return l, step, 0, magnitude @@ -87,9 +87,9 @@ func talbotLinHanrahan(dMin, dMax float32, want int, containment int, Q []float3 // lMin and lMax are the selected min // and max label values. lq is the q // chosen. - lMin, lMax, lStep, lq float32 + lMin, lMax, lStep, lq float64 // score is the score for the selection. - score float32 + score float64 // magnitude is the magnitude of the // label step distance. 
magnitude int @@ -110,22 +110,22 @@ outer: break } - delta := (dMax - dMin) / float32(have+1) / float32(skip) / q + delta := (dMax - dMin) / float64(have+1) / float64(skip) / q const maxExp = 309 - for mag := int(math32.Ceil(math32.Log10(delta))); mag < maxExp; mag++ { - step := float32(skip) * q * math32.Pow10(mag) + for mag := int(math.Ceil(math.Log10(delta))); mag < maxExp; mag++ { + step := float64(skip) * q * math.Pow10(mag) - cm := maxCoverage(dMin, dMax, step*float32(have-1)) + cm := maxCoverage(dMin, dMax, step*float64(have-1)) if w.score(sm, cm, dm, 1) < best.score { break } - fracStep := step / float32(skip) - kStep := step * float32(have-1) + fracStep := step / float64(skip) + kStep := step * float64(have-1) - minStart := (math32.Floor(dMax/step) - float32(have-1)) * float32(skip) - maxStart := math32.Ceil(dMax/step) * float32(skip) + minStart := (math.Floor(dMax/step) - float64(have-1)) * float64(skip) + maxStart := math.Ceil(dMax/step) * float64(skip) for start := minStart; start <= maxStart && start != start-1; start++ { lMin := start * fracStep lMax := lMin + kStep @@ -154,7 +154,7 @@ outer: n: have, lMin: lMin, lMax: lMax, - lStep: float32(skip) * q, + lStep: float64(skip) * q, lq: q, score: score, magnitude: mag, @@ -167,51 +167,51 @@ outer: } if best.score == -2 { - l := make([]float32, want) - step := (dMax - dMin) / float32(want-1) + l := make([]float64, want) + step := (dMax - dMin) / float64(want-1) for i := range l { - l[i] = dMin + float32(i)*step + l[i] = dMin + float64(i)*step } magnitude = minAbsMag(dMin, dMax) return l, step, 0, magnitude } - l := make([]float32, best.n) - step = best.lStep * math32.Pow10(best.magnitude) + l := make([]float64, best.n) + step = best.lStep * math.Pow10(best.magnitude) for i := range l { - l[i] = best.lMin + float32(i)*step + l[i] = best.lMin + float64(i)*step } return l, best.lStep, best.lq, best.magnitude } // minAbsMag returns the minumum magnitude of the absolute values of a and b. 
-func minAbsMag(a, b float32) int { - return int(math32.Min(math32.Floor(math32.Log10(math32.Abs(a))), (math32.Floor(math32.Log10(math32.Abs(b)))))) +func minAbsMag(a, b float64) int { + return int(math.Min(math.Floor(math.Log10(math.Abs(a))), (math.Floor(math.Log10(math.Abs(b)))))) } // simplicity returns the simplicity score for how will the curent q, lMin, lMax, // lStep and skip match the given nice numbers, Q. -func simplicity(q float32, Q []float32, skip int, lMin, lMax, lStep float32) float32 { +func simplicity(q float64, Q []float64, skip int, lMin, lMax, lStep float64) float64 { const eps = dlamchP * 100 for i, v := range Q { if v == q { - m := math32.Mod(lMin, lStep) + m := math.Mod(lMin, lStep) v = 0 if (m < eps || lStep-m < eps) && lMin <= 0 && 0 <= lMax { v = 1 } - return 1 - float32(i)/(float32(len(Q))-1) - float32(skip) + v + return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + v } } panic("labelling: invalid q for Q") } // maxSimplicity returns the maximum simplicity for q, Q and skip. -func maxSimplicity(q float32, Q []float32, skip int) float32 { +func maxSimplicity(q float64, Q []float64, skip int) float64 { for i, v := range Q { if v == q { - return 1 - float32(i)/(float32(len(Q))-1) - float32(skip) + 1 + return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + 1 } } panic("labelling: invalid q for Q") @@ -220,7 +220,7 @@ func maxSimplicity(q float32, Q []float32, skip int) float32 { // coverage returns the coverage score for based on the average // squared distance between the extreme labels, lMin and lMax, and // the extreme data points, dMin and dMax. -func coverage(dMin, dMax, lMin, lMax float32) float32 { +func coverage(dMin, dMax, lMin, lMax float64) float64 { r := 0.1 * (dMax - dMin) max := dMax - lMax min := dMin - lMin @@ -229,7 +229,7 @@ func coverage(dMin, dMax, lMin, lMax float32) float32 { // maxCoverage returns the maximum coverage achievable for the data // range. 
-func maxCoverage(dMin, dMax, span float32) float32 { +func maxCoverage(dMin, dMax, span float64) float64 { r := dMax - dMin if span <= r { return 1 @@ -242,9 +242,9 @@ func maxCoverage(dMin, dMax, span float32) float32 { // density returns the density score which measures the goodness of // the labelling density compared to the user defined target // based on the want parameter given to talbotLinHanrahan. -func density(have, want int, dMin, dMax, lMin, lMax float32) float32 { - rho := float32(have-1) / (lMax - lMin) - rhot := float32(want-1) / (math32.Max(lMax, dMax) - math32.Min(dMin, lMin)) +func density(have, want int, dMin, dMax, lMin, lMax float64) float64 { + rho := float64(have-1) / (lMax - lMin) + rhot := float64(want-1) / (math.Max(lMax, dMax) - math.Min(dMin, lMin)) if d := rho / rhot; d >= 1 { return 2 - d } @@ -252,26 +252,26 @@ func density(have, want int, dMin, dMax, lMin, lMax float32) float32 { } // maxDensity returns the maximum density score achievable for have and want. -func maxDensity(have, want int) float32 { +func maxDensity(have, want int) float64 { if have < want { return 1 } - return 2 - float32(have-1)/float32(want-1) + return 2 - float64(have-1)/float64(want-1) } // unitLegibility returns a default legibility score ignoring label // spacing. -func unitLegibility(_, _, _ float32) float32 { +func unitLegibility(_, _, _ float64) float64 { return 1 } // weights is a helper type to calcuate the labelling scheme's total score. type weights struct { - simplicity, coverage, density, legibility float32 + simplicity, coverage, density, legibility float64 } // score returns the score for a labelling scheme with simplicity, s, // coverage, c, density, d and legibility l. 
-func (w *weights) score(s, c, d, l float32) float32 { +func (w *weights) score(s, c, d, l float64) float64 { return w.simplicity*s + w.coverage*c + w.density*d + w.legibility*l } diff --git a/plot/legend.go b/plot/legend.go index b7b4a2596a..bb0cacc049 100644 --- a/plot/legend.go +++ b/plot/legend.go @@ -12,6 +12,35 @@ import ( "cogentcore.org/core/styles/units" ) +// LegendStyle has the styling properties for the Legend. +type LegendStyle struct { //types:add -setters + + // Column is for table-based plotting, specifying the column with legend values. + Column string + + // Text is the style given to the legend entry texts. + Text TextStyle `display:"add-fields"` + + // position of the legend + Position LegendPosition `display:"inline"` + + // ThumbnailWidth is the width of legend thumbnails. + ThumbnailWidth units.Value `display:"inline"` + + // Fill specifies the background fill color for the legend box, + // if non-nil. + Fill image.Image +} + +func (ls *LegendStyle) Defaults() { + ls.Text.Defaults() + ls.Text.Padding.Dp(2) + ls.Text.Size.Dp(20) + ls.Position.Defaults() + ls.ThumbnailWidth.Pt(20) + ls.Fill = gradient.ApplyOpacity(colors.Scheme.Surface, 0.75) +} + // LegendPosition specifies where to put the legend type LegendPosition struct { // Top and Left specify the location of the legend. @@ -31,30 +60,16 @@ func (lg *LegendPosition) Defaults() { // and a thumbnail, where the thumbnail shows a small // sample of the display style of the corresponding data. type Legend struct { - // TextStyle is the style given to the legend entry texts. - TextStyle TextStyle - - // position of the legend - Position LegendPosition `display:"inline"` - // ThumbnailWidth is the width of legend thumbnails. - ThumbnailWidth units.Value - - // Fill specifies the background fill color for the legend box, - // if non-nil. - Fill image.Image + // Style has the legend styling parameters. + Style LegendStyle // Entries are all of the LegendEntries described by this legend. 
Entries []LegendEntry } func (lg *Legend) Defaults() { - lg.TextStyle.Defaults() - lg.TextStyle.Padding.Dp(2) - lg.TextStyle.Font.Size.Dp(20) - lg.Position.Defaults() - lg.ThumbnailWidth.Pt(20) - lg.Fill = gradient.ApplyOpacity(colors.Scheme.Surface, 0.75) + lg.Style.Defaults() } // Add adds an entry to the legend with the given name. @@ -78,14 +93,12 @@ func (lg *Legend) LegendForPlotter(plt Plotter) string { return "" } -// Thumbnailer wraps the Thumbnail method, which -// draws the small image in a legend representing the -// style of data. +// Thumbnailer wraps the Thumbnail method, which draws the small +// image in a legend representing the style of data. type Thumbnailer interface { - // Thumbnail draws an thumbnail representing - // a legend entry. The thumbnail will usually show - // a smaller representation of the style used - // to plot the corresponding data. + // Thumbnail draws an thumbnail representing a legend entry. + // The thumbnail will usually show a smaller representation + // of the style used to plot the corresponding data. Thumbnail(pt *Plot) } diff --git a/plot/line.go b/plot/line.go index 2e05073391..fba304e5ea 100644 --- a/plot/line.go +++ b/plot/line.go @@ -12,29 +12,48 @@ import ( "cogentcore.org/core/styles/units" ) -// LineStyle has style properties for line drawing -type LineStyle struct { +// LineStyle has style properties for drawing lines. +type LineStyle struct { //types:add -setters + // On indicates whether to plot lines. + On DefaultOffOn - // stroke color image specification; stroking is off if nil + // Color is the stroke color image specification. + // Setting to nil turns line off. Color image.Image - // line width + // Width is the line width, with a default of 1 Pt (point). + // Setting to 0 turns line off. Width units.Value // Dashes are the dashes of the stroke. Each pair of values specifies // the amount to paint and then the amount to skip. 
Dashes []float32 + + // Fill is the color to fill solid regions, in a plot-specific + // way (e.g., the area below a Line plot, the bar color). + // Use nil to disable filling. + Fill image.Image + + // NegativeX specifies whether to draw lines that connect points with a negative + // X-axis direction; otherwise there is a break in the line. + // default is false, so that repeated series of data across the X axis + // are plotted separately. + NegativeX bool + + // Step specifies how to step the line between points. + Step StepKind } func (ls *LineStyle) Defaults() { ls.Color = colors.Scheme.OnSurface + ls.Fill = colors.Uniform(colors.Transparent) ls.Width.Pt(1) } // SetStroke sets the stroke style in plot paint to current line style. // returns false if either the Width = 0 or Color is nil func (ls *LineStyle) SetStroke(pt *Plot) bool { - if ls.Color == nil { + if ls.On == Off || ls.Color == nil { return false } pc := pt.Paint @@ -49,6 +68,17 @@ func (ls *LineStyle) SetStroke(pt *Plot) bool { return true } +func (ls *LineStyle) HasFill() bool { + if ls.Fill == nil { + return false + } + clr := colors.ToUniform(ls.Fill) + if clr == colors.Transparent { + return false + } + return true +} + // Draw draws a line between given coordinates, setting the stroke style // to current parameters. Returns false if either Width = 0 or Color = nil func (ls *LineStyle) Draw(pt *Plot, start, end math32.Vector2) bool { @@ -61,3 +91,21 @@ func (ls *LineStyle) Draw(pt *Plot, start, end math32.Vector2) bool { pc.Stroke() return true } + +// StepKind specifies a form of a connection of two consecutive points. +type StepKind int32 //enums:enum + +const ( + // NoStep connects two points by simple line. + NoStep StepKind = iota + + // PreStep connects two points by following lines: vertical, horizontal. + PreStep + + // MidStep connects two points by following lines: horizontal, vertical, horizontal. + // Vertical line is placed in the middle of the interval. 
+ MidStep + + // PostStep connects two points by following lines: horizontal, vertical. + PostStep +) diff --git a/plot/plot.go b/plot/plot.go index cdd40511b1..cec08d9224 100644 --- a/plot/plot.go +++ b/plot/plot.go @@ -17,39 +17,163 @@ import ( "cogentcore.org/core/base/iox/imagex" "cogentcore.org/core/colors" "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/paint" "cogentcore.org/core/styles" + "cogentcore.org/core/styles/units" ) +// XAxisStyle has overall plot level styling properties for the XAxis. +type XAxisStyle struct { //types:add -setters + // Column specifies the column to use for the common X axis, + // for [plot.NewTablePlot] [table.Table] driven plots. + // If empty, standard Group-based role binding is used: the last column + // within the same group with Role=X is used. + Column string + + // Rotation is the rotation of the X Axis labels, in degrees. + Rotation float32 + + // Label is the optional label to use for the XAxis instead of the default. + Label string + + // Range is the effective range of XAxis data to plot, where either end can be fixed. + Range minmax.Range64 `display:"inline"` + + // Scale specifies how values are scaled along the X axis: + // Linear, Log, Inverted + Scale AxisScales +} + +// PlotStyle has overall plot level styling properties. +// Some properties provide defaults for individual elements, which can +// then be overwritten by element-level properties. +type PlotStyle struct { //types:add -setters + + // Title is the overall title of the plot. + Title string + + // TitleStyle is the text styling parameters for the title. + TitleStyle TextStyle + + // Background is the background of the plot. + // The default is [colors.Scheme.Surface]. + Background image.Image + + // Scale multiplies the plot DPI value, to change the overall scale + // of the rendered plot. Larger numbers produce larger scaling. 
+ // Typically use larger numbers when generating plots for inclusion in + // documents or other cases where the overall plot size will be small. + Scale float32 `default:"1,2"` + + // Legend has the styling properties for the Legend. + Legend LegendStyle `display:"add-fields"` + + // Axis has the styling properties for the Axes. + Axis AxisStyle `display:"add-fields"` + + // XAxis has plot-level XAxis style properties. + XAxis XAxisStyle `display:"add-fields"` + + // YAxisLabel is the optional label to use for the YAxis instead of the default. + YAxisLabel string + + // LinesOn determines whether lines are plotted by default, + // for elements that plot lines (e.g., plots.XY). + LinesOn DefaultOffOn + + // LineWidth sets the default line width for data plotting lines. + LineWidth units.Value + + // PointsOn determines whether points are plotted by default, + // for elements that plot points (e.g., plots.XY). + PointsOn DefaultOffOn + + // PointSize sets the default point size. + PointSize units.Value + + // LabelSize sets the default label text size. + LabelSize units.Value + + // BarWidth for Bar plot sets the default width of the bars, + // which should be less than the Stride (1 typically) to prevent + // bar overlap. Defaults to .8. + BarWidth float64 +} + +func (ps *PlotStyle) Defaults() { + ps.TitleStyle.Defaults() + ps.TitleStyle.Size.Dp(24) + ps.Background = colors.Scheme.Surface + ps.Scale = 1 + ps.Legend.Defaults() + ps.Axis.Defaults() + ps.LineWidth.Pt(1) + ps.PointSize.Pt(4) + ps.LabelSize.Dp(16) + ps.BarWidth = .8 +} + +// SetElementStyle sets the properties for given element's style +// based on the global default settings in this PlotStyle. 
+func (ps *PlotStyle) SetElementStyle(es *Style) { + if ps.LinesOn != Default { + es.Line.On = ps.LinesOn + } + if ps.PointsOn != Default { + es.Point.On = ps.PointsOn + } + es.Line.Width = ps.LineWidth + es.Point.Size = ps.PointSize + es.Width.Width = ps.BarWidth + es.Text.Size = ps.LabelSize +} + +// PanZoom provides post-styling pan and zoom range manipulation. +type PanZoom struct { + + // XOffset adds offset to X range (pan). + XOffset float64 + + // XScale multiplies X range (zoom). + XScale float64 + + // YOffset adds offset to Y range (pan). + YOffset float64 + + // YScale multiplies Y range (zoom). + YScale float64 +} + +func (pz *PanZoom) Defaults() { + pz.XScale = 1 + pz.YScale = 1 +} + // Plot is the basic type representing a plot. // It renders into its own image.RGBA Pixels image, // and can also save a corresponding SVG version. -// The Axis ranges are updated automatically when plots -// are added, so setting a fixed range should happen -// after that point. See [UpdateRange] method as well. type Plot struct { // Title of the plot Title Text - // Background is the background of the plot. - // The default is [colors.Scheme.Surface]. - Background image.Image + // Style has the styling properties for the plot. + Style PlotStyle // standard text style with default options StandardTextStyle styles.Text - // X and Y are the horizontal and vertical axes + // X, Y, and Z are the horizontal, vertical, and depth axes // of the plot respectively. - X, Y Axis + X, Y, Z Axis // Legend is the plot's legend. Legend Legend - // plotters are drawn by calling their Plot method - // after the axes are drawn. + // Plotters are drawn by calling their Plot method after the axes are drawn. Plotters []Plotter - // size is the target size of the image to render to + // Size is the target size of the image to render to. Size image.Point // DPI is the dots per inch for rendering the image. 
@@ -57,57 +181,102 @@ type Plot struct { // which is strongly recommended for print (e.g., use 300 for print) DPI float32 `default:"96,160,300"` - // painter for rendering - Paint *paint.Context + // PanZoom provides post-styling pan and zoom range factors. + PanZoom PanZoom + + // HighlightPlotter is the Plotter to highlight. Used for mouse hovering for example. + // It is the responsibility of the Plotter Plot function to implement highlighting. + HighlightPlotter Plotter + + // HighlightIndex is the index of the data point to highlight, for HighlightPlotter. + HighlightIndex int // pixels that we render into Pixels *image.RGBA `copier:"-" json:"-" xml:"-" edit:"-"` + // Paint is the painter for rendering + Paint *paint.Context + // Current plot bounding box in image coordinates, for plotting coordinates PlotBox math32.Box2 } +// New returns a new plot with some reasonable default settings. +func New() *Plot { + pt := &Plot{} + pt.Defaults() + return pt +} + // Defaults sets defaults func (pt *Plot) Defaults() { + pt.Style.Defaults() pt.Title.Defaults() pt.Title.Style.Size.Dp(24) - pt.Background = colors.Scheme.Surface pt.X.Defaults(math32.X) pt.Y.Defaults(math32.Y) pt.Legend.Defaults() pt.DPI = 96 + pt.PanZoom.Defaults() pt.Size = image.Point{1280, 1024} pt.StandardTextStyle.Defaults() pt.StandardTextStyle.WhiteSpace = styles.WhiteSpaceNowrap } -// New returns a new plot with some reasonable default settings. 
-func New() *Plot { - pt := &Plot{} - pt.Defaults() - return pt +// applyStyle applies all the style parameters +func (pt *Plot) applyStyle() { + // first update the global plot style settings + var st Style + st.Defaults() + st.Plot = pt.Style + for _, plt := range pt.Plotters { + stlr := plt.Stylers() + stlr.Run(&st) + } + pt.Style = st.Plot + // then apply to elements + for _, plt := range pt.Plotters { + plt.ApplyStyle(&pt.Style) + } + // now style plot: + pt.DPI *= pt.Style.Scale + pt.Title.Style = pt.Style.TitleStyle + if pt.Style.Title != "" { + pt.Title.Text = pt.Style.Title + } + pt.Legend.Style = pt.Style.Legend + pt.X.Style = pt.Style.Axis + pt.X.Style.Scale = pt.Style.XAxis.Scale + pt.Y.Style = pt.Style.Axis + if pt.Style.XAxis.Label != "" { + pt.X.Label.Text = pt.Style.XAxis.Label + } + if pt.Style.YAxisLabel != "" { + pt.Y.Label.Text = pt.Style.YAxisLabel + } + pt.X.Label.Style = pt.Style.Axis.Text + pt.Y.Label.Style = pt.Style.Axis.Text + pt.X.TickText.Style = pt.Style.Axis.TickText + pt.X.TickText.Style.Rotation = pt.Style.XAxis.Rotation + pt.Y.TickText.Style = pt.Style.Axis.TickText + pt.Y.Label.Style.Rotation = -90 + pt.Y.Style.TickText.Align = styles.End + pt.UpdateRange() } -// Add adds a Plotters to the plot. -// -// If the plotters implements DataRanger then the -// minimum and maximum values of the X and Y -// axes are changed if necessary to fit the range of -// the data. -// +// Add adds Plotter element(s) to the plot. // When drawing the plot, Plotters are drawn in the // order in which they were added to the plot. func (pt *Plot) Add(ps ...Plotter) { pt.Plotters = append(pt.Plotters, ps...) } -// SetPixels sets the backing pixels image to given image.RGBA +// SetPixels sets the backing pixels image to given image.RGBA. 
func (pt *Plot) SetPixels(img *image.RGBA) { pt.Pixels = img pt.Paint = paint.NewContextFromImage(pt.Pixels) pt.Paint.UnitContext.DPI = pt.DPI pt.Size = pt.Pixels.Bounds().Size() - pt.UpdateRange() // needs context, to automatically update for labels } // Resize sets the size of the output image to given size. @@ -136,28 +305,28 @@ func (pt *Plot) SaveImage(filename string) error { // that do not end up in range of the X axis will not have // tick marks. func (pt *Plot) NominalX(names ...string) { - pt.X.TickLine.Width.Pt(0) - pt.X.TickLength.Pt(0) - pt.X.Line.Width.Pt(0) - // pt.Y.Padding.Pt(pt.X.Tick.Label.Width(names[0]) / 2) + pt.X.Style.TickLine.Width.Pt(0) + pt.X.Style.TickLength.Pt(0) + pt.X.Style.Line.Width.Pt(0) + // pt.Y.Padding.Pt(pt.X.Style.Tick.Label.Width(names[0]) / 2) ticks := make([]Tick, len(names)) for i, name := range names { - ticks[i] = Tick{float32(i), name} + ticks[i] = Tick{float64(i), name} } pt.X.Ticker = ConstantTicks(ticks) } // HideX configures the X axis so that it will not be drawn. func (pt *Plot) HideX() { - pt.X.TickLength.Pt(0) - pt.X.Line.Width.Pt(0) + pt.X.Style.TickLength.Pt(0) + pt.X.Style.Line.Width.Pt(0) pt.X.Ticker = ConstantTicks([]Tick{}) } // HideY configures the Y axis so that it will not be drawn. func (pt *Plot) HideY() { - pt.Y.TickLength.Pt(0) - pt.Y.Line.Width.Pt(0) + pt.Y.Style.TickLength.Pt(0) + pt.Y.Style.Line.Width.Pt(0) pt.Y.Ticker = ConstantTicks([]Tick{}) } @@ -169,13 +338,13 @@ func (pt *Plot) HideAxes() { // NominalY is like NominalX, but for the Y axis. 
func (pt *Plot) NominalY(names ...string) { - pt.Y.TickLine.Width.Pt(0) - pt.Y.TickLength.Pt(0) - pt.Y.Line.Width.Pt(0) + pt.Y.Style.TickLine.Width.Pt(0) + pt.Y.Style.TickLength.Pt(0) + pt.Y.Style.Line.Width.Pt(0) // pt.X.Padding = pt.Y.Tick.Label.Height(names[0]) / 2 ticks := make([]Tick, len(names)) for i, name := range names { - ticks[i] = Tick{float32(i), name} + ticks[i] = Tick{float64(i), name} } pt.Y.Ticker = ConstantTicks(ticks) } @@ -184,55 +353,63 @@ func (pt *Plot) NominalY(names ...string) { // This first resets the range so any fixed additional range values should // be set after this point. func (pt *Plot) UpdateRange() { - pt.X.Min = math32.Inf(+1) - pt.X.Max = math32.Inf(-1) - pt.Y.Min = math32.Inf(+1) - pt.Y.Max = math32.Inf(-1) - for _, d := range pt.Plotters { - pt.UpdateRangeFromPlotter(d) + pt.X.Range.SetInfinity() + pt.Y.Range.SetInfinity() + pt.Z.Range.SetInfinity() + if pt.Style.XAxis.Range.FixMin { + pt.X.Range.Min = pt.Style.XAxis.Range.Min } -} - -func (pt *Plot) UpdateRangeFromPlotter(d Plotter) { - if x, ok := d.(DataRanger); ok { - xmin, xmax, ymin, ymax := x.DataRange(pt) - pt.X.Min = math32.Min(pt.X.Min, xmin) - pt.X.Max = math32.Max(pt.X.Max, xmax) - pt.Y.Min = math32.Min(pt.Y.Min, ymin) - pt.Y.Max = math32.Max(pt.Y.Max, ymax) + if pt.Style.XAxis.Range.FixMax { + pt.X.Range.Max = pt.Style.XAxis.Range.Max + } + for _, pl := range pt.Plotters { + pl.UpdateRange(pt, &pt.X.Range, &pt.Y.Range, &pt.Z.Range) } + pt.X.Range.Sanitize() + pt.Y.Range.Sanitize() + pt.Z.Range.Sanitize() + + pt.X.Range.Min *= pt.PanZoom.XScale + pt.X.Range.Max *= pt.PanZoom.XScale + pt.X.Range.Min += pt.PanZoom.XOffset + pt.X.Range.Max += pt.PanZoom.XOffset + + pt.Y.Range.Min *= pt.PanZoom.YScale + pt.Y.Range.Max *= pt.PanZoom.YScale + pt.Y.Range.Min += pt.PanZoom.YOffset + pt.Y.Range.Max += pt.PanZoom.YOffset } // PX returns the X-axis plotting coordinate for given raw data value // using the current plot bounding region -func (pt *Plot) PX(v float32) float32 { 
- return pt.PlotBox.ProjectX(pt.X.Norm(v)) +func (pt *Plot) PX(v float64) float32 { + return pt.PlotBox.ProjectX(float32(pt.X.Norm(v))) } // PY returns the Y-axis plotting coordinate for given raw data value -func (pt *Plot) PY(v float32) float32 { - return pt.PlotBox.ProjectY(1 - pt.Y.Norm(v)) +func (pt *Plot) PY(v float64) float32 { + return pt.PlotBox.ProjectY(float32(1 - pt.Y.Norm(v))) } // ClosestDataToPixel returns the Plotter data point closest to given pixel point, // in the Pixels image. -func (pt *Plot) ClosestDataToPixel(px, py int) (plt Plotter, idx int, dist float32, data, pixel math32.Vector2, legend string) { +func (pt *Plot) ClosestDataToPixel(px, py int) (plt Plotter, plotterIndex, pointIndex int, dist float32, pixel math32.Vector2, data Data, legend string) { tp := math32.Vec2(float32(px), float32(py)) dist = float32(math32.MaxFloat32) - for _, p := range pt.Plotters { - dts, pxls := p.XYData() - for i := range pxls.Len() { - ptx, pty := pxls.XY(i) + for pi, pl := range pt.Plotters { + dts, pxX, pxY := pl.Data() + for i, ptx := range pxX { + pty := pxY[i] pxy := math32.Vec2(ptx, pty) d := pxy.DistanceTo(tp) if d < dist { dist = d pixel = pxy - plt = p - idx = i - dx, dy := dts.XY(i) - data = math32.Vec2(dx, dy) - legend = pt.Legend.LegendForPlotter(p) + plt = pl + plotterIndex = pi + pointIndex = i + data = dts + legend = pt.Legend.LegendForPlotter(pl) } } } diff --git a/plot/plot_test.go b/plot/plot_test.go index 8332dee986..d3c0eff9bb 100644 --- a/plot/plot_test.go +++ b/plot/plot_test.go @@ -21,11 +21,9 @@ func TestMain(m *testing.M) { func TestPlot(t *testing.T) { pt := New() pt.Title.Text = "Test Plot" - pt.X.Min = 0 - pt.X.Max = 100 + pt.X.Range.Max = 100 pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 + pt.Y.Range.Max = 100 pt.Y.Label.Text = "Y Axis" pt.Resize(image.Point{640, 480}) diff --git a/plot/plotcore/barplot.go b/plot/plotcore/barplot.go deleted file mode 100644 index de74cd45ff..0000000000 --- a/plot/plotcore/barplot.go 
+++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package plotcore - -import ( - "fmt" - "log" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/colors" - "cogentcore.org/core/math32" - "cogentcore.org/core/math32/minmax" - "cogentcore.org/core/plot" - "cogentcore.org/core/plot/plots" - "cogentcore.org/core/tensor/stats/split" - "cogentcore.org/core/tensor/table" -) - -// bar plot is on integer positions, with different Y values and / or -// legend values interleaved - -// genPlotBar generates a Bar plot, setting GPlot variable -func (pl *PlotEditor) genPlotBar() { - plt := plot.New() // note: not clear how to re-use, due to newtablexynames - if pl.Options.BarWidth > 1 { - pl.Options.BarWidth = .8 - } - - // process xaxis first - xi, xview, err := pl.plotXAxis(plt, pl.table) - if err != nil { - return - } - xp := pl.Columns[xi] - - var lsplit *table.Splits - nleg := 1 - if pl.Options.Legend != "" { - _, err = pl.table.Table.ColumnIndex(pl.Options.Legend) - if err != nil { - log.Println("plot.Legend: " + err.Error()) - } else { - xview.SortColumnNames([]string{pl.Options.Legend, xp.Column}, table.Ascending) // make it fit! 
- lsplit = split.GroupBy(xview, pl.Options.Legend) - nleg = max(lsplit.Len(), 1) - } - } - - var firstXY *tableXY - var strCols []*ColumnOptions - nys := 0 - for _, cp := range pl.Columns { - if !cp.On { - continue - } - if cp.IsString { - strCols = append(strCols, cp) - continue - } - if cp.TensorIndex < 0 { - yc := errors.Log1(pl.table.Table.ColumnByName(cp.Column)) - _, sz := yc.RowCellSize() - nys += sz - } else { - nys++ - } - } - - if nys == 0 { - return - } - - stride := nys * nleg - if stride > 1 { - stride += 1 // extra gap - } - - yoff := 0 - yidx := 0 - maxx := 0 // max number of x values - for _, cp := range pl.Columns { - if !cp.On || cp == xp { - continue - } - if cp.IsString { - continue - } - start := yoff - for li := 0; li < nleg; li++ { - lview := xview - leg := "" - if lsplit != nil && len(lsplit.Values) > li { - leg = lsplit.Values[li][0] - lview = lsplit.Splits[li] - } - nidx := 1 - stidx := cp.TensorIndex - if cp.TensorIndex < 0 { // do all - yc := errors.Log1(pl.table.Table.ColumnByName(cp.Column)) - _, sz := yc.RowCellSize() - nidx = sz - stidx = 0 - } - for ii := 0; ii < nidx; ii++ { - idx := stidx + ii - xy, _ := newTableXYName(lview, xi, xp.TensorIndex, cp.Column, idx, cp.Range) - if xy == nil { - continue - } - maxx = max(maxx, lview.Len()) - if firstXY == nil { - firstXY = xy - } - lbl := cp.getLabel() - clr := cp.Color - if leg != "" { - lbl = leg + " " + lbl - } - if nleg > 1 { - cidx := yidx*nleg + li - clr = colors.Uniform(colors.Spaced(cidx)) - } - if nidx > 1 { - clr = colors.Uniform(colors.Spaced(idx)) - lbl = fmt.Sprintf("%s_%02d", lbl, idx) - } - ec := -1 - if cp.ErrColumn != "" { - ec, _ = pl.table.Table.ColumnIndex(cp.ErrColumn) - } - var bar *plots.BarChart - if ec >= 0 { - exy, _ := newTableXY(lview, ec, 0, ec, 0, minmax.Range32{}) - bar, err = plots.NewBarChart(xy, exy) - if err != nil { - // log.Println(err) - continue - } - } else { - bar, err = plots.NewBarChart(xy, nil) - if err != nil { - // log.Println(err) - 
continue - } - } - bar.Color = clr - bar.Stride = float32(stride) - bar.Offset = float32(start) - bar.Width = pl.Options.BarWidth - plt.Add(bar) - plt.Legend.Add(lbl, bar) - start++ - } - } - yidx++ - yoff += nleg - } - mid := (stride - 1) / 2 - if stride > 1 { - mid = (stride - 2) / 2 - } - if firstXY != nil && len(strCols) > 0 { - firstXY.table = xview - n := xview.Len() - for _, cp := range strCols { - xy, _ := newTableXY(xview, xi, xp.TensorIndex, firstXY.yColumn, cp.TensorIndex, firstXY.yRange) - xy.labelColumn, _ = xview.Table.ColumnIndex(cp.Column) - xy.yIndex = firstXY.yIndex - - xyl := plots.XYLabels{} - xyl.XYs = make(plot.XYs, n) - xyl.Labels = make([]string, n) - - for i := range xview.Indexes { - y := firstXY.Value(i) - x := float32(mid + (i%maxx)*stride) - xyl.XYs[i] = math32.Vec2(x, y) - xyl.Labels[i] = xy.Label(i) - } - lbls, _ := plots.NewLabels(xyl) - if lbls != nil { - plt.Add(lbls) - } - } - } - - netn := pl.table.Len() * stride - xc := pl.table.Table.Columns[xi] - vals := make([]string, netn) - for i, dx := range pl.table.Indexes { - pi := mid + i*stride - if pi < netn && dx < xc.Len() { - vals[pi] = xc.String1D(dx) - } - } - plt.NominalX(vals...) - - pl.configPlot(plt) - pl.plot = plt -} diff --git a/plot/plotcore/enumgen.go b/plot/plotcore/enumgen.go deleted file mode 100644 index 3ea13f9f1c..0000000000 --- a/plot/plotcore/enumgen.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by "core generate"; DO NOT EDIT. - -package plotcore - -import ( - "cogentcore.org/core/enums" -) - -var _PlotTypesValues = []PlotTypes{0, 1} - -// PlotTypesN is the highest valid value for type PlotTypes, plus one. 
-const PlotTypesN PlotTypes = 2 - -var _PlotTypesValueMap = map[string]PlotTypes{`XY`: 0, `Bar`: 1} - -var _PlotTypesDescMap = map[PlotTypes]string{0: `XY is a standard line / point plot.`, 1: `Bar plots vertical bars.`} - -var _PlotTypesMap = map[PlotTypes]string{0: `XY`, 1: `Bar`} - -// String returns the string representation of this PlotTypes value. -func (i PlotTypes) String() string { return enums.String(i, _PlotTypesMap) } - -// SetString sets the PlotTypes value from its string representation, -// and returns an error if the string is invalid. -func (i *PlotTypes) SetString(s string) error { - return enums.SetString(i, s, _PlotTypesValueMap, "PlotTypes") -} - -// Int64 returns the PlotTypes value as an int64. -func (i PlotTypes) Int64() int64 { return int64(i) } - -// SetInt64 sets the PlotTypes value from an int64. -func (i *PlotTypes) SetInt64(in int64) { *i = PlotTypes(in) } - -// Desc returns the description of the PlotTypes value. -func (i PlotTypes) Desc() string { return enums.Desc(i, _PlotTypesDescMap) } - -// PlotTypesValues returns all possible values for the type PlotTypes. -func PlotTypesValues() []PlotTypes { return _PlotTypesValues } - -// Values returns all possible values for the type PlotTypes. -func (i PlotTypes) Values() []enums.Enum { return enums.Values(_PlotTypesValues) } - -// MarshalText implements the [encoding.TextMarshaler] interface. -func (i PlotTypes) MarshalText() ([]byte, error) { return []byte(i.String()), nil } - -// UnmarshalText implements the [encoding.TextUnmarshaler] interface. -func (i *PlotTypes) UnmarshalText(text []byte) error { - return enums.UnmarshalText(i, text, "PlotTypes") -} diff --git a/plot/plotcore/options.go b/plot/plotcore/options.go deleted file mode 100644 index 5b91e3d865..0000000000 --- a/plot/plotcore/options.go +++ /dev/null @@ -1,301 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package plotcore - -import ( - "image" - "strings" - - "cogentcore.org/core/base/option" - "cogentcore.org/core/base/reflectx" - "cogentcore.org/core/math32/minmax" - "cogentcore.org/core/plot" - "cogentcore.org/core/plot/plots" - "cogentcore.org/core/tensor/table" -) - -// PlotOptions are options for the overall plot. -type PlotOptions struct { //types:add - - // optional title at top of plot - Title string - - // type of plot to generate. For a Bar plot, items are plotted ordinally by row and the XAxis is optional - Type PlotTypes - - // whether to plot lines - Lines bool `default:"true"` - - // whether to plot points with symbols - Points bool - - // width of lines - LineWidth float32 `default:"1"` - - // size of points - PointSize float32 `default:"3"` - - // the shape used to draw points - PointShape plots.Shapes - - // width of bars for bar plot, as fraction of available space (1 = no gaps) - BarWidth float32 `min:"0.01" max:"1" default:"0.8"` - - // if true, draw lines that connect points with a negative X-axis direction; - // otherwise there is a break in the line. - // default is false, so that repeated series of data across the X axis - // are plotted separately. - NegativeXDraw bool - - // Scale multiplies the plot DPI value, to change the overall scale - // of the rendered plot. Larger numbers produce larger scaling. - // Typically use larger numbers when generating plots for inclusion in - // documents or other cases where the overall plot size will be small. - Scale float32 `default:"1,2"` - - // what column to use for the common X axis. if empty or not found, - // the row number is used. This optional for Bar plots, if present and - // Legend is also present, then an extra space will be put between X values. 
- XAxis string - - // optional column for adding a separate colored / styled line or bar - // according to this value, and acts just like a separate Y variable, - // crossed with Y variables. - Legend string - - // position of the Legend - LegendPosition plot.LegendPosition `display:"inline"` - - // rotation of the X Axis labels, in degrees - XAxisRotation float32 - - // optional label to use for XAxis instead of column name - XAxisLabel string - - // optional label to use for YAxis -- if empty, first column name is used - YAxisLabel string -} - -// defaults sets defaults if unset values are present. -func (po *PlotOptions) defaults() { - if po.LineWidth == 0 { - po.LineWidth = 1 - po.Lines = true - po.Points = false - po.PointSize = 3 - po.BarWidth = .8 - po.LegendPosition.Defaults() - } - if po.Scale == 0 { - po.Scale = 1 - } -} - -// fromMeta sets plot options from meta data. -func (po *PlotOptions) fromMeta(dt *table.Table) { - po.FromMetaMap(dt.MetaData) -} - -// metaMapLower tries meta data access by lower-case version of key too -func metaMapLower(meta map[string]string, key string) (string, bool) { - vl, has := meta[key] - if has { - return vl, has - } - vl, has = meta[strings.ToLower(key)] - return vl, has -} - -// FromMetaMap sets plot options from meta data map. 
-func (po *PlotOptions) FromMetaMap(meta map[string]string) { - if typ, has := metaMapLower(meta, "Type"); has { - po.Type.SetString(typ) - } - if op, has := metaMapLower(meta, "Lines"); has { - if op == "+" || op == "true" { - po.Lines = true - } else { - po.Lines = false - } - } - if op, has := metaMapLower(meta, "Points"); has { - if op == "+" || op == "true" { - po.Points = true - } else { - po.Points = false - } - } - if lw, has := metaMapLower(meta, "LineWidth"); has { - po.LineWidth, _ = reflectx.ToFloat32(lw) - } - if ps, has := metaMapLower(meta, "PointSize"); has { - po.PointSize, _ = reflectx.ToFloat32(ps) - } - if bw, has := metaMapLower(meta, "BarWidth"); has { - po.BarWidth, _ = reflectx.ToFloat32(bw) - } - if op, has := metaMapLower(meta, "NegativeXDraw"); has { - if op == "+" || op == "true" { - po.NegativeXDraw = true - } else { - po.NegativeXDraw = false - } - } - if scl, has := metaMapLower(meta, "Scale"); has { - po.Scale, _ = reflectx.ToFloat32(scl) - } - if xc, has := metaMapLower(meta, "XAxis"); has { - po.XAxis = xc - } - if lc, has := metaMapLower(meta, "Legend"); has { - po.Legend = lc - } - if xrot, has := metaMapLower(meta, "XAxisRotation"); has { - po.XAxisRotation, _ = reflectx.ToFloat32(xrot) - } - if lb, has := metaMapLower(meta, "XAxisLabel"); has { - po.XAxisLabel = lb - } - if lb, has := metaMapLower(meta, "YAxisLabel"); has { - po.YAxisLabel = lb - } -} - -// ColumnOptions are options for plotting one column of data. 
-type ColumnOptions struct { //types:add - - // whether to plot this column - On bool - - // name of column being plotting - Column string - - // whether to plot lines; uses the overall plot option if unset - Lines option.Option[bool] - - // whether to plot points with symbols; uses the overall plot option if unset - Points option.Option[bool] - - // the width of lines; uses the overall plot option if unset - LineWidth option.Option[float32] - - // the size of points; uses the overall plot option if unset - PointSize option.Option[float32] - - // the shape used to draw points; uses the overall plot option if unset - PointShape option.Option[plots.Shapes] - - // effective range of data to plot -- either end can be fixed - Range minmax.Range32 `display:"inline"` - - // full actual range of data -- only valid if specifically computed - FullRange minmax.F32 `display:"inline"` - - // color to use when plotting the line / column - Color image.Image - - // desired number of ticks - NTicks int - - // if specified, this is an alternative label to use when plotting - Label string - - // if column has n-dimensional tensor cells in each row, this is the index within each cell to plot -- use -1 to plot *all* indexes as separate lines - TensorIndex int - - // specifies a column containing error bars for this column - ErrColumn string - - // if true this is a string column -- plots as labels - IsString bool `edit:"-"` -} - -// defaults sets defaults if unset values are present. -func (co *ColumnOptions) defaults() { - if co.NTicks == 0 { - co.NTicks = 10 - } -} - -// getLabel returns the effective label of the column. -func (co *ColumnOptions) getLabel() string { - if co.Label != "" { - return co.Label - } - return co.Column -} - -// fromMetaMap sets column options from meta data map. 
-func (co *ColumnOptions) fromMetaMap(meta map[string]string) { - if op, has := metaMapLower(meta, co.Column+":On"); has { - if op == "+" || op == "true" || op == "" { - co.On = true - } else { - co.On = false - } - } - if op, has := metaMapLower(meta, co.Column+":Off"); has { - if op == "+" || op == "true" || op == "" { - co.On = false - } else { - co.On = true - } - } - if op, has := metaMapLower(meta, co.Column+":FixMin"); has { - if op == "+" || op == "true" { - co.Range.FixMin = true - } else { - co.Range.FixMin = false - } - } - if op, has := metaMapLower(meta, co.Column+":FixMax"); has { - if op == "+" || op == "true" { - co.Range.FixMax = true - } else { - co.Range.FixMax = false - } - } - if op, has := metaMapLower(meta, co.Column+":FloatMin"); has { - if op == "+" || op == "true" { - co.Range.FixMin = false - } else { - co.Range.FixMin = true - } - } - if op, has := metaMapLower(meta, co.Column+":FloatMax"); has { - if op == "+" || op == "true" { - co.Range.FixMax = false - } else { - co.Range.FixMax = true - } - } - if vl, has := metaMapLower(meta, co.Column+":Max"); has { - co.Range.Max, _ = reflectx.ToFloat32(vl) - } - if vl, has := metaMapLower(meta, co.Column+":Min"); has { - co.Range.Min, _ = reflectx.ToFloat32(vl) - } - if lb, has := metaMapLower(meta, co.Column+":Label"); has { - co.Label = lb - } - if lb, has := metaMapLower(meta, co.Column+":ErrColumn"); has { - co.ErrColumn = lb - } - if vl, has := metaMapLower(meta, co.Column+":TensorIndex"); has { - iv, _ := reflectx.ToInt(vl) - co.TensorIndex = int(iv) - } -} - -// PlotTypes are different types of plots. -type PlotTypes int32 //enums:enum - -const ( - // XY is a standard line / point plot. - XY PlotTypes = iota - - // Bar plots vertical bars. 
- Bar -) diff --git a/plot/plotcore/plot.go b/plot/plotcore/plot.go index 61b3e27fb2..286ae8f3a9 100644 --- a/plot/plotcore/plot.go +++ b/plot/plotcore/plot.go @@ -13,6 +13,7 @@ import ( "cogentcore.org/core/core" "cogentcore.org/core/cursors" "cogentcore.org/core/events" + "cogentcore.org/core/events/key" "cogentcore.org/core/plot" "cogentcore.org/core/styles" "cogentcore.org/core/styles/abilities" @@ -26,12 +27,6 @@ import ( type Plot struct { core.WidgetBase - // Scale multiplies the plot DPI value, to change the overall scale - // of the rendered plot. Larger numbers produce larger scaling. - // Typically use larger numbers when generating plots for inclusion in - // documents or other cases where the overall plot size will be small. - Scale float32 - // Plot is the Plot to display in this widget Plot *plot.Plot `set:"-"` @@ -44,7 +39,7 @@ type Plot struct { // drawn at the current size of this widget func (pt *Plot) SetPlot(pl *plot.Plot) { if pl != nil && pt.Plot != nil && pt.Plot.Pixels != nil { - pl.DPI = pt.Scale * pt.Styles.UnitContext.DPI + pl.DPI = pt.Styles.UnitContext.DPI pl.SetPixels(pt.Plot.Pixels) // re-use the image! 
} pt.Plot = pl @@ -62,7 +57,7 @@ func (pt *Plot) updatePlot() { if sz == (image.Point{}) { return } - pt.Plot.DPI = pt.Scale * pt.Styles.UnitContext.DPI + pt.Plot.DPI = pt.Styles.UnitContext.DPI pt.Plot.Resize(sz) if pt.SetRangesFunc != nil { pt.SetRangesFunc() @@ -73,7 +68,6 @@ func (pt *Plot) updatePlot() { func (pt *Plot) Init() { pt.WidgetBase.Init() - pt.Scale = 1 pt.Styler(func(s *styles.Style) { s.Min.Set(units.Dp(256)) ro := pt.IsReadOnly() @@ -93,15 +87,18 @@ func (pt *Plot) Init() { if pt.Plot == nil { return } + xf, yf := 1.0, 1.0 + if e.HasAnyModifier(key.Shift) { + yf = 0 + } else if e.HasAnyModifier(key.Alt) { + xf = 0 + } del := e.PrevDelta() - dx := -float32(del.X) * (pt.Plot.X.Max - pt.Plot.X.Min) * 0.0008 - dy := float32(del.Y) * (pt.Plot.Y.Max - pt.Plot.Y.Min) * 0.0008 - pt.Plot.X.Min += dx - pt.Plot.X.Max += dx - pt.Plot.Y.Min += dy - pt.Plot.Y.Max += dy + dx := -float64(del.X) * (pt.Plot.X.Range.Range()) * 0.0008 * xf + dy := float64(del.Y) * (pt.Plot.Y.Range.Range()) * 0.0008 * yf + pt.Plot.PanZoom.XOffset += dx + pt.Plot.PanZoom.YOffset += dy pt.updatePlot() - pt.NeedsRender() }) pt.On(events.Scroll, func(e events.Event) { @@ -110,13 +107,16 @@ func (pt *Plot) Init() { return } se := e.(*events.MouseScroll) - sc := 1 + (float32(se.Delta.Y) * 0.002) - pt.Plot.X.Min *= sc - pt.Plot.X.Max *= sc - pt.Plot.Y.Min *= sc - pt.Plot.Y.Max *= sc + sc := 1 + (float64(se.Delta.Y) * 0.002) + xsc, ysc := sc, sc + if e.HasAnyModifier(key.Shift) { + ysc = 1 + } else if e.HasAnyModifier(key.Alt) { + xsc = 1 + } + pt.Plot.PanZoom.XScale *= xsc + pt.Plot.PanZoom.YScale *= ysc pt.updatePlot() - pt.NeedsRender() }) } @@ -128,9 +128,25 @@ func (pt *Plot) WidgetTooltip(pos image.Point) (string, image.Point) { return pt.Tooltip, pt.DefaultTooltipPos() } wpos := pos.Sub(pt.Geom.ContentBBox.Min) - _, idx, dist, data, _, legend := pt.Plot.ClosestDataToPixel(wpos.X, wpos.Y) + plt, _, idx, dist, _, data, legend := pt.Plot.ClosestDataToPixel(wpos.X, wpos.Y) if dist <= 10 { 
- return fmt.Sprintf("%s[%d]: (%g, %g)", legend, idx, data.X, data.Y), pos + pt.Plot.HighlightPlotter = plt + pt.Plot.HighlightIndex = idx + pt.updatePlot() + dx := 0.0 + if data[plot.X] != nil { + dx = data[plot.X].Float1D(idx) + } + dy := 0.0 + if data[plot.Y] != nil { + dy = data[plot.Y].Float1D(idx) + } + return fmt.Sprintf("%s[%d]: (%g, %g)", legend, idx, dx, dy), pos + } else { + if pt.Plot.HighlightPlotter != nil { + pt.Plot.HighlightPlotter = nil + pt.updatePlot() + } } return pt.Tooltip, pt.DefaultTooltipPos() } diff --git a/plot/plotcore/ploteditor.go b/plot/plotcore/ploteditor.go index f0916e8952..ba8d7832e8 100644 --- a/plot/plotcore/ploteditor.go +++ b/plot/plotcore/ploteditor.go @@ -8,44 +8,47 @@ package plotcore //go:generate core generate import ( + "fmt" "io/fs" "log/slog" "path/filepath" - "reflect" + "slices" "strings" "time" "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/fsx" "cogentcore.org/core/base/iox/imagex" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/reflectx" "cogentcore.org/core/colors" "cogentcore.org/core/core" "cogentcore.org/core/events" "cogentcore.org/core/icons" - "cogentcore.org/core/math32" "cogentcore.org/core/plot" + "cogentcore.org/core/plot/plots" "cogentcore.org/core/styles" "cogentcore.org/core/styles/states" "cogentcore.org/core/system" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tensor/tensorcore" "cogentcore.org/core/tree" + "golang.org/x/exp/maps" ) // PlotEditor is a widget that provides an interactive 2D plot -// of selected columns of tabular data, represented by a [table.IndexView] into +// of selected columns of tabular data, represented by a [table.Table] into // a [table.Table]. Other types of tabular data can be converted into this format. // The user can change various options for the plot and also modify the underlying data. type PlotEditor struct { //types:add core.Frame // table is the table of data being plotted. 
- table *table.IndexView + table *table.Table - // Options are the overall plot options. - Options PlotOptions - - // Columns are the options for each column of the table. - Columns []*ColumnOptions `set:"-"` + // PlotStyle has the overall plot style parameters. + PlotStyle plot.PlotStyle // plot is the plot object. plot *plot.Plot @@ -59,19 +62,16 @@ type PlotEditor struct { //types:add // currently doing a plot inPlot bool - columnsFrame *core.Frame - plotWidget *Plot + columnsFrame *core.Frame + plotWidget *Plot + plotStyleModified map[string]bool } func (pl *PlotEditor) CopyFieldsFrom(frm tree.Node) { fr := frm.(*PlotEditor) pl.Frame.CopyFieldsFrom(&fr.Frame) - pl.Options = fr.Options - pl.setIndexView(fr.table) - mx := min(len(pl.Columns), len(fr.Columns)) - for i := 0; i < mx; i++ { - *pl.Columns[i] = *fr.Columns[i] - } + pl.PlotStyle = fr.PlotStyle + pl.setTable(fr.table) } // NewSubPlot returns a [PlotEditor] with its own separate [core.Toolbar], @@ -91,7 +91,8 @@ func NewSubPlot(parent ...tree.Node) *PlotEditor { func (pl *PlotEditor) Init() { pl.Frame.Init() - pl.Options.defaults() + pl.PlotStyle.Defaults() + pl.Styler(func(s *styles.Style) { s.Grow.Set(1, 1) if pl.SizeClass() == core.SizeCompact { @@ -104,8 +105,8 @@ func (pl *PlotEditor) Init() { }) pl.Updater(func() { - if pl.table != nil && pl.table.Table != nil { - pl.Options.fromMeta(pl.table.Table) + if pl.table != nil { + pl.plotStyleFromTable(pl.table) } }) tree.AddChildAt(pl, "columns", func(w *core.Frame) { @@ -131,11 +132,11 @@ func (pl *PlotEditor) Init() { }) } -// setIndexView sets the table to view and does Update +// setTable sets the table to view and does Update // to update the Column list, which will also trigger a Layout // and updating of the plot on next render pass. // This is safe to call from a different goroutine. 
-func (pl *PlotEditor) setIndexView(tab *table.IndexView) *PlotEditor { +func (pl *PlotEditor) setTable(tab *table.Table) *PlotEditor { pl.table = tab pl.Update() return pl @@ -146,55 +147,25 @@ func (pl *PlotEditor) setIndexView(tab *table.IndexView) *PlotEditor { // and updating of the plot on next render pass. // This is safe to call from a different goroutine. func (pl *PlotEditor) SetTable(tab *table.Table) *PlotEditor { - pl.table = table.NewIndexView(tab) + pl.table = table.NewView(tab) pl.Update() return pl } -// SetSlice sets the table to a [table.NewSliceTable] -// from the given slice. -func (pl *PlotEditor) SetSlice(sl any) *PlotEditor { - return pl.SetTable(errors.Log1(table.NewSliceTable(sl))) -} - -// ColumnOptions returns the current column options by name -// (to access by index, just use Columns directly). -func (pl *PlotEditor) ColumnOptions(column string) *ColumnOptions { - for _, co := range pl.Columns { - if co.Column == column { - return co - } - } - return nil -} - -// Bool constants for [PlotEditor.SetColumnOptions]. -const ( - On = true - Off = false - FixMin = true - FloatMin = false - FixMax = true - FloatMax = false -) - -// SetColumnOptions sets the main parameters for one column. -func (pl *PlotEditor) SetColumnOptions(column string, on bool, fixMin bool, min float32, fixMax bool, max float32) *ColumnOptions { - co := pl.ColumnOptions(column) - if co == nil { - slog.Error("plotcore.PlotEditor.SetColumnOptions: column not found", "column", column) +// SetSlice sets the table to a [table.NewSliceTable] from the given slice. +// Optional styler functions are used for each struct field in sequence, +// and any can contain global plot style. 
+func (pl *PlotEditor) SetSlice(sl any, stylers ...func(s *plot.Style)) *PlotEditor { + dt, err := table.NewSliceTable(sl) + errors.Log(err) + if dt == nil { return nil } - co.On = on - co.Range.FixMin = fixMin - if fixMin { - co.Range.Min = min + mx := min(dt.NumColumns(), len(stylers)) + for i := range mx { + plot.SetStylersTo(dt.Columns.Values[i], plot.Stylers{stylers[i]}) } - co.Range.FixMax = fixMax - if fixMax { - co.Range.Max = max - } - return co + return pl.SetTable(dt) } // SaveSVG saves the plot to an svg -- first updates to ensure that plot is current @@ -213,8 +184,8 @@ func (pl *PlotEditor) SavePNG(fname core.Filename) { //types:add } // SaveCSV saves the Table data to a csv (comma-separated values) file with headers (any delim) -func (pl *PlotEditor) SaveCSV(fname core.Filename, delim table.Delims) { //types:add - pl.table.SaveCSV(fname, delim, table.Headers) +func (pl *PlotEditor) SaveCSV(fname core.Filename, delim tensor.Delims) { //types:add + pl.table.SaveCSV(fsx.Filename(fname), delim, table.Headers) pl.dataFile = fname } @@ -223,57 +194,29 @@ func (pl *PlotEditor) SaveCSV(fname core.Filename, delim table.Delims) { //types func (pl *PlotEditor) SaveAll(fname core.Filename) { //types:add fn := string(fname) fn = strings.TrimSuffix(fn, filepath.Ext(fn)) - pl.SaveCSV(core.Filename(fn+".tsv"), table.Tab) + pl.SaveCSV(core.Filename(fn+".tsv"), tensor.Tab) pl.SavePNG(core.Filename(fn + ".png")) pl.SaveSVG(core.Filename(fn + ".svg")) } // OpenCSV opens the Table data from a csv (comma-separated values) file (or any delim) -func (pl *PlotEditor) OpenCSV(filename core.Filename, delim table.Delims) { //types:add - pl.table.Table.OpenCSV(filename, delim) +func (pl *PlotEditor) OpenCSV(filename core.Filename, delim tensor.Delims) { //types:add + pl.table.OpenCSV(fsx.Filename(filename), delim) pl.dataFile = filename pl.UpdatePlot() } // OpenFS opens the Table data from a csv (comma-separated values) file (or any delim) // from the given filesystem. 
-func (pl *PlotEditor) OpenFS(fsys fs.FS, filename core.Filename, delim table.Delims) { - pl.table.Table.OpenFS(fsys, string(filename), delim) +func (pl *PlotEditor) OpenFS(fsys fs.FS, filename core.Filename, delim tensor.Delims) { + pl.table.OpenFS(fsys, string(filename), delim) pl.dataFile = filename pl.UpdatePlot() } -// yLabel returns the Y-axis label -func (pl *PlotEditor) yLabel() string { - if pl.Options.YAxisLabel != "" { - return pl.Options.YAxisLabel - } - for _, cp := range pl.Columns { - if cp.On { - return cp.getLabel() - } - } - return "Y" -} - -// xLabel returns the X-axis label -func (pl *PlotEditor) xLabel() string { - if pl.Options.XAxisLabel != "" { - return pl.Options.XAxisLabel - } - if pl.Options.XAxis != "" { - cp := pl.ColumnOptions(pl.Options.XAxis) - if cp != nil { - return cp.getLabel() - } - return pl.Options.XAxis - } - return "X" -} - -// GoUpdatePlot updates the display based on current IndexView into table. +// GoUpdatePlot updates the display based on current Indexed view into table. // This version can be called from goroutines. It does Sequential() on -// the [table.IndexView], under the assumption that it is used for tracking a +// the [table.Table], under the assumption that it is used for tracking a // the latest updates of a running process. func (pl *PlotEditor) GoUpdatePlot() { if pl == nil || pl.This == nil { @@ -282,7 +225,7 @@ func (pl *PlotEditor) GoUpdatePlot() { if core.TheApp.Platform() == system.Web { time.Sleep(time.Millisecond) // critical to prevent hanging! } - if !pl.IsVisible() || pl.table == nil || pl.table.Table == nil || pl.inPlot { + if !pl.IsVisible() || pl.table == nil || pl.inPlot { return } pl.Scene.AsyncLock() @@ -292,20 +235,20 @@ func (pl *PlotEditor) GoUpdatePlot() { pl.Scene.AsyncUnlock() } -// UpdatePlot updates the display based on current IndexView into table. 
-// It does not automatically update the [table.IndexView] unless it is +// UpdatePlot updates the display based on current Indexed view into table. +// It does not automatically update the [table.Table] unless it is // nil or out date. func (pl *PlotEditor) UpdatePlot() { if pl == nil || pl.This == nil { return } - if pl.table == nil || pl.table.Table == nil || pl.inPlot { + if pl.table == nil || pl.inPlot { return } - if len(pl.Children) != 2 || len(pl.Columns) != pl.table.Table.NumColumns() { + if len(pl.Children) != 2 { // || len(pl.Columns) != pl.table.NumColumns() { // todo: pl.Update() } - if pl.table.Len() == 0 { + if pl.table.NumRows() == 0 { pl.table.Sequential() } pl.genPlot() @@ -326,154 +269,35 @@ func (pl *PlotEditor) genPlot() { if len(pl.table.Indexes) == 0 { pl.table.Sequential() } else { - lsti := pl.table.Indexes[pl.table.Len()-1] - if lsti >= pl.table.Table.Rows { // out of date + lsti := pl.table.Indexes[pl.table.NumRows()-1] + if lsti >= pl.table.NumRows() { // out of date pl.table.Sequential() } } - pl.plot = nil - switch pl.Options.Type { - case XY: - pl.genPlotXY() - case Bar: - pl.genPlotBar() - } - pl.plotWidget.Scale = pl.Options.Scale - pl.plotWidget.SetRangesFunc = func() { - plt := pl.plotWidget.Plot - xi, err := pl.table.Table.ColumnIndex(pl.Options.XAxis) - if err == nil { - xp := pl.Columns[xi] - if xp.Range.FixMin { - plt.X.Min = math32.Min(plt.X.Min, float32(xp.Range.Min)) - } - if xp.Range.FixMax { - plt.X.Max = math32.Max(plt.X.Max, float32(xp.Range.Max)) - } - } - for _, cp := range pl.Columns { // key that this comes at the end, to actually stick - if !cp.On || cp.IsString { - continue - } - if cp.Range.FixMin { - plt.Y.Min = math32.Min(plt.Y.Min, float32(cp.Range.Min)) - } - if cp.Range.FixMax { - plt.Y.Max = math32.Max(plt.Y.Max, float32(cp.Range.Max)) - } - } + var err error + pl.plot, err = plot.NewTablePlot(pl.table) + if err != nil { + core.ErrorSnackbar(pl, fmt.Errorf("%s: %w", pl.PlotStyle.Title, err)) } 
pl.plotWidget.SetPlot(pl.plot) // redraws etc pl.inPlot = false } -// configPlot configures the given plot based on the plot options. -func (pl *PlotEditor) configPlot(plt *plot.Plot) { - plt.Title.Text = pl.Options.Title - plt.X.Label.Text = pl.xLabel() - plt.Y.Label.Text = pl.yLabel() - plt.Legend.Position = pl.Options.LegendPosition - plt.X.TickText.Style.Rotation = float32(pl.Options.XAxisRotation) -} - -// plotXAxis processes the XAxis and returns its index -func (pl *PlotEditor) plotXAxis(plt *plot.Plot, ixvw *table.IndexView) (xi int, xview *table.IndexView, err error) { - xi, err = ixvw.Table.ColumnIndex(pl.Options.XAxis) - if err != nil { - // log.Println("plot.PlotXAxis: " + err.Error()) - return - } - xview = ixvw - xc := ixvw.Table.Columns[xi] - xp := pl.Columns[xi] - sz := 1 - if xp.Range.FixMin { - plt.X.Min = math32.Min(plt.X.Min, float32(xp.Range.Min)) - } - if xp.Range.FixMax { - plt.X.Max = math32.Max(plt.X.Max, float32(xp.Range.Max)) - } - if xc.NumDims() > 1 { - sz = xc.Len() / xc.DimSize(0) - if xp.TensorIndex > sz || xp.TensorIndex < 0 { - slog.Error("plotcore.PlotEditor.plotXAxis: TensorIndex invalid -- reset to 0") - xp.TensorIndex = 0 - } - } - return -} - -const plotColumnsHeaderN = 2 - -// columnsListUpdate updates the list of columns -func (pl *PlotEditor) columnsListUpdate() { - if pl.table == nil || pl.table.Table == nil { - pl.Columns = nil - return - } - dt := pl.table.Table - nc := dt.NumColumns() - if nc == len(pl.Columns) { - return - } - pl.Columns = make([]*ColumnOptions, nc) - clri := 0 - hasOn := false - for ci := range dt.NumColumns() { - cn := dt.ColumnName(ci) - if pl.Options.XAxis == "" && ci == 0 { - pl.Options.XAxis = cn // x-axis defaults to the first column - } - cp := &ColumnOptions{Column: cn} - cp.defaults() - tcol := dt.Columns[ci] - if tcol.IsString() { - cp.IsString = true - } else { - cp.IsString = false - // we enable the first non-string, non-x-axis, non-first column by default - if !hasOn && cn != 
pl.Options.XAxis && ci != 0 { - cp.On = true - hasOn = true - } - } - cp.fromMetaMap(pl.table.Table.MetaData) - inc := 1 - if cn == pl.Options.XAxis || tcol.IsString() || tcol.DataType() == reflect.Int || tcol.DataType() == reflect.Int64 || tcol.DataType() == reflect.Int32 || tcol.DataType() == reflect.Uint8 { - inc = 0 - } - cp.Color = colors.Uniform(colors.Spaced(clri)) - pl.Columns[ci] = cp - clri += inc - } -} - -// ColumnsFromMetaMap updates all the column settings from given meta map -func (pl *PlotEditor) ColumnsFromMetaMap(meta map[string]string) { - for _, cp := range pl.Columns { - cp.fromMetaMap(meta) - } -} +const plotColumnsHeaderN = 3 -// setAllColumns turns all Columns on or off (except X axis) -func (pl *PlotEditor) setAllColumns(on bool) { +// allColumnsOff turns all columns off. +func (pl *PlotEditor) allColumnsOff() { fr := pl.columnsFrame for i, cli := range fr.Children { if i < plotColumnsHeaderN { continue } - ci := i - plotColumnsHeaderN - cp := pl.Columns[ci] - if cp.Column == pl.Options.XAxis { - continue - } - cp.On = on cl := cli.(*core.Frame) sw := cl.Child(0).(*core.Switch) - sw.SetChecked(cp.On) + sw.SetChecked(false) + sw.SendChange() } - pl.UpdatePlot() - pl.NeedsRender() + pl.Update() } // setColumnsByName turns columns on or off if their name contains @@ -484,32 +308,25 @@ func (pl *PlotEditor) setColumnsByName(nameContains string, on bool) { //types:a if i < plotColumnsHeaderN { continue } - ci := i - plotColumnsHeaderN - cp := pl.Columns[ci] - if cp.Column == pl.Options.XAxis { - continue - } - if !strings.Contains(cp.Column, nameContains) { + cl := cli.(*core.Frame) + if !strings.Contains(cl.Name, nameContains) { continue } - cp.On = on - cl := cli.(*core.Frame) sw := cl.Child(0).(*core.Switch) - sw.SetChecked(cp.On) + sw.SetChecked(on) + sw.SendChange() } - pl.UpdatePlot() - pl.NeedsRender() + pl.Update() } // makeColumns makes the Plans for columns func (pl *PlotEditor) makeColumns(p *tree.Plan) { - pl.columnsListUpdate() 
tree.Add(p, func(w *core.Frame) { tree.AddChild(w, func(w *core.Button) { w.SetText("Clear").SetIcon(icons.ClearAll).SetType(core.ButtonAction) w.SetTooltip("Turn all columns off") w.OnClick(func(e events.Event) { - pl.setAllColumns(false) + pl.allColumnsOff() }) }) tree.AddChild(w, func(w *core.Button) { @@ -521,29 +338,61 @@ func (pl *PlotEditor) makeColumns(p *tree.Plan) { }) }) tree.Add(p, func(w *core.Separator) {}) - for _, cp := range pl.Columns { - tree.AddAt(p, cp.Column, func(w *core.Frame) { + if pl.table == nil { + return + } + colorIdx := 0 // index for color sequence -- skips various types + for ci, cl := range pl.table.Columns.Values { + cnm := pl.table.Columns.Keys[ci] + tree.AddAt(p, cnm, func(w *core.Frame) { + psty := plot.GetStylersFrom(cl) + cst, mods := pl.defaultColumnStyle(cl, ci, &colorIdx, psty) + stys := psty + stys.Add(func(s *plot.Style) { + mf := modFields(mods) + errors.Log(reflectx.CopyFields(s, cst, mf...)) + errors.Log(reflectx.CopyFields(&s.Plot, &pl.PlotStyle, modFields(pl.plotStyleModified)...)) + }) + plot.SetStylersTo(cl, stys) + w.Styler(func(s *styles.Style) { s.CenterAll() }) tree.AddChild(w, func(w *core.Switch) { w.SetType(core.SwitchCheckbox).SetTooltip("Turn this column on or off") + w.Styler(func(s *styles.Style) { + s.Color = cst.Line.Color + }) + tree.AddChildInit(w, "stack", func(w *core.Frame) { + f := func(name string) { + tree.AddChildInit(w, name, func(w *core.Icon) { + w.Styler(func(s *styles.Style) { + s.Color = cst.Line.Color + }) + }) + } + f("icon-on") + f("icon-off") + f("icon-indeterminate") + }) w.OnChange(func(e events.Event) { - cp.On = w.IsChecked() + mods["On"] = true + cst.On = w.IsChecked() pl.UpdatePlot() }) w.Updater(func() { - xaxis := cp.Column == pl.Options.XAxis || cp.Column == pl.Options.Legend + xaxis := cst.Role == plot.X // || cp.Column == pl.Options.Legend w.SetState(xaxis, states.Disabled, states.Indeterminate) if xaxis { - cp.On = false + cst.On = false } else { - w.SetChecked(cp.On) + 
w.SetChecked(cst.On) } }) }) tree.AddChild(w, func(w *core.Button) { - w.SetText(cp.Column).SetType(core.ButtonAction).SetTooltip("Edit column options including setting it as the x-axis or legend") + tt := "[Edit all styling options for this column] " + metadata.Doc(cl) + w.SetText(cnm).SetType(core.ButtonAction).SetTooltip(tt) w.OnClick(func(e events.Event) { update := func() { if core.TheApp.Platform().IsMobile() { @@ -556,27 +405,28 @@ func (pl *PlotEditor) makeColumns(p *tree.Plan) { pl.Update() pl.AsyncUnlock() } - d := core.NewBody("Column options") - core.NewForm(d).SetStruct(cp). - OnChange(func(e events.Event) { - update() - }) - d.AddTopBar(func(bar *core.Frame) { - core.NewToolbar(bar).Maker(func(p *tree.Plan) { - tree.Add(p, func(w *core.Button) { - w.SetText("Set x-axis").OnClick(func(e events.Event) { - pl.Options.XAxis = cp.Column - update() - }) - }) - tree.Add(p, func(w *core.Button) { - w.SetText("Set legend").OnClick(func(e events.Event) { - pl.Options.Legend = cp.Column - update() - }) - }) - }) + d := core.NewBody(cnm + " style properties") + fm := core.NewForm(d).SetStruct(cst) + fm.Modified = mods + fm.OnChange(func(e events.Event) { + update() }) + // d.AddTopBar(func(bar *core.Frame) { + // core.NewToolbar(bar).Maker(func(p *tree.Plan) { + // tree.Add(p, func(w *core.Button) { + // w.SetText("Set x-axis").OnClick(func(e events.Event) { + // pl.Options.XAxis = cp.Column + // update() + // }) + // }) + // tree.Add(p, func(w *core.Button) { + // w.SetText("Set legend").OnClick(func(e events.Event) { + // pl.Options.Legend = cp.Column + // update() + // }) + // }) + // }) + // }) d.RunWindowDialog(pl) }) }) @@ -584,6 +434,97 @@ func (pl *PlotEditor) makeColumns(p *tree.Plan) { } } +// defaultColumnStyle initializes the column style with any existing stylers +// plus additional general defaults, returning the initially modified field names. 
+func (pl *PlotEditor) defaultColumnStyle(cl tensor.Values, ci int, colorIdx *int, psty plot.Stylers) (*plot.Style, map[string]bool) { + cst := &plot.Style{} + cst.Defaults() + if psty != nil { + psty.Run(cst) + } + mods := map[string]bool{} + isfloat := reflectx.KindIsFloat(cl.DataType()) + if cst.Plotter == "" { + if isfloat { + cst.Plotter = plot.PlotterName(plots.XYType) + mods["Plotter"] = true + } else if cl.IsString() { + cst.Plotter = plot.PlotterName(plots.LabelsType) + mods["Plotter"] = true + } + } + if cst.Role == plot.NoRole { + mods["Role"] = true + if isfloat { + cst.Role = plot.Y + } else if cl.IsString() { + cst.Role = plot.Label + } else { + cst.Role = plot.X + } + } + if cst.Line.Color == colors.Scheme.OnSurface { + if cst.Role == plot.Y && isfloat { + spclr := colors.Uniform(colors.Spaced(*colorIdx)) + cst.Line.Color = spclr + mods["Line.Color"] = true + cst.Point.Color = spclr + mods["Point.Color"] = true + if cst.Plotter == plots.BarType { + cst.Line.Fill = spclr + mods["Line.Fill"] = true + } + (*colorIdx)++ + } + } + return cst, mods +} + +func (pl *PlotEditor) plotStyleFromTable(dt *table.Table) { + if pl.plotStyleModified != nil { // already set + return + } + pst := &pl.PlotStyle + mods := map[string]bool{} + pl.plotStyleModified = mods + tst := &plot.Style{} + tst.Defaults() + tst.Plot.Defaults() + for _, cl := range pl.table.Columns.Values { + stl := plot.GetStylersFrom(cl) + if stl == nil { + continue + } + stl.Run(tst) + } + *pst = tst.Plot + if pst.PointsOn == plot.Default { + pst.PointsOn = plot.Off + mods["PointsOn"] = true + } + if pst.Title == "" { + pst.Title = metadata.Name(pl.table) + if pst.Title != "" { + mods["Title"] = true + } + } +} + +// modFields returns the modified fields as field paths using . 
separators +func modFields(mods map[string]bool) []string { + fns := maps.Keys(mods) + rf := make([]string, 0, len(fns)) + for _, f := range fns { + if mods[f] == false { + continue + } + fc := strings.ReplaceAll(f, " • ", ".") + rf = append(rf, fc) + } + slices.Sort(rf) + return rf +} + func (pl *PlotEditor) MakeToolbar(p *tree.Plan) { if pl.table == nil { return @@ -614,14 +555,15 @@ func (pl *PlotEditor) MakeToolbar(p *tree.Plan) { }) }) tree.Add(p, func(w *core.Button) { - w.SetText("Options").SetIcon(icons.Settings). - SetTooltip("Options for how the plot is rendered"). + w.SetText("Style").SetIcon(icons.Settings). + SetTooltip("Style for how the plot is rendered"). OnClick(func(e events.Event) { - d := core.NewBody("Plot options") - core.NewForm(d).SetStruct(&pl.Options). - OnChange(func(e events.Event) { - pl.GoUpdatePlot() - }) + d := core.NewBody("Plot style") + fm := core.NewForm(d).SetStruct(&pl.PlotStyle) + fm.Modified = pl.plotStyleModified + fm.OnChange(func(e events.Event) { + pl.GoUpdatePlot() + }) d.RunWindowDialog(pl) }) }) @@ -630,7 +572,7 @@ func (pl *PlotEditor) MakeToolbar(p *tree.Plan) { SetTooltip("open a Table window of the data"). 
OnClick(func(e events.Event) { d := core.NewBody(pl.Name + " Data") - tv := tensorcore.NewTable(d).SetTable(pl.table.Table) + tv := tensorcore.NewTable(d).SetTable(pl.table) d.AddTopBar(func(bar *core.Frame) { core.NewToolbar(bar).Maker(tv.MakeToolbar) }) @@ -653,7 +595,7 @@ func (pl *PlotEditor) MakeToolbar(p *tree.Plan) { }) tree.Add(p, func(w *core.Separator) {}) tree.Add(p, func(w *core.FuncButton) { - w.SetFunc(pl.table.FilterColumnName).SetText("Filter").SetIcon(icons.FilterAlt) + w.SetFunc(pl.table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt) w.SetAfterFunc(pl.UpdatePlot) }) tree.Add(p, func(w *core.FuncButton) { diff --git a/plot/plotcore/ploteditor_test.go b/plot/plotcore/ploteditor_test.go index 2ecaa13bd8..7e89f8b0be 100644 --- a/plot/plotcore/ploteditor_test.go +++ b/plot/plotcore/ploteditor_test.go @@ -8,10 +8,13 @@ import ( "testing" "cogentcore.org/core/core" + "cogentcore.org/core/plot" + "cogentcore.org/core/plot/plots" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" ) -type Data struct { +type data struct { City string Population float32 Area float32 @@ -20,16 +23,20 @@ type Data struct { func TestTablePlotEditor(t *testing.T) { b := core.NewBody() - epc := table.NewTable("epc") - epc.OpenCSV("testdata/ra25epoch.tsv", table.Tab) + epc := table.New("epc") + epc.OpenCSV("testdata/ra25epoch.tsv", tensor.Tab) pl := NewPlotEditor(b) - pl.Options.Title = "RA25 Epoch Train" - pl.Options.XAxis = "Epoch" - // pl.Options.Scale = 2 - pl.Options.Points = true + pst := func(s *plot.Style) { + s.Plot.Title = "RA25 Epoch Train" + s.Plot.PointsOn = plot.On + } + perr := epc.Column("PctErr") + plot.SetStylersTo(perr, plot.Stylers{pst, func(s *plot.Style) { + s.On = true + s.Role = plot.Y + }}) pl.SetTable(epc) - pl.ColumnOptions("UnitErr").On = true b.AddTopBar(func(bar *core.Frame) { core.NewToolbar(bar).Maker(pl.MakeToolbar) }) @@ -37,18 +44,24 @@ func TestTablePlotEditor(t *testing.T) { } func TestSlicePlotEditor(t *testing.T) { - 
t.Skip("TODO: this test randomly hangs on CI") - data := []Data{ + dt := []data{ {"Davis", 62000, 500}, {"Boulder", 85000, 800}, } b := core.NewBody() - pl := NewPlotEditor(b) - pl.Options.Title = "Slice Data" - pl.Options.Points = true - pl.SetSlice(data) + pst := func(s *plot.Style) { + s.Plot.Title = "Test Data" + s.Plot.PointsOn = plot.On + } + onst := func(s *plot.Style) { + pst(s) + s.Plotter = plots.BarType + s.On = true + s.Role = plot.Y + } + pl.SetSlice(dt, pst, onst) b.AddTopBar(func(bar *core.Frame) { core.NewToolbar(bar).Maker(pl.MakeToolbar) }) diff --git a/plot/plotcore/plotterchooser.go b/plot/plotcore/plotterchooser.go new file mode 100644 index 0000000000..bbc121928a --- /dev/null +++ b/plot/plotcore/plotterchooser.go @@ -0,0 +1,31 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package plotcore + +import ( + "slices" + + "cogentcore.org/core/core" + "cogentcore.org/core/plot" + _ "cogentcore.org/core/plot/plots" + "golang.org/x/exp/maps" +) + +func init() { + core.AddValueType[plot.PlotterName, PlotterChooser]() +} + +// PlotterChooser represents a [Plottername] value with a [core.Chooser] +// for selecting a plotter. +type PlotterChooser struct { + core.Chooser +} + +func (fc *PlotterChooser) Init() { + fc.Chooser.Init() + pnms := maps.Keys(plot.Plotters) + slices.Sort(pnms) + fc.SetStrings(pnms...) +} diff --git a/plot/plotcore/tablexy.go b/plot/plotcore/tablexy.go deleted file mode 100644 index 93d22fdc3e..0000000000 --- a/plot/plotcore/tablexy.go +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package plotcore - -import ( - "cogentcore.org/core/base/errors" - "cogentcore.org/core/math32" - "cogentcore.org/core/math32/minmax" - "cogentcore.org/core/plot" - "cogentcore.org/core/plot/plots" - "cogentcore.org/core/tensor/table" -) - -// tableXY selects two columns from a [table.Table] data table to plot in a [plot.Plot], -// satisfying the [plot.XYer], [plot.Valuer], [plot.Labeler], and [plots.YErrorer] interfaces. -// For Tensor-valued cells, Index's specify tensor cell. -// Also satisfies the plot/plots.Labeler interface for labels attached to a line, and -// plot/plots.YErrorer for error bars. -type tableXY struct { - - // the index view of data table to plot from - table *table.IndexView - - // the indexes of the tensor columns to use for the X and Y data, respectively - xColumn, yColumn int - - // numer of elements in each row of data -- 1 for scalar, > 1 for multi-dimensional - xRowSize, yRowSize int - - // the indexes of the element within each tensor cell if cells are n-dimensional, respectively - xIndex, yIndex int - - // the column to use for returning a label using Label interface -- for string cols - labelColumn int - - // the column to use for returning errorbars (+/- given value) -- if YColumn is tensor then this must also be a tensor and given YIndex used - errColumn int - - // range constraints on Y values - yRange minmax.Range32 -} - -var _ plot.XYer = &tableXY{} -var _ plot.Valuer = &tableXY{} -var _ plot.Labeler = &tableXY{} -var _ plots.YErrorer = &tableXY{} - -// newTableXY returns a new XY plot view onto the given IndexView of table.Table (makes a copy), -// from given column indexes, and tensor indexes within each cell. -// Column indexes are enforced to be valid, with an error message if they are not. 
-func newTableXY(dt *table.IndexView, xcol, xtsrIndex, ycol, ytsrIndex int, yrng minmax.Range32) (*tableXY, error) { - txy := &tableXY{table: dt.Clone(), xColumn: xcol, yColumn: ycol, xIndex: xtsrIndex, yIndex: ytsrIndex, yRange: yrng} - return txy, txy.validate() -} - -// newTableXYName returns a new XY plot view onto the given IndexView of table.Table (makes a copy), -// from given column name and tensor indexes within each cell. -// Column indexes are enforced to be valid, with an error message if they are not. -func newTableXYName(dt *table.IndexView, xi, xtsrIndex int, ycol string, ytsrIndex int, yrng minmax.Range32) (*tableXY, error) { - yi, err := dt.Table.ColumnIndex(ycol) - if errors.Log(err) != nil { - return nil, err - } - txy := &tableXY{table: dt.Clone(), xColumn: xi, yColumn: yi, xIndex: xtsrIndex, yIndex: ytsrIndex, yRange: yrng} - return txy, txy.validate() -} - -// validate returns error message if column indexes are invalid, else nil -// it also sets column indexes to 0 so nothing crashes. 
-func (txy *tableXY) validate() error { - if txy.table == nil { - return errors.New("eplot.TableXY table is nil") - } - nc := txy.table.Table.NumColumns() - if txy.xColumn >= nc || txy.xColumn < 0 { - txy.xColumn = 0 - return errors.New("eplot.TableXY XColumn index invalid -- reset to 0") - } - if txy.yColumn >= nc || txy.yColumn < 0 { - txy.yColumn = 0 - return errors.New("eplot.TableXY YColumn index invalid -- reset to 0") - } - xc := txy.table.Table.Columns[txy.xColumn] - yc := txy.table.Table.Columns[txy.yColumn] - if xc.NumDims() > 1 { - _, txy.xRowSize = xc.RowCellSize() - // note: index already validated - } - if yc.NumDims() > 1 { - _, txy.yRowSize = yc.RowCellSize() - if txy.yIndex >= txy.yRowSize || txy.yIndex < 0 { - txy.yIndex = 0 - return errors.New("eplot.TableXY Y TensorIndex invalid -- reset to 0") - } - } - txy.filterValues() - return nil -} - -// filterValues removes items with NaN values, and out of Y range -func (txy *tableXY) filterValues() { - txy.table.Filter(func(et *table.Table, row int) bool { - xv := txy.tRowXValue(row) - yv := txy.tRowValue(row) - if math32.IsNaN(yv) || math32.IsNaN(xv) { - return false - } - if txy.yRange.FixMin && yv < txy.yRange.Min { - return false - } - if txy.yRange.FixMax && yv > txy.yRange.Max { - return false - } - return true - }) -} - -// Len returns the number of rows in the view of table -func (txy *tableXY) Len() int { - if txy.table == nil || txy.table.Table == nil { - return 0 - } - return txy.table.Len() -} - -// tRowValue returns the y value at given true table row in table -func (txy *tableXY) tRowValue(row int) float32 { - yc := txy.table.Table.Columns[txy.yColumn] - y := float32(0.0) - switch { - case yc.IsString(): - y = float32(row) - case yc.NumDims() > 1: - _, sz := yc.RowCellSize() - if txy.yIndex < sz && txy.yIndex >= 0 { - y = float32(yc.FloatRowCell(row, txy.yIndex)) - } - default: - y = float32(yc.Float1D(row)) - } - return y -} - -// Value returns the y value at given row in table -func 
(txy *tableXY) Value(row int) float32 { - if txy.table == nil || txy.table.Table == nil || row >= txy.table.Len() { - return 0 - } - trow := txy.table.Indexes[row] // true table row - yc := txy.table.Table.Columns[txy.yColumn] - y := float32(0.0) - switch { - case yc.IsString(): - y = float32(row) - case yc.NumDims() > 1: - _, sz := yc.RowCellSize() - if txy.yIndex < sz && txy.yIndex >= 0 { - y = float32(yc.FloatRowCell(trow, txy.yIndex)) - } - default: - y = float32(yc.Float1D(trow)) - } - return y -} - -// tRowXValue returns an x value at given actual row in table -func (txy *tableXY) tRowXValue(row int) float32 { - if txy.table == nil || txy.table.Table == nil { - return 0 - } - xc := txy.table.Table.Columns[txy.xColumn] - x := float32(0.0) - switch { - case xc.IsString(): - x = float32(row) - case xc.NumDims() > 1: - _, sz := xc.RowCellSize() - if txy.xIndex < sz && txy.xIndex >= 0 { - x = float32(xc.FloatRowCell(row, txy.xIndex)) - } - default: - x = float32(xc.Float1D(row)) - } - return x -} - -// xValue returns an x value at given row in table -func (txy *tableXY) xValue(row int) float32 { - if txy.table == nil || txy.table.Table == nil || row >= txy.table.Len() { - return 0 - } - trow := txy.table.Indexes[row] // true table row - xc := txy.table.Table.Columns[txy.xColumn] - x := float32(0.0) - switch { - case xc.IsString(): - x = float32(row) - case xc.NumDims() > 1: - _, sz := xc.RowCellSize() - if txy.xIndex < sz && txy.xIndex >= 0 { - x = float32(xc.FloatRowCell(trow, txy.xIndex)) - } - default: - x = float32(xc.Float1D(trow)) - } - return x -} - -// XY returns an x, y pair at given row in table -func (txy *tableXY) XY(row int) (x, y float32) { - if txy.table == nil || txy.table.Table == nil { - return 0, 0 - } - x = txy.xValue(row) - y = txy.Value(row) - return -} - -// Label returns a label for given row in table, implementing [plot.Labeler] interface -func (txy *tableXY) Label(row int) string { - if txy.table == nil || txy.table.Table == nil || row >= 
txy.table.Len() { - return "" - } - trow := txy.table.Indexes[row] // true table row - return txy.table.Table.Columns[txy.labelColumn].String1D(trow) -} - -// YError returns error bars, implementing [plots.YErrorer] interface. -func (txy *tableXY) YError(row int) (float32, float32) { - if txy.table == nil || txy.table.Table == nil || row >= txy.table.Len() { - return 0, 0 - } - trow := txy.table.Indexes[row] // true table row - ec := txy.table.Table.Columns[txy.errColumn] - eval := float32(0.0) - switch { - case ec.IsString(): - eval = float32(row) - case ec.NumDims() > 1: - _, sz := ec.RowCellSize() - if txy.yIndex < sz && txy.yIndex >= 0 { - eval = float32(ec.FloatRowCell(trow, txy.yIndex)) - } - default: - eval = float32(ec.Float1D(trow)) - } - return -eval, eval -} diff --git a/plot/plotcore/typegen.go b/plot/plotcore/typegen.go index 9a925c9750..ae170910ed 100644 --- a/plot/plotcore/typegen.go +++ b/plot/plotcore/typegen.go @@ -3,15 +3,12 @@ package plotcore import ( + "cogentcore.org/core/plot" "cogentcore.org/core/tree" "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.PlotOptions", IDName: "plot-options", Doc: "PlotOptions are options for the overall plot.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Title", Doc: "optional title at top of plot"}, {Name: "Type", Doc: "type of plot to generate. 
For a Bar plot, items are plotted ordinally by row and the XAxis is optional"}, {Name: "Lines", Doc: "whether to plot lines"}, {Name: "Points", Doc: "whether to plot points with symbols"}, {Name: "LineWidth", Doc: "width of lines"}, {Name: "PointSize", Doc: "size of points"}, {Name: "PointShape", Doc: "the shape used to draw points"}, {Name: "BarWidth", Doc: "width of bars for bar plot, as fraction of available space (1 = no gaps)"}, {Name: "NegativeXDraw", Doc: "if true, draw lines that connect points with a negative X-axis direction;\notherwise there is a break in the line.\ndefault is false, so that repeated series of data across the X axis\nare plotted separately."}, {Name: "Scale", Doc: "Scale multiplies the plot DPI value, to change the overall scale\nof the rendered plot. Larger numbers produce larger scaling.\nTypically use larger numbers when generating plots for inclusion in\ndocuments or other cases where the overall plot size will be small."}, {Name: "XAxis", Doc: "what column to use for the common X axis. if empty or not found,\nthe row number is used. 
This optional for Bar plots, if present and\nLegend is also present, then an extra space will be put between X values."}, {Name: "Legend", Doc: "optional column for adding a separate colored / styled line or bar\naccording to this value, and acts just like a separate Y variable,\ncrossed with Y variables."}, {Name: "LegendPosition", Doc: "position of the Legend"}, {Name: "XAxisRotation", Doc: "rotation of the X Axis labels, in degrees"}, {Name: "XAxisLabel", Doc: "optional label to use for XAxis instead of column name"}, {Name: "YAxisLabel", Doc: "optional label to use for YAxis -- if empty, first column name is used"}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.ColumnOptions", IDName: "column-options", Doc: "ColumnOptions are options for plotting one column of data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "On", Doc: "whether to plot this column"}, {Name: "Column", Doc: "name of column being plotting"}, {Name: "Lines", Doc: "whether to plot lines; uses the overall plot option if unset"}, {Name: "Points", Doc: "whether to plot points with symbols; uses the overall plot option if unset"}, {Name: "LineWidth", Doc: "the width of lines; uses the overall plot option if unset"}, {Name: "PointSize", Doc: "the size of points; uses the overall plot option if unset"}, {Name: "PointShape", Doc: "the shape used to draw points; uses the overall plot option if unset"}, {Name: "Range", Doc: "effective range of data to plot -- either end can be fixed"}, {Name: "FullRange", Doc: "full actual range of data -- only valid if specifically computed"}, {Name: "Color", Doc: "color to use when plotting the line / column"}, {Name: "NTicks", Doc: "desired number of ticks"}, {Name: "Label", Doc: "if specified, this is an alternative label to use when plotting"}, {Name: "TensorIndex", Doc: "if column has n-dimensional tensor cells in each row, this is the index within each cell to plot -- use -1 to 
plot *all* indexes as separate lines"}, {Name: "ErrColumn", Doc: "specifies a column containing error bars for this column"}, {Name: "IsString", Doc: "if true this is a string column -- plots as labels"}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.Plot", IDName: "plot", Doc: "Plot is a widget that renders a [plot.Plot] object.\nIf it is not [states.ReadOnly], the user can pan and zoom the graph.\nSee [PlotEditor] for an interactive interface for selecting columns to view.", Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Scale", Doc: "Scale multiplies the plot DPI value, to change the overall scale\nof the rendered plot. Larger numbers produce larger scaling.\nTypically use larger numbers when generating plots for inclusion in\ndocuments or other cases where the overall plot size will be small."}, {Name: "Plot", Doc: "Plot is the Plot to display in this widget"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.Plot", IDName: "plot", Doc: "Plot is a widget that renders a [plot.Plot] object.\nIf it is not [states.ReadOnly], the user can pan and zoom the graph.\nSee [PlotEditor] for an interactive interface for selecting columns to view.", Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Plot", Doc: "Plot is the Plot to display in this widget"}, {Name: "SetRangesFunc", Doc: "SetRangesFunc, if set, is called to adjust the data ranges\nafter the point when these ranges are updated based on the plot data."}}}) // NewPlot returns a new [Plot] with the given optional parent: // Plot is a widget that renders a [plot.Plot] object. @@ -19,22 +16,29 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.Plot" // See [PlotEditor] for an interactive interface for selecting columns to view. func NewPlot(parent ...tree.Node) *Plot { return tree.New[Plot](parent...) 
} -// SetScale sets the [Plot.Scale]: -// Scale multiplies the plot DPI value, to change the overall scale -// of the rendered plot. Larger numbers produce larger scaling. -// Typically use larger numbers when generating plots for inclusion in -// documents or other cases where the overall plot size will be small. -func (t *Plot) SetScale(v float32) *Plot { t.Scale = v; return t } +// SetSetRangesFunc sets the [Plot.SetRangesFunc]: +// SetRangesFunc, if set, is called to adjust the data ranges +// after the point when these ranges are updated based on the plot data. +func (t *Plot) SetSetRangesFunc(v func()) *Plot { t.SetRangesFunc = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.PlotEditor", IDName: "plot-editor", Doc: "PlotEditor is a widget that provides an interactive 2D plot\nof selected columns of tabular data, represented by a [table.IndexView] into\na [table.Table]. Other types of tabular data can be converted into this format.\nThe user can change various options for the plot and also modify the underlying data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "SaveSVG", Doc: "SaveSVG saves the plot to an svg -- first updates to ensure that plot is current", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SavePNG", Doc: "SavePNG saves the current plot to a png, capturing current render", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SaveCSV", Doc: "SaveCSV saves the Table data to a csv (comma-separated values) file with headers (any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname", "delim"}}, {Name: "SaveAll", Doc: "SaveAll saves the current plot to a png, svg, and the data to a tsv -- full save\nAny extension is removed and appropriate extensions are added", Directives: []types.Directive{{Tool: "types", 
Directive: "add"}}, Args: []string{"fname"}}, {Name: "OpenCSV", Doc: "OpenCSV opens the Table data from a csv (comma-separated values) file (or any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}}, {Name: "setColumnsByName", Doc: "setColumnsByName turns columns on or off if their name contains\nthe given string.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"nameContains", "on"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "table", Doc: "table is the table of data being plotted."}, {Name: "Options", Doc: "Options are the overall plot options."}, {Name: "Columns", Doc: "Columns are the options for each column of the table."}, {Name: "plot", Doc: "plot is the plot object."}, {Name: "svgFile", Doc: "current svg file"}, {Name: "dataFile", Doc: "current csv data file"}, {Name: "inPlot", Doc: "currently doing a plot"}, {Name: "columnsFrame"}, {Name: "plotWidget"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.PlotEditor", IDName: "plot-editor", Doc: "PlotEditor is a widget that provides an interactive 2D plot\nof selected columns of tabular data, represented by a [table.Table] into\na [table.Table]. 
Other types of tabular data can be converted into this format.\nThe user can change various options for the plot and also modify the underlying data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "SaveSVG", Doc: "SaveSVG saves the plot to an svg -- first updates to ensure that plot is current", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SavePNG", Doc: "SavePNG saves the current plot to a png, capturing current render", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SaveCSV", Doc: "SaveCSV saves the Table data to a csv (comma-separated values) file with headers (any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname", "delim"}}, {Name: "SaveAll", Doc: "SaveAll saves the current plot to a png, svg, and the data to a tsv -- full save\nAny extension is removed and appropriate extensions are added", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "OpenCSV", Doc: "OpenCSV opens the Table data from a csv (comma-separated values) file (or any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}}, {Name: "setColumnsByName", Doc: "setColumnsByName turns columns on or off if their name contains\nthe given string.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"nameContains", "on"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "table", Doc: "table is the table of data being plotted."}, {Name: "PlotStyle", Doc: "PlotStyle has the overall plot style parameters."}, {Name: "plot", Doc: "plot is the plot object."}, {Name: "svgFile", Doc: "current svg file"}, {Name: "dataFile", Doc: "current csv data file"}, {Name: "inPlot", Doc: "currently doing a plot"}, {Name: "columnsFrame"}, {Name: "plotWidget"}, {Name: 
"plotStyleModified"}}}) // NewPlotEditor returns a new [PlotEditor] with the given optional parent: // PlotEditor is a widget that provides an interactive 2D plot -// of selected columns of tabular data, represented by a [table.IndexView] into +// of selected columns of tabular data, represented by a [table.Table] into // a [table.Table]. Other types of tabular data can be converted into this format. // The user can change various options for the plot and also modify the underlying data. func NewPlotEditor(parent ...tree.Node) *PlotEditor { return tree.New[PlotEditor](parent...) } -// SetOptions sets the [PlotEditor.Options]: -// Options are the overall plot options. -func (t *PlotEditor) SetOptions(v PlotOptions) *PlotEditor { t.Options = v; return t } +// SetPlotStyle sets the [PlotEditor.PlotStyle]: +// PlotStyle has the overall plot style parameters. +func (t *PlotEditor) SetPlotStyle(v plot.PlotStyle) *PlotEditor { t.PlotStyle = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot/plotcore.PlotterChooser", IDName: "plotter-chooser", Doc: "PlotterChooser represents a [Plottername] value with a [core.Chooser]\nfor selecting a plotter.", Embeds: []types.Field{{Name: "Chooser"}}}) + +// NewPlotterChooser returns a new [PlotterChooser] with the given optional parent: +// PlotterChooser represents a [Plottername] value with a [core.Chooser] +// for selecting a plotter. +func NewPlotterChooser(parent ...tree.Node) *PlotterChooser { + return tree.New[PlotterChooser](parent...) +} diff --git a/plot/plotcore/xyplot.go b/plot/plotcore/xyplot.go deleted file mode 100644 index eb405ee337..0000000000 --- a/plot/plotcore/xyplot.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package plotcore - -import ( - "fmt" - "log/slog" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/colors" - "cogentcore.org/core/plot" - "cogentcore.org/core/plot/plots" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/split" - "cogentcore.org/core/tensor/table" -) - -// genPlotXY generates an XY (lines, points) plot, setting Plot variable -func (pl *PlotEditor) genPlotXY() { - plt := plot.New() - - // process xaxis first - xi, xview, err := pl.plotXAxis(plt, pl.table) - if err != nil { - return - } - xp := pl.Columns[xi] - - var lsplit *table.Splits - nleg := 1 - if pl.Options.Legend != "" { - _, err = pl.table.Table.ColumnIndex(pl.Options.Legend) - if err != nil { - slog.Error("plot.Legend", "err", err.Error()) - } else { - errors.Log(xview.SortStableColumnNames([]string{pl.Options.Legend, xp.Column}, table.Ascending)) - lsplit = split.GroupBy(xview, pl.Options.Legend) - nleg = max(lsplit.Len(), 1) - } - } - - var firstXY *tableXY - var strCols []*ColumnOptions - nys := 0 - for _, cp := range pl.Columns { - if !cp.On { - continue - } - if cp.IsString { - strCols = append(strCols, cp) - continue - } - if cp.TensorIndex < 0 { - yc := errors.Log1(pl.table.Table.ColumnByName(cp.Column)) - _, sz := yc.RowCellSize() - nys += sz - } else { - nys++ - } - } - - if nys == 0 { - return - } - - firstXY = nil - yidx := 0 - for _, cp := range pl.Columns { - if !cp.On || cp == xp { - continue - } - if cp.IsString { - continue - } - for li := 0; li < nleg; li++ { - lview := xview - leg := "" - if lsplit != nil && len(lsplit.Values) > li { - leg = lsplit.Values[li][0] - lview = lsplit.Splits[li] - } - nidx := 1 - stidx := cp.TensorIndex - if cp.TensorIndex < 0 { // do all - yc := errors.Log1(pl.table.Table.ColumnByName(cp.Column)) - _, sz := yc.RowCellSize() - nidx = sz - stidx = 0 - } - for ii := 0; ii < nidx; ii++ { - idx := stidx + ii - tix := lview.Clone() - xy, _ := newTableXYName(tix, xi, xp.TensorIndex, cp.Column, idx, cp.Range) - if xy == 
nil { - continue - } - if firstXY == nil { - firstXY = xy - } - var pts *plots.Scatter - var lns *plots.Line - lbl := cp.getLabel() - clr := cp.Color - if leg != "" { - lbl = leg + " " + lbl - } - if nleg > 1 { - cidx := yidx*nleg + li - clr = colors.Uniform(colors.Spaced(cidx)) - } - if nidx > 1 { - clr = colors.Uniform(colors.Spaced(idx)) - lbl = fmt.Sprintf("%s_%02d", lbl, idx) - } - if cp.Lines.Or(pl.Options.Lines) && cp.Points.Or(pl.Options.Points) { - lns, pts, _ = plots.NewLinePoints(xy) - } else if cp.Points.Or(pl.Options.Points) { - pts, _ = plots.NewScatter(xy) - } else { - lns, _ = plots.NewLine(xy) - } - if lns != nil { - lns.LineStyle.Width.Pt(float32(cp.LineWidth.Or(pl.Options.LineWidth))) - lns.LineStyle.Color = clr - lns.NegativeXDraw = pl.Options.NegativeXDraw - plt.Add(lns) - if pts != nil { - plt.Legend.Add(lbl, lns, pts) - } else { - plt.Legend.Add(lbl, lns) - } - } - if pts != nil { - pts.LineStyle.Color = clr - pts.LineStyle.Width.Pt(float32(cp.LineWidth.Or(pl.Options.LineWidth))) - pts.PointSize.Pt(float32(cp.PointSize.Or(pl.Options.PointSize))) - pts.PointShape = cp.PointShape.Or(pl.Options.PointShape) - plt.Add(pts) - if lns == nil { - plt.Legend.Add(lbl, pts) - } - } - if cp.ErrColumn != "" { - ec := errors.Log1(pl.table.Table.ColumnIndex(cp.ErrColumn)) - if ec >= 0 { - xy.errColumn = ec - eb, _ := plots.NewYErrorBars(xy) - eb.LineStyle.Color = clr - plt.Add(eb) - } - } - } - } - yidx++ - } - if firstXY != nil && len(strCols) > 0 { - for _, cp := range strCols { - xy, _ := newTableXY(xview, xi, xp.TensorIndex, firstXY.yColumn, cp.TensorIndex, firstXY.yRange) - xy.labelColumn, _ = xview.Table.ColumnIndex(cp.Column) - xy.yIndex = firstXY.yIndex - lbls, _ := plots.NewLabels(xy) - if lbls != nil { - plt.Add(lbls) - } - } - } - - // Use string labels for X axis if X is a string - xc := pl.table.Table.Columns[xi] - if xc.IsString() { - xcs := xc.(*tensor.String) - vals := make([]string, pl.table.Len()) - for i, dx := range pl.table.Indexes { - 
vals[i] = xcs.Values[dx] - } - plt.NominalX(vals...) - } - - pl.configPlot(plt) - pl.plot = plt -} diff --git a/plot/plots/bar.go b/plot/plots/bar.go new file mode 100644 index 0000000000..4fe1fe8238 --- /dev/null +++ b/plot/plots/bar.go @@ -0,0 +1,247 @@ +// Copyright (c) 2019, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This is copied and modified directly from gonum to add better error-bar +// plotting for bar plots, along with multiple groups. + +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package plots + +import ( + "math" + + "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" + "cogentcore.org/core/plot" +) + +// BarType is be used for specifying the type name. +const BarType = "Bar" + +func init() { + plot.RegisterPlotter(BarType, "A Bar presents ordinally-organized data with rectangular bars with lengths proportional to the data values, and an optional error bar at the top of the bar using the High data role.", []plot.Roles{plot.Y}, []plot.Roles{plot.High}, func(data plot.Data) plot.Plotter { + return NewBar(data) + }) +} + +// A Bar presents ordinally-organized data with rectangular bars +// with lengths proportional to the data values, and an optional +// error bar ("handle") at the top of the bar using the High data role. +// +// Bars are plotted centered at integer multiples of Stride plus Start offset. +// Full data range also includes Pad value to extend range beyond edge bar centers. +// Bar Width is in data units, e.g., should be <= Stride. +// Defaults provide a unit-spaced plot. +type Bar struct { + // copies of data + Y, Err plot.Values + + // actual plotting X, Y values in data coordinates, taking into account stacking etc. 
+ X, Yp plot.Values + + // PX, PY are the actual pixel plotting coordinates for each XY value. + PX, PY []float32 + + // Style has the properties used to render the bars. + Style plot.Style + + // Horizontal dictates whether the bars should be in the vertical + // (default) or horizontal direction. If Horizontal is true, all + // X locations and distances referred to here will actually be Y + // locations and distances. + Horizontal bool + + // stackedOn is the bar chart upon which this bar chart is stacked. + StackedOn *Bar + + stylers plot.Stylers +} + +// NewBar returns a new bar plotter with a single bar for each value. +// The bars heights correspond to the values and their x locations correspond +// to the index of their value in the Valuer. +// Optional error-bar values can be provided using the High data role. +// Styler functions are obtained from the Y metadata if present. +func NewBar(data plot.Data) *Bar { + if data.CheckLengths() != nil { + return nil + } + bc := &Bar{} + bc.Y = plot.MustCopyRole(data, plot.Y) + if bc.Y == nil { + return nil + } + bc.stylers = plot.GetStylersFromData(data, plot.Y) + bc.Err = plot.CopyRole(data, plot.High) + bc.Defaults() + return bc +} + +func (bc *Bar) Defaults() { + bc.Style.Defaults() +} + +func (bc *Bar) Styler(f func(s *plot.Style)) *Bar { + bc.stylers.Add(f) + return bc +} + +func (bc *Bar) ApplyStyle(ps *plot.PlotStyle) { + ps.SetElementStyle(&bc.Style) + bc.stylers.Run(&bc.Style) +} + +func (bc *Bar) Stylers() *plot.Stylers { return &bc.stylers } + +func (bc *Bar) Data() (data plot.Data, pixX, pixY []float32) { + pixX = bc.PX + pixY = bc.PY + data = plot.Data{} + data[plot.X] = bc.X + data[plot.Y] = bc.Y + if bc.Err != nil { + data[plot.High] = bc.Err + } + return +} + +// BarHeight returns the maximum y value of the +// ith bar, taking into account any bars upon +// which it is stacked. 
+func (bc *Bar) BarHeight(i int) float64 { + ht := float64(0.0) + if bc == nil { + return 0 + } + if i >= 0 && i < len(bc.Y) { + ht += bc.Y[i] + } + if bc.StackedOn != nil { + ht += bc.StackedOn.BarHeight(i) + } + return ht +} + +// StackOn stacks a bar chart on top of another, +// and sets the bar positioning options to that of the +// chart upon which it is being stacked. +func (bc *Bar) StackOn(on *Bar) { + bc.Style.Width = on.Style.Width + bc.StackedOn = on +} + +// Plot implements the plot.Plotter interface. +func (bc *Bar) Plot(plt *plot.Plot) { + pc := plt.Paint + bc.Style.Line.SetStroke(plt) + pc.FillStyle.Color = bc.Style.Line.Fill + bw := bc.Style.Width + + nv := len(bc.Y) + bc.X = make(plot.Values, nv) + bc.Yp = make(plot.Values, nv) + bc.PX = make([]float32, nv) + bc.PY = make([]float32, nv) + + hw := 0.5 * bw.Width + ew := bw.Width / 3 + for i, ht := range bc.Y { + cat := bw.Offset + float64(i)*bw.Stride + var bottom float64 + var catVal, catMin, catMax, valMin, valMax float32 + var box math32.Box2 + if bc.Horizontal { + catVal = plt.PY(cat) + catMin = plt.PY(cat - hw) + catMax = plt.PY(cat + hw) + bottom = bc.StackedOn.BarHeight(i) // nil safe + valMin = plt.PX(bottom) + valMax = plt.PX(bottom + ht) + bc.X[i] = bottom + ht + bc.Yp[i] = cat + bc.PX[i] = valMax + bc.PY[i] = catVal + box.Min.Set(valMin, catMin) + box.Max.Set(valMax, catMax) + } else { + catVal = plt.PX(cat) + catMin = plt.PX(cat - hw) + catMax = plt.PX(cat + hw) + bottom = bc.StackedOn.BarHeight(i) // nil safe + valMin = plt.PY(bottom) + valMax = plt.PY(bottom + ht) + bc.X[i] = cat + bc.Yp[i] = bottom + ht + bc.PX[i] = catVal + bc.PY[i] = valMax + box.Min.Set(catMin, valMin) + box.Max.Set(catMax, valMax) + } + + pc.DrawRectangle(box.Min.X, box.Min.Y, box.Size().X, box.Size().Y) + pc.FillStrokeClear() + + if i < len(bc.Err) { + errval := math.Abs(bc.Err[i]) + if bc.Horizontal { + eVal := plt.PX(bottom + ht + math.Abs(errval)) + pc.MoveTo(valMax, catVal) + pc.LineTo(eVal, catVal) + 
pc.MoveTo(eVal, plt.PY(cat-ew)) + pc.LineTo(eVal, plt.PY(cat+ew)) + } else { + eVal := plt.PY(bottom + ht + math.Abs(errval)) + pc.MoveTo(catVal, valMax) + pc.LineTo(catVal, eVal) + pc.MoveTo(plt.PX(cat-ew), eVal) + pc.LineTo(plt.PX(cat+ew), eVal) + } + pc.Stroke() + } + } + pc.FillStyle.Color = nil +} + +// UpdateRange updates the given ranges. +func (bc *Bar) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) { + bw := bc.Style.Width + catMin := bw.Offset - bw.Pad + catMax := bw.Offset + float64(len(bc.Y)-1)*bw.Stride + bw.Pad + + for i, val := range bc.Y { + valBot := bc.StackedOn.BarHeight(i) + valTop := valBot + val + if i < len(bc.Err) { + valTop += math.Abs(bc.Err[i]) + } + if bc.Horizontal { + xr.FitValInRange(valBot) + xr.FitValInRange(valTop) + } else { + yr.FitValInRange(valBot) + yr.FitValInRange(valTop) + } + } + if bc.Horizontal { + xr.Min, xr.Max = bc.Style.Range.Clamp(xr.Min, xr.Max) + yr.FitInRange(minmax.F64{catMin, catMax}) + } else { + yr.Min, yr.Max = bc.Style.Range.Clamp(yr.Min, yr.Max) + xr.FitInRange(minmax.F64{catMin, catMax}) + } +} + +// Thumbnail fulfills the plot.Thumbnailer interface. +func (bc *Bar) Thumbnail(plt *plot.Plot) { + pc := plt.Paint + bc.Style.Line.SetStroke(plt) + pc.FillStyle.Color = bc.Style.Line.Fill + ptb := pc.Bounds + pc.DrawRectangle(float32(ptb.Min.X), float32(ptb.Min.Y), float32(ptb.Size().X), float32(ptb.Size().Y)) + pc.FillStrokeClear() + pc.FillStyle.Color = nil +} diff --git a/plot/plots/barchart.go b/plot/plots/barchart.go deleted file mode 100644 index 0b20ce7b00..0000000000 --- a/plot/plots/barchart.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) 2019, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// This is copied and modified directly from gonum to add better error-bar -// plotting for bar plots, along with multiple groups. - -// Copyright ©2015 The Gonum Authors. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package plots - -import ( - "image" - - "cogentcore.org/core/colors" - "cogentcore.org/core/math32" - "cogentcore.org/core/plot" -) - -// A BarChart presents ordinally-organized data with rectangular bars -// with lengths proportional to the data values, and an optional -// error bar ("handle") at the top of the bar using given error value -// (single value, like a standard deviation etc, not drawn below the bar). -// -// Bars are plotted centered at integer multiples of Stride plus Start offset. -// Full data range also includes Pad value to extend range beyond edge bar centers. -// Bar Width is in data units, e.g., should be <= Stride. -// Defaults provide a unit-spaced plot. -type BarChart struct { - // Values are the plotted values - Values plot.Values - - // YErrors is a copy of the Y errors for each point. - Errors plot.Values - - // XYs is the actual pixel plotting coordinates for each value. - XYs plot.XYs - - // PXYs is the actual pixel plotting coordinates for each value. - PXYs plot.XYs - - // Offset is offset added to each X axis value relative to the - // Stride computed value (X = offset + index * Stride) - // Defaults to 1. - Offset float32 - - // Stride is distance between bars. Defaults to 1. - Stride float32 - - // Width is the width of the bars, which should be less than - // the Stride to prevent bar overlap. - // Defaults to .8 - Width float32 - - // Pad is additional space at start / end of data range, to keep bars from - // overflowing ends. This amount is subtracted from Offset - // and added to (len(Values)-1)*Stride -- no other accommodation for bar - // width is provided, so that should be built into this value as well. - // Defaults to 1. - Pad float32 - - // Color is the fill color of the bars. - Color image.Image - - // LineStyle is the style of the line connecting the points. - // Use zero width to disable lines. 
- LineStyle plot.LineStyle - - // Horizontal dictates whether the bars should be in the vertical - // (default) or horizontal direction. If Horizontal is true, all - // X locations and distances referred to here will actually be Y - // locations and distances. - Horizontal bool - - // stackedOn is the bar chart upon which this bar chart is stacked. - StackedOn *BarChart -} - -// NewBarChart returns a new bar chart with a single bar for each value. -// The bars heights correspond to the values and their x locations correspond -// to the index of their value in the Valuer. Optional error-bar values can be -// provided. -func NewBarChart(vs, ers plot.Valuer) (*BarChart, error) { - values, err := plot.CopyValues(vs) - if err != nil { - return nil, err - } - var errs plot.Values - if ers != nil { - errs, err = plot.CopyValues(ers) - if err != nil { - return nil, err - } - } - b := &BarChart{ - Values: values, - Errors: errs, - } - b.Defaults() - return b, nil -} - -func (b *BarChart) Defaults() { - b.Offset = 1 - b.Stride = 1 - b.Width = .8 - b.Pad = 1 - b.Color = colors.Scheme.OnSurface - b.LineStyle.Defaults() -} - -func (b *BarChart) XYData() (data plot.XYer, pixels plot.XYer) { - data = b.XYs - pixels = b.PXYs - return -} - -// BarHeight returns the maximum y value of the -// ith bar, taking into account any bars upon -// which it is stacked. -func (b *BarChart) BarHeight(i int) float32 { - ht := float32(0.0) - if b == nil { - return 0 - } - if i >= 0 && i < len(b.Values) { - ht += b.Values[i] - } - if b.StackedOn != nil { - ht += b.StackedOn.BarHeight(i) - } - return ht -} - -// StackOn stacks a bar chart on top of another, -// and sets the bar positioning options to that of the -// chart upon which it is being stacked. -func (b *BarChart) StackOn(on *BarChart) { - b.Offset = on.Offset - b.Stride = on.Stride - b.Pad = on.Pad - b.StackedOn = on -} - -// Plot implements the plot.Plotter interface. 
-func (b *BarChart) Plot(plt *plot.Plot) { - pc := plt.Paint - pc.FillStyle.Color = b.Color - b.LineStyle.SetStroke(plt) - - nv := len(b.Values) - b.XYs = make(plot.XYs, nv) - b.PXYs = make(plot.XYs, nv) - - hw := 0.5 * b.Width - ew := b.Width / 3 - for i, ht := range b.Values { - cat := b.Offset + float32(i)*b.Stride - var bottom, catVal, catMin, catMax, valMin, valMax float32 - var box math32.Box2 - if b.Horizontal { - catVal = plt.PY(cat) - catMin = plt.PY(cat - hw) - catMax = plt.PY(cat + hw) - bottom = b.StackedOn.BarHeight(i) // nil safe - valMin = plt.PX(bottom) - valMax = plt.PX(bottom + ht) - b.XYs[i] = math32.Vec2(bottom+ht, cat) - b.PXYs[i] = math32.Vec2(valMax, catVal) - box.Min.Set(valMin, catMin) - box.Max.Set(valMax, catMax) - } else { - catVal = plt.PX(cat) - catMin = plt.PX(cat - hw) - catMax = plt.PX(cat + hw) - bottom = b.StackedOn.BarHeight(i) // nil safe - valMin = plt.PY(bottom) - valMax = plt.PY(bottom + ht) - b.XYs[i] = math32.Vec2(cat, bottom+ht) - b.PXYs[i] = math32.Vec2(catVal, valMax) - box.Min.Set(catMin, valMin) - box.Max.Set(catMax, valMax) - } - - pc.DrawRectangle(box.Min.X, box.Min.Y, box.Size().X, box.Size().Y) - pc.FillStrokeClear() - - if i < len(b.Errors) { - errval := b.Errors[i] - if b.Horizontal { - eVal := plt.PX(bottom + ht + math32.Abs(errval)) - pc.MoveTo(valMax, catVal) - pc.LineTo(eVal, catVal) - pc.MoveTo(eVal, plt.PY(cat-ew)) - pc.LineTo(eVal, plt.PY(cat+ew)) - } else { - eVal := plt.PY(bottom + ht + math32.Abs(errval)) - pc.MoveTo(catVal, valMax) - pc.LineTo(catVal, eVal) - pc.MoveTo(plt.PX(cat-ew), eVal) - pc.LineTo(plt.PX(cat+ew), eVal) - } - pc.Stroke() - } - } -} - -// DataRange implements the plot.DataRanger interface. 
-func (b *BarChart) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - catMin := b.Offset - b.Pad - catMax := b.Offset + float32(len(b.Values)-1)*b.Stride + b.Pad - - valMin := math32.Inf(1) - valMax := math32.Inf(-1) - for i, val := range b.Values { - valBot := b.StackedOn.BarHeight(i) - valTop := valBot + val - if i < len(b.Errors) { - valTop += math32.Abs(b.Errors[i]) - } - valMin = math32.Min(valMin, math32.Min(valBot, valTop)) - valMax = math32.Max(valMax, math32.Max(valBot, valTop)) - } - if !b.Horizontal { - return catMin, catMax, valMin, valMax - } - return valMin, valMax, catMin, catMax -} - -// Thumbnail fulfills the plot.Thumbnailer interface. -func (b *BarChart) Thumbnail(plt *plot.Plot) { - pc := plt.Paint - pc.FillStyle.Color = b.Color - b.LineStyle.SetStroke(plt) - ptb := pc.Bounds - pc.DrawRectangle(float32(ptb.Min.X), float32(ptb.Min.Y), float32(ptb.Size().X), float32(ptb.Size().Y)) - pc.FillStrokeClear() -} diff --git a/plot/plots/doc.go b/plot/plots/doc.go index 072b3aef99..f4b79d3f06 100644 --- a/plot/plots/doc.go +++ b/plot/plots/doc.go @@ -18,6 +18,5 @@ // data points, and are just skipped over. // // New* functions return an error if the data contains Inf or is -// empty. Some of the New* functions return other plotter-specific errors -// too. +// empty. Some of the New* functions return other plotter-specific errors too. package plots diff --git a/plot/plots/enumgen.go b/plot/plots/enumgen.go deleted file mode 100644 index b795361bc0..0000000000 --- a/plot/plots/enumgen.go +++ /dev/null @@ -1,87 +0,0 @@ -// Code generated by "core generate"; DO NOT EDIT. - -package plots - -import ( - "cogentcore.org/core/enums" -) - -var _StepKindValues = []StepKind{0, 1, 2, 3} - -// StepKindN is the highest valid value for type StepKind, plus one. 
-const StepKindN StepKind = 4 - -var _StepKindValueMap = map[string]StepKind{`NoStep`: 0, `PreStep`: 1, `MidStep`: 2, `PostStep`: 3} - -var _StepKindDescMap = map[StepKind]string{0: `NoStep connects two points by simple line`, 1: `PreStep connects two points by following lines: vertical, horizontal.`, 2: `MidStep connects two points by following lines: horizontal, vertical, horizontal. Vertical line is placed in the middle of the interval.`, 3: `PostStep connects two points by following lines: horizontal, vertical.`} - -var _StepKindMap = map[StepKind]string{0: `NoStep`, 1: `PreStep`, 2: `MidStep`, 3: `PostStep`} - -// String returns the string representation of this StepKind value. -func (i StepKind) String() string { return enums.String(i, _StepKindMap) } - -// SetString sets the StepKind value from its string representation, -// and returns an error if the string is invalid. -func (i *StepKind) SetString(s string) error { - return enums.SetString(i, s, _StepKindValueMap, "StepKind") -} - -// Int64 returns the StepKind value as an int64. -func (i StepKind) Int64() int64 { return int64(i) } - -// SetInt64 sets the StepKind value from an int64. -func (i *StepKind) SetInt64(in int64) { *i = StepKind(in) } - -// Desc returns the description of the StepKind value. -func (i StepKind) Desc() string { return enums.Desc(i, _StepKindDescMap) } - -// StepKindValues returns all possible values for the type StepKind. -func StepKindValues() []StepKind { return _StepKindValues } - -// Values returns all possible values for the type StepKind. -func (i StepKind) Values() []enums.Enum { return enums.Values(_StepKindValues) } - -// MarshalText implements the [encoding.TextMarshaler] interface. -func (i StepKind) MarshalText() ([]byte, error) { return []byte(i.String()), nil } - -// UnmarshalText implements the [encoding.TextUnmarshaler] interface. 
-func (i *StepKind) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "StepKind") } - -var _ShapesValues = []Shapes{0, 1, 2, 3, 4, 5, 6, 7} - -// ShapesN is the highest valid value for type Shapes, plus one. -const ShapesN Shapes = 8 - -var _ShapesValueMap = map[string]Shapes{`Ring`: 0, `Circle`: 1, `Square`: 2, `Box`: 3, `Triangle`: 4, `Pyramid`: 5, `Plus`: 6, `Cross`: 7} - -var _ShapesDescMap = map[Shapes]string{0: `Ring is the outline of a circle`, 1: `Circle is a solid circle`, 2: `Square is the outline of a square`, 3: `Box is a filled square`, 4: `Triangle is the outline of a triangle`, 5: `Pyramid is a filled triangle`, 6: `Plus is a plus sign`, 7: `Cross is a big X`} - -var _ShapesMap = map[Shapes]string{0: `Ring`, 1: `Circle`, 2: `Square`, 3: `Box`, 4: `Triangle`, 5: `Pyramid`, 6: `Plus`, 7: `Cross`} - -// String returns the string representation of this Shapes value. -func (i Shapes) String() string { return enums.String(i, _ShapesMap) } - -// SetString sets the Shapes value from its string representation, -// and returns an error if the string is invalid. -func (i *Shapes) SetString(s string) error { return enums.SetString(i, s, _ShapesValueMap, "Shapes") } - -// Int64 returns the Shapes value as an int64. -func (i Shapes) Int64() int64 { return int64(i) } - -// SetInt64 sets the Shapes value from an int64. -func (i *Shapes) SetInt64(in int64) { *i = Shapes(in) } - -// Desc returns the description of the Shapes value. -func (i Shapes) Desc() string { return enums.Desc(i, _ShapesDescMap) } - -// ShapesValues returns all possible values for the type Shapes. -func ShapesValues() []Shapes { return _ShapesValues } - -// Values returns all possible values for the type Shapes. -func (i Shapes) Values() []enums.Enum { return enums.Values(_ShapesValues) } - -// MarshalText implements the [encoding.TextMarshaler] interface. 
-func (i Shapes) MarshalText() ([]byte, error) { return []byte(i.String()), nil } - -// UnmarshalText implements the [encoding.TextUnmarshaler] interface. -func (i *Shapes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Shapes") } diff --git a/plot/plots/errbars.go b/plot/plots/errbars.go index 09e305e017..4d37a6279e 100644 --- a/plot/plots/errbars.go +++ b/plot/plots/errbars.go @@ -5,124 +5,118 @@ package plots import ( - "cogentcore.org/core/math32" + "math" + + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/plot" - "cogentcore.org/core/styles/units" ) -////////////////////////////////////////////////// -// XErrorer - -// XErrorer provides an interface for a list of Low, High error bar values. -// This is used in addition to an XYer interface, if implemented. -type XErrorer interface { - // XError returns Low, High error values for X data. - XError(i int) (low, high float32) -} - -// Errors is a slice of low and high error values. -type Errors []struct{ Low, High float32 } - -// XErrors implements the XErrorer interface. -type XErrors Errors - -func (xe XErrors) XError(i int) (low, high float32) { - return xe[i].Low, xe[i].High -} +const ( + // YErrorBarsType is be used for specifying the type name. + YErrorBarsType = "YErrorBars" -// YErrorer provides an interface for YError method. -// This is used in addition to an XYer interface, if implemented. -type YErrorer interface { - // YError returns two error values for Y data. - YError(i int) (float32, float32) -} - -// YErrors implements the YErrorer interface. -type YErrors Errors + // XErrorBarsType is be used for specifying the type name. 
+ XErrorBarsType = "XErrorBars" +) -func (ye YErrors) YError(i int) (float32, float32) { - return ye[i].Low, ye[i].High +func init() { + plot.RegisterPlotter(YErrorBarsType, "draws draws vertical error bars, denoting error in Y values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(data plot.Data) plot.Plotter { + return NewYErrorBars(data) + }) + plot.RegisterPlotter(XErrorBarsType, "draws draws horizontal error bars, denoting error in X values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(data plot.Data) plot.Plotter { + return NewXErrorBars(data) + }) } -// YErrorBars implements the plot.Plotter, plot.DataRanger, -// and plot.GlyphBoxer interfaces, drawing vertical error -// bars, denoting error in Y values. +// YErrorBars draws vertical error bars, denoting error in Y values, +// using ether High or Low, High data roles for error deviations +// around X, Y coordinates. type YErrorBars struct { - // XYs is a copy of the points for this line. - plot.XYs + // copies of data for this line + X, Y, Low, High plot.Values - // YErrors is a copy of the Y errors for each point. - YErrors + // PX, PY are the actual pixel plotting coordinates for each XY value. + PX, PY []float32 - // PXYs is the actual pixel plotting coordinates for each XY value, - // representing the high, center value of the error bar. - PXYs plot.XYs + // Style is the style for plotting. + Style plot.Style - // LineStyle is the style used to draw the error bars. - LineStyle plot.LineStyle - - // CapWidth is the width of the caps drawn at the top of each error bar. 
- CapWidth units.Value + stylers plot.Stylers + ystylers plot.Stylers } func (eb *YErrorBars) Defaults() { - eb.LineStyle.Defaults() - eb.CapWidth.Dp(10) + eb.Style.Defaults() } -// NewYErrorBars returns a new YErrorBars plotter, or an error on failure. -// The error values from the YErrorer interface are interpreted as relative -// to the corresponding Y value. The errors for a given Y value are computed -// by taking the absolute value of the error returned by the YErrorer -// and subtracting the first and adding the second to the Y value. -func NewYErrorBars(yerrs interface { - plot.XYer - YErrorer -}) (*YErrorBars, error) { - - errors := make(YErrors, yerrs.Len()) - for i := range errors { - errors[i].Low, errors[i].High = yerrs.YError(i) - if err := plot.CheckFloats(errors[i].Low, errors[i].High); err != nil { - return nil, err - } +// NewYErrorBars returns a new YErrorBars plotter, +// using Low, High data roles for error deviations around X, Y coordinates. +// Styler functions are obtained from the High data if present. +func NewYErrorBars(data plot.Data) *YErrorBars { + if data.CheckLengths() != nil { + return nil } - xys, err := plot.CopyXYs(yerrs) - if err != nil { - return nil, err + eb := &YErrorBars{} + eb.X = plot.MustCopyRole(data, plot.X) + eb.Y = plot.MustCopyRole(data, plot.Y) + eb.Low = plot.CopyRole(data, plot.Low) + eb.High = plot.CopyRole(data, plot.High) + if eb.Low == nil && eb.High != nil { + eb.Low = eb.High } - - eb := &YErrorBars{ - XYs: xys, - YErrors: errors, + if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil { + return nil } + eb.stylers = plot.GetStylersFromData(data, plot.High) + eb.ystylers = plot.GetStylersFromData(data, plot.Y) eb.Defaults() - return eb, nil + return eb +} + +// Styler adds a style function to set style parameters. 
+func (eb *YErrorBars) Styler(f func(s *plot.Style)) *YErrorBars { + eb.stylers.Add(f) + return eb +} + +func (eb *YErrorBars) ApplyStyle(ps *plot.PlotStyle) { + ps.SetElementStyle(&eb.Style) + yst := &plot.Style{} + eb.ystylers.Run(yst) + eb.Style.Range = yst.Range // get range from y + eb.stylers.Run(&eb.Style) } -func (e *YErrorBars) XYData() (data plot.XYer, pixels plot.XYer) { - data = e.XYs - pixels = e.PXYs +func (eb *YErrorBars) Stylers() *plot.Stylers { return &eb.stylers } + +func (eb *YErrorBars) Data() (data plot.Data, pixX, pixY []float32) { + pixX = eb.PX + pixY = eb.PY + data = plot.Data{} + data[plot.X] = eb.X + data[plot.Y] = eb.Y + data[plot.Low] = eb.Low + data[plot.High] = eb.High return } -// Plot implements the Plotter interface, drawing labels. -func (e *YErrorBars) Plot(plt *plot.Plot) { +func (eb *YErrorBars) Plot(plt *plot.Plot) { pc := plt.Paint uc := &pc.UnitContext - e.CapWidth.ToDots(uc) - cw := 0.5 * e.CapWidth.Dots - nv := len(e.YErrors) - e.PXYs = make(plot.XYs, nv) - e.LineStyle.SetStroke(plt) - for i, err := range e.YErrors { - x := plt.PX(e.XYs[i].X) - ylow := plt.PY(e.XYs[i].Y - math32.Abs(err.Low)) - yhigh := plt.PY(e.XYs[i].Y + math32.Abs(err.High)) + eb.Style.Width.Cap.ToDots(uc) + cw := 0.5 * eb.Style.Width.Cap.Dots + nv := len(eb.X) + eb.PX = make([]float32, nv) + eb.PY = make([]float32, nv) + eb.Style.Line.SetStroke(plt) + for i, y := range eb.Y { + x := plt.PX(eb.X.Float1D(i)) + ylow := plt.PY(y - math.Abs(eb.Low[i])) + yhigh := plt.PY(y + math.Abs(eb.High[i])) - e.PXYs[i].X = x - e.PXYs[i].Y = yhigh + eb.PX[i] = x + eb.PY[i] = yhigh pc.MoveTo(x, ylow) pc.LineTo(x, yhigh) @@ -136,102 +130,111 @@ func (e *YErrorBars) Plot(plt *plot.Plot) { } } -// DataRange implements the plot.DataRanger interface. 
-func (e *YErrorBars) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - xmin, xmax = plot.Range(plot.XValues{e}) - ymin = math32.Inf(1) - ymax = math32.Inf(-1) - for i, err := range e.YErrors { - y := e.XYs[i].Y - ylow := y - math32.Abs(err.Low) - yhigh := y + math32.Abs(err.High) - ymin = math32.Min(math32.Min(math32.Min(ymin, y), ylow), yhigh) - ymax = math32.Max(math32.Max(math32.Max(ymax, y), ylow), yhigh) +// UpdateRange updates the given ranges. +func (eb *YErrorBars) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) { + plot.Range(eb.X, xr) + plot.RangeClamp(eb.Y, yr, &eb.Style.Range) + for i, y := range eb.Y { + ylow := y - math.Abs(eb.Low[i]) + yhigh := y + math.Abs(eb.High[i]) + yr.FitInRange(minmax.F64{ylow, yhigh}) } return } -// XErrorBars implements the plot.Plotter, plot.DataRanger, -// and plot.GlyphBoxer interfaces, drawing horizontal error -// bars, denoting error in Y values. +//////// XErrorBars + +// XErrorBars draws horizontal error bars, denoting error in X values, +// using ether High or Low, High data roles for error deviations +// around X, Y coordinates. type XErrorBars struct { - // XYs is a copy of the points for this line. - plot.XYs + // copies of data for this line + X, Y, Low, High plot.Values - // XErrors is a copy of the X errors for each point. - XErrors + // PX, PY are the actual pixel plotting coordinates for each XY value. + PX, PY []float32 - // PXYs is the actual pixel plotting coordinates for each XY value, - // representing the high, center value of the error bar. - PXYs plot.XYs + // Style is the style for plotting. + Style plot.Style - // LineStyle is the style used to draw the error bars. - LineStyle plot.LineStyle + stylers plot.Stylers + ystylers plot.Stylers + yrange minmax.Range64 +} - // CapWidth is the width of the caps drawn at the top - // of each error bar. 
- CapWidth units.Value +func (eb *XErrorBars) Defaults() { + eb.Style.Defaults() } -// Returns a new XErrorBars plotter, or an error on failure. The error values -// from the XErrorer interface are interpreted as relative to the corresponding -// X value. The errors for a given X value are computed by taking the absolute -// value of the error returned by the XErrorer and subtracting the first and -// adding the second to the X value. -func NewXErrorBars(xerrs interface { - plot.XYer - XErrorer -}) (*XErrorBars, error) { - - errors := make(XErrors, xerrs.Len()) - for i := range errors { - errors[i].Low, errors[i].High = xerrs.XError(i) - if err := plot.CheckFloats(errors[i].Low, errors[i].High); err != nil { - return nil, err - } +// NewXErrorBars returns a new XErrorBars plotter, +// using Low, High data roles for error deviations around X, Y coordinates. +func NewXErrorBars(data plot.Data) *XErrorBars { + if data.CheckLengths() != nil { + return nil } - xys, err := plot.CopyXYs(xerrs) - if err != nil { - return nil, err + eb := &XErrorBars{} + eb.X = plot.MustCopyRole(data, plot.X) + eb.Y = plot.MustCopyRole(data, plot.Y) + eb.Low = plot.MustCopyRole(data, plot.Low) + eb.High = plot.MustCopyRole(data, plot.High) + eb.Low = plot.CopyRole(data, plot.Low) + eb.High = plot.CopyRole(data, plot.High) + if eb.Low == nil && eb.High != nil { + eb.Low = eb.High } - - eb := &XErrorBars{ - XYs: xys, - XErrors: errors, + if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil { + return nil } + eb.stylers = plot.GetStylersFromData(data, plot.High) + eb.ystylers = plot.GetStylersFromData(data, plot.Y) eb.Defaults() - return eb, nil + return eb } -func (eb *XErrorBars) Defaults() { - eb.LineStyle.Defaults() - eb.CapWidth.Dp(10) +// Styler adds a style function to set style parameters. 
+func (eb *XErrorBars) Styler(f func(s *plot.Style)) *XErrorBars { + eb.stylers.Add(f) + return eb +} + +func (eb *XErrorBars) ApplyStyle(ps *plot.PlotStyle) { + ps.SetElementStyle(&eb.Style) + yst := &plot.Style{} + eb.ystylers.Run(yst) + eb.yrange = yst.Range // get range from y + eb.stylers.Run(&eb.Style) } -func (e *XErrorBars) XYData() (data plot.XYer, pixels plot.XYer) { - data = e.XYs - pixels = e.PXYs +func (eb *XErrorBars) Stylers() *plot.Stylers { return &eb.stylers } + +func (eb *XErrorBars) Data() (data plot.Data, pixX, pixY []float32) { + pixX = eb.PX + pixY = eb.PY + data = plot.Data{} + data[plot.X] = eb.X + data[plot.Y] = eb.Y + data[plot.Low] = eb.Low + data[plot.High] = eb.High return } -// Plot implements the Plotter interface, drawing labels. -func (e *XErrorBars) Plot(plt *plot.Plot) { +func (eb *XErrorBars) Plot(plt *plot.Plot) { pc := plt.Paint uc := &pc.UnitContext - e.CapWidth.ToDots(uc) - cw := 0.5 * e.CapWidth.Dots - - nv := len(e.XErrors) - e.PXYs = make(plot.XYs, nv) - e.LineStyle.SetStroke(plt) - for i, err := range e.XErrors { - y := plt.PY(e.XYs[i].Y) - xlow := plt.PX(e.XYs[i].X - math32.Abs(err.Low)) - xhigh := plt.PX(e.XYs[i].X + math32.Abs(err.High)) + eb.Style.Width.Cap.ToDots(uc) + cw := 0.5 * eb.Style.Width.Cap.Dots + nv := len(eb.X) + eb.PX = make([]float32, nv) + eb.PY = make([]float32, nv) + eb.Style.Line.SetStroke(plt) + for i, x := range eb.X { + y := plt.PY(eb.Y.Float1D(i)) + xlow := plt.PX(x - math.Abs(eb.Low[i])) + xhigh := plt.PX(x + math.Abs(eb.High[i])) - e.PXYs[i].X = xhigh - e.PXYs[i].Y = y + eb.PX[i] = xhigh + eb.PY[i] = y pc.MoveTo(xlow, y) pc.LineTo(xhigh, y) @@ -245,17 +248,14 @@ func (e *XErrorBars) Plot(plt *plot.Plot) { } } -// DataRange implements the plot.DataRanger interface. 
-func (e *XErrorBars) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - ymin, ymax = plot.Range(plot.YValues{e}) - xmin = math32.Inf(1) - xmax = math32.Inf(-1) - for i, err := range e.XErrors { - x := e.XYs[i].X - xlow := x - math32.Abs(err.Low) - xhigh := x + math32.Abs(err.High) - xmin = math32.Min(math32.Min(math32.Min(xmin, x), xlow), xhigh) - xmax = math32.Max(math32.Max(math32.Max(xmax, x), xlow), xhigh) +// UpdateRange updates the given ranges. +func (eb *XErrorBars) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) { + plot.RangeClamp(eb.X, xr, &eb.Style.Range) + plot.RangeClamp(eb.Y, yr, &eb.yrange) + for i, xv := range eb.X { + xlow := xv - math.Abs(eb.Low[i]) + xhigh := xv + math.Abs(eb.High[i]) + xr.FitInRange(minmax.F64{xlow, xhigh}) } return } diff --git a/plot/plots/labels.go b/plot/plots/labels.go index be431c2665..d377804e03 100644 --- a/plot/plots/labels.go +++ b/plot/plots/labels.go @@ -5,162 +5,144 @@ package plots import ( - "errors" "image" "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/plot" - "cogentcore.org/core/styles/units" ) -// Labels implements the Plotter interface, -// drawing a set of labels at specified points. -type Labels struct { - // XYs is a copy of the points for labels - plot.XYs +// LabelsType is be used for specifying the type name. +const LabelsType = "Labels" - // PXYs is the actual pixel plotting coordinates for each XY value. - PXYs plot.XYs +func init() { + plot.RegisterPlotter(LabelsType, "draws text labels at specified X, Y points.", []plot.Roles{plot.X, plot.Y, plot.Label}, []plot.Roles{}, func(data plot.Data) plot.Plotter { + return NewLabels(data) + }) +} - // Labels is the set of labels corresponding to each point. - Labels []string +// Labels draws text labels at specified X, Y points. +type Labels struct { + // copies of data for this line + X, Y plot.Values + Labels plot.Labels - // TextStyle is the style of the label text. 
- // Each label can have a different text style, but - // by default they share a common one (len = 1) - TextStyle []plot.TextStyle + // PX, PY are the actual pixel plotting coordinates for each XY value. + PX, PY []float32 - // Offset is added directly to the final label location. - Offset units.XY + // Style is the style of the label text. + Style plot.Style // plot size and number of TextStyle when styles last generated -- don't regen styleSize image.Point - styleN int + stylers plot.Stylers + ystylers plot.Stylers } // NewLabels returns a new Labels using defaults -func NewLabels(d XYLabeler) (*Labels, error) { - xys, err := plot.CopyXYs(d) - if err != nil { - return nil, err +// Styler functions are obtained from the Label metadata if present. +func NewLabels(data plot.Data) *Labels { + if data.CheckLengths() != nil { + return nil } - - if d.Len() != len(xys) { - return nil, errors.New("plotter: number of points does not match the number of labels") + lb := &Labels{} + lb.X = plot.MustCopyRole(data, plot.X) + lb.Y = plot.MustCopyRole(data, plot.Y) + if lb.X == nil || lb.Y == nil { + return nil } - - strs := make([]string, d.Len()) - for i := range strs { - strs[i] = d.Label(i) + ld := data[plot.Label] + if ld == nil { + return nil } - - styles := make([]plot.TextStyle, 1) - for i := range styles { - styles[i].Defaults() + lb.Labels = make(plot.Labels, lb.X.Len()) + for i := range ld.Len() { + lb.Labels[i] = ld.String1D(i) } - return &Labels{ - XYs: xys, - Labels: strs, - TextStyle: styles, - }, nil + lb.stylers = plot.GetStylersFromData(data, plot.Label) + lb.ystylers = plot.GetStylersFromData(data, plot.Y) + lb.Defaults() + return lb } -func (l *Labels) XYData() (data plot.XYer, pixels plot.XYer) { - data = l.XYs - pixels = l.PXYs - return +func (lb *Labels) Defaults() { + lb.Style.Defaults() } -// updateStyles updates the text styles and dots. 
-// returns true if custom styles are used per point -func (l *Labels) updateStyles(plt *plot.Plot) bool { - customStyles := len(l.TextStyle) == len(l.XYs) - if plt.Size == l.styleSize && len(l.TextStyle) == l.styleN { - return customStyles - } - l.styleSize = plt.Size - l.styleN = len(l.TextStyle) - pc := plt.Paint - uc := &pc.UnitContext - l.Offset.ToDots(uc) - for i := range l.TextStyle { - l.TextStyle[i].ToDots(uc) - } - return customStyles +// Styler adds a style function to set style parameters. +func (lb *Labels) Styler(f func(s *plot.Style)) *Labels { + lb.stylers.Add(f) + return lb +} + +func (lb *Labels) ApplyStyle(ps *plot.PlotStyle) { + ps.SetElementStyle(&lb.Style) + yst := &plot.Style{} + lb.ystylers.Run(yst) + lb.Style.Range = yst.Range // get range from y + lb.stylers.Run(&lb.Style) // can still override here +} + +func (lb *Labels) Stylers() *plot.Stylers { return &lb.stylers } + +func (lb *Labels) Data() (data plot.Data, pixX, pixY []float32) { + pixX = lb.PX + pixY = lb.PY + data = plot.Data{} + data[plot.X] = lb.X + data[plot.Y] = lb.Y + data[plot.Label] = lb.Labels + return } // Plot implements the Plotter interface, drawing labels. 
-func (l *Labels) Plot(plt *plot.Plot) { - ps := plot.PlotXYs(plt, l.XYs) - customStyles := l.updateStyles(plt) +func (lb *Labels) Plot(plt *plot.Plot) { + pc := plt.Paint + uc := &pc.UnitContext + lb.PX = plot.PlotX(plt, lb.X) + lb.PY = plot.PlotY(plt, lb.Y) + st := &lb.Style.Text + st.Offset.ToDots(uc) var ltxt plot.Text - for i, label := range l.Labels { + ltxt.Defaults() + ltxt.Style = *st + ltxt.ToDots(uc) + for i, label := range lb.Labels { if label == "" { continue } - if customStyles { - ltxt.Style = l.TextStyle[i] - } else { - ltxt.Style = l.TextStyle[0] - } ltxt.Text = label ltxt.Config(plt) tht := ltxt.PaintText.BBox.Size().Y - ltxt.Draw(plt, math32.Vec2(ps[i].X+l.Offset.X.Dots, ps[i].Y+l.Offset.Y.Dots-tht)) + ltxt.Draw(plt, math32.Vec2(lb.PX[i]+st.Offset.X.Dots, lb.PY[i]+st.Offset.Y.Dots-tht)) } } -// DataRange returns the minimum and maximum X and Y values -func (l *Labels) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - xmin, xmax, ymin, ymax = plot.XYRange(l) // first get basic numerical range +// UpdateRange updates the given ranges. +func (lb *Labels) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) { + // todo: include point sizes! 
+ plot.Range(lb.X, xr) + plot.RangeClamp(lb.Y, yr, &lb.Style.Range) pxToData := math32.FromPoint(plt.Size) - pxToData.X = (xmax - xmin) / pxToData.X - pxToData.Y = (ymax - ymin) / pxToData.Y - customStyles := l.updateStyles(plt) + pxToData.X = float32(xr.Range()) / pxToData.X + pxToData.Y = float32(yr.Range()) / pxToData.Y + st := &lb.Style.Text var ltxt plot.Text - for i, label := range l.Labels { + ltxt.Style = *st + for i, label := range lb.Labels { if label == "" { continue } - if customStyles { - ltxt.Style = l.TextStyle[i] - } else { - ltxt.Style = l.TextStyle[0] - } ltxt.Text = label ltxt.Config(plt) tht := pxToData.Y * ltxt.PaintText.BBox.Size().Y twd := 1.1 * pxToData.X * ltxt.PaintText.BBox.Size().X - x, y := l.XY(i) - minx := x - maxx := x + pxToData.X*l.Offset.X.Dots + twd - miny := y - maxy := y + pxToData.Y*l.Offset.Y.Dots + tht // y is up here - xmin = min(xmin, minx) - xmax = max(xmax, maxx) - ymin = min(ymin, miny) - ymax = max(ymax, maxy) + x := lb.X[i] + y := lb.Y[i] + maxx := x + float64(pxToData.X*st.Offset.X.Dots+twd) + maxy := y + float64(pxToData.Y*st.Offset.Y.Dots+tht) // y is up here + xr.FitInRange(minmax.F64{x, maxx}) + yr.FitInRange(minmax.F64{y, maxy}) } - return } - -// XYLabeler combines the [plot.XYer] and [plot.Labeler] types. -type XYLabeler interface { - plot.XYer - plot.Labeler -} - -// XYLabels holds XY data with labels. -// The ith label corresponds to the ith XY. -type XYLabels struct { - plot.XYs - Labels []string -} - -// Label returns the label for point index i. -func (l XYLabels) Label(i int) string { - return l.Labels[i] -} - -var _ XYLabeler = (*XYLabels)(nil) diff --git a/plot/plots/line.go b/plot/plots/line.go deleted file mode 100644 index f1697e3691..0000000000 --- a/plot/plots/line.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -// Adapted from github.com/gonum/plot: -// Copyright ©2015 The Gonum Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package plots - -//go:generate core generate - -import ( - "image" - - "cogentcore.org/core/math32" - "cogentcore.org/core/plot" -) - -// StepKind specifies a form of a connection of two consecutive points. -type StepKind int32 //enums:enum - -const ( - // NoStep connects two points by simple line - NoStep StepKind = iota - - // PreStep connects two points by following lines: vertical, horizontal. - PreStep - - // MidStep connects two points by following lines: horizontal, vertical, horizontal. - // Vertical line is placed in the middle of the interval. - MidStep - - // PostStep connects two points by following lines: horizontal, vertical. - PostStep -) - -// Line implements the Plotter interface, drawing a line using XYer data. -type Line struct { - // XYs is a copy of the points for this line. - plot.XYs - - // PXYs is the actual pixel plotting coordinates for each XY value. - PXYs plot.XYs - - // StepStyle is the kind of the step line. - StepStyle StepKind - - // LineStyle is the style of the line connecting the points. - // Use zero width to disable lines. - LineStyle plot.LineStyle - - // Fill is the color to fill the area below the plot. - // Use nil to disable filling, which is the default. - Fill image.Image - - // if true, draw lines that connect points with a negative X-axis direction; - // otherwise there is a break in the line. - // default is false, so that repeated series of data across the X axis - // are plotted separately. - NegativeXDraw bool -} - -// NewLine returns a Line that uses the default line style and -// does not draw glyphs. 
-func NewLine(xys plot.XYer) (*Line, error) { - data, err := plot.CopyXYs(xys) - if err != nil { - return nil, err - } - ln := &Line{XYs: data} - ln.Defaults() - return ln, nil -} - -// NewLinePoints returns both a Line and a -// Scatter plot for the given point data. -func NewLinePoints(xys plot.XYer) (*Line, *Scatter, error) { - sc, err := NewScatter(xys) - if err != nil { - return nil, nil, err - } - ln := &Line{XYs: sc.XYs} - ln.Defaults() - return ln, sc, nil -} - -func (pts *Line) Defaults() { - pts.LineStyle.Defaults() -} - -func (pts *Line) XYData() (data plot.XYer, pixels plot.XYer) { - data = pts.XYs - pixels = pts.PXYs - return -} - -// Plot draws the Line, implementing the plot.Plotter interface. -func (pts *Line) Plot(plt *plot.Plot) { - pc := plt.Paint - - ps := plot.PlotXYs(plt, pts.XYs) - np := len(ps) - pts.PXYs = ps - - if pts.Fill != nil { - pc.FillStyle.Color = pts.Fill - minY := plt.PY(plt.Y.Min) - prev := math32.Vec2(ps[0].X, minY) - pc.MoveTo(prev.X, prev.Y) - for i := range ps { - pt := ps[i] - switch pts.StepStyle { - case NoStep: - if pt.X < prev.X { - pc.LineTo(prev.X, minY) - pc.ClosePath() - pc.MoveTo(pt.X, minY) - } - pc.LineTo(pt.X, pt.Y) - case PreStep: - if i == 0 { - continue - } - if pt.X < prev.X { - pc.LineTo(prev.X, minY) - pc.ClosePath() - pc.MoveTo(pt.X, minY) - } else { - pc.LineTo(prev.X, pt.Y) - } - pc.LineTo(pt.X, pt.Y) - case MidStep: - if pt.X < prev.X { - pc.LineTo(prev.X, minY) - pc.ClosePath() - pc.MoveTo(pt.X, minY) - } else { - pc.LineTo(0.5*(prev.X+pt.X), prev.Y) - pc.LineTo(0.5*(prev.X+pt.X), pt.Y) - } - pc.LineTo(pt.X, pt.Y) - case PostStep: - if pt.X < prev.X { - pc.LineTo(prev.X, minY) - pc.ClosePath() - pc.MoveTo(pt.X, minY) - } else { - pc.LineTo(pt.X, prev.Y) - } - pc.LineTo(pt.X, pt.Y) - } - prev = pt - } - pc.LineTo(prev.X, minY) - pc.ClosePath() - pc.Fill() - } - pc.FillStyle.Color = nil - - if !pts.LineStyle.SetStroke(plt) { - return - } - prev := ps[0] - pc.MoveTo(prev.X, prev.Y) - for i := 1; i < np; 
i++ { - pt := ps[i] - if pts.StepStyle != NoStep { - if pt.X >= prev.X { - switch pts.StepStyle { - case PreStep: - pc.LineTo(prev.X, pt.Y) - case MidStep: - pc.LineTo(0.5*(prev.X+pt.X), prev.Y) - pc.LineTo(0.5*(prev.X+pt.X), pt.Y) - case PostStep: - pc.LineTo(pt.X, prev.Y) - } - } else { - pc.MoveTo(pt.X, pt.Y) - } - } - if !pts.NegativeXDraw && pt.X < prev.X { - pc.MoveTo(pt.X, pt.Y) - } else { - pc.LineTo(pt.X, pt.Y) - } - prev = pt - } - pc.Stroke() -} - -// DataRange returns the minimum and maximum -// x and y values, implementing the plot.DataRanger interface. -func (pts *Line) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - return plot.XYRange(pts) -} - -// Thumbnail returns the thumbnail for the LineTo, implementing the plot.Thumbnailer interface. -func (pts *Line) Thumbnail(plt *plot.Plot) { - pc := plt.Paint - ptb := pc.Bounds - midY := 0.5 * float32(ptb.Min.Y+ptb.Max.Y) - - if pts.Fill != nil { - tb := ptb - if pts.LineStyle.Width.Value > 0 { - tb.Min.Y = int(midY) - } - pc.FillBox(math32.FromPoint(tb.Min), math32.FromPoint(tb.Size()), pts.Fill) - } - - if pts.LineStyle.SetStroke(plt) { - pc.MoveTo(float32(ptb.Min.X), midY) - pc.LineTo(float32(ptb.Max.X), midY) - pc.Stroke() - } -} diff --git a/plot/plots/plot_test.go b/plot/plots/plot_test.go index c85f75020d..3026a5737e 100644 --- a/plot/plots/plot_test.go +++ b/plot/plots/plot_test.go @@ -7,321 +7,570 @@ package plots import ( "fmt" "image" + "math" + "math/rand" "os" + "slices" + "strconv" "testing" + "cogentcore.org/core/base/errors" "cogentcore.org/core/base/iox/imagex" "cogentcore.org/core/colors" - "cogentcore.org/core/math32" "cogentcore.org/core/paint" "cogentcore.org/core/plot" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/maps" ) +func ExampleLine() { + xd, yd := make(plot.Values, 21), make(plot.Values, 21) + for i := range xd { + xd[i] = float64(i * 5) + yd[i] = 50.0 + 
40*math.Sin((float64(i)/8)*math.Pi) + } + data := plot.Data{plot.X: xd, plot.Y: yd} + plt := plot.New() + plt.Add(NewLine(data).Styler(func(s *plot.Style) { + s.Plot.Title = "Test Line" + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plot.XAxis.Range.SetMax(105) + s.Plot.LineWidth.Pt(2) + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.On) + s.Plot.TitleStyle.Size.Dp(48) + s.Plot.Legend.Position.Left = true + s.Plot.Legend.Text.Size.Dp(24) + s.Plot.Axis.Text.Size.Dp(32) + s.Plot.Axis.TickText.Size.Dp(24) + s.Plot.XAxis.Rotation = -45 + s.Line.Color = colors.Uniform(colors.Red) + s.Point.Color = colors.Uniform(colors.Blue) + s.Range.SetMin(0).SetMax(100) + })) + plt.Draw() + imagex.Save(plt.Pixels, "testdata/ex_line_plot.png") + // Output: +} + +func ExampleStylerMetadata() { + tx, ty := tensor.NewFloat64(21), tensor.NewFloat64(21) + for i := range tx.DimSize(0) { + tx.SetFloat1D(float64(i*5), i) + ty.SetFloat1D(50.0+40*math.Sin((float64(i)/8)*math.Pi), i) + } + // attach stylers to the Y axis data: that is where plotter looks for it + plot.SetStylersTo(ty, plot.Stylers{func(s *plot.Style) { + s.Plot.Title = "Test Line" + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plot.Scale = 2 + s.Plot.XAxis.Range.SetMax(105) + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.On) + s.Line.Color = colors.Uniform(colors.Red) + s.Point.Color = colors.Uniform(colors.Blue) + s.Range.SetMin(0).SetMax(100) + }}) + + // somewhere else in the code: + + plt := plot.New() + // NewLine automatically gets stylers from ty tensor metadata + plt.Add(NewLine(plot.Data{plot.X: tx, plot.Y: ty})) + plt.Draw() + imagex.Save(plt.Pixels, "testdata/ex_styler_metadata.png") + // Output: +} + +func ExampleTable() { + rand.Seed(1) + n := 21 + tx, ty, th := tensor.NewFloat64(n), tensor.NewFloat64(n), tensor.NewFloat64(n) + lbls := tensor.NewString(n) + for i := range n { + tx.SetFloat1D(float64(i*5), i) + ty.SetFloat1D(50.0+40*math.Sin((float64(i)/8)*math.Pi), i) + 
th.SetFloat1D(5*rand.Float64(), i) + lbls.SetString1D(strconv.Itoa(i), i) + } + genst := func(s *plot.Style) { + s.Plot.Title = "Test Table" + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plot.Scale = 2 + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.Off) + } + plot.SetStylersTo(ty, plot.Stylers{genst, func(s *plot.Style) { + s.On = true + s.Plotter = "XY" + s.Role = plot.Y + s.Line.Color = colors.Uniform(colors.Red) + s.Range.SetMin(0).SetMax(100) + }}) + // others get basic styling + plot.SetStylersTo(tx, plot.Stylers{func(s *plot.Style) { + s.Role = plot.X + }}) + plot.SetStylersTo(th, plot.Stylers{func(s *plot.Style) { + s.On = true + s.Plotter = "YErrorBars" + s.Role = plot.High + }}) + plot.SetStylersTo(lbls, plot.Stylers{func(s *plot.Style) { + s.On = true + s.Plotter = "Labels" + s.Role = plot.Label + s.Text.Offset.X.Dp(6) + s.Text.Offset.Y.Dp(-6) + }}) + dt := table.New("Test Table") // todo: use Name by default for plot. + dt.AddColumn("X", tx) + dt.AddColumn("Y", ty) + dt.AddColumn("High", th) + dt.AddColumn("Labels", lbls) + + plt := errors.Log1(plot.NewTablePlot(dt)) + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Save(plt.Pixels, "testdata/ex_table.png") + // Output: +} + func TestMain(m *testing.M) { paint.FontLibrary.InitFontPaths(paint.FontPaths...) os.Exit(m.Run()) } -func TestLine(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Line" - pt.X.Min = 0 - pt.X.Max = 100 - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y Axis" - - // note: making two overlapping series - data := make(plot.XYs, 42) - for i := range data { - x := float32(i % 21) - data[i].X = x * 5 +// sinCosWrapData returns overlapping sin / cos curves in one sequence. 
+func sinCosWrapData() plot.Data { + xd, yd := make(plot.Values, 42), make(plot.Values, 42) + for i := range xd { + x := float64(i % 21) + xd[i] = x * 5 if i < 21 { - data[i].Y = float32(50) + 40*math32.Sin((x/8)*math32.Pi) + yd[i] = float64(50) + 40*math.Sin((x/8)*math.Pi) } else { - data[i].Y = float32(50) + 40*math32.Cos((x/8)*math32.Pi) + yd[i] = float64(50) + 40*math.Cos((x/8)*math.Pi) } } + return plot.Data{plot.X: xd, plot.Y: yd} +} - l1, err := NewLine(data) - if err != nil { - t.Error(err.Error()) +func sinDataXY() plot.Data { + xd, yd := make(plot.Values, 21), make(plot.Values, 21) + for i := range xd { + xd[i] = float64(i * 5) + yd[i] = float64(50) + 40*math.Sin((float64(i)/8)*math.Pi) } - pt.Add(l1) - pt.Legend.Add("Sine", l1) - pt.Legend.Add("Cos", l1) - - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "line.png") - - l1.Fill = colors.Uniform(colors.Yellow) - pt.Draw() - imagex.Assert(t, pt.Pixels, "line-fill.png") - - l1.StepStyle = PreStep - pt.Draw() - imagex.Assert(t, pt.Pixels, "line-prestep.png") - - l1.StepStyle = MidStep - pt.Draw() - imagex.Assert(t, pt.Pixels, "line-midstep.png") - - l1.StepStyle = PostStep - pt.Draw() - imagex.Assert(t, pt.Pixels, "line-poststep.png") + return plot.Data{plot.X: xd, plot.Y: yd} +} - l1.StepStyle = NoStep - l1.Fill = nil - l1.NegativeXDraw = true - pt.Draw() - imagex.Assert(t, pt.Pixels, "line-negx.png") +func sinData() plot.Data { + yd := make(plot.Values, 21) + for i := range yd { + x := float64(i % 21) + yd[i] = float64(50) + 40*math.Sin((x/8)*math.Pi) + } + return plot.Data{plot.Y: yd} +} +func cosData() plot.Data { + yd := make(plot.Values, 21) + for i := range yd { + x := float64(i % 21) + yd[i] = float64(50) + 40*math.Cos((x/8)*math.Pi) + } + return plot.Data{plot.Y: yd} } -func TestScatter(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Scatter" - pt.X.Min = 0 - pt.X.Max = 100 - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y 
Axis" - - data := make(plot.XYs, 21) - for i := range data { - data[i].X = float32(i * 5) - data[i].Y = float32(50) + 40*math32.Sin((float32(i)/8)*math32.Pi) +func TestLine(t *testing.T) { + data := sinCosWrapData() + + plt := plot.New() + plt.Title.Text = "Test Line" + plt.X.Range.Min = 0 + plt.X.Range.Max = 100 + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" + + l1 := NewLine(data) + if l1 == nil { + t.Error("bad data") } + plt.Add(l1) + plt.Legend.Add("Sine", l1) + plt.Legend.Add("Cos", l1) + + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "line.png") + + l1.Style.Line.Fill = colors.Uniform(colors.Yellow) + plt.Draw() + imagex.Assert(t, plt.Pixels, "line-fill.png") + + l1.Style.Line.Step = plot.PreStep + plt.Draw() + imagex.Assert(t, plt.Pixels, "line-prestep.png") + + l1.Style.Line.Step = plot.MidStep + plt.Draw() + imagex.Assert(t, plt.Pixels, "line-midstep.png") + + l1.Style.Line.Step = plot.PostStep + plt.Draw() + imagex.Assert(t, plt.Pixels, "line-poststep.png") + + l1.Style.Line.Step = plot.NoStep + l1.Style.Line.Fill = nil + l1.Style.Line.NegativeX = true + plt.Draw() + imagex.Assert(t, plt.Pixels, "line-negx.png") +} - l1, err := NewScatter(data) - if err != nil { - t.Error(err.Error()) +func TestScatter(t *testing.T) { + data := sinDataXY() + + plt := plot.New() + plt.Title.Text = "Test Scatter" + plt.X.Range.Min = 0 + plt.X.Range.Max = 100 + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" + + l1 := NewScatter(data) + if l1 == nil { + t.Error("bad data") } - pt.Add(l1) + plt.Add(l1) - pt.Resize(image.Point{640, 480}) + plt.Resize(image.Point{640, 480}) - shs := ShapesValues() + shs := plot.ShapesValues() for _, sh := range shs { - l1.PointShape = sh - pt.Draw() - imagex.Assert(t, pt.Pixels, "scatter-"+sh.String()+".png") + l1.Style.Point.Shape = sh + plt.Draw() + imagex.Assert(t, plt.Pixels, 
"scatter-"+sh.String()+".png") } } func TestLabels(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Labels" - pt.X.Label.Text = "X Axis" - pt.Y.Label.Text = "Y Axis" - - // note: making two overlapping series - data := make(plot.XYs, 12) - labels := make([]string, 12) - for i := range data { - x := float32(i % 21) - data[i].X = x * 5 - data[i].Y = float32(50) + 40*math32.Sin((x/8)*math32.Pi) - labels[i] = fmt.Sprintf("%7.4g", data[i].Y) + plt := plot.New() + plt.Title.Text = "Test Labels" + plt.X.Label.Text = "X Axis" + plt.Y.Label.Text = "Y Axis" + + xd, yd := make(plot.Values, 12), make(plot.Values, 12) + labels := make(plot.Labels, 12) + for i := range xd { + x := float64(i % 21) + xd[i] = x * 5 + yd[i] = float64(50) + 40*math.Sin((x/8)*math.Pi) + labels[i] = fmt.Sprintf("%7.4g", yd[i]) } - - l1, sc, err := NewLinePoints(data) - if err != nil { - t.Error(err.Error()) + data := plot.Data{} + data[plot.X] = xd + data[plot.Y] = yd + data[plot.Label] = labels + + l1 := NewLine(data) + if l1 == nil { + t.Error("bad data") } - pt.Add(l1) - pt.Add(sc) - pt.Legend.Add("Sine", l1, sc) + l1.Style.Point.On = plot.On + plt.Add(l1) + plt.Legend.Add("Sine", l1) - l2, err := NewLabels(XYLabels{XYs: data, Labels: labels}) - if err != nil { - t.Error(err.Error()) + l2 := NewLabels(data) + if l2 == nil { + t.Error("bad data") } - l2.Offset.X.Dp(6) - l2.Offset.Y.Dp(-6) - pt.Add(l2) + l2.Style.Text.Offset.X.Dp(6) + l2.Style.Text.Offset.Y.Dp(-6) + plt.Add(l2) - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "labels.png") + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "labels.png") } -func TestBarChart(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Bar Chart" - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y Axis" - - data := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - data[i] = float32(50) + 40*math32.Sin((x/8)*math32.Pi) - } +func TestBar(t 
*testing.T) { + plt := plot.New() + plt.Title.Text = "Test Bar Chart" + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" - cos := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - cos[i] = float32(50) + 40*math32.Cos((x/8)*math32.Pi) - } + data := sinData() + cos := cosData() - l1, err := NewBarChart(data, nil) - if err != nil { - t.Error(err.Error()) + l1 := NewBar(data) + if l1 == nil { + t.Error("bad data") } - l1.Color = colors.Uniform(colors.Red) - pt.Add(l1) - pt.Legend.Add("Sine", l1) + l1.Style.Line.Fill = colors.Uniform(colors.Red) + plt.Add(l1) + plt.Legend.Add("Sine", l1) - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "bar.png") + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "bar.png") - l2, err := NewBarChart(cos, nil) - if err != nil { - t.Error(err.Error()) + l2 := NewBar(cos) + if l2 == nil { + t.Error("bad data") } - l2.Color = colors.Uniform(colors.Blue) - pt.Legend.Add("Cosine", l2) + l2.Style.Line.Fill = colors.Uniform(colors.Blue) + plt.Legend.Add("Cosine", l2) - l1.Stride = 2 - l2.Stride = 2 - l2.Offset = 2 + l1.Style.Width.Stride = 2 + l2.Style.Width.Stride = 2 + l2.Style.Width.Offset = 2 - pt.Add(l2) // note: range updated when added! 
- pt.Draw() - imagex.Assert(t, pt.Pixels, "bar-cos.png") + plt.Add(l2) + plt.Draw() + imagex.Assert(t, plt.Pixels, "bar-cos.png") } -func TestBarChartErr(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Bar Chart Errors" - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y Axis" - - data := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - data[i] = float32(50) + 40*math32.Sin((x/8)*math32.Pi) +func TestBarErr(t *testing.T) { + plt := plot.New() + plt.Title.Text = "Test Bar Chart Errors" + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" + + data := sinData() + cos := cosData() + data[plot.High] = cos[plot.Y] + + l1 := NewBar(data) + if l1 == nil { + t.Error("bad data") } + l1.Style.Line.Fill = colors.Uniform(colors.Red) + plt.Add(l1) + plt.Legend.Add("Sine", l1) - cos := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - cos[i] = float32(5) + 4*math32.Cos((x/8)*math32.Pi) - } - - l1, err := NewBarChart(data, cos) - if err != nil { - t.Error(err.Error()) - } - l1.Color = colors.Uniform(colors.Red) - pt.Add(l1) - pt.Legend.Add("Sine", l1) - - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "bar-err.png") + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "bar-err.png") l1.Horizontal = true - pt.UpdateRange() - pt.X.Min = 0 - pt.X.Max = 100 - pt.Draw() - imagex.Assert(t, pt.Pixels, "bar-err-horiz.png") + plt.UpdateRange() + plt.X.Range.Min = 0 + plt.X.Range.Max = 100 + plt.Draw() + imagex.Assert(t, plt.Pixels, "bar-err-horiz.png") } -func TestBarChartStack(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Bar Chart Stacked" - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y Axis" - - data := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - data[i] = float32(50) + 40*math32.Sin((x/8)*math32.Pi) - } +func 
TestBarStack(t *testing.T) { + plt := plot.New() + plt.Title.Text = "Test Bar Chart Stacked" + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" - cos := make(plot.Values, 21) - for i := range data { - x := float32(i % 21) - cos[i] = float32(5) + 4*math32.Cos((x/8)*math32.Pi) - } + data := sinData() + cos := cosData() - l1, err := NewBarChart(data, nil) - if err != nil { - t.Error(err.Error()) + l1 := NewBar(data) + if l1 == nil { + t.Error("bad data") } - l1.Color = colors.Uniform(colors.Red) - pt.Add(l1) - pt.Legend.Add("Sine", l1) + l1.Style.Line.Fill = colors.Uniform(colors.Red) + plt.Add(l1) + plt.Legend.Add("Sine", l1) - l2, err := NewBarChart(cos, nil) - if err != nil { - t.Error(err.Error()) + l2 := NewBar(cos) + if l2 == nil { + t.Error("bad data") } - l2.Color = colors.Uniform(colors.Blue) + l2.Style.Line.Fill = colors.Uniform(colors.Blue) l2.StackedOn = l1 - pt.Add(l2) - pt.Legend.Add("Cos", l2) + plt.Add(l2) + plt.Legend.Add("Cos", l2) - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "bar-stacked.png") -} - -type XYErr struct { - plot.XYs - YErrors + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "bar-stacked.png") } func TestErrBar(t *testing.T) { - pt := plot.New() - pt.Title.Text = "Test Line Errors" - pt.X.Label.Text = "X Axis" - pt.Y.Min = 0 - pt.Y.Max = 100 - pt.Y.Label.Text = "Y Axis" - - data := make(plot.XYs, 21) - for i := range data { - x := float32(i % 21) - data[i].X = x * 5 - data[i].Y = float32(50) + 40*math32.Sin((x/8)*math32.Pi) + plt := plot.New() + plt.Title.Text = "Test Line Errors" + plt.X.Label.Text = "X Axis" + plt.Y.Range.Min = 0 + plt.Y.Range.Max = 100 + plt.Y.Label.Text = "Y Axis" + + xd, yd := make(plot.Values, 21), make(plot.Values, 21) + for i := range xd { + x := float64(i % 21) + xd[i] = x * 5 + yd[i] = float64(50) + 40*math.Sin((x/8)*math.Pi) } - yerr := make(YErrors, 21) - for i := range yerr { - x := 
float32(i % 21) - yerr[i].High = float32(5) + 4*math32.Cos((x/8)*math32.Pi) - yerr[i].Low = -yerr[i].High + low, high := make(plot.Values, 21), make(plot.Values, 21) + for i := range low { + x := float64(i % 21) + high[i] = float64(5) + 4*math.Cos((x/8)*math.Pi) + low[i] = -high[i] } - xyerr := XYErr{XYs: data, YErrors: yerr} + data := plot.Data{plot.X: xd, plot.Y: yd, plot.Low: low, plot.High: high} + + l1 := NewLine(data) + if l1 == nil { + t.Error("bad data") + } + plt.Add(l1) + plt.Legend.Add("Sine", l1) - l1, err := NewLine(data) - if err != nil { - t.Error(err.Error()) + l2 := NewYErrorBars(data) + if l2 == nil { + t.Error("bad data") } - pt.Add(l1) - pt.Legend.Add("Sine", l1) + plt.Add(l2) + + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "errbar.png") +} - l2, err := NewYErrorBars(xyerr) - if err != nil { - t.Error(err.Error()) +func TestStyle(t *testing.T) { + data := sinCosWrapData() + + stf := func(s *plot.Style) { + s.Plot.Title = "Test Line" + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plot.XAxis.Range.SetMax(105) + s.Plot.LineWidth.Pt(2) + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.On) + s.Plot.TitleStyle.Size.Dp(48) + s.Plot.Legend.Position.Left = true + s.Plot.Legend.Text.Size.Dp(24) + s.Plot.Axis.Text.Size.Dp(32) + s.Plot.Axis.TickText.Size.Dp(24) + s.Plot.XAxis.Rotation = -45 + // s.Line.On = plot.Off + s.Line.Color = colors.Uniform(colors.Red) + s.Point.Color = colors.Uniform(colors.Blue) + s.Range.SetMax(100) } - pt.Add(l2) - pt.Resize(image.Point{640, 480}) - pt.Draw() - imagex.Assert(t, pt.Pixels, "errbar.png") + plt := plot.New() + l1 := NewLine(data).Styler(stf) + plt.Add(l1) + plt.Legend.Add("Sine", l1) // todo: auto-add! + plt.Legend.Add("Cos", l1) + + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "style_line_point.png") + + plt = plot.New() + tdy := tensor.NewFloat64FromValues(data[plot.Y].(plot.Values)...) 
+ plot.SetStylersTo(tdy, plot.Stylers{stf}) // set metadata for tensor + tdx := tensor.NewFloat64FromValues(data[plot.X].(plot.Values)...) + // NewLine auto-grabs from Y metadata + l1 = NewLine(plot.Data{plot.X: tdx, plot.Y: tdy}) + plt.Add(l1) + plt.Legend.Add("Sine", l1) // todo: auto-add! + plt.Legend.Add("Cos", l1) + plt.Resize(image.Point{640, 480}) + plt.Draw() + imagex.Assert(t, plt.Pixels, "style_line_point_auto.png") +} + +// todo: move into statplot and test everything + +func TestTable(t *testing.T) { + rand.Seed(1) + n := 21 + tx, ty := tensor.NewFloat64(n), tensor.NewFloat64(n) + tl, th := tensor.NewFloat64(n), tensor.NewFloat64(n) + ts, tc := tensor.NewFloat64(n), tensor.NewFloat64(n) + lbls := tensor.NewString(n) + for i := range n { + tx.SetFloat1D(float64(i*5), i) + ty.SetFloat1D(50.0+40*math.Sin((float64(i)/8)*math.Pi), i) + tl.SetFloat1D(5*rand.Float64(), i) + th.SetFloat1D(5*rand.Float64(), i) + ts.SetFloat1D(1+5*rand.Float64(), i) + tc.SetFloat1D(float64(i), i) + lbls.SetString1D(strconv.Itoa(i), i) + } + ptyps := maps.Keys(plot.Plotters) + slices.Sort(ptyps) + for _, ttyp := range ptyps { + // attach stylers to the Y axis data: that is where plotter looks for it + genst := func(s *plot.Style) { + s.Plot.Title = "Test " + ttyp + s.Plot.XAxis.Label = "X Axis" + s.Plot.YAxisLabel = "Y Axis" + s.Plotter = plot.PlotterName(ttyp) + s.Plot.Scale = 2 + s.Plot.SetLinesOn(plot.On).SetPointsOn(plot.On) + s.Line.Color = colors.Uniform(colors.Red) + s.Point.Color = colors.Uniform(colors.Blue) + s.Range.SetMin(0).SetMax(100) + } + plot.SetStylersTo(ty, plot.Stylers{genst, func(s *plot.Style) { + s.On = true + s.Role = plot.Y + s.Group = "Y" + }}) + // others get basic styling + plot.SetStylersTo(tx, plot.Stylers{func(s *plot.Style) { + s.Role = plot.X + s.Group = "Y" + }}) + plot.SetStylersTo(tl, plot.Stylers{func(s *plot.Style) { + s.Role = plot.Low + s.Group = "Y" + }}) + plot.SetStylersTo(th, plot.Stylers{genst, func(s *plot.Style) { + s.On = true + 
s.Role = plot.High + s.Group = "Y" + }}) + plot.SetStylersTo(ts, plot.Stylers{func(s *plot.Style) { + s.Role = plot.Size + s.Group = "Y" + }}) + plot.SetStylersTo(tc, plot.Stylers{func(s *plot.Style) { + s.Role = plot.Color + s.Group = "Y" + }}) + plot.SetStylersTo(lbls, plot.Stylers{genst, func(s *plot.Style) { + s.On = true + s.Role = plot.Label + s.Group = "Y" + }}) + dt := table.New("Test Table") // todo: use Name by default for plot. + dt.AddColumn("X", tx) + dt.AddColumn("Y", ty) + dt.AddColumn("Low", tl) + dt.AddColumn("High", th) + dt.AddColumn("Size", ts) + dt.AddColumn("Color", tc) + dt.AddColumn("Labels", lbls) + + plt, err := plot.NewTablePlot(dt) + assert.NoError(t, err) + plt.Resize(image.Point{640, 480}) + plt.Draw() + fnm := "table_" + ttyp + ".png" + imagex.Assert(t, plt.Pixels, fnm) + } } diff --git a/plot/plots/scatter.go b/plot/plots/scatter.go deleted file mode 100644 index bfd5cdf754..0000000000 --- a/plot/plots/scatter.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Adapted from github.com/gonum/plot: -// Copyright ©2015 The Gonum Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package plots - -import ( - "cogentcore.org/core/math32" - "cogentcore.org/core/plot" - "cogentcore.org/core/styles/units" -) - -// Scatter implements the Plotter interface, drawing -// a shape for each point. -type Scatter struct { - // XYs is a copy of the points for this scatter. - plot.XYs - - // PXYs is the actual plotting coordinates for each XY value. - PXYs plot.XYs - - // size of shape to draw for each point - PointSize units.Value - - // shape to draw for each point - PointShape Shapes - - // LineStyle is the style of the line connecting the points. - // Use zero width to disable lines. 
- LineStyle plot.LineStyle -} - -// NewScatter returns a Scatter that uses the -// default glyph style. -func NewScatter(xys plot.XYer) (*Scatter, error) { - data, err := plot.CopyXYs(xys) - if err != nil { - return nil, err - } - sc := &Scatter{XYs: data} - sc.LineStyle.Defaults() - sc.PointSize.Pt(4) - return sc, nil -} - -func (pts *Scatter) XYData() (data plot.XYer, pixels plot.XYer) { - data = pts.XYs - pixels = pts.PXYs - return -} - -// Plot draws the Line, implementing the plot.Plotter interface. -func (pts *Scatter) Plot(plt *plot.Plot) { - pc := plt.Paint - if !pts.LineStyle.SetStroke(plt) { - return - } - pts.PointSize.ToDots(&pc.UnitContext) - pc.FillStyle.Color = pts.LineStyle.Color - ps := plot.PlotXYs(plt, pts.XYs) - for i := range ps { - pt := ps[i] - DrawShape(pc, math32.Vec2(pt.X, pt.Y), pts.PointSize.Dots, pts.PointShape) - } - pc.FillStyle.Color = nil -} - -// DataRange returns the minimum and maximum -// x and y values, implementing the plot.DataRanger interface. -func (pts *Scatter) DataRange(plt *plot.Plot) (xmin, xmax, ymin, ymax float32) { - return plot.XYRange(pts) -} - -// Thumbnail the thumbnail for the Scatter, -// implementing the plot.Thumbnailer interface. -func (pts *Scatter) Thumbnail(plt *plot.Plot) { - if !pts.LineStyle.SetStroke(plt) { - return - } - pc := plt.Paint - pts.PointSize.ToDots(&pc.UnitContext) - pc.FillStyle.Color = pts.LineStyle.Color - ptb := pc.Bounds - midX := 0.5 * float32(ptb.Min.X+ptb.Max.X) - midY := 0.5 * float32(ptb.Min.Y+ptb.Max.Y) - - DrawShape(pc, math32.Vec2(midX, midY), pts.PointSize.Dots, pts.PointShape) -} diff --git a/plot/plots/table.go b/plot/plots/table.go deleted file mode 100644 index 41644ce7a6..0000000000 --- a/plot/plots/table.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package plots - -import "cogentcore.org/core/plot" - -// Table is an interface for tabular data for plotting, -// with columns of values. -type Table interface { - // number of columns of data - NumColumns() int - - // name of given column - ColumnName(i int) string - - // number of rows of data - NumRows() int - - // PlotData returns the data value at given column and row - PlotData(column, row int) float32 -} - -func TableColumnIndex(tab Table, name string) int { - for i := range tab.NumColumns() { - if tab.ColumnName(i) == name { - return i - } - } - return -1 -} - -// TableXYer is an interface for providing XY access to Table data -type TableXYer struct { - Table Table - - // the indexes of the tensor columns to use for the X and Y data, respectively - XColumn, YColumn int -} - -func NewTableXYer(tab Table, xcolumn, ycolumn int) *TableXYer { - txy := &TableXYer{Table: tab, XColumn: xcolumn, YColumn: ycolumn} - return txy -} - -func (dt *TableXYer) Len() int { - return dt.Table.NumRows() -} - -func (dt *TableXYer) XY(i int) (x, y float32) { - return dt.Table.PlotData(dt.XColumn, i), dt.Table.PlotData(dt.YColumn, i) -} - -// AddTableLine adds Line with given x, y columns from given tabular data -func AddTableLine(plt *plot.Plot, tab Table, xcolumn, ycolumn int) (*Line, error) { - txy := NewTableXYer(tab, xcolumn, ycolumn) - ln, err := NewLine(txy) - if err != nil { - return nil, err - } - plt.Add(ln) - return ln, nil -} - -// AddTableLinePoints adds Line w/ Points with given x, y columns from given tabular data -func AddTableLinePoints(plt *plot.Plot, tab Table, xcolumn, ycolumn int) (*Line, *Scatter, error) { - txy := &TableXYer{Table: tab, XColumn: xcolumn, YColumn: ycolumn} - ln, sc, err := NewLinePoints(txy) - if err != nil { - return nil, nil, err - } - plt.Add(ln) - plt.Add(sc) - return ln, sc, nil -} diff --git a/plot/plots/xy.go b/plot/plots/xy.go new file mode 100644 index 0000000000..06c8e15664 --- /dev/null +++ b/plot/plots/xy.go @@ -0,0 +1,271 @@ 
+// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Adapted from github.com/gonum/plot: +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package plots + +//go:generate core generate + +import ( + "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" + "cogentcore.org/core/plot" +) + +// XYType is be used for specifying the type name. +const XYType = "XY" + +func init() { + plot.RegisterPlotter(XYType, "draws lines between and / or points for X,Y data values, using optional Size and Color data for the points, for a bubble plot.", []plot.Roles{plot.X, plot.Y}, []plot.Roles{plot.Size, plot.Color}, func(data plot.Data) plot.Plotter { + return NewXY(data) + }) +} + +// XY draws lines between and / or points for XY data values. +type XY struct { + // copies of data for this line + X, Y, Color, Size plot.Values + + // PX, PY are the actual pixel plotting coordinates for each XY value. + PX, PY []float32 + + // Style is the style for plotting. + Style plot.Style + + stylers plot.Stylers +} + +// NewXY returns an XY plotter for given X, Y data. +// data can also include Color and / or Size for the points. +// Styler functions are obtained from the Y metadata if present. +func NewXY(data plot.Data) *XY { + if data.CheckLengths() != nil { + return nil + } + ln := &XY{} + ln.X = plot.MustCopyRole(data, plot.X) + ln.Y = plot.MustCopyRole(data, plot.Y) + if ln.X == nil || ln.Y == nil { + return nil + } + ln.stylers = plot.GetStylersFromData(data, plot.Y) + ln.Color = plot.CopyRole(data, plot.Color) + ln.Size = plot.CopyRole(data, plot.Size) + ln.Defaults() + return ln +} + +// NewLine returns an XY plot drawing Lines by default. 
+func NewLine(data plot.Data) *XY { + ln := NewXY(data) + if ln == nil { + return ln + } + ln.Style.Line.On = plot.On + ln.Style.Point.On = plot.Off + return ln +} + +// NewScatter returns an XY scatter plot drawing Points by default. +func NewScatter(data plot.Data) *XY { + ln := NewXY(data) + if ln == nil { + return ln + } + ln.Style.Line.On = plot.Off + ln.Style.Point.On = plot.On + return ln +} + +func (ln *XY) Defaults() { + ln.Style.Defaults() +} + +// Styler adds a style function to set style parameters. +func (ln *XY) Styler(f func(s *plot.Style)) *XY { + ln.stylers.Add(f) + return ln +} + +func (ln *XY) Stylers() *plot.Stylers { return &ln.stylers } + +func (ln *XY) ApplyStyle(ps *plot.PlotStyle) { + ps.SetElementStyle(&ln.Style) + ln.stylers.Run(&ln.Style) +} + +func (ln *XY) Data() (data plot.Data, pixX, pixY []float32) { + pixX = ln.PX + pixY = ln.PY + data = plot.Data{} + data[plot.X] = ln.X + data[plot.Y] = ln.Y + if ln.Size != nil { + data[plot.Size] = ln.Size + } + if ln.Color != nil { + data[plot.Color] = ln.Color + } + return +} + +// Plot does the drawing, implementing the plot.Plotter interface. 
+func (ln *XY) Plot(plt *plot.Plot) { + ln.PX = plot.PlotX(plt, ln.X) + ln.PY = plot.PlotY(plt, ln.Y) + np := len(ln.PX) + if np == 0 || len(ln.PY) == 0 { + return + } + pc := plt.Paint + if ln.Style.Line.HasFill() { + pc.FillStyle.Color = ln.Style.Line.Fill + minY := plt.PY(plt.Y.Range.Min) + prevX := ln.PX[0] + prevY := minY + pc.MoveTo(prevX, prevY) + for i, ptx := range ln.PX { + pty := ln.PY[i] + switch ln.Style.Line.Step { + case plot.NoStep: + if ptx < prevX { + pc.LineTo(prevX, minY) + pc.ClosePath() + pc.MoveTo(ptx, minY) + } + pc.LineTo(ptx, pty) + case plot.PreStep: + if i == 0 { + continue + } + if ptx < prevX { + pc.LineTo(prevX, minY) + pc.ClosePath() + pc.MoveTo(ptx, minY) + } else { + pc.LineTo(prevX, pty) + } + pc.LineTo(ptx, pty) + case plot.MidStep: + if ptx < prevX { + pc.LineTo(prevX, minY) + pc.ClosePath() + pc.MoveTo(ptx, minY) + } else { + pc.LineTo(0.5*(prevX+ptx), prevY) + pc.LineTo(0.5*(prevX+ptx), pty) + } + pc.LineTo(ptx, pty) + case plot.PostStep: + if ptx < prevX { + pc.LineTo(prevX, minY) + pc.ClosePath() + pc.MoveTo(ptx, minY) + } else { + pc.LineTo(ptx, prevY) + } + pc.LineTo(ptx, pty) + } + prevX, prevY = ptx, pty + } + pc.LineTo(prevX, minY) + pc.ClosePath() + pc.Fill() + } + pc.FillStyle.Color = nil + + if ln.Style.Line.SetStroke(plt) { + if plt.HighlightPlotter == ln { + pc.StrokeStyle.Width.Dots *= 1.5 + } + prevX, prevY := ln.PX[0], ln.PY[0] + pc.MoveTo(prevX, prevY) + for i := 1; i < np; i++ { + ptx, pty := ln.PX[i], ln.PY[i] + if ln.Style.Line.Step != plot.NoStep { + if ptx >= prevX { + switch ln.Style.Line.Step { + case plot.PreStep: + pc.LineTo(prevX, pty) + case plot.MidStep: + pc.LineTo(0.5*(prevX+ptx), prevY) + pc.LineTo(0.5*(prevX+ptx), pty) + case plot.PostStep: + pc.LineTo(ptx, prevY) + } + } else { + pc.MoveTo(ptx, pty) + } + } + if !ln.Style.Line.NegativeX && ptx < prevX { + pc.MoveTo(ptx, pty) + } else { + pc.LineTo(ptx, pty) + } + prevX, prevY = ptx, pty + } + pc.Stroke() + } + if ln.Style.Point.SetStroke(plt) { 
+ origWidth := pc.StrokeStyle.Width.Dots + for i, ptx := range ln.PX { + pty := ln.PY[i] + if plt.HighlightPlotter == ln { + if i == plt.HighlightIndex { + pc.StrokeStyle.Width.Dots *= 1.5 + } else { + pc.StrokeStyle.Width.Dots = origWidth + } + } + ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty)) + } + } else if plt.HighlightPlotter == ln { + op := ln.Style.Point.On + ln.Style.Point.On = plot.On + ln.Style.Point.SetStroke(plt) + ptx := ln.PX[plt.HighlightIndex] + pty := ln.PY[plt.HighlightIndex] + ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty)) + ln.Style.Point.On = op + } + pc.FillStyle.Color = nil +} + +// UpdateRange updates the given ranges. +func (ln *XY) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) { + // todo: include point sizes! + plot.Range(ln.X, xr) + plot.RangeClamp(ln.Y, yr, &ln.Style.Range) +} + +// Thumbnail returns the thumbnail, implementing the plot.Thumbnailer interface. +func (ln *XY) Thumbnail(plt *plot.Plot) { + pc := plt.Paint + ptb := pc.Bounds + midY := 0.5 * float32(ptb.Min.Y+ptb.Max.Y) + + if ln.Style.Line.Fill != nil { + tb := ptb + if ln.Style.Line.Width.Value > 0 { + tb.Min.Y = int(midY) + } + pc.FillBox(math32.FromPoint(tb.Min), math32.FromPoint(tb.Size()), ln.Style.Line.Fill) + } + + if ln.Style.Line.SetStroke(plt) { + pc.MoveTo(float32(ptb.Min.X), midY) + pc.LineTo(float32(ptb.Max.X), midY) + pc.Stroke() + } + + if ln.Style.Point.SetStroke(plt) { + midX := 0.5 * float32(ptb.Min.X+ptb.Max.X) + ln.Style.Point.DrawShape(pc, math32.Vec2(midX, midY)) + } + pc.FillStyle.Color = nil +} diff --git a/plot/plotter.go b/plot/plotter.go index a656a3f808..6a74c431ed 100644 --- a/plot/plotter.go +++ b/plot/plotter.go @@ -4,20 +4,84 @@ package plot +import ( + "fmt" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/math32/minmax" +) + // Plotter is an interface that wraps the Plot method. -// Some standard implementations of Plotter can be found in plotters. 
+// Standard implementations of Plotter are in the [plots] package. type Plotter interface { - // Plot draws the data to the Plot Paint + + // Plot draws the data to the Plot Paint. Plot(pt *Plot) - // returns the data for this plot as X,Y points, - // including corresponding pixel data. - // This allows gui interface to inspect data etc. - XYData() (data XYer, pixels XYer) + // UpdateRange updates the given ranges. + UpdateRange(plt *Plot, xr, yr, zr *minmax.F64) + + // Data returns the data by roles for this plot, for both the original + // data and the pixel-transformed X,Y coordinates for that data. + // This allows a GUI interface to inspect data etc. + Data() (data Data, pixX, pixY []float32) + + // Stylers returns the styler functions for this element. + Stylers() *Stylers + + // ApplyStyle applies any stylers to this element, + // first initializing from the given global plot style, which has + // already been styled with defaults and all the plot element stylers. + ApplyStyle(plotStyle *PlotStyle) +} + +// PlotterType registers a Plotter so that it can be created with appropriate data. +type PlotterType struct { + // Name of the plot type. + Name string + + // Doc is the documentation for this Plotter. + Doc string + + // Required Data roles for this plot. Data for these Roles must be provided. + Required []Roles + + // Optional Data roles for this plot. + Optional []Roles + + // New returns a new plotter of this type with given data in given roles. + New func(data Data) Plotter +} + +// PlotterName is the name of a specific plotter type. +type PlotterName string + +// Plotters is the registry of [Plotter] types. +var Plotters = map[string]PlotterType{} + +// RegisterPlotter registers a plotter type. 
+func RegisterPlotter(name, doc string, required, optional []Roles, newFun func(data Data) Plotter) { + Plotters[name] = PlotterType{Name: name, Doc: doc, Required: required, Optional: optional, New: newFun} +} + +// PlotterByType returns [PlotterType] info for a registered [Plotter] +// of given type name, e.g., "XY", "Bar" etc, +// Returns an error and nil if type name is not a registered type. +func PlotterByType(typeName string) (*PlotterType, error) { + pt, ok := Plotters[typeName] + if !ok { + return nil, fmt.Errorf("plot.PlotterByType type name is not registered: %s", typeName) + } + return &pt, nil } -// DataRanger wraps the DataRange method. -type DataRanger interface { - // DataRange returns the range of X and Y values. - DataRange(pt *Plot) (xmin, xmax, ymin, ymax float32) +// NewPlotter returns a new plotter of given type, e.g., "XY", "Bar" etc, +// for given data roles (which must include Required roles, and may include Optional ones). +// Logs an error and returns nil if type name is not a registered type. +func NewPlotter(typeName string, data Data) Plotter { + pt, err := PlotterByType(typeName) + if errors.Log(err) != nil { + return nil + } + return pt.New(data) } diff --git a/plot/plots/shapes.go b/plot/point.go similarity index 60% rename from plot/plots/shapes.go rename to plot/point.go index 3d6252343d..739aa439f2 100644 --- a/plot/plots/shapes.go +++ b/plot/point.go @@ -2,44 +2,77 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package plots +package plot import ( + "image" + + "cogentcore.org/core/colors" "cogentcore.org/core/math32" "cogentcore.org/core/paint" + "cogentcore.org/core/styles/units" ) -type Shapes int32 //enums:enum +// PointStyle has style properties for drawing points as different shapes. +type PointStyle struct { //types:add -setters + // On indicates whether to plot points. 
+ On DefaultOffOn -const ( - // Ring is the outline of a circle - Ring Shapes = iota - - // Circle is a solid circle - Circle + // Shape to draw. + Shape Shapes - // Square is the outline of a square - Square + // Color is the stroke color image specification. + // Setting to nil turns line off. + Color image.Image - // Box is a filled square - Box + // Fill is the color to fill solid regions, in a plot-specific + // way (e.g., the area below a Line plot, the bar color). + // Use nil to disable filling. + Fill image.Image - // Triangle is the outline of a triangle - Triangle + // Width is the line width for point glyphs, with a default of 1 Pt (point). + // Setting to 0 turns line off. + Width units.Value - // Pyramid is a filled triangle - Pyramid + // Size of shape to draw for each point. + // Defaults to 4 Pt (point). + Size units.Value +} - // Plus is a plus sign - Plus +func (ps *PointStyle) Defaults() { + ps.Color = colors.Scheme.OnSurface + ps.Fill = colors.Scheme.OnSurface + ps.Width.Pt(1) + ps.Size.Pt(4) +} - // Cross is a big X - Cross -) +// SetStroke sets the stroke style in plot paint to current line style. 
+// returns false if either the Width = 0 or Color is nil +func (ps *PointStyle) SetStroke(pt *Plot) bool { + if ps.On == Off || ps.Color == nil { + return false + } + pc := pt.Paint + uc := &pc.UnitContext + ps.Width.ToDots(uc) + ps.Size.ToDots(uc) + if ps.Width.Dots == 0 || ps.Size.Dots == 0 { + return false + } + pc.StrokeStyle.Width = ps.Width + pc.StrokeStyle.Color = ps.Color + pc.StrokeStyle.ToDots(uc) + pc.FillStyle.Color = ps.Fill + return true +} // DrawShape draws the given shape -func DrawShape(pc *paint.Context, pos math32.Vector2, size float32, shape Shapes) { - switch shape { +func (ps *PointStyle) DrawShape(pc *paint.Context, pos math32.Vector2) { + size := ps.Size.Dots + if size == 0 { + return + } + switch ps.Shape { case Ring: DrawRing(pc, pos, size) case Circle: @@ -126,3 +159,32 @@ func DrawCross(pc *paint.Context, pos math32.Vector2, size float32) { pc.ClosePath() pc.Stroke() } + +// Shapes has the options for how to draw points in the plot. +type Shapes int32 //enums:enum + +const ( + // Ring is the outline of a circle + Ring Shapes = iota + + // Circle is a solid circle + Circle + + // Square is the outline of a square + Square + + // Box is a filled square + Box + + // Triangle is the outline of a triangle + Triangle + + // Pyramid is a filled triangle + Pyramid + + // Plus is a plus sign + Plus + + // Cross is a big X + Cross +) diff --git a/plot/style.go b/plot/style.go new file mode 100644 index 0000000000..d54a45343a --- /dev/null +++ b/plot/style.go @@ -0,0 +1,201 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package plot + +import ( + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/math32/minmax" + "cogentcore.org/core/styles/units" +) + +// Style contains the plot styling properties relevant across +// most plot types. 
These properties apply to individual plot elements +// while the Plot properties applies to the overall plot itself. +type Style struct { //types:add -setters + + // Plot has overall plot-level properties, which can be set by any + // plot element, and are updated first, before applying element-wise styles. + Plot PlotStyle `display:"-"` + + // On specifies whether to plot this item, for table-based plots. + On bool + + // Plotter is the type of plotter to use in plotting this data, + // for [plot.NewTablePlot] [table.Table] driven plots. + // Blank means use default ([plots.XY] is overall default). + Plotter PlotterName + + // Role specifies how a particular column of data should be used, + // for [plot.NewTablePlot] [table.Table] driven plots. + Role Roles + + // Group specifies a group of related data items, + // for [plot.NewTablePlot] [table.Table] driven plots, + // where different columns of data within the same Group play different Roles. + Group string + + // Range is the effective range of data to plot, where either end can be fixed. + Range minmax.Range64 `display:"inline"` + + // Label provides an alternative label to use for axis, if set. + Label string + + // NoLegend excludes this item from the legend when it otherwise would be included, + // for [plot.NewTablePlot] [table.Table] driven plots. + // Role = Y values are included in the Legend by default. + NoLegend bool + + // NTicks sets the desired number of ticks for the axis, if > 0. + NTicks int + + // Line has style properties for drawing lines. + Line LineStyle `display:"add-fields"` + + // Point has style properties for drawing points. + Point PointStyle `display:"add-fields"` + + // Text has style properties for rendering text. + Text TextStyle `display:"add-fields"` + + // Width has various plot width properties. + Width WidthStyle `display:"inline"` +} + +// NewStyle returns a new Style object with defaults applied. 
+func NewStyle() *Style { + st := &Style{} + st.Defaults() + return st +} + +func (st *Style) Defaults() { + st.Plot.Defaults() + st.Line.Defaults() + st.Point.Defaults() + st.Text.Defaults() + st.Width.Defaults() +} + +// WidthStyle contains various plot width properties relevant across +// different plot types. +type WidthStyle struct { //types:add -setters + // Cap is the width of the caps drawn at the top of error bars. + // The default is 10dp + Cap units.Value + + // Offset for Bar plot is the offset added to each X axis value + // relative to the Stride computed value (X = offset + index * Stride) + // Defaults to 0. + Offset float64 + + // Stride for Bar plot is distance between bars. Defaults to 1. + Stride float64 + + // Width for Bar plot is the width of the bars, as a fraction of the Stride, + // to prevent bar overlap. Defaults to .8. + Width float64 `min:"0.01" max:"1" default:"0.8"` + + // Pad for Bar plot is additional space at start / end of data range, + // to keep bars from overflowing ends. This amount is subtracted from Offset + // and added to (len(Values)-1)*Stride -- no other accommodation for bar + // width is provided, so that should be built into this value as well. + // Defaults to 1. + Pad float64 +} + +func (ws *WidthStyle) Defaults() { + ws.Cap.Dp(10) + ws.Offset = 1 + ws.Stride = 1 + ws.Width = .8 + ws.Pad = 1 +} + +// Stylers is a list of styling functions that set Style properties. +// These are called in the order added. +type Stylers []func(s *Style) + +// Add Adds a styling function to the list. +func (st *Stylers) Add(f func(s *Style)) { + *st = append(*st, f) +} + +// Run runs the list of styling functions on given [Style] object. +func (st *Stylers) Run(s *Style) { + for _, f := range *st { + f(s) + } +} + +// NewStyle returns a new Style object with styling functions applied +// on top of Style defaults. 
+func (st *Stylers) NewStyle(ps *PlotStyle) *Style { + s := NewStyle() + ps.SetElementStyle(s) + st.Run(s) + return s +} + +// SetStylersTo sets the [Stylers] into given object's [metadata]. +func SetStylersTo(obj any, st Stylers) { + metadata.SetTo(obj, "PlotStylers", st) +} + +// GetStylersFrom returns [Stylers] from given object's [metadata]. +// Returns nil if none or no metadata. +func GetStylersFrom(obj any) Stylers { + st, _ := metadata.GetFrom[Stylers](obj, "PlotStylers") + return st +} + +// SetStylerTo sets the [Styler] function into given object's [metadata], +// replacing anything that might have already been added. +func SetStylerTo(obj any, f func(s *Style)) { + metadata.SetTo(obj, "PlotStylers", Stylers{f}) +} + +// SetFirstStylerTo sets the [Styler] function into given object's [metadata], +// only if there are no other stylers present. +func SetFirstStylerTo(obj any, f func(s *Style)) { + st := GetStylersFrom(obj) + if len(st) > 0 { + return + } + metadata.SetTo(obj, "PlotStylers", Stylers{f}) +} + +// AddStylerTo adds the given [Styler] function into given object's [metadata]. +func AddStylerTo(obj any, f func(s *Style)) { + st := GetStylersFrom(obj) + st.Add(f) + SetStylersTo(obj, st) +} + +// GetStylersFromData returns [Stylers] from given role +// in given [Data]. nil if not present. +func GetStylersFromData(data Data, role Roles) Stylers { + vr, ok := data[role] + if !ok { + return nil + } + return GetStylersFrom(vr) +} + +//////// + +// DefaultOffOn specifies whether to use the default value for a bool option, +// or to override the default and set Off or On. +type DefaultOffOn int32 //enums:enum + +const ( + // Default means use the default value. + Default DefaultOffOn = iota + + // Off means to override the default and turn Off. + Off + + // On means to override the default and turn On. 
+ On +) diff --git a/plot/table.go b/plot/table.go new file mode 100644 index 0000000000..91e8aa7c11 --- /dev/null +++ b/plot/table.go @@ -0,0 +1,234 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package plot + +import ( + "fmt" + "reflect" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/reflectx" + "cogentcore.org/core/tensor/table" + "golang.org/x/exp/maps" +) + +// NewTablePlot returns a new Plot with all configuration based on given +// [table.Table] set of columns and associated metadata, which must have +// [Stylers] functions set (e.g., [SetStylersTo]) that at least set basic +// table parameters, including: +// - On: Set the main (typically Role = Y) column On to include in plot. +// - Role: Set the appropriate [Roles] role for this column (Y, X, etc). +// - Group: Multiple columns used for a given Plotter type must be grouped +// together with a common name (typically the name of the main Y axis), +// e.g., for Low, High error bars, Size, Color, etc. If only one On column, +// then Group can be empty and all other such columns will be grouped. +// - Plotter: Determines the type of Plotter element to use, which in turn +// determines the additional Roles that can be used within a Group. 
+func NewTablePlot(dt *table.Table) (*Plot, error) { + nc := len(dt.Columns.Values) + if nc == 0 { + return nil, errors.New("plot.NewTablePlot: no columns in data table") + } + csty := make([]*Style, nc) + gps := make(map[string][]int, nc) + xi := -1 // get the _last_ role = X column -- most specific counter + var errs []error + var pstySt Style // overall PlotStyle accumulator + pstySt.Defaults() + for ci, cl := range dt.Columns.Values { + st := &Style{} + st.Defaults() + stl := GetStylersFrom(cl) + if stl != nil { + stl.Run(st) + } + csty[ci] = st + stl.Run(&pstySt) + gps[st.Group] = append(gps[st.Group], ci) + if st.Role == X { + xi = ci + } + } + psty := pstySt.Plot + globalX := false + xidxs := map[int]bool{} // map of all the _unique_ x indexes used + if psty.XAxis.Column != "" { + xc := dt.Columns.IndexByKey(psty.XAxis.Column) + if xc >= 0 { + xi = xc + globalX = true + xidxs[xi] = true + } else { + errs = append(errs, errors.New("XAxis.Column name not found: "+psty.XAxis.Column)) + } + } + doneGps := map[string]bool{} + plt := New() + var legends []Thumbnailer // candidates for legend adding -- only add if > 1 + var legLabels []string + var barCols []int // column indexes of bar plots + var barPlots []int // plotter indexes of bar plots + for ci, cl := range dt.Columns.Values { + cnm := dt.Columns.Keys[ci] + st := csty[ci] + if !st.On || st.Role == X { + continue + } + lbl := cnm + if st.Label != "" { + lbl = st.Label + } + gp := st.Group + if doneGps[gp] { + continue + } + if gp != "" { + doneGps[gp] = true + } + ptyp := "XY" + if st.Plotter != "" { + ptyp = string(st.Plotter) + } + pt, err := PlotterByType(ptyp) + if err != nil { + errs = append(errs, err) + continue + } + data := Data{st.Role: cl} + gcols := gps[gp] + gotReq := true + gotX := -1 + if globalX { + data[X] = dt.Columns.Values[xi] + gotX = xi + } + for _, rl := range pt.Required { + if rl == st.Role || (rl == X && globalX) { + continue + } + got := false + for _, gi := range gcols { + gst := 
csty[gi] + if gst.Role == rl { + if rl == Y { + if !gst.On { + continue + } + } + data[rl] = dt.Columns.Values[gi] + got = true + if rl == X { + gotX = gi // fallthrough so we get the last X + } else { + break + } + } + } + if !got { + if rl == X && xi >= 0 { + gotX = xi + data[rl] = dt.Columns.Values[xi] + } else { + err = fmt.Errorf("plot.NewTablePlot: Required Role %q not found in Group %q, Plotter %q not added for Column: %q", rl.String(), gp, ptyp, cnm) + errs = append(errs, err) + gotReq = false + } + } + } + if !gotReq { + continue + } + if gotX >= 0 { + xidxs[gotX] = true + } + for _, rl := range pt.Optional { + if rl == st.Role { // should not happen + continue + } + for _, gi := range gcols { + gst := csty[gi] + if gst.Role == rl { + data[rl] = dt.Columns.Values[gi] + break + } + } + } + pl := pt.New(data) + if reflectx.IsNil(reflect.ValueOf(pl)) { + err = fmt.Errorf("plot.NewTablePlot: error in creating plotter type: %q", ptyp) + errs = append(errs, err) + continue + } + plt.Add(pl) + if !st.NoLegend { + if tn, ok := pl.(Thumbnailer); ok { + legends = append(legends, tn) + legLabels = append(legLabels, lbl) + } + } + if ptyp == "Bar" { + barCols = append(barCols, ci) + barPlots = append(barPlots, len(plt.Plotters)-1) + } + } + if len(legends) > 1 { + for i, l := range legends { + plt.Legend.Add(legLabels[i], l) + } + } + if psty.XAxis.Label == "" && len(xidxs) == 1 { + xi := maps.Keys(xidxs)[0] + lbl := dt.Columns.Keys[xi] + if csty[xi].Label != "" { + lbl = csty[xi].Label + } + if len(plt.Plotters) > 0 { + pl0 := plt.Plotters[0] + if pl0 != nil { + pl0.Stylers().Add(func(s *Style) { + s.Plot.XAxis.Label = lbl + }) + } + } + } + nbar := len(barCols) + if nbar > 1 { + sz := 1.0 / (float64(nbar) + 0.5) + for bi, bp := range barPlots { + pl := plt.Plotters[bp] + pl.Stylers().Add(func(s *Style) { + s.Width.Stride = 1 + s.Width.Offset = float64(bi) * sz + s.Width.Width = psty.BarWidth * sz + }) + } + } + return plt, errors.Join(errs...) 
+} + +// todo: bar chart rows, if needed +// +// netn := pl.table.NumRows() * stride +// xc := pl.table.ColumnByIndex(xi) +// vals := make([]string, netn) +// for i, dx := range pl.table.Indexes { +// pi := mid + i*stride +// if pi < netn && dx < xc.Len() { +// vals[pi] = xc.String1D(dx) +// } +// } +// plt.NominalX(vals...) + +// todo: +// Use string labels for X axis if X is a string +// xc := pl.table.ColumnByIndex(xi) +// if xc.Tensor.IsString() { +// xcs := xc.Tensor.(*tensor.String) +// vals := make([]string, pl.table.NumRows()) +// for i, dx := range pl.table.Indexes { +// vals[i] = xcs.Values[dx] +// } +// plt.NominalX(vals...) +// } diff --git a/plot/text.go b/plot/text.go index cb86017a65..040e6f0cd1 100644 --- a/plot/text.go +++ b/plot/text.go @@ -5,6 +5,8 @@ package plot import ( + "image" + "cogentcore.org/core/colors" "cogentcore.org/core/math32" "cogentcore.org/core/paint" @@ -16,40 +18,42 @@ import ( // if not set, the standard Cogent Core default font is used. var DefaultFontFamily = "" -// TextStyle specifies styling parameters for Text elements -type TextStyle struct { - styles.FontRender +// TextStyle specifies styling parameters for Text elements. +type TextStyle struct { //types:add -setters + // Size of font to render. Default is 16dp + Size units.Value + + // Family name for font (inherited): ordered list of comma-separated names + // from more general to more specific to use. Use split on, to parse. + Family string - // how to align text along the relevant dimension for the text element + // Color of text. + Color image.Image + + // Align specifies how to align text along the relevant + // dimension for the text element. Align styles.Aligns - // Padding is used in a case-dependent manner to add space around text elements + // Padding is used in a case-dependent manner to add + // space around text elements. Padding units.Value - // rotation of the text, in Degrees + // Rotation of the text, in degrees. 
Rotation float32 + + // Offset is added directly to the final label location. + Offset units.XY } func (ts *TextStyle) Defaults() { - ts.FontRender.Defaults() + ts.Size.Dp(16) ts.Color = colors.Scheme.OnSurface ts.Align = styles.Center if DefaultFontFamily != "" { - ts.FontRender.Family = DefaultFontFamily + ts.Family = DefaultFontFamily } } -func (ts *TextStyle) openFont(pt *Plot) { - if ts.Font.Face == nil { - paint.OpenFont(&ts.FontRender, &pt.Paint.UnitContext) // calls SetUnContext after updating metrics - } -} - -func (ts *TextStyle) ToDots(uc *units.Context) { - ts.FontRender.ToDots(uc) - ts.Padding.ToDots(uc) -} - // Text specifies a single text element in a plot type Text struct { @@ -59,6 +63,9 @@ type Text struct { // styling for this text element Style TextStyle + // font has the full font rendering styles. + font styles.FontRender + // PaintText is the [paint.Text] for the text. PaintText paint.Text } @@ -70,7 +77,10 @@ func (tx *Text) Defaults() { // config is called during the layout of the plot, prior to drawing func (tx *Text) Config(pt *Plot) { uc := &pt.Paint.UnitContext - fs := &tx.Style.FontRender + fs := &tx.font + fs.Size = tx.Style.Size + fs.Family = tx.Style.Family + fs.Color = tx.Style.Color if math32.Abs(tx.Style.Rotation) > 10 { tx.Style.Align = styles.End } @@ -88,6 +98,17 @@ func (tx *Text) Config(pt *Plot) { } } +func (tx *Text) openFont(pt *Plot) { + if tx.font.Face == nil { + paint.OpenFont(&tx.font, &pt.Paint.UnitContext) // calls SetUnContext after updating metrics + } +} + +func (tx *Text) ToDots(uc *units.Context) { + tx.font.ToDots(uc) + tx.Style.Padding.ToDots(uc) +} + // PosX returns the starting position for a horizontally-aligned text element, // based on given width. Text must have been config'd already. 
func (tx *Text) PosX(width float32) math32.Vector2 { diff --git a/plot/tick.go b/plot/tick.go index 8045c2ac49..3be73dfc27 100644 --- a/plot/tick.go +++ b/plot/tick.go @@ -5,16 +5,15 @@ package plot import ( + "math" "strconv" "time" - - "cogentcore.org/core/math32" ) // A Tick is a single tick mark on an axis. type Tick struct { // Value is the data value marked by this Tick. - Value float32 + Value float64 // Label is the text to display at the tick mark. // If Label is an empty string then this is a minor tick mark. @@ -28,8 +27,9 @@ func (tk *Tick) IsMinor() bool { // Ticker creates Ticks in a specified range type Ticker interface { - // Ticks returns Ticks in a specified range - Ticks(min, max float32) []Tick + // Ticks returns Ticks in a specified range, with desired number of ticks, + // which can be ignored depending on the ticker type. + Ticks(min, max float64, nticks int) []Tick } // DefaultTicks is suitable for the Ticker field of an Axis, @@ -39,15 +39,13 @@ type DefaultTicks struct{} var _ Ticker = DefaultTicks{} // Ticks returns Ticks in the specified range. -func (DefaultTicks) Ticks(min, max float32) []Tick { +func (DefaultTicks) Ticks(min, max float64, nticks int) []Tick { if max <= min { panic("illegal range") } - const suggestedTicks = 3 - - labels, step, q, mag := talbotLinHanrahan(min, max, suggestedTicks, withinData, nil, nil, nil) - majorDelta := step * math32.Pow10(mag) + labels, step, q, mag := talbotLinHanrahan(min, max, nticks, withinData, nil, nil, nil) + majorDelta := step * math.Pow10(mag) if q == 0 { // Simple fall back was chosen, so // majorDelta is the label distance. 
@@ -62,16 +60,16 @@ func (DefaultTicks) Ticks(min, max float32) []Tick { off = 1 fc = 'g' } - if math32.Trunc(q) != q { + if math.Trunc(q) != q { off += 2 } prec := minInt(6, maxInt(off, -mag)) ticks := make([]Tick, len(labels)) for i, v := range labels { - ticks[i] = Tick{Value: v, Label: strconv.FormatFloat(float64(v), fc, prec, 32)} + ticks[i] = Tick{Value: v, Label: strconv.FormatFloat(float64(v), fc, prec, 64)} } - var minorDelta float32 + var minorDelta float64 // See talbotLinHanrahan for the values used here. switch step { case 1, 2.5: @@ -87,7 +85,7 @@ func (DefaultTicks) Ticks(min, max float32) []Tick { // Find the first minor tick not greater // than the lowest data value. - var i float32 + var i float64 for labels[0]+(i-1)*minorDelta > min { i-- } @@ -101,7 +99,7 @@ func (DefaultTicks) Ticks(min, max float32) []Tick { } found := false for _, t := range ticks { - if math32.Abs(t.Value-val) < minorDelta/2 { + if math.Abs(t.Value-val) < minorDelta/2 { found = true } } @@ -139,20 +137,20 @@ type LogTicks struct { var _ Ticker = LogTicks{} // Ticks returns Ticks in a specified range -func (t LogTicks) Ticks(min, max float32) []Tick { +func (t LogTicks) Ticks(min, max float64, nticks int) []Tick { if min <= 0 || max <= 0 { panic("Values must be greater than 0 for a log scale.") } - val := math32.Pow10(int(math32.Log10(min))) - max = math32.Pow10(int(math32.Ceil(math32.Log10(max)))) + val := math.Pow10(int(math.Log10(min))) + max = math.Pow10(int(math.Ceil(math.Log10(max)))) var ticks []Tick for val < max { for i := 1; i < 10; i++ { if i == 1 { ticks = append(ticks, Tick{Value: val, Label: formatFloatTick(val, t.Prec)}) } - ticks = append(ticks, Tick{Value: val * float32(i)}) + ticks = append(ticks, Tick{Value: val * float64(i)}) } val *= 10 } @@ -168,13 +166,13 @@ type ConstantTicks []Tick var _ Ticker = ConstantTicks{} // Ticks returns Ticks in a specified range -func (ts ConstantTicks) Ticks(float32, float32) []Tick { +func (ts ConstantTicks) Ticks(float64, 
float64, int) []Tick { return ts } // UnixTimeIn returns a time conversion function for the given location. -func UnixTimeIn(loc *time.Location) func(t float32) time.Time { - return func(t float32) time.Time { +func UnixTimeIn(loc *time.Location) func(t float64) time.Time { + return func(t float64) time.Time { return time.Unix(int64(t), 0).In(loc) } } @@ -194,13 +192,13 @@ type TimeTicks struct { // Time takes a float32 value and converts it into a time.Time. // If nil, UTCUnixTime is used. - Time func(t float32) time.Time + Time func(t float64) time.Time } var _ Ticker = TimeTicks{} // Ticks implements plot.Ticker. -func (t TimeTicks) Ticks(min, max float32) []Tick { +func (t TimeTicks) Ticks(min, max float64, nticks int) []Tick { if t.Ticker == nil { t.Ticker = DefaultTicks{} } @@ -211,7 +209,7 @@ func (t TimeTicks) Ticks(min, max float32) []Tick { t.Time = UTCUnixTime } - ticks := t.Ticker.Ticks(min, max) + ticks := t.Ticker.Ticks(min, max, nticks) for i := range ticks { tick := &ticks[i] if tick.Label == "" { @@ -269,17 +267,17 @@ func tickLabelWidth(sty text.Style, ticks []Tick) vg.Length { // formatFloatTick returns a g-formated string representation of v // to the specified precision. -func formatFloatTick(v float32, prec int) string { - return strconv.FormatFloat(float64(v), 'g', prec, 32) +func formatFloatTick(v float64, prec int) string { + return strconv.FormatFloat(float64(v), 'g', prec, 64) } -// TickerFunc is suitable for the Ticker field of an Axis. -// It is an adapter which allows to quickly setup a Ticker using a function with an appropriate signature. -type TickerFunc func(min, max float32) []Tick - -var _ Ticker = TickerFunc(nil) - -// Ticks implements plot.Ticker. -func (f TickerFunc) Ticks(min, max float32) []Tick { - return f(min, max) -} +// // TickerFunc is suitable for the Ticker field of an Axis. +// // It is an adapter which allows to quickly setup a Ticker using a function with an appropriate signature. 
+// type TickerFunc func(min, max float64) []Tick +// +// var _ Ticker = TickerFunc(nil) +// +// // Ticks implements plot.Ticker. +// func (f TickerFunc) Ticks(min, max float64) []Tick { +// return f(min, max) +// } diff --git a/plot/typegen.go b/plot/typegen.go index cc812fd88e..f7020e7b3c 100644 --- a/plot/typegen.go +++ b/plot/typegen.go @@ -3,12 +3,57 @@ package plot import ( + "image" + + "cogentcore.org/core/math32/minmax" + "cogentcore.org/core/styles" + "cogentcore.org/core/styles/units" "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Normalizer", IDName: "normalizer", Doc: "Normalizer rescales values from the data coordinate system to the\nnormalized coordinate system.", Methods: []types.Method{{Name: "Normalize", Doc: "Normalize transforms a value x in the data coordinate system to\nthe normalized coordinate system.", Args: []string{"min", "max", "x"}, Returns: []string{"float32"}}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.AxisScales", IDName: "axis-scales", Doc: "AxisScales are the scaling options for how values are distributed\nalong an axis: Linear, Log, etc."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.AxisStyle", IDName: "axis-style", Doc: "AxisStyle has style properties for the axis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Text", Doc: "Text has the text style parameters for the text label."}, {Name: "Line", Doc: "Line has styling properties for the axis line."}, {Name: "Padding", Doc: "Padding between the axis line and the data. 
Having\nnon-zero padding ensures that the data is never drawn\non the axis, thus making it easier to see."}, {Name: "NTicks", Doc: "NTicks is the desired number of ticks (actual likely will be different)."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the axis:\nLinear, Log, Inverted"}, {Name: "TickText", Doc: "TickText has the text style for rendering tick labels,\nand is shared for actual rendering."}, {Name: "TickLine", Doc: "TickLine has line style for drawing tick lines."}, {Name: "TickLength", Doc: "TickLength is the length of tick lines."}}}) + +// SetText sets the [AxisStyle.Text]: +// Text has the text style parameters for the text label. +func (t *AxisStyle) SetText(v TextStyle) *AxisStyle { t.Text = v; return t } + +// SetLine sets the [AxisStyle.Line]: +// Line has styling properties for the axis line. +func (t *AxisStyle) SetLine(v LineStyle) *AxisStyle { t.Line = v; return t } + +// SetPadding sets the [AxisStyle.Padding]: +// Padding between the axis line and the data. Having +// non-zero padding ensures that the data is never drawn +// on the axis, thus making it easier to see. +func (t *AxisStyle) SetPadding(v units.Value) *AxisStyle { t.Padding = v; return t } + +// SetNTicks sets the [AxisStyle.NTicks]: +// NTicks is the desired number of ticks (actual likely will be different). +func (t *AxisStyle) SetNTicks(v int) *AxisStyle { t.NTicks = v; return t } + +// SetScale sets the [AxisStyle.Scale]: +// Scale specifies how values are scaled along the axis: +// Linear, Log, Inverted +func (t *AxisStyle) SetScale(v AxisScales) *AxisStyle { t.Scale = v; return t } + +// SetTickText sets the [AxisStyle.TickText]: +// TickText has the text style for rendering tick labels, +// and is shared for actual rendering. +func (t *AxisStyle) SetTickText(v TextStyle) *AxisStyle { t.TickText = v; return t } + +// SetTickLine sets the [AxisStyle.TickLine]: +// TickLine has line style for drawing tick lines. 
+func (t *AxisStyle) SetTickLine(v LineStyle) *AxisStyle { t.TickLine = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Axis", IDName: "axis", Doc: "Axis represents either a horizontal or vertical\naxis of a plot.", Fields: []types.Field{{Name: "Min", Doc: "Min and Max are the minimum and maximum data\nvalues represented by the axis."}, {Name: "Max", Doc: "Min and Max are the minimum and maximum data\nvalues represented by the axis."}, {Name: "Axis", Doc: "specifies which axis this is: X or Y"}, {Name: "Label", Doc: "Label for the axis"}, {Name: "Line", Doc: "Line styling properties for the axis line."}, {Name: "Padding", Doc: "Padding between the axis line and the data. Having\nnon-zero padding ensures that the data is never drawn\non the axis, thus making it easier to see."}, {Name: "TickText", Doc: "has the text style for rendering tick labels, and is shared for actual rendering"}, {Name: "TickLine", Doc: "line style for drawing tick lines"}, {Name: "TickLength", Doc: "length of tick lines"}, {Name: "Ticker", Doc: "Ticker generates the tick marks. Any tick marks\nreturned by the Marker function that are not in\nrange of the axis are not drawn."}, {Name: "Scale", Doc: "Scale transforms a value given in the data coordinate system\nto the normalized coordinate system of the axis—its distance\nalong the axis as a fraction of the axis range."}, {Name: "AutoRescale", Doc: "AutoRescale enables an axis to automatically adapt its minimum\nand maximum boundaries, according to its underlying Ticker."}, {Name: "ticks", Doc: "cached list of ticks, set in size"}}}) +// SetTickLength sets the [AxisStyle.TickLength]: +// TickLength is the length of tick lines. 
+func (t *AxisStyle) SetTickLength(v units.Value) *AxisStyle { t.TickLength = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Axis", IDName: "axis", Doc: "Axis represents either a horizontal or vertical\naxis of a plot.", Fields: []types.Field{{Name: "Range", Doc: "Range has the Min, Max range of values for the axis (in raw data units.)"}, {Name: "Axis", Doc: "specifies which axis this is: X, Y or Z."}, {Name: "Label", Doc: "Label for the axis."}, {Name: "Style", Doc: "Style has the style parameters for the Axis."}, {Name: "TickText", Doc: "TickText is used for rendering the tick text labels."}, {Name: "Ticker", Doc: "Ticker generates the tick marks. Any tick marks\nreturned by the Marker function that are not in\nrange of the axis are not drawn."}, {Name: "Scale", Doc: "Scale transforms a value given in the data coordinate system\nto the normalized coordinate system of the axis—its distance\nalong the axis as a fraction of the axis range."}, {Name: "AutoRescale", Doc: "AutoRescale enables an axis to automatically adapt its minimum\nand maximum boundaries, according to its underlying Ticker."}, {Name: "ticks", Doc: "cached list of ticks, set in size"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Normalizer", IDName: "normalizer", Doc: "Normalizer rescales values from the data coordinate system to the\nnormalized coordinate system.", Methods: []types.Method{{Name: "Normalize", Doc: "Normalize transforms a value x in the data coordinate system to\nthe normalized coordinate system.", Args: []string{"min", "max", "x"}, Returns: []string{"float64"}}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LinearScale", IDName: "linear-scale", Doc: "LinearScale an be used as the value of an Axis.Scale function to\nset the axis to a standard linear scale."}) @@ -16,55 +61,362 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LogScale", IDN var _ = types.AddType(&types.Type{Name: 
"cogentcore.org/core/plot.InvertedScale", IDName: "inverted-scale", Doc: "InvertedScale can be used as the value of an Axis.Scale function to\ninvert the axis using any Normalizer.", Embeds: []types.Field{{Name: "Normalizer"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Valuer", IDName: "valuer", Doc: "Valuer provides an interface for a list of scalar values", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of values.", Returns: []string{"int"}}, {Name: "Value", Doc: "Value returns a value.", Args: []string{"i"}, Returns: []string{"float32"}}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Data", IDName: "data", Doc: "Data is a map of Roles and Data for that Role, providing the\nprimary way of passing data to a Plotter"}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Values", IDName: "values", Doc: "Values implements the Valuer interface."}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Valuer", IDName: "valuer", Doc: "Valuer is the data interface for plotting, supporting either\nfloat64 or string representations. 
It is satisfied by the tensor.Tensor\ninterface, so a tensor can be used directly for plot Data.", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of values.", Returns: []string{"int"}}, {Name: "Float1D", Doc: "Float1D(i int) returns float64 value at given index.", Args: []string{"i"}, Returns: []string{"float64"}}, {Name: "String1D", Doc: "String1D(i int) returns string value at given index.", Args: []string{"i"}, Returns: []string{"string"}}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYer", IDName: "x-yer", Doc: "XYer provides an interface for a list of X,Y data pairs", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of x, y pairs.", Returns: []string{"int"}}, {Name: "XY", Doc: "XY returns an x, y pair.", Args: []string{"i"}, Returns: []string{"x", "y"}}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Roles", IDName: "roles", Doc: "Roles are the roles that a given set of data values can play,\ndesigned to be sufficiently generalizable across all different\ntypes of plots, even if sometimes it is a bit of a stretch."}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYs", IDName: "x-ys", Doc: "XYs implements the XYer interface."}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Values", IDName: "values", Doc: "Values provides a minimal implementation of the Data interface\nusing a slice of float64."}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XValues", IDName: "x-values", Doc: "XValues implements the Valuer interface,\nreturning the x value from an XYer.", Embeds: []types.Field{{Name: "XYer"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Labels", IDName: "labels", Doc: "Labels provides a minimal implementation of the Data interface\nusing a slice of string. 
It always returns 0 for Float1D."}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.YValues", IDName: "y-values", Doc: "YValues implements the Valuer interface,\nreturning the y value from an XYer.", Embeds: []types.Field{{Name: "XYer"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.selection", IDName: "selection", Fields: []types.Field{{Name: "n", Doc: "n is the number of labels selected."}, {Name: "lMin", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lMax", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lStep", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lq", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "score", Doc: "score is the score for the selection."}, {Name: "magnitude", Doc: "magnitude is the magnitude of the\nlabel step distance."}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYZer", IDName: "xy-zer", Doc: "XYZer provides an interface for a list of X,Y,Z data triples.\nIt also satisfies the XYer interface for the X,Y pairs.", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of x, y, z triples.", Returns: []string{"int"}}, {Name: "XYZ", Doc: "XYZ returns an x, y, z triple.", Args: []string{"i"}, Returns: []string{"float32", "float32", "float32"}}, {Name: "XY", Doc: "XY returns an x, y pair.", Args: []string{"i"}, Returns: []string{"float32", "float32"}}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.weights", IDName: "weights", Doc: "weights is a helper type to calcuate the labelling scheme's total score.", Fields: []types.Field{{Name: "simplicity"}, {Name: "coverage"}, {Name: "density"}, {Name: "legibility"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYZs", IDName: "xy-zs", Doc: "XYZs implements the XYZer interface using a 
slice."}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LegendStyle", IDName: "legend-style", Doc: "LegendStyle has the styling properties for the Legend.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column is for table-based plotting, specifying the column with legend values."}, {Name: "Text", Doc: "Text is the style given to the legend entry texts."}, {Name: "Position", Doc: "position of the legend"}, {Name: "ThumbnailWidth", Doc: "ThumbnailWidth is the width of legend thumbnails."}, {Name: "Fill", Doc: "Fill specifies the background fill color for the legend box,\nif non-nil."}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYZ", IDName: "xyz", Doc: "XYZ is an x, y and z value.", Fields: []types.Field{{Name: "X"}, {Name: "Y"}, {Name: "Z"}}}) +// SetColumn sets the [LegendStyle.Column]: +// Column is for table-based plotting, specifying the column with legend values. +func (t *LegendStyle) SetColumn(v string) *LegendStyle { t.Column = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XYValues", IDName: "xy-values", Doc: "XYValues implements the XYer interface, returning\nthe x and y values from an XYZer.", Embeds: []types.Field{{Name: "XYZer"}}}) +// SetText sets the [LegendStyle.Text]: +// Text is the style given to the legend entry texts. 
+func (t *LegendStyle) SetText(v TextStyle) *LegendStyle { t.Text = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Labeler", IDName: "labeler", Doc: "Labeler provides an interface for a list of string labels", Methods: []types.Method{{Name: "Label", Doc: "Label returns a label.", Args: []string{"i"}, Returns: []string{"string"}}}}) +// SetPosition sets the [LegendStyle.Position]: +// position of the legend +func (t *LegendStyle) SetPosition(v LegendPosition) *LegendStyle { t.Position = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.selection", IDName: "selection", Fields: []types.Field{{Name: "n", Doc: "n is the number of labels selected."}, {Name: "lMin", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lMax", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lStep", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lq", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "score", Doc: "score is the score for the selection."}, {Name: "magnitude", Doc: "magnitude is the magnitude of the\nlabel step distance."}}}) +// SetThumbnailWidth sets the [LegendStyle.ThumbnailWidth]: +// ThumbnailWidth is the width of legend thumbnails. +func (t *LegendStyle) SetThumbnailWidth(v units.Value) *LegendStyle { t.ThumbnailWidth = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.weights", IDName: "weights", Doc: "weights is a helper type to calcuate the labelling scheme's total score.", Fields: []types.Field{{Name: "simplicity"}, {Name: "coverage"}, {Name: "density"}, {Name: "legibility"}}}) +// SetFill sets the [LegendStyle.Fill]: +// Fill specifies the background fill color for the legend box, +// if non-nil. 
+func (t *LegendStyle) SetFill(v image.Image) *LegendStyle { t.Fill = v; return t } var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LegendPosition", IDName: "legend-position", Doc: "LegendPosition specifies where to put the legend", Fields: []types.Field{{Name: "Top", Doc: "Top and Left specify the location of the legend."}, {Name: "Left", Doc: "Top and Left specify the location of the legend."}, {Name: "XOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}, {Name: "YOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Legend", IDName: "legend", Doc: "A Legend gives a description of the meaning of different\ndata elements of the plot. Each legend entry has a name\nand a thumbnail, where the thumbnail shows a small\nsample of the display style of the corresponding data.", Fields: []types.Field{{Name: "TextStyle", Doc: "TextStyle is the style given to the legend entry texts."}, {Name: "Position", Doc: "position of the legend"}, {Name: "ThumbnailWidth", Doc: "ThumbnailWidth is the width of legend thumbnails."}, {Name: "Fill", Doc: "Fill specifies the background fill color for the legend box,\nif non-nil."}, {Name: "Entries", Doc: "Entries are all of the LegendEntries described by this legend."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Legend", IDName: "legend", Doc: "A Legend gives a description of the meaning of different\ndata elements of the plot. 
Each legend entry has a name\nand a thumbnail, where the thumbnail shows a small\nsample of the display style of the corresponding data.", Fields: []types.Field{{Name: "Style", Doc: "Style has the legend styling parameters."}, {Name: "Entries", Doc: "Entries are all of the LegendEntries described by this legend."}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Thumbnailer", IDName: "thumbnailer", Doc: "Thumbnailer wraps the Thumbnail method, which\ndraws the small image in a legend representing the\nstyle of data.", Methods: []types.Method{{Name: "Thumbnail", Doc: "Thumbnail draws an thumbnail representing\na legend entry. The thumbnail will usually show\na smaller representation of the style used\nto plot the corresponding data.", Args: []string{"pt"}}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Thumbnailer", IDName: "thumbnailer", Doc: "Thumbnailer wraps the Thumbnail method, which draws the small\nimage in a legend representing the style of data.", Methods: []types.Method{{Name: "Thumbnail", Doc: "Thumbnail draws an thumbnail representing a legend entry.\nThe thumbnail will usually show a smaller representation\nof the style used to plot the corresponding data.", Args: []string{"pt"}}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LegendEntry", IDName: "legend-entry", Doc: "A LegendEntry represents a single line of a legend, it\nhas a name and an icon.", Fields: []types.Field{{Name: "Text", Doc: "text is the text associated with this entry."}, {Name: "Thumbs", Doc: "thumbs is a slice of all of the thumbnails styles"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LineStyle", IDName: "line-style", Doc: "LineStyle has style properties for line drawing", Fields: []types.Field{{Name: "Color", Doc: "stroke color image specification; stroking is off if nil"}, {Name: "Width", Doc: "line width"}, {Name: "Dashes", Doc: "Dashes are the dashes of the stroke. 
Each pair of values specifies\nthe amount to paint and then the amount to skip."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LineStyle", IDName: "line-style", Doc: "LineStyle has style properties for drawing lines.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot lines."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns line off."}, {Name: "Width", Doc: "Width is the line width, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Dashes", Doc: "Dashes are the dashes of the stroke. Each pair of values specifies\nthe amount to paint and then the amount to skip."}, {Name: "Fill", Doc: "Fill is the color to fill solid regions, in a plot-specific\nway (e.g., the area below a Line plot, the bar color).\nUse nil to disable filling."}, {Name: "NegativeX", Doc: "NegativeX specifies whether to draw lines that connect points with a negative\nX-axis direction; otherwise there is a break in the line.\ndefault is false, so that repeated series of data across the X axis\nare plotted separately."}, {Name: "Step", Doc: "Step specifies how to step the line between points."}}}) + +// SetOn sets the [LineStyle.On]: +// On indicates whether to plot lines. +func (t *LineStyle) SetOn(v DefaultOffOn) *LineStyle { t.On = v; return t } + +// SetColor sets the [LineStyle.Color]: +// Color is the stroke color image specification. +// Setting to nil turns line off. +func (t *LineStyle) SetColor(v image.Image) *LineStyle { t.Color = v; return t } + +// SetWidth sets the [LineStyle.Width]: +// Width is the line width, with a default of 1 Pt (point). +// Setting to 0 turns line off. +func (t *LineStyle) SetWidth(v units.Value) *LineStyle { t.Width = v; return t } + +// SetDashes sets the [LineStyle.Dashes]: +// Dashes are the dashes of the stroke. 
Each pair of values specifies +// the amount to paint and then the amount to skip. +func (t *LineStyle) SetDashes(v ...float32) *LineStyle { t.Dashes = v; return t } + +// SetFill sets the [LineStyle.Fill]: +// Fill is the color to fill solid regions, in a plot-specific +// way (e.g., the area below a Line plot, the bar color). +// Use nil to disable filling. +func (t *LineStyle) SetFill(v image.Image) *LineStyle { t.Fill = v; return t } + +// SetNegativeX sets the [LineStyle.NegativeX]: +// NegativeX specifies whether to draw lines that connect points with a negative +// X-axis direction; otherwise there is a break in the line. +// default is false, so that repeated series of data across the X axis +// are plotted separately. +func (t *LineStyle) SetNegativeX(v bool) *LineStyle { t.NegativeX = v; return t } + +// SetStep sets the [LineStyle.Step]: +// Step specifies how to step the line between points. +func (t *LineStyle) SetStep(v StepKind) *LineStyle { t.Step = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.StepKind", IDName: "step-kind", Doc: "StepKind specifies a form of a connection of two consecutive points."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.XAxisStyle", IDName: "x-axis-style", Doc: "XAxisStyle has overall plot level styling properties for the XAxis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column specifies the column to use for the common X axis,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nIf empty, standard Group-based role binding is used: the last column\nwithin the same group with Role=X is used."}, {Name: "Rotation", Doc: "Rotation is the rotation of the X Axis labels, in degrees."}, {Name: "Label", Doc: "Label is the optional label to use for the XAxis instead of the default."}, {Name: "Range", Doc: "Range is the effective range of XAxis data to plot, where either end 
can be fixed."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the X axis:\nLinear, Log, Inverted"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Plot", IDName: "plot", Doc: "Plot is the basic type representing a plot.\nIt renders into its own image.RGBA Pixels image,\nand can also save a corresponding SVG version.\nThe Axis ranges are updated automatically when plots\nare added, so setting a fixed range should happen\nafter that point. See [UpdateRange] method as well.", Fields: []types.Field{{Name: "Title", Doc: "Title of the plot"}, {Name: "Background", Doc: "Background is the background of the plot.\nThe default is [colors.Scheme.Surface]."}, {Name: "StandardTextStyle", Doc: "standard text style with default options"}, {Name: "X", Doc: "X and Y are the horizontal and vertical axes\nof the plot respectively."}, {Name: "Y", Doc: "X and Y are the horizontal and vertical axes\nof the plot respectively."}, {Name: "Legend", Doc: "Legend is the plot's legend."}, {Name: "Plotters", Doc: "plotters are drawn by calling their Plot method\nafter the axes are drawn."}, {Name: "Size", Doc: "size is the target size of the image to render to"}, {Name: "DPI", Doc: "DPI is the dots per inch for rendering the image.\nLarger numbers result in larger scaling of the plot contents\nwhich is strongly recommended for print (e.g., use 300 for print)"}, {Name: "Paint", Doc: "painter for rendering"}, {Name: "Pixels", Doc: "pixels that we render into"}, {Name: "PlotBox", Doc: "Current plot bounding box in image coordinates, for plotting coordinates"}}}) +// SetColumn sets the [XAxisStyle.Column]: +// Column specifies the column to use for the common X axis, +// for [plot.NewTablePlot] [table.Table] driven plots. +// If empty, standard Group-based role binding is used: the last column +// within the same group with Role=X is used. 
+func (t *XAxisStyle) SetColumn(v string) *XAxisStyle { t.Column = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Plotter", IDName: "plotter", Doc: "Plotter is an interface that wraps the Plot method.\nSome standard implementations of Plotter can be found in plotters.", Methods: []types.Method{{Name: "Plot", Doc: "Plot draws the data to the Plot Paint", Args: []string{"pt"}}, {Name: "XYData", Doc: "returns the data for this plot as X,Y points,\nincluding corresponding pixel data.\nThis allows gui interface to inspect data etc.", Returns: []string{"data", "pixels"}}}}) +// SetRotation sets the [XAxisStyle.Rotation]: +// Rotation is the rotation of the X Axis labels, in degrees. +func (t *XAxisStyle) SetRotation(v float32) *XAxisStyle { t.Rotation = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.DataRanger", IDName: "data-ranger", Doc: "DataRanger wraps the DataRange method.", Methods: []types.Method{{Name: "DataRange", Doc: "DataRange returns the range of X and Y values.", Returns: []string{"xmin", "xmax", "ymin", "ymax"}}}}) +// SetLabel sets the [XAxisStyle.Label]: +// Label is the optional label to use for the XAxis instead of the default. +func (t *XAxisStyle) SetLabel(v string) *XAxisStyle { t.Label = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.TextStyle", IDName: "text-style", Doc: "TextStyle specifies styling parameters for Text elements", Embeds: []types.Field{{Name: "FontRender"}}, Fields: []types.Field{{Name: "Align", Doc: "how to align text along the relevant dimension for the text element"}, {Name: "Padding", Doc: "Padding is used in a case-dependent manner to add space around text elements"}, {Name: "Rotation", Doc: "rotation of the text, in Degrees"}}}) +// SetRange sets the [XAxisStyle.Range]: +// Range is the effective range of XAxis data to plot, where either end can be fixed. 
+func (t *XAxisStyle) SetRange(v minmax.Range64) *XAxisStyle { t.Range = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Text", IDName: "text", Doc: "Text specifies a single text element in a plot", Fields: []types.Field{{Name: "Text", Doc: "text string, which can use HTML formatting"}, {Name: "Style", Doc: "styling for this text element"}, {Name: "PaintText", Doc: "PaintText is the [paint.Text] for the text."}}}) +// SetScale sets the [XAxisStyle.Scale]: +// Scale specifies how values are scaled along the X axis: +// Linear, Log, Inverted +func (t *XAxisStyle) SetScale(v AxisScales) *XAxisStyle { t.Scale = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.PlotStyle", IDName: "plot-style", Doc: "PlotStyle has overall plot level styling properties.\nSome properties provide defaults for individual elements, which can\nthen be overwritten by element-level properties.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Title", Doc: "Title is the overall title of the plot."}, {Name: "TitleStyle", Doc: "TitleStyle is the text styling parameters for the title."}, {Name: "Background", Doc: "Background is the background of the plot.\nThe default is [colors.Scheme.Surface]."}, {Name: "Scale", Doc: "Scale multiplies the plot DPI value, to change the overall scale\nof the rendered plot. 
Larger numbers produce larger scaling.\nTypically use larger numbers when generating plots for inclusion in\ndocuments or other cases where the overall plot size will be small."}, {Name: "Legend", Doc: "Legend has the styling properties for the Legend."}, {Name: "Axis", Doc: "Axis has the styling properties for the Axes."}, {Name: "XAxis", Doc: "XAxis has plot-level XAxis style properties."}, {Name: "YAxisLabel", Doc: "YAxisLabel is the optional label to use for the YAxis instead of the default."}, {Name: "LinesOn", Doc: "LinesOn determines whether lines are plotted by default,\nfor elements that plot lines (e.g., plots.XY)."}, {Name: "LineWidth", Doc: "LineWidth sets the default line width for data plotting lines."}, {Name: "PointsOn", Doc: "PointsOn determines whether points are plotted by default,\nfor elements that plot points (e.g., plots.XY)."}, {Name: "PointSize", Doc: "PointSize sets the default point size."}, {Name: "LabelSize", Doc: "LabelSize sets the default label text size."}, {Name: "BarWidth", Doc: "BarWidth for Bar plot sets the default width of the bars,\nwhich should be less than the Stride (1 typically) to prevent\nbar overlap. Defaults to .8."}}}) + +// SetTitle sets the [PlotStyle.Title]: +// Title is the overall title of the plot. +func (t *PlotStyle) SetTitle(v string) *PlotStyle { t.Title = v; return t } + +// SetTitleStyle sets the [PlotStyle.TitleStyle]: +// TitleStyle is the text styling parameters for the title. +func (t *PlotStyle) SetTitleStyle(v TextStyle) *PlotStyle { t.TitleStyle = v; return t } + +// SetBackground sets the [PlotStyle.Background]: +// Background is the background of the plot. +// The default is [colors.Scheme.Surface]. +func (t *PlotStyle) SetBackground(v image.Image) *PlotStyle { t.Background = v; return t } + +// SetScale sets the [PlotStyle.Scale]: +// Scale multiplies the plot DPI value, to change the overall scale +// of the rendered plot. Larger numbers produce larger scaling. 
+// Typically use larger numbers when generating plots for inclusion in +// documents or other cases where the overall plot size will be small. +func (t *PlotStyle) SetScale(v float32) *PlotStyle { t.Scale = v; return t } + +// SetLegend sets the [PlotStyle.Legend]: +// Legend has the styling properties for the Legend. +func (t *PlotStyle) SetLegend(v LegendStyle) *PlotStyle { t.Legend = v; return t } + +// SetAxis sets the [PlotStyle.Axis]: +// Axis has the styling properties for the Axes. +func (t *PlotStyle) SetAxis(v AxisStyle) *PlotStyle { t.Axis = v; return t } + +// SetXAxis sets the [PlotStyle.XAxis]: +// XAxis has plot-level XAxis style properties. +func (t *PlotStyle) SetXAxis(v XAxisStyle) *PlotStyle { t.XAxis = v; return t } + +// SetYAxisLabel sets the [PlotStyle.YAxisLabel]: +// YAxisLabel is the optional label to use for the YAxis instead of the default. +func (t *PlotStyle) SetYAxisLabel(v string) *PlotStyle { t.YAxisLabel = v; return t } + +// SetLinesOn sets the [PlotStyle.LinesOn]: +// LinesOn determines whether lines are plotted by default, +// for elements that plot lines (e.g., plots.XY). +func (t *PlotStyle) SetLinesOn(v DefaultOffOn) *PlotStyle { t.LinesOn = v; return t } + +// SetLineWidth sets the [PlotStyle.LineWidth]: +// LineWidth sets the default line width for data plotting lines. +func (t *PlotStyle) SetLineWidth(v units.Value) *PlotStyle { t.LineWidth = v; return t } + +// SetPointsOn sets the [PlotStyle.PointsOn]: +// PointsOn determines whether points are plotted by default, +// for elements that plot points (e.g., plots.XY). +func (t *PlotStyle) SetPointsOn(v DefaultOffOn) *PlotStyle { t.PointsOn = v; return t } + +// SetPointSize sets the [PlotStyle.PointSize]: +// PointSize sets the default point size. +func (t *PlotStyle) SetPointSize(v units.Value) *PlotStyle { t.PointSize = v; return t } + +// SetLabelSize sets the [PlotStyle.LabelSize]: +// LabelSize sets the default label text size. 
+func (t *PlotStyle) SetLabelSize(v units.Value) *PlotStyle { t.LabelSize = v; return t } + +// SetBarWidth sets the [PlotStyle.BarWidth]: +// BarWidth for Bar plot sets the default width of the bars, +// which should be less than the Stride (1 typically) to prevent +// bar overlap. Defaults to .8. +func (t *PlotStyle) SetBarWidth(v float64) *PlotStyle { t.BarWidth = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.PanZoom", IDName: "pan-zoom", Doc: "PanZoom provides post-styling pan and zoom range manipulation.", Fields: []types.Field{{Name: "XOffset", Doc: "XOffset adds offset to X range (pan)."}, {Name: "XScale", Doc: "XScale multiplies X range (zoom)."}, {Name: "YOffset", Doc: "YOffset adds offset to Y range (pan)."}, {Name: "YScale", Doc: "YScale multiplies Y range (zoom)."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Plot", IDName: "plot", Doc: "Plot is the basic type representing a plot.\nIt renders into its own image.RGBA Pixels image,\nand can also save a corresponding SVG version.", Fields: []types.Field{{Name: "Title", Doc: "Title of the plot"}, {Name: "Style", Doc: "Style has the styling properties for the plot."}, {Name: "StandardTextStyle", Doc: "standard text style with default options"}, {Name: "X", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Y", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Z", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Legend", Doc: "Legend is the plot's legend."}, {Name: "Plotters", Doc: "Plotters are drawn by calling their Plot method after the axes are drawn."}, {Name: "Size", Doc: "Size is the target size of the image to render to."}, {Name: "DPI", Doc: "DPI is the dots per inch for rendering the image.\nLarger numbers result in larger scaling of the plot contents\nwhich is strongly recommended for 
print (e.g., use 300 for print)"}, {Name: "PanZoom", Doc: "PanZoom provides post-styling pan and zoom range factors."}, {Name: "HighlightPlotter", Doc: "\tHighlightPlotter is the Plotter to highlight. Used for mouse hovering for example.\nIt is the responsibility of the Plotter Plot function to implement highlighting."}, {Name: "HighlightIndex", Doc: "HighlightIndex is the index of the data point to highlight, for HighlightPlotter."}, {Name: "Pixels", Doc: "pixels that we render into"}, {Name: "Paint", Doc: "Paint is the painter for rendering"}, {Name: "PlotBox", Doc: "Current plot bounding box in image coordinates, for plotting coordinates"}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Plotter", IDName: "plotter", Doc: "Plotter is an interface that wraps the Plot method.\nStandard implementations of Plotter are in the [plots] package.", Methods: []types.Method{{Name: "Plot", Doc: "Plot draws the data to the Plot Paint.", Args: []string{"pt"}}, {Name: "UpdateRange", Doc: "UpdateRange updates the given ranges.", Args: []string{"plt", "xr", "yr", "zr"}}, {Name: "Data", Doc: "Data returns the data by roles for this plot, for both the original\ndata and the pixel-transformed X,Y coordinates for that data.\nThis allows a GUI interface to inspect data etc.", Returns: []string{"data", "pixX", "pixY"}}, {Name: "Stylers", Doc: "Stylers returns the styler functions for this element.", Returns: []string{"Stylers"}}, {Name: "ApplyStyle", Doc: "ApplyStyle applies any stylers to this element,\nfirst initializing from the given global plot style, which has\nalready been styled with defaults and all the plot element stylers.", Args: []string{"plotStyle"}}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.PlotterType", IDName: "plotter-type", Doc: "PlotterType registers a Plotter so that it can be created with appropriate data.", Fields: []types.Field{{Name: "Name", Doc: "Name of the plot type."}, {Name: "Doc", Doc: "Doc is the 
documentation for this Plotter."}, {Name: "Required", Doc: "Required Data roles for this plot. Data for these Roles must be provided."}, {Name: "Optional", Doc: "Optional Data roles for this plot."}, {Name: "New", Doc: "New returns a new plotter of this type with given data in given roles."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.PlotterName", IDName: "plotter-name", Doc: "PlotterName is the name of a specific plotter type."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.PointStyle", IDName: "point-style", Doc: "PointStyle has style properties for drawing points as different shapes.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot points."}, {Name: "Shape", Doc: "Shape to draw."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns line off."}, {Name: "Fill", Doc: "Fill is the color to fill solid regions, in a plot-specific\nway (e.g., the area below a Line plot, the bar color).\nUse nil to disable filling."}, {Name: "Width", Doc: "Width is the line width for point glyphs, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Size", Doc: "Size of shape to draw for each point.\nDefaults to 4 Pt (point)."}}}) + +// SetOn sets the [PointStyle.On]: +// On indicates whether to plot points. +func (t *PointStyle) SetOn(v DefaultOffOn) *PointStyle { t.On = v; return t } + +// SetShape sets the [PointStyle.Shape]: +// Shape to draw. +func (t *PointStyle) SetShape(v Shapes) *PointStyle { t.Shape = v; return t } + +// SetColor sets the [PointStyle.Color]: +// Color is the stroke color image specification. +// Setting to nil turns line off. 
+func (t *PointStyle) SetColor(v image.Image) *PointStyle { t.Color = v; return t } + +// SetFill sets the [PointStyle.Fill]: +// Fill is the color to fill solid regions, in a plot-specific +// way (e.g., the area below a Line plot, the bar color). +// Use nil to disable filling. +func (t *PointStyle) SetFill(v image.Image) *PointStyle { t.Fill = v; return t } + +// SetWidth sets the [PointStyle.Width]: +// Width is the line width for point glyphs, with a default of 1 Pt (point). +// Setting to 0 turns line off. +func (t *PointStyle) SetWidth(v units.Value) *PointStyle { t.Width = v; return t } + +// SetSize sets the [PointStyle.Size]: +// Size of shape to draw for each point. +// Defaults to 4 Pt (point). +func (t *PointStyle) SetSize(v units.Value) *PointStyle { t.Size = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Shapes", IDName: "shapes", Doc: "Shapes has the options for how to draw points in the plot."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Style", IDName: "style", Doc: "Style contains the plot styling properties relevant across\nmost plot types. 
These properties apply to individual plot elements\nwhile the Plot properties applies to the overall plot itself.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Plot", Doc: "\tPlot has overall plot-level properties, which can be set by any\nplot element, and are updated first, before applying element-wise styles."}, {Name: "On", Doc: "On specifies whether to plot this item, for table-based plots."}, {Name: "Plotter", Doc: "Plotter is the type of plotter to use in plotting this data,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nBlank means use default ([plots.XY] is overall default)."}, {Name: "Role", Doc: "Role specifies how a particular column of data should be used,\nfor [plot.NewTablePlot] [table.Table] driven plots."}, {Name: "Group", Doc: "Group specifies a group of related data items,\nfor [plot.NewTablePlot] [table.Table] driven plots,\nwhere different columns of data within the same Group play different Roles."}, {Name: "Range", Doc: "Range is the effective range of data to plot, where either end can be fixed."}, {Name: "Label", Doc: "Label provides an alternative label to use for axis, if set."}, {Name: "NoLegend", Doc: "NoLegend excludes this item from the legend when it otherwise would be included,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nRole = Y values are included in the Legend by default."}, {Name: "NTicks", Doc: "NTicks sets the desired number of ticks for the axis, if > 0."}, {Name: "Line", Doc: "Line has style properties for drawing lines."}, {Name: "Point", Doc: "Point has style properties for drawing points."}, {Name: "Text", Doc: "Text has style properties for rendering text."}, {Name: "Width", Doc: "Width has various plot width properties."}}}) + +// SetPlot sets the [Style.Plot]: +// +// Plot has overall plot-level properties, which can be set by any +// +// plot element, and are updated first, before applying element-wise styles. 
+func (t *Style) SetPlot(v PlotStyle) *Style { t.Plot = v; return t } + +// SetOn sets the [Style.On]: +// On specifies whether to plot this item, for table-based plots. +func (t *Style) SetOn(v bool) *Style { t.On = v; return t } + +// SetPlotter sets the [Style.Plotter]: +// Plotter is the type of plotter to use in plotting this data, +// for [plot.NewTablePlot] [table.Table] driven plots. +// Blank means use default ([plots.XY] is overall default). +func (t *Style) SetPlotter(v PlotterName) *Style { t.Plotter = v; return t } + +// SetRole sets the [Style.Role]: +// Role specifies how a particular column of data should be used, +// for [plot.NewTablePlot] [table.Table] driven plots. +func (t *Style) SetRole(v Roles) *Style { t.Role = v; return t } + +// SetGroup sets the [Style.Group]: +// Group specifies a group of related data items, +// for [plot.NewTablePlot] [table.Table] driven plots, +// where different columns of data within the same Group play different Roles. +func (t *Style) SetGroup(v string) *Style { t.Group = v; return t } + +// SetRange sets the [Style.Range]: +// Range is the effective range of data to plot, where either end can be fixed. +func (t *Style) SetRange(v minmax.Range64) *Style { t.Range = v; return t } + +// SetLabel sets the [Style.Label]: +// Label provides an alternative label to use for axis, if set. +func (t *Style) SetLabel(v string) *Style { t.Label = v; return t } + +// SetNoLegend sets the [Style.NoLegend]: +// NoLegend excludes this item from the legend when it otherwise would be included, +// for [plot.NewTablePlot] [table.Table] driven plots. +// Role = Y values are included in the Legend by default. +func (t *Style) SetNoLegend(v bool) *Style { t.NoLegend = v; return t } + +// SetNTicks sets the [Style.NTicks]: +// NTicks sets the desired number of ticks for the axis, if > 0. 
+func (t *Style) SetNTicks(v int) *Style { t.NTicks = v; return t } + +// SetLine sets the [Style.Line]: +// Line has style properties for drawing lines. +func (t *Style) SetLine(v LineStyle) *Style { t.Line = v; return t } + +// SetPoint sets the [Style.Point]: +// Point has style properties for drawing points. +func (t *Style) SetPoint(v PointStyle) *Style { t.Point = v; return t } + +// SetText sets the [Style.Text]: +// Text has style properties for rendering text. +func (t *Style) SetText(v TextStyle) *Style { t.Text = v; return t } + +// SetWidth sets the [Style.Width]: +// Width has various plot width properties. +func (t *Style) SetWidth(v WidthStyle) *Style { t.Width = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.WidthStyle", IDName: "width-style", Doc: "WidthStyle contains various plot width properties relevant across\ndifferent plot types.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Cap", Doc: "Cap is the width of the caps drawn at the top of error bars.\nThe default is 10dp"}, {Name: "Offset", Doc: "Offset for Bar plot is the offset added to each X axis value\nrelative to the Stride computed value (X = offset + index * Stride)\nDefaults to 0."}, {Name: "Stride", Doc: "Stride for Bar plot is distance between bars. Defaults to 1."}, {Name: "Width", Doc: "Width for Bar plot is the width of the bars, as a fraction of the Stride,\nto prevent bar overlap. Defaults to .8."}, {Name: "Pad", Doc: "Pad for Bar plot is additional space at start / end of data range,\nto keep bars from overflowing ends. This amount is subtracted from Offset\nand added to (len(Values)-1)*Stride -- no other accommodation for bar\nwidth is provided, so that should be built into this value as well.\nDefaults to 1."}}}) + +// SetCap sets the [WidthStyle.Cap]: +// Cap is the width of the caps drawn at the top of error bars. 
+// The default is 10dp +func (t *WidthStyle) SetCap(v units.Value) *WidthStyle { t.Cap = v; return t } + +// SetOffset sets the [WidthStyle.Offset]: +// Offset for Bar plot is the offset added to each X axis value +// relative to the Stride computed value (X = offset + index * Stride) +// Defaults to 0. +func (t *WidthStyle) SetOffset(v float64) *WidthStyle { t.Offset = v; return t } + +// SetStride sets the [WidthStyle.Stride]: +// Stride for Bar plot is distance between bars. Defaults to 1. +func (t *WidthStyle) SetStride(v float64) *WidthStyle { t.Stride = v; return t } + +// SetWidth sets the [WidthStyle.Width]: +// Width for Bar plot is the width of the bars, as a fraction of the Stride, +// to prevent bar overlap. Defaults to .8. +func (t *WidthStyle) SetWidth(v float64) *WidthStyle { t.Width = v; return t } + +// SetPad sets the [WidthStyle.Pad]: +// Pad for Bar plot is additional space at start / end of data range, +// to keep bars from overflowing ends. This amount is subtracted from Offset +// and added to (len(Values)-1)*Stride -- no other accommodation for bar +// width is provided, so that should be built into this value as well. +// Defaults to 1. 
+func (t *WidthStyle) SetPad(v float64) *WidthStyle { t.Pad = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Stylers", IDName: "stylers", Doc: "Stylers is a list of styling functions that set Style properties.\nThese are called in the order added."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.DefaultOffOn", IDName: "default-off-on", Doc: "DefaultOffOn specifies whether to use the default value for a bool option,\nor to override the default and set Off or On."}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.TextStyle", IDName: "text-style", Doc: "TextStyle specifies styling parameters for Text elements.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Size", Doc: "Size of font to render. Default is 16dp"}, {Name: "Family", Doc: "Family name for font (inherited): ordered list of comma-separated names\nfrom more general to more specific to use. Use split on, to parse."}, {Name: "Color", Doc: "Color of text."}, {Name: "Align", Doc: "Align specifies how to align text along the relevant\ndimension for the text element."}, {Name: "Padding", Doc: "Padding is used in a case-dependent manner to add\nspace around text elements."}, {Name: "Rotation", Doc: "Rotation of the text, in degrees."}, {Name: "Offset", Doc: "Offset is added directly to the final label location."}}}) + +// SetSize sets the [TextStyle.Size]: +// Size of font to render. Default is 16dp +func (t *TextStyle) SetSize(v units.Value) *TextStyle { t.Size = v; return t } + +// SetFamily sets the [TextStyle.Family]: +// Family name for font (inherited): ordered list of comma-separated names +// from more general to more specific to use. Use split on, to parse. +func (t *TextStyle) SetFamily(v string) *TextStyle { t.Family = v; return t } + +// SetColor sets the [TextStyle.Color]: +// Color of text. 
+func (t *TextStyle) SetColor(v image.Image) *TextStyle { t.Color = v; return t } + +// SetAlign sets the [TextStyle.Align]: +// Align specifies how to align text along the relevant +// dimension for the text element. +func (t *TextStyle) SetAlign(v styles.Aligns) *TextStyle { t.Align = v; return t } + +// SetPadding sets the [TextStyle.Padding]: +// Padding is used in a case-dependent manner to add +// space around text elements. +func (t *TextStyle) SetPadding(v units.Value) *TextStyle { t.Padding = v; return t } + +// SetRotation sets the [TextStyle.Rotation]: +// Rotation of the text, in degrees. +func (t *TextStyle) SetRotation(v float32) *TextStyle { t.Rotation = v; return t } + +// SetOffset sets the [TextStyle.Offset]: +// Offset is added directly to the final label location. +func (t *TextStyle) SetOffset(v units.XY) *TextStyle { t.Offset = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Text", IDName: "text", Doc: "Text specifies a single text element in a plot", Fields: []types.Field{{Name: "Text", Doc: "text string, which can use HTML formatting"}, {Name: "Style", Doc: "styling for this text element"}, {Name: "font", Doc: "font has the full font rendering styles."}, {Name: "PaintText", Doc: "PaintText is the [paint.Text] for the text."}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Tick", IDName: "tick", Doc: "A Tick is a single tick mark on an axis.", Fields: []types.Field{{Name: "Value", Doc: "Value is the data value marked by this Tick."}, {Name: "Label", Doc: "Label is the text to display at the tick mark.\nIf Label is an empty string then this is a minor tick mark."}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Ticker", IDName: "ticker", Doc: "Ticker creates Ticks in a specified range", Methods: []types.Method{{Name: "Ticks", Doc: "Ticks returns Ticks in a specified range", Args: []string{"min", "max"}, Returns: []string{"Tick"}}}}) +var _ = 
types.AddType(&types.Type{Name: "cogentcore.org/core/plot.Ticker", IDName: "ticker", Doc: "Ticker creates Ticks in a specified range", Methods: []types.Method{{Name: "Ticks", Doc: "Ticks returns Ticks in a specified range, with desired number of ticks,\nwhich can be ignored depending on the ticker type.", Args: []string{"min", "max", "nticks"}, Returns: []string{"Tick"}}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.DefaultTicks", IDName: "default-ticks", Doc: "DefaultTicks is suitable for the Ticker field of an Axis,\nit returns a reasonable default set of tick marks."}) @@ -73,5 +425,3 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.LogTicks", IDN var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.ConstantTicks", IDName: "constant-ticks", Doc: "ConstantTicks is suitable for the Ticker field of an Axis.\nThis function returns the given set of ticks."}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.TimeTicks", IDName: "time-ticks", Doc: "TimeTicks is suitable for axes representing time values.", Fields: []types.Field{{Name: "Ticker", Doc: "Ticker is used to generate a set of ticks.\nIf nil, DefaultTicks will be used."}, {Name: "Format", Doc: "Format is the textual representation of the time value.\nIf empty, time.RFC3339 will be used"}, {Name: "Time", Doc: "Time takes a float32 value and converts it into a time.Time.\nIf nil, UTCUnixTime is used."}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/plot.TickerFunc", IDName: "ticker-func", Doc: "TickerFunc is suitable for the Ticker field of an Axis.\nIt is an adapter which allows to quickly setup a Ticker using a function with an appropriate signature."}) diff --git a/shell/README.md b/shell/README.md deleted file mode 100644 index c66364d59c..0000000000 --- a/shell/README.md +++ /dev/null @@ -1,190 +0,0 @@ -# Cogent Shell (cosh) - -Cogent Shell (cosh) is a shell that combines the best parts of Go and command-based shell 
languages like `bash` to provide an integrated shell experience that allows you to easily run terminal commands while using Go for complicated logic. This allows you to write concise, elegant, and readable shell code that runs quickly on all platforms, and transpiles to Go (i.e, can be compiled by `go build`). - -The simple idea is that each line is either Go or shell commands, determined in a fairly intuitive way mostly by the content at the start of the line (formal rules below), and they can be intermixed by wrapping Go within `{ }` and shell code from within backticks (\`). We henceforth refer to shell code as `exec` code (in reference to the Go & Cogent `exec` package that we use to execute programs), given the potential ambituity of the entire `cosh` language being the shell. There are different syntactic formatting rules for these two domains of Go and Exec, within cosh: - -* Go code is processed and formatted as usual (e.g., white space is irrelevant, etc). -* Exec code is space separated, like normal command-line invocations. - -Examples: - -```go -for i, f := range cosh.SplitLines(`ls -la`) { // `ls` executes returns string - echo {i} {strings.ToLower(f)} // {} surrounds go within shell -} -``` - -`shell.SplitLines` is a function that runs `strings.Split(arg, "\n")`, defined in the cosh standard library of such frequently-used helper functions. - -You can easily perform handy duration and data size formatting: - -```go -22010706 * time.Nanosecond // 22.010706ms -datasize.Size(44610930) // 42.5 MB -``` - -# Special syntax - -## Multiple statements per line - -* Multiple statements can be combined on one line, separated by `;` as in regular Go and shell languages. 
Critically, the language determination for the first statement determines the language for the remaining statements; you cannot intermix the two on one line, when using `;` -# Exec mode - -## Environment variables - -* `set ` (space delimited as in all exec mode, no equals) - -## Output redirction - -* Standard output redirect: `>` and `>&` (and `|`, `|&` if needed) - -## Control flow - -* Any error stops the script execution, except for statements wrapped in `[ ]`, indicating an "optional" statement, e.g.: - -```sh -cd some; [mkdir sub]; cd sub -``` - -* `&` at the end of a statement runs in the background (as in bash) -- otherwise it waits until it completes before it continues. - -* `jobs`, `fg`, `bg`, and `kill` builtin commands function as in usual bash. - -## Exec functions (aliases) - -Use the `command` keyword to define new functions for Exec mode execution, which can then be used like any other command, for example: - -```sh -command list { - ls -la args... -} -``` - -```sh -cd data -list *.tsv -``` - -The `command` is transpiled into a Go function that takes `args ...string`. In the command function body, you can use the `args...` expression to pass all of the args, or `args[1]` etc to refer to specific positional indexes, as usual. - -The command function name is registered so that the standard shell execution code can run the function, passing the args. You can also call it directly from Go code using the standard parentheses expression. - -# Script Files and Makefile-like functionality - -As with most scripting languages, a file of cosh code can be made directly executable by appending a "shebang" expression at the start of the file: - -```sh -#!/usr/bin/env cosh -``` - -When executed this way, any additional args are available via an `args []any` variable, which can be passed to a command as follows: -```go -install {args...} -``` -or by referring to specific arg indexes etc. 
- -To make a script behave like a standard Makefile, you can define different `command`s for each of the make commands, and then add the following at the end of the file to use the args to run commands: - -```go -shell.RunCommands(args) -``` - -See [make](cmd/cosh/testdata/make) for an example, in `cmd/cosh/testdata/make`, which can be run for example using: - -```sh -./make build -``` - -Note that there is nothing special about the name `make` here, so this can be done with any file. - -The `make` package defines a number of useful utility functions that accomplish the standard dependency and file timestamp checking functionality from the standard `make` command, as in the [magefile](https://magefile.org/dependencies/) system. Note that the cosh direct exec command syntax makes the resulting make files much closer to a standard bash-like Makefile, while still having all the benefits of Go control and expressions, compared to magefile. - -TODO: implement and document above. - -# SSH connections to remote hosts - -Any number of active SSH connections can be maintained and used dynamically within a script, including simple ways of copying data among the different hosts (including the local host). The Go mode execution is always on the local host in one running process, and only the shell commands are executed remotely, enabling a unique ability to easily coordinate and distribute processing and data across various hosts. - -Each host maintains its own working directory and environment variables, which can be configured and re-used by default whenever using a given host. - -* `cossh hostname.org [name]` establishes a connection, using given optional name to refer to this connection. If the name is not provided, a sequential number will be used, starting with 1, with 0 referring always to the local host. - -* `@name` then refers to the given host in all subsequent commands, with `@0` referring to the local host where the cosh script is running. 
- -### Explicit per-command specification of host - -```sh -@name cd subdir; ls -``` - -### Default host - -```sh -@name // or: -cossh @name -``` - -uses the given host for all subsequent commands (unless explicitly specified), until the default is changed. Use `cossh @0` to return to localhost. - -### Redirect input / output among hosts - -The output of a remote host command can be sent to a file on the local host: -```sh -@name cat hostfile.tsv > @0:localfile.tsv -``` -Note the use of the `:` colon delimiter after the host name here. TODO: You cannot send output to a remote host file (e.g., `> @host:remotefile.tsv`) -- maybe with sftp? - -The output of any command can also be piped to a remote host as its standard input: -```sh -ls *.tsv | @host cat > files.txt -``` - -### scp to copy files easily - -The builtin `scp` function allows easy copying of files across hosts, using the persistent connections established with `cossh` instead of creating new connections as in the standard scp command. - -`scp` is _always_ run from the local host, with the remote host filename specified as `@name:remotefile` - -```sh -scp @name:hostfile.tsv localfile.tsv -``` - -TODO: Importantly, file wildcard globbing works as expected: -```sh -scp @name:*.tsv @0:data/ -``` - -and entire directories can be copied, as in `cp -a` or `cp -r` (this behavior is automatic and does not require a flag). - -### Close connections - -```sh -cossh close -``` - -Will close all active connections and return the default host to @0. All active connections are also automatically closed when the shell terminates. - -# Other Utilties - -** need a replacement for findnm -- very powerful but garbage.. - -# Rules for Go vs. Shell determination - -The critical extension from standard Go syntax is for lines that are processed by the `Exec` functions, used for running arbitrary programs on the user's executable path. Here are the rules (word = IDENT token): - -* Backticks "``" anywhere: Exec. Returns a `string`. 
-* Within Exec, `{}`: Go -* Line starts with `Go` Keyword: Go -* Line is one word: Exec -* Line starts with `path`: Exec -* Line starts with `"string"`: Exec -* Line starts with `word word`: Exec -* Line starts with `word {`: Exec -* Otherwise: Go - -# TODO: - -* likewise, need to run everything effectively as a bg job with our own explicit Wait, which we can then communicate with to move from fg to bg. - - diff --git a/shell/cmd/cosh/test.cosh b/shell/cmd/cosh/test.cosh deleted file mode 100644 index dda7a55b21..0000000000 --- a/shell/cmd/cosh/test.cosh +++ /dev/null @@ -1,8 +0,0 @@ -// test file for cosh cli - -// todo: doesn't work: #1152 -echo {args} - -for i, fn := range cosh.SplitLines(`/bin/ls -1`) { - fmt.Println(i, fn) -} diff --git a/shell/cmd/cosh/typegen.go b/shell/cmd/cosh/typegen.go deleted file mode 100644 index 9dcf995c36..0000000000 --- a/shell/cmd/cosh/typegen.go +++ /dev/null @@ -1,15 +0,0 @@ -// Code generated by "core generate -add-types -add-funcs"; DO NOT EDIT. - -package main - -import ( - "cogentcore.org/core/types" -) - -var _ = types.AddType(&types.Type{Name: "main.Config", IDName: "config", Doc: "Config is the configuration information for the cosh cli.", Directives: []types.Directive{{Tool: "go", Directive: "generate", Args: []string{"core", "generate", "-add-types", "-add-funcs"}}}, Fields: []types.Field{{Name: "Input", Doc: "Input is the input file to run/compile.\nIf this is provided as the first argument,\nthen the program will exit after running,\nunless the Interactive mode is flagged."}, {Name: "Expr", Doc: "Expr is an optional expression to evaluate, which can be used\nin addition to the Input file to run, to execute commands\ndefined within that file for example, or as a command to run\nprior to starting interactive mode if no Input is specified."}, {Name: "Args", Doc: "Args is an optional list of arguments to pass in the run command.\nThese arguments will be turned into an \"args\" local variable in the shell.\nThese are 
automatically processed from any leftover arguments passed, so\nyou should not need to specify this flag manually."}, {Name: "Interactive", Doc: "Interactive runs the interactive command line after processing any input file.\nInteractive mode is the default mode for the run command unless an input file\nis specified."}}}) - -var _ = types.AddFunc(&types.Func{Name: "main.Run", Doc: "Run runs the specified cosh file. If no file is specified,\nit runs an interactive shell that allows the user to input cosh.", Directives: []types.Directive{{Tool: "cli", Directive: "cmd", Args: []string{"-root"}}}, Args: []string{"c"}, Returns: []string{"error"}}) - -var _ = types.AddFunc(&types.Func{Name: "main.Interactive", Doc: "Interactive runs an interactive shell that allows the user to input cosh.", Args: []string{"c", "in"}, Returns: []string{"error"}}) - -var _ = types.AddFunc(&types.Func{Name: "main.Build", Doc: "Build builds the specified input cosh file, or all .cosh files in the current\ndirectory if no input is specified, to corresponding .go file name(s).\nIf the file does not already contain a \"package\" specification, then\n\"package main; func main()...\" wrappers are added, which allows the same\ncode to be used in interactive and Go compiled modes.", Args: []string{"c"}, Returns: []string{"error"}}) diff --git a/shell/interpreter/cogentcore_org-core-base-datasize.go b/shell/interpreter/cogentcore_org-core-base-datasize.go deleted file mode 100644 index cafb889317..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-datasize.go +++ /dev/null @@ -1,29 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/datasize'. DO NOT EDIT. 
- -package interpreter - -import ( - "cogentcore.org/core/base/datasize" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/datasize/datasize"] = map[string]reflect.Value{ - // function, constant and variable definitions - "B": reflect.ValueOf(datasize.B), - "EB": reflect.ValueOf(datasize.EB), - "ErrBits": reflect.ValueOf(&datasize.ErrBits).Elem(), - "GB": reflect.ValueOf(datasize.GB), - "KB": reflect.ValueOf(datasize.KB), - "MB": reflect.ValueOf(datasize.MB), - "MustParse": reflect.ValueOf(datasize.MustParse), - "MustParseString": reflect.ValueOf(datasize.MustParseString), - "PB": reflect.ValueOf(datasize.PB), - "Parse": reflect.ValueOf(datasize.Parse), - "ParseString": reflect.ValueOf(datasize.ParseString), - "TB": reflect.ValueOf(datasize.TB), - - // type definitions - "Size": reflect.ValueOf((*datasize.Size)(nil)), - } -} diff --git a/shell/interpreter/cogentcore_org-core-base-elide.go b/shell/interpreter/cogentcore_org-core-base-elide.go deleted file mode 100644 index c2cff7e20f..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-elide.go +++ /dev/null @@ -1,17 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/elide'. DO NOT EDIT. - -package interpreter - -import ( - "cogentcore.org/core/base/elide" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/elide/elide"] = map[string]reflect.Value{ - // function, constant and variable definitions - "AppName": reflect.ValueOf(elide.AppName), - "End": reflect.ValueOf(elide.End), - "Middle": reflect.ValueOf(elide.Middle), - } -} diff --git a/shell/interpreter/cogentcore_org-core-base-errors.go b/shell/interpreter/cogentcore_org-core-base-errors.go deleted file mode 100644 index f595569dde..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-errors.go +++ /dev/null @@ -1,25 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/errors'. DO NOT EDIT. 
- -package interpreter - -import ( - "cogentcore.org/core/base/errors" - "github.com/cogentcore/yaegi/interp" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/errors/errors"] = map[string]reflect.Value{ - // function, constant and variable definitions - "As": reflect.ValueOf(errors.As), - "CallerInfo": reflect.ValueOf(errors.CallerInfo), - "ErrUnsupported": reflect.ValueOf(&errors.ErrUnsupported).Elem(), - "Is": reflect.ValueOf(errors.Is), - "Join": reflect.ValueOf(errors.Join), - "Log": reflect.ValueOf(errors.Log), - "Log1": reflect.ValueOf(interp.GenericFunc("func Log1[T any](v T, err error) T { //yaegi:add\n\tif err != nil {\n\t\tslog.Error(err.Error() + \" | \" + CallerInfo())\n\t}\n\treturn v\n}")), - "Must": reflect.ValueOf(errors.Must), - "New": reflect.ValueOf(errors.New), - "Unwrap": reflect.ValueOf(errors.Unwrap), - } -} diff --git a/shell/interpreter/cogentcore_org-core-base-fsx.go b/shell/interpreter/cogentcore_org-core-base-fsx.go deleted file mode 100644 index abc1208cb4..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-fsx.go +++ /dev/null @@ -1,27 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/fsx'. DO NOT EDIT. 
- -package interpreter - -import ( - "cogentcore.org/core/base/fsx" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/fsx/fsx"] = map[string]reflect.Value{ - // function, constant and variable definitions - "DirAndFile": reflect.ValueOf(fsx.DirAndFile), - "DirFS": reflect.ValueOf(fsx.DirFS), - "Dirs": reflect.ValueOf(fsx.Dirs), - "FileExists": reflect.ValueOf(fsx.FileExists), - "FileExistsFS": reflect.ValueOf(fsx.FileExistsFS), - "Filenames": reflect.ValueOf(fsx.Filenames), - "Files": reflect.ValueOf(fsx.Files), - "FindFilesOnPaths": reflect.ValueOf(fsx.FindFilesOnPaths), - "GoSrcDir": reflect.ValueOf(fsx.GoSrcDir), - "HasFile": reflect.ValueOf(fsx.HasFile), - "LatestMod": reflect.ValueOf(fsx.LatestMod), - "RelativeFilePath": reflect.ValueOf(fsx.RelativeFilePath), - "Sub": reflect.ValueOf(fsx.Sub), - } -} diff --git a/shell/interpreter/cogentcore_org-core-base-strcase.go b/shell/interpreter/cogentcore_org-core-base-strcase.go deleted file mode 100644 index 4c65387f68..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-strcase.go +++ /dev/null @@ -1,54 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/strcase'. DO NOT EDIT. 
- -package interpreter - -import ( - "cogentcore.org/core/base/strcase" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/strcase/strcase"] = map[string]reflect.Value{ - // function, constant and variable definitions - "CamelCase": reflect.ValueOf(strcase.CamelCase), - "CasesN": reflect.ValueOf(strcase.CasesN), - "CasesValues": reflect.ValueOf(strcase.CasesValues), - "FormatList": reflect.ValueOf(strcase.FormatList), - "KEBABCase": reflect.ValueOf(strcase.KEBABCase), - "KebabCase": reflect.ValueOf(strcase.KebabCase), - "LowerCamelCase": reflect.ValueOf(strcase.LowerCamelCase), - "LowerCase": reflect.ValueOf(strcase.LowerCase), - "Noop": reflect.ValueOf(strcase.Noop), - "SNAKECase": reflect.ValueOf(strcase.SNAKECase), - "SentenceCase": reflect.ValueOf(strcase.SentenceCase), - "Skip": reflect.ValueOf(strcase.Skip), - "SkipSplit": reflect.ValueOf(strcase.SkipSplit), - "SnakeCase": reflect.ValueOf(strcase.SnakeCase), - "Split": reflect.ValueOf(strcase.Split), - "TitleCase": reflect.ValueOf(strcase.TitleCase), - "To": reflect.ValueOf(strcase.To), - "ToCamel": reflect.ValueOf(strcase.ToCamel), - "ToKEBAB": reflect.ValueOf(strcase.ToKEBAB), - "ToKebab": reflect.ValueOf(strcase.ToKebab), - "ToLowerCamel": reflect.ValueOf(strcase.ToLowerCamel), - "ToSNAKE": reflect.ValueOf(strcase.ToSNAKE), - "ToSentence": reflect.ValueOf(strcase.ToSentence), - "ToSnake": reflect.ValueOf(strcase.ToSnake), - "ToTitle": reflect.ValueOf(strcase.ToTitle), - "ToWordCase": reflect.ValueOf(strcase.ToWordCase), - "UpperCase": reflect.ValueOf(strcase.UpperCase), - "WordCamelCase": reflect.ValueOf(strcase.WordCamelCase), - "WordCasesN": reflect.ValueOf(strcase.WordCasesN), - "WordCasesValues": reflect.ValueOf(strcase.WordCasesValues), - "WordLowerCase": reflect.ValueOf(strcase.WordLowerCase), - "WordOriginal": reflect.ValueOf(strcase.WordOriginal), - "WordSentenceCase": reflect.ValueOf(strcase.WordSentenceCase), - "WordTitleCase": reflect.ValueOf(strcase.WordTitleCase), - 
"WordUpperCase": reflect.ValueOf(strcase.WordUpperCase), - - // type definitions - "Cases": reflect.ValueOf((*strcase.Cases)(nil)), - "SplitAction": reflect.ValueOf((*strcase.SplitAction)(nil)), - "WordCases": reflect.ValueOf((*strcase.WordCases)(nil)), - } -} diff --git a/shell/interpreter/cogentcore_org-core-base-stringsx.go b/shell/interpreter/cogentcore_org-core-base-stringsx.go deleted file mode 100644 index f4f77cf0e5..0000000000 --- a/shell/interpreter/cogentcore_org-core-base-stringsx.go +++ /dev/null @@ -1,19 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/base/stringsx'. DO NOT EDIT. - -package interpreter - -import ( - "cogentcore.org/core/base/stringsx" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/base/stringsx/stringsx"] = map[string]reflect.Value{ - // function, constant and variable definitions - "ByteSplitLines": reflect.ValueOf(stringsx.ByteSplitLines), - "ByteTrimCR": reflect.ValueOf(stringsx.ByteTrimCR), - "InsertFirstUnique": reflect.ValueOf(stringsx.InsertFirstUnique), - "SplitLines": reflect.ValueOf(stringsx.SplitLines), - "TrimCR": reflect.ValueOf(stringsx.TrimCR), - } -} diff --git a/shell/interpreter/cogentcore_org-core-shell-cosh.go b/shell/interpreter/cogentcore_org-core-shell-cosh.go deleted file mode 100644 index 68bfce6a75..0000000000 --- a/shell/interpreter/cogentcore_org-core-shell-cosh.go +++ /dev/null @@ -1,21 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/shell/cosh'. DO NOT EDIT. 
- -package interpreter - -import ( - "cogentcore.org/core/shell/cosh" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/shell/cosh/cosh"] = map[string]reflect.Value{ - // function, constant and variable definitions - "AllFiles": reflect.ValueOf(cosh.AllFiles), - "FileExists": reflect.ValueOf(cosh.FileExists), - "ReadFile": reflect.ValueOf(cosh.ReadFile), - "ReplaceInFile": reflect.ValueOf(cosh.ReplaceInFile), - "SplitLines": reflect.ValueOf(cosh.SplitLines), - "StringsToAnys": reflect.ValueOf(cosh.StringsToAnys), - "WriteFile": reflect.ValueOf(cosh.WriteFile), - } -} diff --git a/shell/interpreter/imports.go b/shell/interpreter/imports.go deleted file mode 100644 index 27c79398fd..0000000000 --- a/shell/interpreter/imports.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package interpreter - -//go:generate ./make - -import ( - "reflect" - - "github.com/cogentcore/yaegi/interp" -) - -var Symbols = map[string]map[string]reflect.Value{} - -// ImportShell imports special symbols from the shell package. 
-func (in *Interpreter) ImportShell() { - in.Interp.Use(interp.Exports{ - "cogentcore.org/core/shell/shell": map[string]reflect.Value{ - "Run": reflect.ValueOf(in.Shell.Run), - "RunErrOK": reflect.ValueOf(in.Shell.RunErrOK), - "Output": reflect.ValueOf(in.Shell.Output), - "OutputErrOK": reflect.ValueOf(in.Shell.OutputErrOK), - "Start": reflect.ValueOf(in.Shell.Start), - "AddCommand": reflect.ValueOf(in.Shell.AddCommand), - "RunCommands": reflect.ValueOf(in.Shell.RunCommands), - }, - }) -} diff --git a/shell/interpreter/make b/shell/interpreter/make deleted file mode 100755 index da045d6d28..0000000000 --- a/shell/interpreter/make +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env cosh -// add standard imports here; mostly base - -command base { - println("extracting base packages") - yaegi extract cogentcore.org/core/base/fsx cogentcore.org/core/base/errors cogentcore.org/core/base/strcase cogentcore.org/core/base/elide cogentcore.org/core/base/stringsx cogentcore.org/core/base/datasize -} - -command cosh { - println("extracting cosh packages") - yaegi extract cogentcore.org/core/shell/cosh -} - -// shell.RunCommands(args) -base -cosh - diff --git a/shell/shell.go b/shell/shell.go deleted file mode 100644 index 655262a17e..0000000000 --- a/shell/shell.go +++ /dev/null @@ -1,494 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package shell provides the Cogent Shell (cosh), which combines the best parts -// of Go and bash to provide an integrated shell experience that allows you to -// easily run terminal commands while using Go for complicated logic. 
-package shell - -import ( - "context" - "fmt" - "io/fs" - "log/slog" - "os" - "path/filepath" - "slices" - "strconv" - "strings" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/base/exec" - "cogentcore.org/core/base/logx" - "cogentcore.org/core/base/num" - "cogentcore.org/core/base/reflectx" - "cogentcore.org/core/base/sshclient" - "cogentcore.org/core/base/stack" - "cogentcore.org/core/base/stringsx" - "github.com/mitchellh/go-homedir" - "golang.org/x/tools/imports" -) - -// Shell represents one running shell context. -type Shell struct { - - // Config is the [exec.Config] used to run commands. - Config exec.Config - - // StdIOWrappers are IO wrappers sent to the interpreter, so we can - // control the IO streams used within the interpreter. - // Call SetWrappers on this with another StdIO object to update settings. - StdIOWrappers exec.StdIO - - // ssh connection, configuration - SSH *sshclient.Config - - // collection of ssh clients - SSHClients map[string]*sshclient.Client - - // SSHActive is the name of the active SSH client - SSHActive string - - // depth of delim at the end of the current line. if 0, was complete. - ParenDepth, BraceDepth, BrackDepth, TypeDepth, DeclDepth int - - // Chunks of code lines that are accumulated during Transpile, - // each of which should be evaluated separately, to avoid - // issues with contextual effects from import, package etc. - Chunks []string - - // current stack of transpiled lines, that are accumulated into - // code Chunks - Lines []string - - // stack of runtime errors - Errors []error - - // Builtins are all the builtin shell commands - Builtins map[string]func(cmdIO *exec.CmdIO, args ...string) error - - // commands that have been defined, which can be run in Exec mode. 
- Commands map[string]func(args ...string) - - // Jobs is a stack of commands running in the background - // (via Start instead of Run) - Jobs stack.Stack[*exec.CmdIO] - - // Cancel, while the interpreter is running, can be called - // to stop the code interpreting. - // It is connected to the Ctx context, by StartContext() - // Both can be nil. - Cancel func() - - // Ctx is the context used for cancelling current shell running - // a single chunk of code, typically from the interpreter. - // We are not able to pass the context around so it is set here, - // in the StartContext function. Clear when done with ClearContext. - Ctx context.Context - - // original standard IO setings, to restore - OrigStdIO exec.StdIO - - // Hist is the accumulated list of command-line input, - // which is displayed with the history builtin command, - // and saved / restored from ~/.coshhist file - Hist []string - - // FuncToVar translates function definitions into variable definitions, - // which is the default for interactive use of random code fragments - // without the complete go formatting. - // For pure transpiling of a complete codebase with full proper Go formatting - // this should be turned off. - FuncToVar bool - - // commandArgs is a stack of args passed to a command, used for simplified - // processing of args expressions. - commandArgs stack.Stack[[]string] - - // isCommand is a stack of bools indicating whether the _immediate_ run context - // is a command, which affects the way that args are processed. - isCommand stack.Stack[bool] - - // if this is non-empty, it is the name of the last command defined. - // triggers insertion of the AddCommand call to add to list of defined commands. - lastCommand string -} - -// NewShell returns a new [Shell] with default options. 
-func NewShell() *Shell { - sh := &Shell{ - Config: exec.Config{ - Dir: errors.Log1(os.Getwd()), - Env: map[string]string{}, - Buffer: false, - }, - } - sh.FuncToVar = true - sh.Config.StdIO.SetFromOS() - sh.SSH = sshclient.NewConfig(&sh.Config) - sh.SSHClients = make(map[string]*sshclient.Client) - sh.Commands = make(map[string]func(args ...string)) - sh.InstallBuiltins() - return sh -} - -// StartContext starts a processing context, -// setting the Ctx and Cancel Fields. -// Call EndContext when current operation finishes. -func (sh *Shell) StartContext() context.Context { - sh.Ctx, sh.Cancel = context.WithCancel(context.Background()) - return sh.Ctx -} - -// EndContext ends a processing context, clearing the -// Ctx and Cancel fields. -func (sh *Shell) EndContext() { - sh.Ctx = nil - sh.Cancel = nil -} - -// SaveOrigStdIO saves the current Config.StdIO as the original to revert to -// after an error, and sets the StdIOWrappers to use them. -func (sh *Shell) SaveOrigStdIO() { - sh.OrigStdIO = sh.Config.StdIO - sh.StdIOWrappers.NewWrappers(&sh.OrigStdIO) -} - -// RestoreOrigStdIO reverts to using the saved OrigStdIO -func (sh *Shell) RestoreOrigStdIO() { - sh.Config.StdIO = sh.OrigStdIO - sh.OrigStdIO.SetToOS() - sh.StdIOWrappers.SetWrappers(&sh.OrigStdIO) -} - -// Close closes any resources associated with the shell, -// including terminating any commands that are not running "nohup" -// in the background. -func (sh *Shell) Close() { - sh.CloseSSH() - // todo: kill jobs etc -} - -// CloseSSH closes all open ssh client connections -func (sh *Shell) CloseSSH() { - sh.SSHActive = "" - for _, cl := range sh.SSHClients { - cl.Close() - } - sh.SSHClients = make(map[string]*sshclient.Client) -} - -// ActiveSSH returns the active ssh client -func (sh *Shell) ActiveSSH() *sshclient.Client { - if sh.SSHActive == "" { - return nil - } - return sh.SSHClients[sh.SSHActive] -} - -// Host returns the name we're running commands on, -// which is empty if localhost (default). 
-func (sh *Shell) Host() string { - cl := sh.ActiveSSH() - if cl == nil { - return "" - } - return "@" + sh.SSHActive + ":" + cl.Host -} - -// HostAndDir returns the name we're running commands on, -// which is empty if localhost (default), -// and the current directory on that host. -func (sh *Shell) HostAndDir() string { - host := "" - dir := sh.Config.Dir - home := errors.Log1(homedir.Dir()) - cl := sh.ActiveSSH() - if cl != nil { - host = "@" + sh.SSHActive + ":" + cl.Host + ":" - dir = cl.Dir - home = cl.HomeDir - } - rel := errors.Log1(filepath.Rel(home, dir)) - // if it has to go back, then it is not in home dir, so no ~ - if strings.Contains(rel, "..") { - return host + dir + string(filepath.Separator) - } - return host + filepath.Join("~", rel) + string(filepath.Separator) -} - -// SSHByHost returns the SSH client for given host name, with err if not found -func (sh *Shell) SSHByHost(host string) (*sshclient.Client, error) { - if scl, ok := sh.SSHClients[host]; ok { - return scl, nil - } - return nil, fmt.Errorf("ssh connection named: %q not found", host) -} - -// TotalDepth returns the sum of any unresolved paren, brace, or bracket depths. -func (sh *Shell) TotalDepth() int { - return num.Abs(sh.ParenDepth) + num.Abs(sh.BraceDepth) + num.Abs(sh.BrackDepth) -} - -// ResetCode resets the stack of transpiled code -func (sh *Shell) ResetCode() { - sh.Chunks = nil - sh.Lines = nil -} - -// ResetDepth resets the current depths to 0 -func (sh *Shell) ResetDepth() { - sh.ParenDepth, sh.BraceDepth, sh.BrackDepth, sh.TypeDepth, sh.DeclDepth = 0, 0, 0, 0, 0 -} - -// DepthError reports an error if any of the parsing depths are not zero, -// to be called at the end of transpiling a complete block of code. 
-func (sh *Shell) DepthError() error { - if sh.TotalDepth() == 0 { - return nil - } - str := "" - if sh.ParenDepth != 0 { - str += fmt.Sprintf("Incomplete parentheses (), remaining depth: %d\n", sh.ParenDepth) - } - if sh.BraceDepth != 0 { - str += fmt.Sprintf("Incomplete braces [], remaining depth: %d\n", sh.BraceDepth) - } - if sh.BrackDepth != 0 { - str += fmt.Sprintf("Incomplete brackets {}, remaining depth: %d\n", sh.BrackDepth) - } - if str != "" { - slog.Error(str) - return errors.New(str) - } - return nil -} - -// AddLine adds line on the stack -func (sh *Shell) AddLine(ln string) { - sh.Lines = append(sh.Lines, ln) -} - -// Code returns the current transpiled lines, -// split into chunks that should be compiled separately. -func (sh *Shell) Code() string { - sh.AddChunk() - if len(sh.Chunks) == 0 { - return "" - } - return strings.Join(sh.Chunks, "\n") -} - -// AddChunk adds current lines into a chunk of code -// that should be compiled separately. -func (sh *Shell) AddChunk() { - if len(sh.Lines) == 0 { - return - } - sh.Chunks = append(sh.Chunks, strings.Join(sh.Lines, "\n")) - sh.Lines = nil -} - -// TranspileCode processes each line of given code, -// adding the results to the LineStack -func (sh *Shell) TranspileCode(code string) { - lns := strings.Split(code, "\n") - n := len(lns) - if n == 0 { - return - } - for _, ln := range lns { - hasDecl := sh.DeclDepth > 0 - tl := sh.TranspileLine(ln) - sh.AddLine(tl) - if sh.BraceDepth == 0 && sh.BrackDepth == 0 && sh.ParenDepth == 1 && sh.lastCommand != "" { - sh.lastCommand = "" - nl := len(sh.Lines) - sh.Lines[nl-1] = sh.Lines[nl-1] + ")" - sh.ParenDepth-- - } - if hasDecl && sh.DeclDepth == 0 { // break at decl - sh.AddChunk() - } - } -} - -// TranspileCodeFromFile transpiles the code in given file -func (sh *Shell) TranspileCodeFromFile(file string) error { - b, err := os.ReadFile(file) - if err != nil { - return err - } - sh.TranspileCode(string(b)) - return nil -} - -// TranspileFile transpiles the 
given input cosh file to the -// given output Go file. If no existing package declaration -// is found, then package main and func main declarations are -// added. This also affects how functions are interpreted. -func (sh *Shell) TranspileFile(in string, out string) error { - b, err := os.ReadFile(in) - if err != nil { - return err - } - code := string(b) - lns := stringsx.SplitLines(code) - hasPackage := false - for _, ln := range lns { - if strings.HasPrefix(ln, "package ") { - hasPackage = true - break - } - } - if hasPackage { - sh.FuncToVar = false // use raw functions - } - sh.TranspileCode(code) - sh.FuncToVar = true - if err != nil { - return err - } - gen := "// Code generated by \"cosh build\"; DO NOT EDIT.\n\n" - if hasPackage { - sh.Lines = slices.Insert(sh.Lines, 0, gen) - } else { - sh.Lines = slices.Insert(sh.Lines, 0, gen, "package main", "", "func main() {", "shell := shell.NewShell()") - sh.Lines = append(sh.Lines, "}") - } - src := []byte(sh.Code()) - res, err := imports.Process(out, src, nil) - if err != nil { - res = src - slog.Error(err.Error()) - } else { - err = sh.DepthError() - } - werr := os.WriteFile(out, res, 0666) - return errors.Join(err, werr) -} - -// AddError adds the given error to the error stack if it is non-nil, -// and calls the Cancel function if set, to stop execution. -// This is the main way that shell errors are handled. -// It also prints the error. -func (sh *Shell) AddError(err error) error { - if err == nil { - return nil - } - sh.Errors = append(sh.Errors, err) - logx.PrintlnError(err) - sh.CancelExecution() - return err -} - -// TranspileConfig transpiles the .cosh startup config file in the user's -// home directory if it exists. 
-func (sh *Shell) TranspileConfig() error { - path, err := homedir.Expand("~/.cosh") - if err != nil { - return err - } - b, err := os.ReadFile(path) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - return nil - } - return err - } - sh.TranspileCode(string(b)) - return nil -} - -// AddHistory adds given line to the Hist record of commands -func (sh *Shell) AddHistory(line string) { - sh.Hist = append(sh.Hist, line) -} - -// SaveHistory saves up to the given number of lines of current history -// to given file, e.g., ~/.coshhist for the default cosh program. -// If n is <= 0 all lines are saved. n is typically 500 by default. -func (sh *Shell) SaveHistory(n int, file string) error { - path, err := homedir.Expand(file) - if err != nil { - return err - } - hn := len(sh.Hist) - sn := hn - if n > 0 { - sn = min(n, hn) - } - lh := strings.Join(sh.Hist[hn-sn:hn], "\n") - err = os.WriteFile(path, []byte(lh), 0666) - if err != nil { - return err - } - return nil -} - -// OpenHistory opens Hist history lines from given file, -// e.g., ~/.coshhist -func (sh *Shell) OpenHistory(file string) error { - path, err := homedir.Expand(file) - if err != nil { - return err - } - b, err := os.ReadFile(path) - if err != nil { - return err - } - sh.Hist = strings.Split(string(b), "\n") - return nil -} - -// AddCommand adds given command to list of available commands -func (sh *Shell) AddCommand(name string, cmd func(args ...string)) { - sh.Commands[name] = cmd -} - -// RunCommands runs the given command(s). This is typically called -// from a Makefile-style cosh script. 
-func (sh *Shell) RunCommands(cmds []any) error { - for _, cmd := range cmds { - if cmdFun, hasCmd := sh.Commands[reflectx.ToString(cmd)]; hasCmd { - cmdFun() - } else { - return errors.Log(fmt.Errorf("command %q not found", cmd)) - } - } - return nil -} - -// DeleteJob deletes the given job and returns true if successful, -func (sh *Shell) DeleteJob(cmdIO *exec.CmdIO) bool { - idx := slices.Index(sh.Jobs, cmdIO) - if idx >= 0 { - sh.Jobs = slices.Delete(sh.Jobs, idx, idx+1) - return true - } - return false -} - -// JobIDExpand expands %n job id values in args with the full PID -// returns number of PIDs expanded -func (sh *Shell) JobIDExpand(args []string) int { - exp := 0 - for i, id := range args { - if id[0] == '%' { - idx, err := strconv.Atoi(id[1:]) - if err == nil { - if idx > 0 && idx <= len(sh.Jobs) { - jb := sh.Jobs[idx-1] - if jb.Cmd != nil && jb.Cmd.Process != nil { - args[i] = fmt.Sprintf("%d", jb.Cmd.Process.Pid) - exp++ - } - } else { - sh.AddError(fmt.Errorf("cosh: job number out of range: %d", idx)) - } - } - } - } - return exp -} diff --git a/shell/transpile.go b/shell/transpile.go deleted file mode 100644 index 305e832255..0000000000 --- a/shell/transpile.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package shell - -import ( - "fmt" - "go/token" - "strings" - - "cogentcore.org/core/base/logx" -) - -// TranspileLine is the main function for parsing a single line of shell input, -// returning a new transpiled line of code that converts Exec code into corresponding -// Go function calls. 
-func (sh *Shell) TranspileLine(ln string) string { - if len(ln) == 0 { - return ln - } - if strings.HasPrefix(ln, "#!") { - return "" - } - toks := sh.TranspileLineTokens(ln) - paren, brace, brack := toks.BracketDepths() - sh.ParenDepth += paren - sh.BraceDepth += brace - sh.BrackDepth += brack - if sh.TypeDepth > 0 && sh.BraceDepth == 0 { - sh.TypeDepth = 0 - } - if sh.DeclDepth > 0 && sh.ParenDepth == 0 { - sh.DeclDepth = 0 - } - // logx.PrintlnDebug("depths: ", sh.ParenDepth, sh.BraceDepth, sh.BrackDepth) - return toks.Code() -} - -// TranspileLineTokens returns the tokens for the full line -func (sh *Shell) TranspileLineTokens(ln string) Tokens { - if ln == "" { - return nil - } - toks := sh.Tokens(ln) - n := len(toks) - if n == 0 { - return toks - } - ewords, err := ExecWords(ln) - if err != nil { - sh.AddError(err) - return nil - } - logx.PrintlnDebug("\n########## line:\n", ln, "\nTokens:\n", toks.String(), "\nWords:\n", ewords) - - if toks[0].Tok == token.TYPE { - sh.TypeDepth++ - } - if toks[0].Tok == token.IMPORT || toks[0].Tok == token.VAR || toks[0].Tok == token.CONST { - sh.DeclDepth++ - } - - if sh.TypeDepth > 0 || sh.DeclDepth > 0 { - logx.PrintlnDebug("go: type / decl defn") - return sh.TranspileGo(toks) - } - - t0 := toks[0] - _, t0pn := toks.Path(true) // true = first position - en := len(ewords) - - f0exec := (t0.Tok == token.IDENT && ExecWordIsCommand(ewords[0])) - - switch { - case t0.Tok == token.LBRACE: - logx.PrintlnDebug("go: { } line") - return sh.TranspileGo(toks[1 : n-1]) - case t0.Tok == token.LBRACK: - logx.PrintlnDebug("exec: [ ] line") - return sh.TranspileExec(ewords, false) // it processes the [ ] - case t0.Tok == token.ILLEGAL: - logx.PrintlnDebug("exec: illegal") - return sh.TranspileExec(ewords, false) - case t0.IsBacktickString(): - logx.PrintlnDebug("exec: backquoted string") - exe := sh.TranspileExecString(t0.Str, false) - if n > 1 { // todo: is this an error? 
- exe.AddTokens(sh.TranspileGo(toks[1:])) - } - return exe - case t0.Tok == token.IDENT && t0.Str == "command": - sh.lastCommand = toks[1].Str // 1 is the name -- triggers AddCommand - toks = toks[2:] // get rid of first - toks.Insert(0, token.IDENT, "shell.AddCommand") - toks.Insert(1, token.LPAREN) - toks.Insert(2, token.STRING, `"`+sh.lastCommand+`"`) - toks.Insert(3, token.COMMA) - toks.Insert(4, token.FUNC) - toks.Insert(5, token.LPAREN) - toks.Insert(6, token.IDENT, "args") - toks.Insert(7, token.ELLIPSIS) - toks.Insert(8, token.IDENT, "string") - toks.Insert(9, token.RPAREN) - toks.AddTokens(sh.TranspileGo(toks[11:])) - case t0.IsGo(): - if t0.Tok == token.GO { - if !toks.Contains(token.LPAREN) { - logx.PrintlnDebug("exec: go command") - return sh.TranspileExec(ewords, false) - } - } - logx.PrintlnDebug("go keyword") - return sh.TranspileGo(toks) - case toks[n-1].Tok == token.INC: - return sh.TranspileGo(toks) - case t0pn > 0: // path expr - logx.PrintlnDebug("exec: path...") - return sh.TranspileExec(ewords, false) - case t0.Tok == token.STRING: - logx.PrintlnDebug("exec: string...") - return sh.TranspileExec(ewords, false) - case f0exec && en == 1: - logx.PrintlnDebug("exec: 1 word") - return sh.TranspileExec(ewords, false) - case !f0exec: // exec must be IDENT - logx.PrintlnDebug("go: not ident") - return sh.TranspileGo(toks) - case f0exec && en > 1 && (ewords[1][0] == '=' || ewords[1][0] == ':' || ewords[1][0] == '+' || toks[1].Tok == token.COMMA): - logx.PrintlnDebug("go: assignment or defn") - return sh.TranspileGo(toks) - case f0exec: // now any ident - logx.PrintlnDebug("exec: ident..") - return sh.TranspileExec(ewords, false) - default: - logx.PrintlnDebug("go: default") - return sh.TranspileGo(toks) - } - return toks -} - -// TranspileGo returns transpiled tokens assuming Go code. -// Unpacks any backtick encapsulated shell commands. 
-func (sh *Shell) TranspileGo(toks Tokens) Tokens { - n := len(toks) - if n == 0 { - return toks - } - if sh.FuncToVar && toks[0].Tok == token.FUNC { // reorder as an assignment - if len(toks) > 1 && toks[1].Tok == token.IDENT { - toks[0] = toks[1] - toks.Insert(1, token.DEFINE) - toks[2] = &Token{Tok: token.FUNC} - } - } - gtoks := make(Tokens, 0, len(toks)) // return tokens - for _, tok := range toks { - if sh.TypeDepth == 0 && tok.IsBacktickString() { - gtoks = append(gtoks, sh.TranspileExecString(tok.Str, true)...) - } else { - gtoks = append(gtoks, tok) - } - } - return gtoks -} - -// TranspileExecString returns transpiled tokens assuming Exec code, -// from a backtick-encoded string, with the given bool indicating -// whether [Output] is needed. -func (sh *Shell) TranspileExecString(str string, output bool) Tokens { - if len(str) <= 1 { - return nil - } - ewords, err := ExecWords(str[1 : len(str)-1]) // enclosed string - if err != nil { - sh.AddError(err) - } - return sh.TranspileExec(ewords, output) -} - -// TranspileExec returns transpiled tokens assuming Exec code, -// with the given bools indicating the type of run to execute. 
-func (sh *Shell) TranspileExec(ewords []string, output bool) Tokens { - n := len(ewords) - if n == 0 { - return nil - } - etoks := make(Tokens, 0, n+5) // return tokens - var execTok *Token - bgJob := false - noStop := false - if ewords[0] == "[" { - ewords = ewords[1:] - n-- - noStop = true - } - startExec := func() { - bgJob = false - etoks.Add(token.IDENT, "shell") - etoks.Add(token.PERIOD) - switch { - case output && noStop: - execTok = etoks.Add(token.IDENT, "OutputErrOK") - case output && !noStop: - execTok = etoks.Add(token.IDENT, "Output") - case !output && noStop: - execTok = etoks.Add(token.IDENT, "RunErrOK") - case !output && !noStop: - execTok = etoks.Add(token.IDENT, "Run") - } - etoks.Add(token.LPAREN) - } - endExec := func() { - if bgJob { - execTok.Str = "Start" - } - etoks.DeleteLastComma() - etoks.Add(token.RPAREN) - } - - startExec() - - for i := 0; i < n; i++ { - f := ewords[i] - switch { - case f == "{": // embedded go - if n < i+3 { - sh.AddError(fmt.Errorf("cosh: no matching right brace } found in exec command line")) - } else { - gstr := ewords[i+1] - etoks.AddTokens(sh.TranspileGo(sh.Tokens(gstr))) - etoks.Add(token.COMMA) - i += 2 - } - case f == "[": - noStop = true - case f == "]": // solo is def end - // just skip - noStop = false - case f == "&": - bgJob = true - case f[0] == '|': - execTok.Str = "Start" - etoks.Add(token.IDENT, AddQuotes(f)) - etoks.Add(token.COMMA) - endExec() - etoks.Add(token.SEMICOLON) - etoks.AddTokens(sh.TranspileExec(ewords[i+1:], output)) - return etoks - case f == ";": - endExec() - etoks.Add(token.SEMICOLON) - etoks.AddTokens(sh.TranspileExec(ewords[i+1:], output)) - return etoks - default: - if f[0] == '"' || f[0] == '`' { - etoks.Add(token.STRING, f) - } else { - etoks.Add(token.IDENT, AddQuotes(f)) // mark as an IDENT but add quotes! 
- } - etoks.Add(token.COMMA) - } - } - endExec() - return etoks -} diff --git a/shell/transpile_test.go b/shell/transpile_test.go deleted file mode 100644 index 94d07d470f..0000000000 --- a/shell/transpile_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package shell - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -type exIn struct { - i string - e string -} - -type wexIn struct { - i string - isErr bool - e []string -} - -// these are more general tests of full-line statements of various forms -func TestExecWords(t *testing.T) { - tests := []wexIn{ - {`ls`, false, []string{`ls`}}, - {`cat "be"`, false, []string{`cat`, `"be"`}}, - {`cat "be`, true, []string{`cat`, `"be`}}, - {`cat "be a thing"`, false, []string{`cat`, `"be a thing"`}}, - {`cat "{be \"a\" thing}"`, false, []string{`cat`, `"{be \"a\" thing}"`}}, - {`cat {vals[1:10]}`, false, []string{`cat`, `{`, `vals[1:10]`, `}`}}, - {`cat {myfunc(vals[1:10], "test", false)}`, false, []string{`cat`, `{`, `myfunc(vals[1:10],"test",false)`, `}`}}, - {`cat vals[1:10]`, false, []string{`cat`, `vals[1:10]`}}, - {`cat vals...`, false, []string{`cat`, `vals...`}}, - {`[cat vals...]`, false, []string{`[`, `cat`, `vals...`, `]`}}, - {`[cat vals...]; ls *.tsv`, false, []string{`[`, `cat`, `vals...`, `]`, `;`, `ls`, `*.tsv`}}, - {`cat vals... 
| grep -v "b"`, false, []string{`cat`, `vals...`, `|`, `grep`, `-v`, `"b"`}}, - {`cat vals...>&file.out`, false, []string{`cat`, `vals...`, `>&`, `file.out`}}, - {`cat vals...>&@0:file.out`, false, []string{`cat`, `vals...`, `>&`, `@0:file.out`}}, - {`./"Cogent Code"`, false, []string{`./"Cogent Code"`}}, - {`Cogent\ Code`, false, []string{`Cogent Code`}}, - {`./Cogent\ Code`, false, []string{`./Cogent Code`}}, - } - for _, test := range tests { - o, err := ExecWords(test.i) - assert.Equal(t, test.e, o) - if err != nil { - if !test.isErr { - t.Error("should not have been an error:", test.i) - } - } else if test.isErr { - t.Error("was supposed to be an error:", test.i) - } - } -} - -// Paths tests the Path() code -func TestPaths(t *testing.T) { - // logx.UserLevel = slog.LevelDebug - tests := []exIn{ - {`fmt.Println("hi")`, `fmt.Println`}, - {"./cosh -i", `./cosh`}, - {"main.go", `main.go`}, - {"cogent/", `cogent/`}, - {`./"Cogent Code"`, `./\"Cogent Code\"`}, - {`Cogent\ Code`, ``}, - {`./Cogent\ Code`, `./Cogent Code`}, - {"./ios-deploy", `./ios-deploy`}, - {"ios_deploy/sub", `ios_deploy/sub`}, - {"C:/ios_deploy/sub", `C:/ios_deploy/sub`}, - {"..", `..`}, - {"../another/dir/to/go_to", `../another/dir/to/go_to`}, - {"../an-other/dir/", `../an-other/dir/`}, - {"https://google.com/search?q=hello%20world#body", `https://google.com/search?q=hello%20world#body`}, - } - sh := NewShell() - for _, test := range tests { - toks := sh.Tokens(test.i) - p, _ := toks.Path(false) - assert.Equal(t, test.e, p) - } -} - -// these are more general tests of full-line statements of various forms -func TestTranspile(t *testing.T) { - // logx.UserLevel = slog.LevelDebug - tests := []exIn{ - {"ls", `shell.Run("ls")`}, - {"`ls -la`", `shell.Run("ls", "-la")`}, - {"ls -la", `shell.Run("ls", "-la")`}, - {"ls --help", `shell.Run("ls", "--help")`}, - {"ls go", `shell.Run("ls", "go")`}, - {"cd go", `shell.Run("cd", "go")`}, - {`var name string`, `var name string`}, - {`name = "test"`, `name = 
"test"`}, - {`echo {name}`, `shell.Run("echo", name)`}, - {`echo "testing"`, `shell.Run("echo", "testing")`}, - {`number := 1.23`, `number := 1.23`}, - {`res1, res2 := FunTwoRet()`, `res1, res2 := FunTwoRet()`}, - {`res1, res2, res3 := FunThreeRet()`, `res1, res2, res3 := FunThreeRet()`}, - {`println("hi")`, `println("hi")`}, - {`fmt.Println("hi")`, `fmt.Println("hi")`}, - {`for i := 0; i < 3; i++ { fmt.Println(i, "\n")`, `for i := 0; i < 3; i++ { fmt.Println(i, "\n")`}, - {"for i, v := range `ls -la` {", `for i, v := range shell.Output("ls", "-la") {`}, - {`// todo: fixit`, `// todo: fixit`}, - {"`go build`", `shell.Run("go", "build")`}, - {"{go build()}", `go build()`}, - {"go build", `shell.Run("go", "build")`}, - {"go build()", `go build()`}, - {"go build &", `shell.Start("go", "build")`}, - {"[mkdir subdir]", `shell.RunErrOK("mkdir", "subdir")`}, - {"set something hello-1", `shell.Run("set", "something", "hello-1")`}, - {"set something = hello", `shell.Run("set", "something", "=", "hello")`}, - {`set something = "hello"`, `shell.Run("set", "something", "=", "hello")`}, - {`set something=hello`, `shell.Run("set", "something=hello")`}, - {`set "something=hello"`, `shell.Run("set", "something=hello")`}, - {`set something="hello"`, `shell.Run("set", "something=\"hello\"")`}, - {`add-path /opt/sbin /opt/homebrew/bin`, `shell.Run("add-path", "/opt/sbin", "/opt/homebrew/bin")`}, - {`cat file > test.out`, `shell.Run("cat", "file", ">", "test.out")`}, - {`cat file | grep -v exe > test.out`, `shell.Start("cat", "file", "|"); shell.Run("grep", "-v", "exe", ">", "test.out")`}, - {`cd sub; pwd; ls -la`, `shell.Run("cd", "sub"); shell.Run("pwd"); shell.Run("ls", "-la")`}, - {`cd sub; [mkdir sub]; ls -la`, `shell.Run("cd", "sub"); shell.RunErrOK("mkdir", "sub"); shell.Run("ls", "-la")`}, - {`cd sub; mkdir names[4]`, `shell.Run("cd", "sub"); shell.Run("mkdir", "names[4]")`}, - {"ls -la > test.out", `shell.Run("ls", "-la", ">", "test.out")`}, - {"ls > test.out", 
`shell.Run("ls", ">", "test.out")`}, - {"ls -la >test.out", `shell.Run("ls", "-la", ">", "test.out")`}, - {"ls -la >> test.out", `shell.Run("ls", "-la", ">>", "test.out")`}, - {"ls -la >& test.out", `shell.Run("ls", "-la", ">&", "test.out")`}, - {"ls -la >>& test.out", `shell.Run("ls", "-la", ">>&", "test.out")`}, - {"@1 ls -la", `shell.Run("@1", "ls", "-la")`}, - {"git switch main", `shell.Run("git", "switch", "main")`}, - {"git checkout 123abc", `shell.Run("git", "checkout", "123abc")`}, - {"go get cogentcore.org/core@main", `shell.Run("go", "get", "cogentcore.org/core@main")`}, - {"ls *.go", `shell.Run("ls", "*.go")`}, - {"ls ??.go", `shell.Run("ls", "??.go")`}, - {`fmt.Println("hi")`, `fmt.Println("hi")`}, - {"cosh -i", `shell.Run("cosh", "-i")`}, - {"./cosh -i", `shell.Run("./cosh", "-i")`}, - {"cat main.go", `shell.Run("cat", "main.go")`}, - {"cd cogent", `shell.Run("cd", "cogent")`}, - {"cd cogent/", `shell.Run("cd", "cogent/")`}, - {"echo $PATH", `shell.Run("echo", "$PATH")`}, - {`"./Cogent Code"`, `shell.Run("./Cogent Code")`}, - {`./"Cogent Code"`, `shell.Run("./\"Cogent Code\"")`}, - {`Cogent\ Code`, `shell.Run("Cogent Code")`}, - {`./Cogent\ Code`, `shell.Run("./Cogent Code")`}, - {`ios\ deploy -i`, `shell.Run("ios deploy", "-i")`}, - {"./ios-deploy -i", `shell.Run("./ios-deploy", "-i")`}, - {"ios_deploy -i tree_file", `shell.Run("ios_deploy", "-i", "tree_file")`}, - {"ios_deploy/sub -i tree_file", `shell.Run("ios_deploy/sub", "-i", "tree_file")`}, - {"C:/ios_deploy/sub -i tree_file", `shell.Run("C:/ios_deploy/sub", "-i", "tree_file")`}, - {"ios_deploy -i tree_file/path", `shell.Run("ios_deploy", "-i", "tree_file/path")`}, - {"ios-deploy -i", `shell.Run("ios-deploy", "-i")`}, - {"ios-deploy -i tree-file", `shell.Run("ios-deploy", "-i", "tree-file")`}, - {"ios-deploy -i tree-file/path/here", `shell.Run("ios-deploy", "-i", "tree-file/path/here")`}, - {"cd ..", `shell.Run("cd", "..")`}, - {"cd ../another/dir/to/go_to", `shell.Run("cd", 
"../another/dir/to/go_to")`}, - {"cd ../an-other/dir/", `shell.Run("cd", "../an-other/dir/")`}, - {"curl https://google.com/search?q=hello%20world#body", `shell.Run("curl", "https://google.com/search?q=hello%20world#body")`}, - {"func splitLines(str string) []string {", `splitLines := func(str string)[]string {`}, - {"type Result struct {", `type Result struct {`}, - {"var Jobs *table.Table", `var Jobs *table.Table`}, - {"type Result struct { JobID string", `type Result struct { JobID string`}, - {"type Result struct { JobID string `width:\"60\"`", "type Result struct { JobID string `width:\"60\"`"}, - {"func RunInExamples(fun func()) {", "RunInExamples := func(fun func()) {"}, - {"ctr++", "ctr++"}, - {"stru.ctr++", "stru.ctr++"}, - {"meta += ln", "meta += ln"}, - {"var data map[string]any", "var data map[string]any"}, - } - - sh := NewShell() - for _, test := range tests { - o := sh.TranspileLine(test.i) - assert.Equal(t, test.e, o) - } -} - -// tests command generation -func TestCommand(t *testing.T) { - // logx.UserLevel = slog.LevelDebug - tests := []exIn{ - { - `command list { -ls -la args... -}`, - `shell.AddCommand("list", func(args ...string) { -shell.Run("ls", "-la", "args...") -})`}, - } - - sh := NewShell() - for _, test := range tests { - sh.TranspileCode(test.i) - o := sh.Code() - assert.Equal(t, test.e, o) - } -} diff --git a/spell/dict/dtool.go b/spell/dict/dtool.go index 73d01cee20..84c66102d1 100644 --- a/spell/dict/dtool.go +++ b/spell/dict/dtool.go @@ -14,7 +14,7 @@ import ( //go:generate core generate -add-types -add-funcs -// Config is the configuration information for the cosh cli. +// Config is the configuration information for the dict cli. 
type Config struct { // InputA is the first input dictionary file diff --git a/spell/dict/typegen.go b/spell/dict/typegen.go index 7fcd3f4f0a..af1751e4be 100644 --- a/spell/dict/typegen.go +++ b/spell/dict/typegen.go @@ -6,7 +6,7 @@ import ( "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "main.Config", IDName: "config", Doc: "Config is the configuration information for the cosh cli.", Directives: []types.Directive{{Tool: "go", Directive: "generate", Args: []string{"core", "generate", "-add-types", "-add-funcs"}}}, Fields: []types.Field{{Name: "InputA", Doc: "InputA is the first input dictionary file"}, {Name: "InputB", Doc: "InputB is the second input dictionary file"}, {Name: "Output", Doc: "Output is the output file for merge command"}}}) +var _ = types.AddType(&types.Type{Name: "main.Config", IDName: "config", Doc: "Config is the configuration information for the dict cli.", Directives: []types.Directive{{Tool: "go", Directive: "generate", Args: []string{"core", "generate", "-add-types", "-add-funcs"}}}, Fields: []types.Field{{Name: "InputA", Doc: "InputA is the first input dictionary file"}, {Name: "InputB", Doc: "InputB is the second input dictionary file"}, {Name: "Output", Doc: "Output is the output file for merge command"}}}) var _ = types.AddFunc(&types.Func{Name: "main.Compare", Doc: "Compare compares two dictionaries", Directives: []types.Directive{{Tool: "cli", Directive: "cmd", Args: []string{"-root"}}}, Args: []string{"c"}, Returns: []string{"error"}}) diff --git a/styles/font.go b/styles/font.go index 34bc3cd804..3712adc600 100644 --- a/styles/font.go +++ b/styles/font.go @@ -21,22 +21,24 @@ import ( // for rendering -- see [FontRender] for that. type Font struct { //types:add - // size of font to render (inherited); converted to points when getting font to use + // Size of font to render (inherited). + // Converted to points when getting font to use. 
Size units.Value - // font family (inherited): ordered list of comma-separated names from more general to more specific to use; use split on , to parse + // Family name for font (inherited): ordered list of comma-separated names + // from more general to more specific to use. Use split on, to parse. Family string - // style (inherited): normal, italic, etc + // Style (inherited): normal, italic, etc. Style FontStyles - // weight (inherited): normal, bold, etc + // Weight (inherited): normal, bold, etc. Weight FontWeights - // font stretch / condense options (inherited) + // Stretch / condense options (inherited). Stretch FontStretch - // normal or small caps (inherited) + // Variant specifies normal or small caps (inherited). Variant FontVariants // Decoration contains the bit flag [TextDecorations] @@ -45,15 +47,16 @@ type Font struct { //types:add // It is not inherited. Decoration TextDecorations - // super / sub script (not inherited) + // Shift is the super / sub script (not inherited). Shift BaselineShifts - // full font information including enhanced metrics and actual font codes for drawing text; this is a pointer into FontLibrary of loaded fonts + // Face has full font information including enhanced metrics and actual + // font codes for drawing text; this is a pointer into FontLibrary of loaded fonts. 
Face *FontFace `display:"-"` } func (fs *Font) Defaults() { - fs.Size = units.Dp(16) + fs.Size.Dp(16) } // InheritFields from parent diff --git a/styles/typegen.go b/styles/typegen.go index d0d3de6ca0..1264874c8c 100644 --- a/styles/typegen.go +++ b/styles/typegen.go @@ -10,7 +10,7 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/styles.Border", IDN var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/styles.Shadow", IDName: "shadow", Doc: "style parameters for shadows", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "OffsetX", Doc: "OffsetX is th horizontal offset of the shadow.\nPositive moves it right, negative moves it left."}, {Name: "OffsetY", Doc: "OffsetY is the vertical offset of the shadow.\nPositive moves it down, negative moves it up."}, {Name: "Blur", Doc: "Blur specifies the blur radius of the shadow.\nHigher numbers make it more blurry."}, {Name: "Spread", Doc: "Spread specifies the spread radius of the shadow.\nPositive numbers increase the size of the shadow,\nand negative numbers decrease the size."}, {Name: "Color", Doc: "Color specifies the color of the shadow."}, {Name: "Inset", Doc: "Inset specifies whether the shadow is inset within the\nbox instead of outset outside of the box.\nTODO: implement."}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/styles.Font", IDName: "font", Doc: "Font contains all font styling information.\nMost of font information is inherited.\nFont does not include all information needed\nfor rendering -- see [FontRender] for that.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Size", Doc: "size of font to render (inherited); converted to points when getting font to use"}, {Name: "Family", Doc: "font family (inherited): ordered list of comma-separated names from more general to more specific to use; use split on , to parse"}, {Name: "Style", Doc: "style (inherited): normal, italic, 
etc"}, {Name: "Weight", Doc: "weight (inherited): normal, bold, etc"}, {Name: "Stretch", Doc: "font stretch / condense options (inherited)"}, {Name: "Variant", Doc: "normal or small caps (inherited)"}, {Name: "Decoration", Doc: "Decoration contains the bit flag [TextDecorations]\n(underline, line-through, etc). It must be set using\n[Font.SetDecoration] since it contains bit flags.\nIt is not inherited."}, {Name: "Shift", Doc: "super / sub script (not inherited)"}, {Name: "Face", Doc: "full font information including enhanced metrics and actual font codes for drawing text; this is a pointer into FontLibrary of loaded fonts"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/styles.Font", IDName: "font", Doc: "Font contains all font styling information.\nMost of font information is inherited.\nFont does not include all information needed\nfor rendering -- see [FontRender] for that.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Size", Doc: "Size of font to render (inherited).\nConverted to points when getting font to use."}, {Name: "Family", Doc: "Family name for font (inherited): ordered list of comma-separated names\nfrom more general to more specific to use. Use split on, to parse."}, {Name: "Style", Doc: "Style (inherited): normal, italic, etc."}, {Name: "Weight", Doc: "Weight (inherited): normal, bold, etc."}, {Name: "Stretch", Doc: "Stretch / condense options (inherited)."}, {Name: "Variant", Doc: "Variant specifies normal or small caps (inherited)."}, {Name: "Decoration", Doc: "Decoration contains the bit flag [TextDecorations]\n(underline, line-through, etc). 
It must be set using\n[Font.SetDecoration] since it contains bit flags.\nIt is not inherited."}, {Name: "Shift", Doc: "Shift is the super / sub script (not inherited)."}, {Name: "Face", Doc: "Face has full font information including enhanced metrics and actual\nfont codes for drawing text; this is a pointer into FontLibrary of loaded fonts."}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/styles.FontRender", IDName: "font-render", Doc: "FontRender contains all font styling information\nthat is needed for SVG text rendering. It is passed to\nPaint and Style functions. It should typically not be\nused by end-user code -- see [Font] for that.\nIt stores all values as pointers so that they correspond\nto the values of the style object it was derived from.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Embeds: []types.Field{{Name: "Font"}}, Fields: []types.Field{{Name: "Color", Doc: "text color (inherited)"}, {Name: "Background", Doc: "background color (not inherited, transparent by default)"}, {Name: "Opacity", Doc: "alpha value between 0 and 1 to apply to the foreground and background of this element and all of its children"}}}) diff --git a/system/app.go b/system/app.go index 215b93cf5c..6633448e94 100644 --- a/system/app.go +++ b/system/app.go @@ -192,6 +192,9 @@ type App interface { // IsDark returns whether the system color theme is dark (as oppposed to light). IsDark() bool + + // GPUDevice returns the gpu.GPU device if it is present (else nil). 
+ GPUDevice() any } // OnSystemWindowCreated is a channel used to communicate that the underlying diff --git a/system/driver/base/app.go b/system/driver/base/app.go index 11969dcdbf..05a6bd2806 100644 --- a/system/driver/base/app.go +++ b/system/driver/base/app.go @@ -182,6 +182,10 @@ func (a *App) IsDark() bool { return a.Dark } +func (a *App) GPUDevice() any { + return nil +} + func (a *App) GetScreens() { // no-op by default } diff --git a/system/driver/desktop/app.go b/system/driver/desktop/app.go index 8a184b8f8b..6deb0ecba8 100644 --- a/system/driver/desktop/app.go +++ b/system/driver/desktop/app.go @@ -51,6 +51,10 @@ func (a *App) SendEmptyEvent() { glfw.PostEmptyEvent() } +func (a *App) GPUDevice() any { + return a.GPU +} + // MainLoop starts running event loop on main thread (must be called // from the main thread). func (a *App) MainLoop() { diff --git a/system/driver/web/app.go b/system/driver/web/app.go index d87bcc4127..47c2a13f53 100644 --- a/system/driver/web/app.go +++ b/system/driver/web/app.go @@ -158,6 +158,10 @@ func (a *App) Resize() { a.Event.WindowResize() } +func (a *App) GPUDevice() any { + return a.Draw.wgpu +} + func (a *App) DataDir() string { return "/home/me/.data" } diff --git a/system/driver/web/drawer.go b/system/driver/web/drawer.go index 147a1ae593..f2f6be12d1 100644 --- a/system/driver/web/drawer.go +++ b/system/driver/web/drawer.go @@ -9,7 +9,6 @@ package web import ( "image" "image/draw" - "strings" "syscall/js" "cogentcore.org/core/gpu" @@ -44,7 +43,8 @@ func (a *App) InitDrawer() { // TODO(wgpu): various mobile and Linux browsers do not fully support WebGPU yet. isMobile := a.SystemPlatform().IsMobile() || a.SystemPlatform() == system.Linux // TODO(wgpu): Firefox currently does not support WebGPU in release builds. 
- isFirefox := strings.Contains(js.Global().Get("navigator").Get("userAgent").String(), "Firefox") + // isFirefox := strings.Contains(js.Global().Get("navigator").Get("userAgent").String(), "Firefox") + isFirefox := false if isMobile || isFirefox || !js.Global().Get("navigator").Get("gpu").Truthy() { a.Draw.context2D = js.Global().Get("document").Call("querySelector", "canvas").Call("getContext", "2d") return diff --git a/system/window.go b/system/window.go index 1bc84d796b..949d979a79 100644 --- a/system/window.go +++ b/system/window.go @@ -10,6 +10,7 @@ package system import ( + "fmt" "image" "unicode/utf8" @@ -329,6 +330,14 @@ func (o *NewWindowOptions) Fixup() { sc := TheApp.Screen(o.Screen) scsz := sc.Geometry.Size() // window coords size + if o.Flags.HasFlag(Fullscreen) { + o.Size.X = int(float32(scsz.X) * sc.DevicePixelRatio) + o.Size.Y = int(float32(scsz.Y) * sc.DevicePixelRatio) + o.Pos = image.Point{} + fmt.Println("fullscreen start:", o.Size, o.Pos) + return + } + if o.Size.X <= 0 { o.StdPixels = false o.Size.X = int(0.8 * float32(scsz.X) * sc.DevicePixelRatio) diff --git a/tensor/README.md b/tensor/README.md index 4dadca823b..9da1d04489 100644 --- a/tensor/README.md +++ b/tensor/README.md @@ -1,19 +1,292 @@ # Tensor -Tensor and related sub-packages provide a simple yet powerful framework for representing n-dimensional data of various types, providing similar functionality to the widely used `numpy` and `pandas` libraries in python, and the commercial MATLAB framework. +Tensor and related sub-packages provide a simple yet powerful framework for representing n-dimensional data of various types, providing similar functionality to the widely used [NumPy](https://numpy.org/doc/stable/index.html) and [pandas](https://pandas.pydata.org/) libraries in Python, and the commercial MATLAB framework. -* [table](table) organizes multiple Tensors as columns in a data `Table`, aligned by a common row dimension as the outer-most dimension of each tensor. 
Because the columns are tensors, each cell (value associated with a given row) can also be n-dimensional, allowing efficient representation of patterns and other high-dimensional data. Furthermore, the entire column is organized as a single contiguous slice of data, so it can be efficiently processed. The `table` package also has an `IndexView` that provides an indexed view into the rows of the table for highly efficient filtering and sorting of data. +The [Goal](../goal) augmented version of the _Go_ language directly supports NumPy-like operations on tensors. A `Tensor` is comparable to the NumPy `ndarray` type, and it provides the universal representation of a homogenous data type throughout all the packages here, from scalar to vector, matrix and beyond. All functions take and return `Tensor` arguments. - Data that is encoded as a slice of `struct`s can be bidirectionally converted to / from a Table, which then provides more powerful sorting, filtering and other functionality, including the plotcore. +The `Tensor` interface is implemented at the basic level with n-dimensional indexing into flat Go slices of any numeric data type (by `Number`), along with `String`, and `Bool` (which uses [bitslice](bitslice) for maximum efficiency). These implementations satisfy the `Values` sub-interface of Tensor, which supports the most direct and efficient operations on contiguous memory data. The `Shape` type provides all the n-dimensional indexing with arbitrary strides to allow any ordering, although _row major_ is the default and other orders have to be manually imposed. -* [tensorcore](tensorcore) provides core widgets for the `Tensor` and `Table` data. +In addition, there are five important "view" implementations of `Tensor` that wrap another "source" Tensor to provide more flexible and efficient access to the data, consistent with the NumPy functionality. See [Basic and Advanced Indexing](#basic-and-advanced-indexing) below for more info. 
-* [stats](stats) implements a number of different ways of analyzing tensor and table data. + +* `Sliced` provides a sub-sliced view into the wrapped `Tensor` source, using an indexed list along each dimension. Thus, it can provide a reordered and filtered view onto the raw data, and it has a well-defined shape in terms of the number of indexes per dimension. This corresponds to the NumPy basic sliced indexing model. + +* `Masked` provides a `Bool` masked view onto each element in the wrapped `Tensor`, where the two maintain the same shape. Any cell with a `false` value in the bool mask returns a `NaN` (missing data), and `Set` functions are no-ops. The [stats](stats) packages treat `NaN` as missing data, but [tmath](tmath), [vector](vector), and [matrix](matrix) packages do not, so it is best to call `.AsValues()` on masked data prior to operating on it, in a basic math context (i.e., `copy` in Goal). + +* `Indexed` has a tensor of indexes into the source data, where the final, innermost dimension of the indexes is the same size as the number of dimensions in the wrapped source tensor. The overall shape of this view is that of the remaining outer dimensions of the Indexes tensor, and like other views, assignment and return values are taken from the corresponding indexed value in the wrapped source tensor. + +* `Reshaped` applies a different `Shape` to the source tensor, with the constraint that the new shape has the same length of total elements as the source tensor. It is particularly useful for aligning different tensors so that a binary operation between them produces the desired results, for example by adding a new axis or collapsing multiple dimensions into one. + +* `Rows` is a specialized version of `Sliced` that provides a row index-based view, with the `Indexes` applying to the outermost _row_ dimension, which allows sorting and filtering to operate only on the indexes, leaving the underlying Tensor unchanged. 
This view is returned by the [table](table) data table, which organizes multiple heterogenous Tensor columns along a common outer row dimension, and provides similar functionality to pandas and particularly [xarray](http://xarray.pydata.org/en/stable/) in Python. + +Note that any view can be "stacked" on top of another, to produce more complex net views. + +Each view type implements the `AsValues` method to create a concrete "rendered" version of the view (as a `Values` tensor) where the actual underlying data is organized as it appears in the view. This is like the `copy` function in NumPy, disconnecting the view from the original source data. Note that unlike NumPy, `Masked` and `Indexed` remain views into the underlying source data -- see [Basic and Advanced Indexing](#basic-and-advanced-indexing) below. + +The `float64` ("Float"), `int` ("Int"), and `string` ("String") types are used as universal input / output types, and for intermediate computation in the math functions. Any performance-critical code can be optimized for a specific data type, but these universal interfaces are suitable for misc ad-hoc data analysis. + +There is also a `RowMajor` sub-interface for tensors (implemented by the `Values` and `Rows` types), which supports `[Set]FloatRow[Cell]` methods that provide optimized access to row major data. See [Standard shapes](#standard-shapes) for more info. + +The `Vectorize` function and its variants provide a universal "apply function to tensor data" mechanism (often called a "map" function, but that name is already taken in Go). It takes an `N` function that determines how many indexes to iterate over (and this function can also do any initialization prior to iterating), a compute function that gets the current index value, and a varargs list of tensors. 
In general it is completely up to the compute function how to interpret the index, although we also support the "broadcasting" principles from NumPy for binary functions operating on two tensors, as discussed below. There is a Threaded version of this for parallelizable functions, and a GPU version in the [gosl](../gpu/gosl) Go-as-a-shading-language package. + +To support the best possible performance in compute-intensive code, we have written all the core tensor functions in an `Out` suffixed version that takes the output tensor as an additional input argument (it must be a `Values` type), which allows an appropriately sized tensor to be used to hold the outputs on repeated function calls, instead of requiring new memory allocations every time. These versions are used in other calls where appropriate. The function without the `Out` suffix just wraps the `Out` version, and is what is called directly by Goal, where the output return value is essential for proper chaining of operations. + +To support proper argument handling for tensor functions, the [goal](../goal) transpiler registers all tensor package functions into the global name-to-function map (`tensor.Funcs`), which is used to retrieve the function by name, along with relevant arg metadata. This registry is also key for enum sets of functions, in the `stats` and `metrics` packages, for example, to be able to call the corresponding function. Goal uses symbols collected in the [yaegicore](../yaegicore) package to populate the Funcs, but enums should directly add themselves to ensure they are always available even outside of Goal. + +* [table](table) organizes multiple Tensors as columns in a data `Table`, aligned by a common outer row dimension. Because the columns are tensors, each cell (value associated with a given row) can also be n-dimensional, allowing efficient representation of patterns and other high-dimensional data. 
Furthermore, the entire column is organized as a single contiguous slice of data, so it can be efficiently processed. A `Table` automatically supplies a shared list of row Indexes for its `Indexed` columns, efficiently allowing all the heterogeneous data columns to be sorted and filtered together. + + Data that is encoded as a slice of `struct`s can be bidirectionally converted to / from a Table, which then provides more powerful sorting, filtering and other functionality, including [plot/plotcore](../plot/plotcore). + +* [tensorfs](tensorfs) provides a virtual filesystem (FS) for organizing arbitrary collections of data, supporting interactive, ad-hoc (notebook style) as well as systematic data processing. Interactive [goal](../goal) shell commands (`cd`, `ls`, `mkdir` etc) can be used to navigate the data space, with numerical expressions immediately available to operate on the data and save results back to the filesystem. Furthermore, the data can be directly copied to / from the OS filesystem to persist it, and `goal` can transparently access data on remote systems through ssh. Furthermore, the [databrowser](databrowser) provides a fully interactive GUI for inspecting and plotting data. + +* [tensorcore](tensorcore) provides core widgets for graphically displaying the `Tensor` and `Table` data, which are used in `tensorfs`. + +* [tmath](tmath) implements all standard math functions on `tensor.Indexed` data, including the standard `+, -, *, /` operators. `goal` then calls these functions. * [plot/plotcore](../plot/plotcore) supports interactive plotting of `Table` data. +* [bitslice](bitslice) is a Go slice of bytes `[]byte` that has methods for setting individual bits, as if it was a slice of bools, while being 8x more memory efficient. This is used for encoding null entries in `etensor`, and as a Tensor of bool / bits there as well, and is generally very useful for binary (boolean) data. 
+ +* [stats](stats) implements a number of different ways of analyzing tensor and table data, including: + - [cluster](cluster) implements agglomerative clustering of items based on [metric](metric) distance / similarity matrix data. + - [convolve](convolve) convolves data (e.g., for smoothing). + - [glm](glm) fits a general linear model for one or more dependent variables as a function of one or more independent variables. This encompasses all forms of regression. + - [histogram](histogram) bins data into groups and reports the frequency of elements in the bins. + - [metric](metric) computes similarity / distance metrics for comparing two tensors, and associated distance / similarity matrix functions, including PCA and SVD analysis functions that operate on a covariance matrix. + - [stats](stats) provides a set of standard summary statistics on a range of different data types, including basic slices of floats, to tensor and table data. It also includes the ability to extract Groups of values and generate statistics for each group, as in a "pivot table" in a spreadsheet. + +# Standard shapes + +There are various standard shapes of tensor data that different functions expect, listed below. The two most general-purpose functions for shaping and slicing any tensor to get it into the right shape for a given computation are: + +* `Reshape` returns a `Reshaped` view with the same total length as the source tensor, functioning like the NumPy `reshape` function. + +* `Reslice` returns a re-sliced view of a tensor, extracting or rearranging dimensions. It supports the full NumPy [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) syntax. It also does reshaping as needed, including processing the `NewAxis` option. 
+ +* **Flat, 1D**: this is the simplest data shape, and any tensor can be turned into a flat 1D list using `NewReshaped(-1)` or the `As1D` function, which either returns the tensor itself if it is already 1D, or a `Reshaped` 1D view. The [stats](stats) functions for example report summary statistics across the outermost row dimension, so converting data to this 1D view gives stats across all the data. + +* **Row, Cell 2D**: This is the natural shape for tabular data, and the `RowMajor` type and `Rows` view provide methods for efficiently accessing data in this way. In addition, the [stats](stats) and [metric](metric) packages automatically compute statistics across the outermost row dimension, aggregating results across rows for each cell. Thus, you end up with the "average cell-wise pattern" when you do `stats.Mean` for example. The `NewRowCellsView` function returns a `Reshaped` view of any tensor organized into this 2D shape, with the row vs. cell split specified at any point in the list of dimensions, which can be useful in obtaining the desired results. + +* **Matrix 2D**: For matrix algebra functions, a 2D tensor is treated as a standard row-major 2D matrix, which can be processed using `gonum` based matrix and vector operations, as in the [matrix](matrix) package. + +* **Matrix 3+D**: For functions that specifically process 2D matrices, a 3+D shape can be used as well, which iterates over the outer dimensions to process the inner 2D matrices. + +## Dynamic row sizing (e.g., for logs) + +The `SetNumRows` function can be used to progressively increase the number of rows to fit more data, as is typically the case when logging data (often using a [table](table)). You can set the row dimension to 0 to start -- that is (now) safe. However, for greatest efficiency, it is best to set the number of rows to the largest expected size first, and _then_ set it back to 0. The underlying slice of data retains its capacity when sized back down. 
During incremental increasing of the slice size, if it runs out of capacity, all the elements need to be copied, so it is more efficient to establish the capacity up front instead of having multiple incremental re-allocations. + +# Cheat Sheet + +TODO: update + +`ix` is the `Rows` tensor for these examples: + +## Tensor Access + +### 1D + +```Go +// 5th element in tensor regardless of shape: +val := ix.Float1D(5) +``` + +```Go +// value as a string regardless of underlying data type; numbers converted to strings. +str := ix.String1D(2) +``` + +### 2D Row, Cell + +```Go +// value at row 3, cell 2 (flat index into entire `SubSpace` tensor for this row) +// The row index will be indirected through any `Indexes` present on the Rows view. +val := ix.FloatRow(3, 2) +// string value at row 2, cell 0. this is safe for 1D and 2D+ shapes +// and is a robust way to get 1D data from tensors of unknown shapes. +str := ix.FloatRow(2, 0) +``` + +```Go +// get the whole n-dimensional tensor of data cells at given row. +// row is indirected through indexes. +// the resulting tensor is a "subslice" view into the underlying data +// so changes to it will automatically update the parent tensor. +tsr := ix.RowTensor(4) +.... +// set all n-dimensional tensor values at given row from given tensor. +ix.SetRowTensor(tsr, 4) +``` + +```Go +// returns a flat, 1D Rows view into n-dimensional tensor values at +// given row. This is used in compute routines that operate generically +// on the entire row as a flat pattern. +ci := tensor.Cells1D(ix, 5) +``` + +### Full N-dimensional Indexes + +```Go +// for 3D data +val := ix.Float(3,2,1) +``` + +# `Tensor` vs. Python NumPy + +The [Goal](../goal) language provides a reasonably faithful translation of NumPy `ndarray` syntax into the corresponding Go tensor package implementations. 
For those already familiar with NumPy, it should mostly "just work", but the following provides a more in-depth explanation for how the two relate, and when you might get different results. + +## Basic and Advanced Indexing + +NumPy distinguishes between _basic indexing_ (using a single index or sliced ranges of indexes along each dimension) versus _advanced indexing_ (using an array of indexes or bools). Basic indexing returns a **view** into the original data (where changes to the view directly affect the underlying type), while advanced indexing returns a **copy**. + +However, rather confusingly (per this [stack overflow question](https://stackoverflow.com/questions/15691740/does-assignment-with-advanced-indexing-copy-array-data)), you can do direct assignment through advanced indexing (more on this below): +```Python +a[np.array([1,2])] = 5 # or: +a[a > 0.5] = 1 # boolean advanced indexing +``` + +Although powerful, the semantics of all of this is a bit confusing. In the `tensor` package, we provide what are hopefully more clear and concrete _view_ types that have well-defined semantics, and cover the relevant functionality, while perhaps being a bit easier to reason with. These were described at the start of this README. The correspondence to NumPy indexing is as follows: + +* Basic indexing by individual integer index coordinate values is supported by the `Number`, `String`, `Bool` `Values` Tensors. For example, `Float(3,1,2)` returns the value at the given coordinates. The `Sliced` (and `Rows`) and `Reshaped` views then complete the basic indexing with arbitrary reordering and filtering along entire dimension values, and reshaping dimensions. As noted above, `Reslice` supports the full NumPy basic indexing syntax, and `Reshape` implements the NumPy `reshape` function. 
+ +* The `Masked` view corresponds to the NumPy _advanced_ indexing using a same-shape boolean mask, although in the NumPy case it makes a copy (although practically it is widely used for direct assignment as shown above.) Critically, you can always extract just the `true` values from a Masked view by using the `AsValues` method on the view, which returns a 1D tensor of those values, similar to what the boolean advanced indexing produces in NumPy. In addition, the `SourceIndexes` method returns a 1D list of indexes of the `true` (or `false`) values, which can be used for the `Indexed` view. + +* The `Indexed` view corresponds to the array-based advanced indexing case in NumPy, but again it is a view, not a copy, so the assignment semantics are as expected from a view (and how NumPy behaves some of the time). Note that the NumPy version uses `n` separate index tensors, where each such tensor specifies the value of a corresponding dimension index, and all such tensors _must have the same shape_; that form can be converted into the single Indexes form with a utility function. Also, NumPy advanced indexing has a somewhat confusing property where it de-duplicates index references during some operations, such that `a+=1` only increments +1 even when there are multiple elements in the view. The tensor version does not implement that special case, due to its direct view semantics. + +To reiterate, all view tensors have an `AsValues` function, equivalent to the `copy` function in NumPy, which turns the view into a corresponding basic concrete value Tensor, so the copy semantics of advanced indexing (modulo the direct assignment behavior) can be achieved when assigning to a new variable. 
+ +## Alignment of shapes for computations ("broadcasting") + +The NumPy concept of [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) is critical for flexibly defining the semantics for how functions taking two n-dimensional Tensor arguments behave when they have different shapes. Ultimately, the computation operates by iterating over the length of the longest tensor, and the question is how to _align_ the shapes so that a meaningful computation results from this. + +If both tensors are 1D and the same length, then a simple matched iteration over both can take place. However, the broadcasting logic defines what happens when there is a systematic relationship between the two, enabling powerful (but sometimes difficult to understand) computations to be specified. + +The following examples demonstrate the logic: + +Innermost dimensions that match in dimension are iterated over as you'd expect: +``` +Image (3d array): 256 x 256 x 3 +Scale (1d array): 3 +Result (3d array): 256 x 256 x 3 +``` + +Anything with a dimension size of 1 (a "singleton") will match against any other sized dimension: +``` +A (4d array): 8 x 1 x 6 x 1 +B (3d array): 7 x 1 x 5 +Result (4d array): 8 x 7 x 6 x 5 +``` +In the innermost dimension here, the single value in A acts like a "scalar" in relationship to the 5 values in B along that same dimension, operating on each one in turn. Likewise for the singleton second-to-last dimension in B. + +Any non-1 mismatch represents an error: +``` +A (2d array): 2 x 1 +B (3d array): 8 x 4 x 3 # second from last dimensions mismatched +``` + +The `AlignShapes` function performs this shape alignment logic, and the `WrapIndex1D` function is used to compute a 1D index into a given shape, based on the total output shape sizes, wrapping any singleton dimensions around as needed. These are used in the [tmath](tmath) package for example to implement the basic binary math operators. 
+ +# Printing format + +The following are examples of tensor printing via the `Sprintf` function, which is used with default values for the `String()` stringer method on tensors. It does a 2D projection of higher-dimensional tensors, using the `Projection2D` set of functions, which assume a row-wise outermost dimension in general, and pack even sets of inner dimensions into 2D row x col shapes (see examples below). + +1D (vector): goes column-wise, and wraps around as needed, e.g., length = 4: +``` +[4] 0 1 2 3 +``` +and 40: +``` +[40] 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 + 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +``` + +2D matrix: +``` +[4 3] + [0] [1] [2] +[0] 0 1 2 +[1] 10 11 12 +[2] 20 21 22 +[3] 30 31 32 +``` +and a column vector (2nd dimension is 1): +``` +[4 1] +[0] 0 +[1] 1 +[2] 2 +[3] 3 +``` + +3D tensor, shape = `[4 3 2]` -- note the `[r r c]` legend below the shape which indicates which dimensions are shown on the row (`r`) vs column (`c`) axis, so you know how to interpret the indexes: +``` +[4 3 2] +[r r c] [0] [1] +[0 0] 0 1 +[0 1] 10 11 +[0 2] 20 21 +[1 0] 100 101 +[1 1] 110 111 +[1 2] 120 121 +[2 0] 200 201 +[2 1] 210 211 +[2 2] 220 221 +[3 0] 300 301 +[3 1] 310 311 +[3 2] 320 321 +``` + +4D tensor: note how the row, column dimensions alternate, resulting in a 2D layout of the outer 2 dimensions, with another 2D layout of the inner 2D dimensions inlaid within that: +``` +[5 4 3 2] +[r c r c] [0 0] [0 1] [1 0] [1 1] [2 0] [2 1] [3 0] [3 1] +[0 0] 0 1 100 101 200 201 300 301 +[0 1] 10 11 110 111 210 211 310 311 +[0 2] 20 21 120 121 220 221 320 321 +[1 0] 1000 1001 1100 1101 1200 1201 1300 1301 +[1 1] 1010 1011 1110 1111 1210 1211 1310 1311 +[1 2] 1020 1021 1120 1121 1220 1221 1320 1321 +[2 0] 2000 2001 2100 2101 2200 2201 2300 2301 +[2 1] 2010 2011 2110 2111 2210 2211 2310 2311 +[2 2] 2020 2021 2120 2121 2220 2221 2320 2321 +[3 0] 3000 3001 3100 3101 3200 3201 3300 3301 +[3 1] 3010 3011 3110 3111 3210 3211 3310 3311 
+[3 2] 3020 3021 3120 3121 3220 3221 3320 3321 +[4 0] 4000 4001 4100 4101 4200 4201 4300 4301 +[4 1] 4010 4011 4110 4111 4210 4211 4310 4311 +[4 2] 4020 4021 4120 4121 4220 4221 4320 4321 +``` + +5D tensor: is treated like a 4D with the outermost dimension being an additional row dimension: +``` +[6 5 4 3 2] +[r r c r c] [0 0] [0 1] [1 0] [1 1] [2 0] [2 1] [3 0] [3 1] +[0 0 0] 0 1 100 101 200 201 300 301 +[0 0 1] 10 11 110 111 210 211 310 311 +[0 0 2] 20 21 120 121 220 221 320 321 +[0 1 0] 1000 1001 1100 1101 1200 1201 1300 1301 +[0 1 1] 1010 1011 1110 1111 1210 1211 1310 1311 +[0 1 2] 1020 1021 1120 1121 1220 1221 1320 1321 +[0 2 0] 2000 2001 2100 2101 2200 2201 2300 2301 +[0 2 1] 2010 2011 2110 2111 2210 2211 2310 2311 +[0 2 2] 2020 2021 2120 2121 2220 2221 2320 2321 +[0 3 0] 3000 3001 3100 3101 3200 3201 3300 3301 +[0 3 1] 3010 3011 3110 3111 3210 3211 3310 3311 +[0 3 2] 3020 3021 3120 3121 3220 3221 3320 3321 +[0 4 0] 4000 4001 4100 4101 4200 4201 4300 4301 +[0 4 1] 4010 4011 4110 4111 4210 4211 4310 4311 +[0 4 2] 4020 4021 4120 4121 4220 4221 4320 4321 +[1 0 0] 10000 10001 10100 10101 10200 10201 10300 10301 +[1 0 1] 10010 10011 10110 10111 10210 10211 10310 10311 +[1 0 2] 10020 10021 10120 10121 10220 10221 10320 10321 +[1 1 0] 11000 11001 11100 11101 11200 11201 11300 11301 +[1 1 1] 11010 11011 11110 11111 11210 11211 11310 11311 +... +``` # History -This package was originally developed as [etable](https://github.com/emer/etable) as part of the _emergent_ software framework. It always depended on the GUI framework that became Cogent Core, and having it integrated within the Core monorepo makes it easier to integrate updates, and also makes it easier to build advanced data management and visualization applications. For example, the [plot/plotcore](../plot/plotcore) package uses the `Table` to support flexible and powerful plotting functionality. 
+This package was originally developed as [etable](https://github.com/emer/etable) as part of the _emergent_ software framework. It always depended on the GUI framework that became Cogent Core, and having it integrated within the Core monorepo makes it easier to integrate updates, and also makes it easier to build advanced data management and visualization applications. For example, the [plot/plotcore](../plot/plotcore) package uses the `Table` to support flexible and powerful plotting functionality. + +It was completely rewritten in Sept 2024 to use a single data type (`tensor.Indexed`) and call signature for compute functions taking these args, to provide a simple and efficient data processing framework that greatly simplified the code and enables the [goal](../goal) language to directly transpile simplified math expressions into corresponding tensor compute code. + diff --git a/tensor/align.go b/tensor/align.go new file mode 100644 index 0000000000..a0c23b6509 --- /dev/null +++ b/tensor/align.go @@ -0,0 +1,331 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "fmt" + "slices" + + "cogentcore.org/core/base/errors" +) + +// AlignShapes aligns the shapes of two tensors, a and b for a binary +// computation producing an output, returning the effective aligned shapes +// for a, b, and the output, all with the same number of dimensions. +// Alignment proceeds from the innermost dimension out, with 1s provided +// beyond the number of dimensions for a or b. +// The output has the max of the dimension sizes for each dimension. +// An error is returned if the rules of alignment are violated: +// each dimension size must be either the same, or one of them +// is equal to 1. This corresponds to the "broadcasting" logic of NumPy. 
+func AlignShapes(a, b Tensor) (as, bs, os *Shape, err error) { + asz := a.ShapeSizes() + bsz := b.ShapeSizes() + an := len(asz) + bn := len(bsz) + n := max(an, bn) + osizes := make([]int, n) + asizes := make([]int, n) + bsizes := make([]int, n) + for d := range n { + ai := an - 1 - d + bi := bn - 1 - d + oi := n - 1 - d + ad := 1 + bd := 1 + if ai >= 0 { + ad = asz[ai] + } + if bi >= 0 { + bd = bsz[bi] + } + if ad != bd && !(ad == 1 || bd == 1) { + err = fmt.Errorf("tensor.AlignShapes: output dimension %d does not align for a=%d b=%d: must be either the same or one of them is a 1", oi, ad, bd) + return + } + od := max(ad, bd) + osizes[oi] = od + asizes[oi] = ad + bsizes[oi] = bd + } + as = NewShape(asizes...) + bs = NewShape(bsizes...) + os = NewShape(osizes...) + return +} + +// WrapIndex1D returns the 1d flat index for given n-dimensional index +// based on given shape, where any singleton dimension sizes cause the +// resulting index value to remain at 0, effectively causing that dimension +// to wrap around, while the other tensor is presumably using the full range +// of the values along this dimension. See [AlignShapes] for more info. +func WrapIndex1D(sh *Shape, i ...int) int { + nd := sh.NumDims() + ai := slices.Clone(i) + for d := range nd { + if sh.DimSize(d) == 1 { + ai[d] = 0 + } + } + return sh.IndexTo1D(ai...) +} + +// AlignForAssign ensures that the shapes of two tensors, a and b +// have the proper alignment for assigning b into a. +// Alignment proceeds from the innermost dimension out, with 1s provided +// beyond the number of dimensions for a or b. +// An error is returned if the rules of alignment are violated: +// each dimension size must be either the same, or b is equal to 1. +// This corresponds to the "broadcasting" logic of NumPy. 
+func AlignForAssign(a, b Tensor) (as, bs *Shape, err error) { + asz := a.ShapeSizes() + bsz := b.ShapeSizes() + an := len(asz) + bn := len(bsz) + n := max(an, bn) + asizes := make([]int, n) + bsizes := make([]int, n) + for d := range n { + ai := an - 1 - d + bi := bn - 1 - d + oi := n - 1 - d + ad := 1 + bd := 1 + if ai >= 0 { + ad = asz[ai] + } + if bi >= 0 { + bd = bsz[bi] + } + if ad != bd && bd != 1 { + err = fmt.Errorf("tensor.AlignShapes: dimension %d does not align for a=%d b=%d: must be either the same or b is a 1", oi, ad, bd) + return + } + asizes[oi] = ad + bsizes[oi] = bd + } + as = NewShape(asizes...) + bs = NewShape(bsizes...) + return +} + +// SplitAtInnerDims returns the sizes of the given tensor's shape +// with the given number of inner-most dimensions retained as is, +// and those above collapsed to a single dimension. +// If the total number of dimensions is < nInner the result is nil. +func SplitAtInnerDims(tsr Tensor, nInner int) []int { + sizes := tsr.ShapeSizes() + nd := len(sizes) + if nd < nInner { + return nil + } + rsz := make([]int, nInner+1) + split := nd - nInner + rows := sizes[:split] + copy(rsz[1:], sizes[split:]) + nr := 1 + for _, r := range rows { + nr *= r + } + rsz[0] = nr + return rsz +} + +// FloatAssignFunc sets a to a binary function of a and b float64 values. +func FloatAssignFunc(fun func(a, b float64) float64, a, b Tensor) error { + as, bs, err := AlignForAssign(a, b) + if err != nil { + return err + } + alen := as.Len() + VectorizeThreaded(1, func(tsr ...Tensor) int { return alen }, + func(idx int, tsr ...Tensor) { + ai := as.IndexFrom1D(idx) + bi := WrapIndex1D(bs, ai...) + tsr[0].SetFloat1D(fun(tsr[0].Float1D(idx), tsr[1].Float1D(bi)), idx) + }, a, b) + return nil +} + +// StringAssignFunc sets a to a binary function of a and b string values. 
+func StringAssignFunc(fun func(a, b string) string, a, b Tensor) error { + as, bs, err := AlignForAssign(a, b) + if err != nil { + return err + } + alen := as.Len() + VectorizeThreaded(1, func(tsr ...Tensor) int { return alen }, + func(idx int, tsr ...Tensor) { + ai := as.IndexFrom1D(idx) + bi := WrapIndex1D(bs, ai...) + tsr[0].SetString1D(fun(tsr[0].String1D(idx), tsr[1].String1D(bi)), idx) + }, a, b) + return nil +} + +// FloatBinaryFunc sets output to a binary function of a, b float64 values. +// The flops (floating point operations) estimate is used to control parallel +// threading using goroutines, and should reflect number of flops in the function. +// See [VectorizeThreaded] for more information. +func FloatBinaryFunc(flops int, fun func(a, b float64) float64, a, b Tensor) Tensor { + return CallOut2Gen2(FloatBinaryFuncOut, flops, fun, a, b) +} + +// FloatBinaryFuncOut sets output to a binary function of a, b float64 values. +func FloatBinaryFuncOut(flops int, fun func(a, b float64) float64, a, b Tensor, out Values) error { + as, bs, os, err := AlignShapes(a, b) + if err != nil { + return err + } + out.SetShapeSizes(os.Sizes...) + olen := os.Len() + VectorizeThreaded(flops, func(tsr ...Tensor) int { return olen }, + func(idx int, tsr ...Tensor) { + oi := os.IndexFrom1D(idx) + ai := WrapIndex1D(as, oi...) + bi := WrapIndex1D(bs, oi...) + out.SetFloat1D(fun(tsr[0].Float1D(ai), tsr[1].Float1D(bi)), idx) + }, a, b, out) + return nil +} + +// StringBinaryFunc sets output to a binary function of a, b string values. +func StringBinaryFunc(fun func(a, b string) string, a, b Tensor) Tensor { + return CallOut2Gen1(StringBinaryFuncOut, fun, a, b) +} + +// StringBinaryFuncOut sets output to a binary function of a, b string values. +func StringBinaryFuncOut(fun func(a, b string) string, a, b Tensor, out Values) error { + as, bs, os, err := AlignShapes(a, b) + if err != nil { + return err + } + out.SetShapeSizes(os.Sizes...) 
+ olen := os.Len() + VectorizeThreaded(1, func(tsr ...Tensor) int { return olen }, + func(idx int, tsr ...Tensor) { + oi := os.IndexFrom1D(idx) + ai := WrapIndex1D(as, oi...) + bi := WrapIndex1D(bs, oi...) + out.SetString1D(fun(tsr[0].String1D(ai), tsr[1].String1D(bi)), idx) + }, a, b, out) + return nil +} + +// FloatFunc sets output to a function of tensor float64 values. +// The flops (floating point operations) estimate is used to control parallel +// threading using goroutines, and should reflect number of flops in the function. +// See [VectorizeThreaded] for more information. +func FloatFunc(flops int, fun func(in float64) float64, in Tensor) Values { + return CallOut1Gen2(FloatFuncOut, flops, fun, in) +} + +// FloatFuncOut sets output to a function of tensor float64 values. +func FloatFuncOut(flops int, fun func(in float64) float64, in Tensor, out Values) error { + SetShapeFrom(out, in) + n := in.Len() + VectorizeThreaded(flops, func(tsr ...Tensor) int { return n }, + func(idx int, tsr ...Tensor) { + tsr[1].SetFloat1D(fun(tsr[0].Float1D(idx)), idx) + }, in, out) + return nil +} + +// FloatSetFunc sets tensor float64 values from a function, +// which gets the index. Must be parallel threadsafe. +// The flops (floating point operations) estimate is used to control parallel +// threading using goroutines, and should reflect number of flops in the function. +// See [VectorizeThreaded] for more information. +func FloatSetFunc(flops int, fun func(idx int) float64, a Tensor) error { + n := a.Len() + VectorizeThreaded(flops, func(tsr ...Tensor) int { return n }, + func(idx int, tsr ...Tensor) { + tsr[0].SetFloat1D(fun(idx), idx) + }, a) + return nil +} + +//////// Bool + +// BoolStringsFunc sets boolean output value based on a function involving +// string values from the two tensors. 
+func BoolStringsFunc(fun func(a, b string) bool, a, b Tensor) *Bool { + out := NewBool() + errors.Log(BoolStringsFuncOut(fun, a, b, out)) + return out +} + +// BoolStringsFuncOut sets boolean output value based on a function involving +// string values from the two tensors. +func BoolStringsFuncOut(fun func(a, b string) bool, a, b Tensor, out *Bool) error { + as, bs, os, err := AlignShapes(a, b) + if err != nil { + return err + } + out.SetShapeSizes(os.Sizes...) + olen := os.Len() + VectorizeThreaded(5, func(tsr ...Tensor) int { return olen }, + func(idx int, tsr ...Tensor) { + oi := os.IndexFrom1D(idx) + ai := WrapIndex1D(as, oi...) + bi := WrapIndex1D(bs, oi...) + out.SetBool1D(fun(tsr[0].String1D(ai), tsr[1].String1D(bi)), idx) + }, a, b, out) + return nil +} + +// BoolFloatsFunc sets boolean output value based on a function involving +// float64 values from the two tensors. +func BoolFloatsFunc(fun func(a, b float64) bool, a, b Tensor) *Bool { + out := NewBool() + errors.Log(BoolFloatsFuncOut(fun, a, b, out)) + return out +} + +// BoolFloatsFuncOut sets boolean output value based on a function involving +// float64 values from the two tensors. +func BoolFloatsFuncOut(fun func(a, b float64) bool, a, b Tensor, out *Bool) error { + as, bs, os, err := AlignShapes(a, b) + if err != nil { + return err + } + out.SetShapeSizes(os.Sizes...) + olen := os.Len() + VectorizeThreaded(5, func(tsr ...Tensor) int { return olen }, + func(idx int, tsr ...Tensor) { + oi := os.IndexFrom1D(idx) + ai := WrapIndex1D(as, oi...) + bi := WrapIndex1D(bs, oi...) + out.SetBool1D(fun(tsr[0].Float1D(ai), tsr[1].Float1D(bi)), idx) + }, a, b, out) + return nil +} + +// BoolIntsFunc sets boolean output value based on a function involving +// int values from the two tensors. 
+func BoolIntsFunc(fun func(a, b int) bool, a, b Tensor) *Bool { + out := NewBool() + errors.Log(BoolIntsFuncOut(fun, a, b, out)) + return out +} + +// BoolIntsFuncOut sets boolean output value based on a function involving +// int values from the two tensors. +func BoolIntsFuncOut(fun func(a, b int) bool, a, b Tensor, out *Bool) error { + as, bs, os, err := AlignShapes(a, b) + if err != nil { + return err + } + out.SetShapeSizes(os.Sizes...) + olen := os.Len() + VectorizeThreaded(5, func(tsr ...Tensor) int { return olen }, + func(idx int, tsr ...Tensor) { + oi := os.IndexFrom1D(idx) + ai := WrapIndex1D(as, oi...) + bi := WrapIndex1D(bs, oi...) + out.SetBool1D(fun(tsr[0].Int1D(ai), tsr[1].Int1D(bi)), idx) + }, a, b, out) + return nil +} diff --git a/tensor/base.go b/tensor/base.go index 8bad3301f1..fe7e20b97d 100644 --- a/tensor/base.go +++ b/tensor/base.go @@ -5,43 +5,51 @@ package tensor import ( - "fmt" - "log" "reflect" + "slices" "unsafe" + "cogentcore.org/core/base/metadata" "cogentcore.org/core/base/reflectx" "cogentcore.org/core/base/slicesx" ) -// Base is an n-dim array of float64s. +// Base is the base Tensor implementation for given type. type Base[T any] struct { - Shp Shape + shape Shape Values []T - Meta map[string]string + Meta metadata.Data } -// Shape returns a pointer to the shape that fully parametrizes the tensor shape -func (tsr *Base[T]) Shape() *Shape { return &tsr.Shp } +// Metadata returns the metadata for this tensor, which can be used +// to encode plotting options, etc. +func (tsr *Base[T]) Metadata() *metadata.Data { return &tsr.Meta } + +func (tsr *Base[T]) Shape() *Shape { return &tsr.shape } + +// ShapeSizes returns the sizes of each dimension as a slice of ints. +// This is the preferred access for Go code. +func (tsr *Base[T]) ShapeSizes() []int { return slices.Clone(tsr.shape.Sizes) } + +// SetShapeSizes sets the dimension sizes of the tensor, and resizes +// backing storage appropriately, retaining all existing data that fits. 
+func (tsr *Base[T]) SetShapeSizes(sizes ...int) { + tsr.shape.SetShapeSizes(sizes...) + nln := tsr.shape.Len() + tsr.Values = slicesx.SetLength(tsr.Values, nln) +} // Len returns the number of elements in the tensor (product of shape dimensions). -func (tsr *Base[T]) Len() int { return tsr.Shp.Len() } +func (tsr *Base[T]) Len() int { return tsr.shape.Len() } // NumDims returns the total number of dimensions. -func (tsr *Base[T]) NumDims() int { return tsr.Shp.NumDims() } +func (tsr *Base[T]) NumDims() int { return tsr.shape.NumDims() } -// DimSize returns size of given dimension -func (tsr *Base[T]) DimSize(dim int) int { return tsr.Shp.DimSize(dim) } - -// RowCellSize returns the size of the outer-most Row shape dimension, -// and the size of all the remaining inner dimensions (the "cell" size). -// Used for Tensors that are columns in a data table. -func (tsr *Base[T]) RowCellSize() (rows, cells int) { - return tsr.Shp.RowCellSize() -} +// DimSize returns size of given dimension. +func (tsr *Base[T]) DimSize(dim int) int { return tsr.shape.DimSize(dim) } // DataType returns the type of the data elements in the tensor. -// Bool is returned for the Bits tensor type. +// Bool is returned for the Bool tensor type. func (tsr *Base[T]) DataType() reflect.Kind { var v T return reflect.TypeOf(v).Kind() @@ -56,24 +64,33 @@ func (tsr *Base[T]) Bytes() []byte { return slicesx.ToBytes(tsr.Values) } -func (tsr *Base[T]) Value(i []int) T { j := tsr.Shp.Offset(i); return tsr.Values[j] } -func (tsr *Base[T]) Value1D(i int) T { return tsr.Values[i] } -func (tsr *Base[T]) Set(i []int, val T) { j := tsr.Shp.Offset(i); tsr.Values[j] = val } -func (tsr *Base[T]) Set1D(i int, val T) { tsr.Values[i] = val } +func (tsr *Base[T]) Value(i ...int) T { + return tsr.Values[tsr.shape.IndexTo1D(i...)] +} -// SetShape sets the shape params, resizing backing storage appropriately -func (tsr *Base[T]) SetShape(sizes []int, names ...string) { - tsr.Shp.SetShape(sizes, names...) 
- nln := tsr.Len() - tsr.Values = slicesx.SetLength(tsr.Values, nln) +func (tsr *Base[T]) ValuePtr(i ...int) *T { + return &tsr.Values[tsr.shape.IndexTo1D(i...)] +} + +func (tsr *Base[T]) Set(val T, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = val } -// SetNumRows sets the number of rows (outer-most dimension) in a RowMajor organized tensor. +func (tsr *Base[T]) Value1D(i int) T { return tsr.Values[i] } + +func (tsr *Base[T]) Set1D(val T, i int) { tsr.Values[i] = val } + +// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor. +// It is safe to set this to 0. For incrementally growing tensors (e.g., a log) +// it is best to first set the anticipated full size, which allocates the +// full amount of memory, and then set to 0 and grow incrementally. func (tsr *Base[T]) SetNumRows(rows int) { - rows = max(1, rows) // must be > 0 - _, cells := tsr.Shp.RowCellSize() + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + _, cells := tsr.shape.RowCellSize() nln := rows * cells - tsr.Shp.Sizes[0] = rows + tsr.shape.Sizes[0] = rows tsr.Values = slicesx.SetLength(tsr.Values, nln) } @@ -82,113 +99,45 @@ func (tsr *Base[T]) SetNumRows(rows int) { // The new tensor points to the values of the this tensor (i.e., modifications // will affect both), as its Values slice is a view onto the original (which // is why only inner-most contiguous supsaces are supported). -// Use Clone() method to separate the two. -func (tsr *Base[T]) subSpaceImpl(offs []int) *Base[T] { +// Use AsValues() method to separate the two. +func (tsr *Base[T]) subSpaceImpl(offs ...int) *Base[T] { nd := tsr.NumDims() od := len(offs) - if od >= nd { + if od > nd { return nil } + var ssz []int + if od == nd { // scalar subspace + ssz = []int{1} + } else { + ssz = tsr.shape.Sizes[od:] + } stsr := &Base[T]{} - stsr.SetShape(tsr.Shp.Sizes[od:], tsr.Shp.Names[od:]...) + stsr.SetShapeSizes(ssz...) 
sti := make([]int, nd) copy(sti, offs) - stoff := tsr.Shp.Offset(sti) + stoff := tsr.shape.IndexTo1D(sti...) sln := stsr.Len() stsr.Values = tsr.Values[stoff : stoff+sln] return stsr } -func (tsr *Base[T]) StringValue(i []int) string { - j := tsr.Shp.Offset(i) - return reflectx.ToString(tsr.Values[j]) -} -func (tsr *Base[T]) String1D(off int) string { return reflectx.ToString(tsr.Values[off]) } +//////// Strings -func (tsr *Base[T]) StringRowCell(row, cell int) string { - _, sz := tsr.Shp.RowCellSize() - return reflectx.ToString(tsr.Values[row*sz+cell]) +func (tsr *Base[T]) StringValue(i ...int) string { + return reflectx.ToString(tsr.Values[tsr.shape.IndexTo1D(i...)]) } -// Label satisfies the core.Labeler interface for a summary description of the tensor -func (tsr *Base[T]) Label() string { - return fmt.Sprintf("Tensor: %s", tsr.Shp.String()) -} - -// Dims is the gonum/mat.Matrix interface method for returning the dimensionality of the -// 2D Matrix. Assumes Row-major ordering and logs an error if NumDims < 2. -func (tsr *Base[T]) Dims() (r, c int) { - nd := tsr.NumDims() - if nd < 2 { - log.Println("tensor Dims gonum Matrix call made on Tensor with dims < 2") - return 0, 0 - } - return tsr.Shp.DimSize(nd - 2), tsr.Shp.DimSize(nd - 1) +func (tsr *Base[T]) String1D(i int) string { + return reflectx.ToString(tsr.Values[i]) } -// Symmetric is the gonum/mat.Matrix interface method for returning the dimensionality of a symmetric -// 2D Matrix. -func (tsr *Base[T]) Symmetric() (r int) { - nd := tsr.NumDims() - if nd < 2 { - log.Println("tensor Symmetric gonum Matrix call made on Tensor with dims < 2") - return 0 - } - if tsr.Shp.DimSize(nd-2) != tsr.Shp.DimSize(nd-1) { - log.Println("tensor Symmetric gonum Matrix call made on Tensor that is not symmetric") - return 0 - } - return tsr.Shp.DimSize(nd - 1) -} - -// SymmetricDim returns the number of rows/columns in the matrix. 
-func (tsr *Base[T]) SymmetricDim() int { - nd := tsr.NumDims() - if nd < 2 { - log.Println("tensor Symmetric gonum Matrix call made on Tensor with dims < 2") - return 0 - } - if tsr.Shp.DimSize(nd-2) != tsr.Shp.DimSize(nd-1) { - log.Println("tensor Symmetric gonum Matrix call made on Tensor that is not symmetric") - return 0 - } - return tsr.Shp.DimSize(nd - 1) -} - -// SetMetaData sets a key=value meta data (stored as a map[string]string). -// For TensorGrid display: top-zero=+/-, odd-row=+/-, image=+/-, -// min, max set fixed min / max values, background=color -func (tsr *Base[T]) SetMetaData(key, val string) { - if tsr.Meta == nil { - tsr.Meta = make(map[string]string) - } - tsr.Meta[key] = val -} - -// MetaData retrieves value of given key, bool = false if not set -func (tsr *Base[T]) MetaData(key string) (string, bool) { - if tsr.Meta == nil { - return "", false - } - val, ok := tsr.Meta[key] - return val, ok -} - -// MetaDataMap returns the underlying map used for meta data -func (tsr *Base[T]) MetaDataMap() map[string]string { - return tsr.Meta +func (tsr *Base[T]) StringRow(row, cell int) string { + _, sz := tsr.shape.RowCellSize() + return reflectx.ToString(tsr.Values[row*sz+cell]) } -// CopyMetaData copies meta data from given source tensor -func (tsr *Base[T]) CopyMetaData(frm Tensor) { - fmap := frm.MetaDataMap() - if len(fmap) == 0 { - return - } - if tsr.Meta == nil { - tsr.Meta = make(map[string]string) - } - for k, v := range fmap { - tsr.Meta[k] = v - } +// Label satisfies the core.Labeler interface for a summary description of the tensor. +func (tsr *Base[T]) Label() string { + return label(metadata.Name(tsr), &tsr.shape) } diff --git a/tensor/bits.go b/tensor/bits.go deleted file mode 100644 index 76a15bfd39..0000000000 --- a/tensor/bits.go +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package tensor - -import ( - "fmt" - "log/slog" - "reflect" - "strings" - - "cogentcore.org/core/base/reflectx" - "cogentcore.org/core/base/slicesx" - "cogentcore.org/core/tensor/bitslice" - "gonum.org/v1/gonum/mat" -) - -// Bits is a tensor of bits backed by a bitslice.Slice for efficient storage -// of binary data -type Bits struct { - Shp Shape - Values bitslice.Slice - Meta map[string]string -} - -// NewBits returns a new n-dimensional tensor of bit values -// with the given sizes per dimension (shape), and optional dimension names. -func NewBits(sizes []int, names ...string) *Bits { - tsr := &Bits{} - tsr.SetShape(sizes, names...) - tsr.Values = bitslice.Make(tsr.Len(), 0) - return tsr -} - -// NewBitsShape returns a new n-dimensional tensor of bit values -// using given shape. -func NewBitsShape(shape *Shape) *Bits { - tsr := &Bits{} - tsr.Shp.CopyShape(shape) - tsr.Values = bitslice.Make(tsr.Len(), 0) - return tsr -} - -func Float64ToBool(val float64) bool { - bv := true - if val == 0 { - bv = false - } - return bv -} - -func BoolToFloat64(bv bool) float64 { - if bv { - return 1 - } else { - return 0 - } -} - -func (tsr *Bits) IsString() bool { - return false -} - -// DataType returns the type of the data elements in the tensor. -// Bool is returned for the Bits tensor type. -func (tsr *Bits) DataType() reflect.Kind { - return reflect.Bool -} - -func (tsr *Bits) Sizeof() int64 { - return int64(len(tsr.Values)) -} - -func (tsr *Bits) Bytes() []byte { - return slicesx.ToBytes(tsr.Values) -} - -// Shape returns a pointer to the shape that fully parametrizes the tensor shape -func (tsr *Bits) Shape() *Shape { return &tsr.Shp } - -// Len returns the number of elements in the tensor (product of shape dimensions). -func (tsr *Bits) Len() int { return tsr.Shp.Len() } - -// NumDims returns the total number of dimensions. 
-func (tsr *Bits) NumDims() int { return tsr.Shp.NumDims() } - -// DimSize returns size of given dimension -func (tsr *Bits) DimSize(dim int) int { return tsr.Shp.DimSize(dim) } - -// RowCellSize returns the size of the outer-most Row shape dimension, -// and the size of all the remaining inner dimensions (the "cell" size). -// Used for Tensors that are columns in a data table. -func (tsr *Bits) RowCellSize() (rows, cells int) { - return tsr.Shp.RowCellSize() -} - -// Value returns value at given tensor index -func (tsr *Bits) Value(i []int) bool { j := int(tsr.Shp.Offset(i)); return tsr.Values.Index(j) } - -// Value1D returns value at given tensor 1D (flat) index -func (tsr *Bits) Value1D(i int) bool { return tsr.Values.Index(i) } - -func (tsr *Bits) Set(i []int, val bool) { j := int(tsr.Shp.Offset(i)); tsr.Values.Set(j, val) } -func (tsr *Bits) Set1D(i int, val bool) { tsr.Values.Set(i, val) } - -// SetShape sets the shape params, resizing backing storage appropriately -func (tsr *Bits) SetShape(sizes []int, names ...string) { - tsr.Shp.SetShape(sizes, names...) - nln := tsr.Len() - tsr.Values.SetLen(nln) -} - -// SetNumRows sets the number of rows (outer-most dimension) in a RowMajor organized tensor. 
-func (tsr *Bits) SetNumRows(rows int) { - rows = max(1, rows) // must be > 0 - _, cells := tsr.Shp.RowCellSize() - nln := rows * cells - tsr.Shp.Sizes[0] = rows - tsr.Values.SetLen(nln) -} - -// SubSpace is not possible with Bits -func (tsr *Bits) SubSpace(offs []int) Tensor { - return nil -} - -func (tsr *Bits) Float(i []int) float64 { - j := tsr.Shp.Offset(i) - return BoolToFloat64(tsr.Values.Index(j)) -} - -func (tsr *Bits) SetFloat(i []int, val float64) { - j := tsr.Shp.Offset(i) - tsr.Values.Set(j, Float64ToBool(val)) -} - -func (tsr *Bits) StringValue(i []int) string { - j := tsr.Shp.Offset(i) - return reflectx.ToString(tsr.Values.Index(j)) -} - -func (tsr *Bits) SetString(i []int, val string) { - if bv, err := reflectx.ToBool(val); err == nil { - j := tsr.Shp.Offset(i) - tsr.Values.Set(j, bv) - } -} - -func (tsr *Bits) Float1D(off int) float64 { - return BoolToFloat64(tsr.Values.Index(off)) -} -func (tsr *Bits) SetFloat1D(off int, val float64) { - tsr.Values.Set(off, Float64ToBool(val)) -} - -func (tsr *Bits) FloatRowCell(row, cell int) float64 { - _, sz := tsr.RowCellSize() - return BoolToFloat64(tsr.Values.Index(row*sz + cell)) -} -func (tsr *Bits) SetFloatRowCell(row, cell int, val float64) { - _, sz := tsr.RowCellSize() - tsr.Values.Set(row*sz+cell, Float64ToBool(val)) -} - -func (tsr *Bits) Floats(flt *[]float64) { - sz := tsr.Len() - *flt = slicesx.SetLength(*flt, sz) - for j := 0; j < sz; j++ { - (*flt)[j] = BoolToFloat64(tsr.Values.Index(j)) - } -} - -// SetFloats sets tensor values from a []float64 slice (copies values). 
-func (tsr *Bits) SetFloats(vals []float64) { - sz := min(tsr.Len(), len(vals)) - for j := 0; j < sz; j++ { - tsr.Values.Set(j, Float64ToBool(vals[j])) - } -} - -func (tsr *Bits) String1D(off int) string { - return reflectx.ToString(tsr.Values.Index(off)) -} - -func (tsr *Bits) SetString1D(off int, val string) { - if bv, err := reflectx.ToBool(val); err == nil { - tsr.Values.Set(off, bv) - } -} - -func (tsr *Bits) StringRowCell(row, cell int) string { - _, sz := tsr.RowCellSize() - return reflectx.ToString(tsr.Values.Index(row*sz + cell)) -} - -func (tsr *Bits) SetStringRowCell(row, cell int, val string) { - if bv, err := reflectx.ToBool(val); err == nil { - _, sz := tsr.RowCellSize() - tsr.Values.Set(row*sz+cell, bv) - } -} - -// Label satisfies the core.Labeler interface for a summary description of the tensor -func (tsr *Bits) Label() string { - return fmt.Sprintf("tensor.Bits: %s", tsr.Shp.String()) -} - -// SetMetaData sets a key=value meta data (stored as a map[string]string). -// For TensorGrid display: top-zero=+/-, odd-row=+/-, image=+/-, -// min, max set fixed min / max values, background=color -func (tsr *Bits) SetMetaData(key, val string) { - if tsr.Meta == nil { - tsr.Meta = make(map[string]string) - } - tsr.Meta[key] = val -} - -// MetaData retrieves value of given key, bool = false if not set -func (tsr *Bits) MetaData(key string) (string, bool) { - if tsr.Meta == nil { - return "", false - } - val, ok := tsr.Meta[key] - return val, ok -} - -// MetaDataMap returns the underlying map used for meta data -func (tsr *Bits) MetaDataMap() map[string]string { - return tsr.Meta -} - -// CopyMetaData copies meta data from given source tensor -func (tsr *Bits) CopyMetaData(frm Tensor) { - fmap := frm.MetaDataMap() - if len(fmap) == 0 { - return - } - if tsr.Meta == nil { - tsr.Meta = make(map[string]string) - } - for k, v := range fmap { - tsr.Meta[k] = v - } -} - -// Range is not applicable to Bits tensor -func (tsr *Bits) Range() (min, max float64, minIndex, 
maxIndex int) { - minIndex = -1 - maxIndex = -1 - return -} - -// SetZeros is simple convenience function initialize all values to 0 -func (tsr *Bits) SetZeros() { - ln := tsr.Len() - for j := 0; j < ln; j++ { - tsr.Values.Set(j, false) - } -} - -// Clone clones this tensor, creating a duplicate copy of itself with its -// own separate memory representation of all the values, and returns -// that as a Tensor (which can be converted into the known type as needed). -func (tsr *Bits) Clone() Tensor { - csr := NewBitsShape(&tsr.Shp) - csr.Values = tsr.Values.Clone() - return csr -} - -// CopyFrom copies all avail values from other tensor into this tensor, with an -// optimized implementation if the other tensor is of the same type, and -// otherwise it goes through appropriate standard type. -func (tsr *Bits) CopyFrom(frm Tensor) { - if fsm, ok := frm.(*Bits); ok { - copy(tsr.Values, fsm.Values) - return - } - sz := min(len(tsr.Values), frm.Len()) - for i := 0; i < sz; i++ { - tsr.Values.Set(i, Float64ToBool(frm.Float1D(i))) - } -} - -// CopyShapeFrom copies just the shape from given source tensor -// calling SetShape with the shape params from source (see for more docs). -func (tsr *Bits) CopyShapeFrom(frm Tensor) { - tsr.SetShape(frm.Shape().Sizes, frm.Shape().Names...) -} - -// CopyCellsFrom copies given range of values from other tensor into this tensor, -// using flat 1D indexes: to = starting index in this Tensor to start copying into, -// start = starting index on from Tensor to start copying from, and n = number of -// values to copy. Uses an optimized implementation if the other tensor is -// of the same type, and otherwise it goes through appropriate standard type. 
-func (tsr *Bits) CopyCellsFrom(frm Tensor, to, start, n int) { - if fsm, ok := frm.(*Bits); ok { - for i := 0; i < n; i++ { - tsr.Values.Set(to+i, fsm.Values.Index(start+i)) - } - return - } - for i := 0; i < n; i++ { - tsr.Values.Set(to+i, Float64ToBool(frm.Float1D(start+i))) - } -} - -// Dims is the gonum/mat.Matrix interface method for returning the dimensionality of the -// 2D Matrix. Not supported for Bits -- do not call! -func (tsr *Bits) Dims() (r, c int) { - slog.Error("tensor Dims gonum Matrix call made on Bits Tensor; not supported") - return 0, 0 -} - -// At is the gonum/mat.Matrix interface method for returning 2D matrix element at given -// row, column index. Not supported for Bits -- do not call! -func (tsr *Bits) At(i, j int) float64 { - slog.Error("tensor At gonum Matrix call made on Bits Tensor; not supported") - return 0 -} - -// T is the gonum/mat.Matrix transpose method. -// Not supported for Bits -- do not call! -func (tsr *Bits) T() mat.Matrix { - slog.Error("tensor T gonum Matrix call made on Bits Tensor; not supported") - return mat.Transpose{tsr} -} - -// String satisfies the fmt.Stringer interface for string of tensor data -func (tsr *Bits) String() string { - str := tsr.Label() - sz := tsr.Len() - if sz > 1000 { - return str - } - var b strings.Builder - b.WriteString(str) - b.WriteString("\n") - oddRow := true - rows, cols, _, _ := Projection2DShape(&tsr.Shp, oddRow) - for r := 0; r < rows; r++ { - rc, _ := Projection2DCoords(&tsr.Shp, oddRow, r, 0) - b.WriteString(fmt.Sprintf("%v: ", rc)) - for c := 0; c < cols; c++ { - vl := Projection2DValue(tsr, oddRow, r, c) - b.WriteString(fmt.Sprintf("%g ", vl)) - } - b.WriteString("\n") - } - return b.String() -} diff --git a/tensor/bitslice/bitslice.go b/tensor/bitslice/bitslice.go index eaefa05d9b..69f885779a 100644 --- a/tensor/bitslice/bitslice.go +++ b/tensor/bitslice/bitslice.go @@ -76,7 +76,7 @@ func (bs *Slice) SetLen(ln int) { } // Set sets value of given bit index -- no extra range 
checking is performed -- will panic if out of range -func (bs *Slice) Set(idx int, val bool) { +func (bs *Slice) Set(val bool, idx int) { by, bi := BitIndex(idx) if val { (*bs)[by+1] |= 1 << bi @@ -95,7 +95,7 @@ func (bs *Slice) Index(idx int) bool { func (bs *Slice) Append(val bool) Slice { if len(*bs) == 0 { *bs = Make(1, 0) - bs.Set(0, val) + bs.Set(val, 0) return *bs } ln := bs.Len() @@ -108,7 +108,7 @@ func (bs *Slice) Append(val bool) Slice { } else { (*bs)[0] = 0 } - bs.Set(ln, val) + bs.Set(val, ln) return *bs } @@ -161,7 +161,7 @@ func (bs *Slice) SubSlice(start, end int) Slice { } ss := Make(nln, 0) for i := 0; i < nln; i++ { - ss.Set(i, bs.Index(i+start)) + ss.Set(bs.Index(i+start), i) } return ss } @@ -186,10 +186,10 @@ func (bs *Slice) Delete(start, n int) Slice { } ss := Make(nln, 0) for i := 0; i < start; i++ { - ss.Set(i, bs.Index(i)) + ss.Set(bs.Index(i), i) } for i := end; i < ln; i++ { - ss.Set(i-n, bs.Index(i)) + ss.Set(bs.Index(i), i-n) } return ss } @@ -207,10 +207,10 @@ func (bs *Slice) Insert(start, n int) Slice { nln := ln + n ss := Make(nln, 0) for i := 0; i < start; i++ { - ss.Set(i, bs.Index(i)) + ss.Set(bs.Index(i), i) } for i := start; i < ln; i++ { - ss.Set(i+n, bs.Index(i)) + ss.Set(bs.Index(i), i+n) } return ss } diff --git a/tensor/bitslice/bitslice_test.go b/tensor/bitslice/bitslice_test.go index ceadc96cc6..322e66a5dd 100644 --- a/tensor/bitslice/bitslice_test.go +++ b/tensor/bitslice/bitslice_test.go @@ -24,7 +24,7 @@ func TestBitSlice10(t *testing.T) { t.Errorf("empty != %v", out) } - bs.Set(2, true) + bs.Set(true, 2) // fmt.Printf("2=true: %v\n", bs.String()) ex = "[0 0 1 0 0 0 0 0 0 0 ]" out = bs.String() @@ -32,7 +32,7 @@ func TestBitSlice10(t *testing.T) { t.Errorf("2=true != %v", out) } - bs.Set(4, true) + bs.Set(true, 4) // fmt.Printf("4=true: %v\n", bs.String()) ex = "[0 0 1 0 1 0 0 0 0 0 ]" out = bs.String() @@ -40,7 +40,7 @@ func TestBitSlice10(t *testing.T) { t.Errorf("4=true != %v", out) } - bs.Set(9, true) + 
bs.Set(true, 9) // fmt.Printf("9=true: %v\n", bs.String()) ex = "[0 0 1 0 1 0 0 0 0 1 ]" out = bs.String() diff --git a/tensor/bool.go b/tensor/bool.go new file mode 100644 index 0000000000..b80558d295 --- /dev/null +++ b/tensor/bool.go @@ -0,0 +1,342 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "fmt" + "reflect" + + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/num" + "cogentcore.org/core/base/reflectx" + "cogentcore.org/core/base/slicesx" + "cogentcore.org/core/tensor/bitslice" +) + +// Bool is a tensor of bits backed by a [bitslice.Slice] for efficient storage +// of binary, boolean data. Bool does not support [RowMajor.SubSpace] access +// and related methods due to the nature of the underlying data representation. +type Bool struct { + shape Shape + Values bitslice.Slice + Meta metadata.Data +} + +// NewBool returns a new n-dimensional tensor of bit values +// with the given sizes per dimension (shape). +func NewBool(sizes ...int) *Bool { + tsr := &Bool{} + tsr.SetShapeSizes(sizes...) + tsr.Values = bitslice.Make(tsr.Len(), 0) + return tsr +} + +// NewBoolShape returns a new n-dimensional tensor of bit values +// using given shape. +func NewBoolShape(shape *Shape) *Bool { + tsr := &Bool{} + tsr.shape.CopyFrom(shape) + tsr.Values = bitslice.Make(tsr.Len(), 0) + return tsr +} + +// Float64ToBool converts float64 value to bool. +func Float64ToBool(val float64) bool { + return num.ToBool(val) +} + +// BoolToFloat64 converts bool to float64 value. +func BoolToFloat64(bv bool) float64 { + return num.FromBool[float64](bv) +} + +// IntToBool converts int value to bool. +func IntToBool(val int) bool { + return num.ToBool(val) +} + +// BoolToInt converts bool to int value. 
+func BoolToInt(bv bool) int { + return num.FromBool[int](bv) +} + +// String satisfies the fmt.Stringer interface for string of tensor data. +func (tsr *Bool) String() string { return Sprintf("", tsr, 0) } + +// Label satisfies the core.Labeler interface for a summary description of the tensor +func (tsr *Bool) Label() string { + return label(metadata.Name(tsr), tsr.Shape()) +} + +func (tsr *Bool) IsString() bool { return false } + +func (tsr *Bool) AsValues() Values { return tsr } + +// DataType returns the type of the data elements in the tensor. +// Bool is returned for the Bool tensor type. +func (tsr *Bool) DataType() reflect.Kind { return reflect.Bool } + +func (tsr *Bool) Sizeof() int64 { return int64(len(tsr.Values)) } + +func (tsr *Bool) Bytes() []byte { return slicesx.ToBytes(tsr.Values) } + +func (tsr *Bool) Shape() *Shape { return &tsr.shape } + +// ShapeSizes returns the sizes of each dimension as a slice of ints. +// This is the preferred access for Go code. +func (tsr *Bool) ShapeSizes() []int { return tsr.shape.Sizes } + +// Metadata returns the metadata for this tensor, which can be used +// to encode plotting options, etc. +func (tsr *Bool) Metadata() *metadata.Data { return &tsr.Meta } + +// Len returns the number of elements in the tensor (product of shape dimensions). +func (tsr *Bool) Len() int { return tsr.shape.Len() } + +// NumDims returns the total number of dimensions. +func (tsr *Bool) NumDims() int { return tsr.shape.NumDims() } + +// DimSize returns size of given dimension +func (tsr *Bool) DimSize(dim int) int { return tsr.shape.DimSize(dim) } + +func (tsr *Bool) SetShapeSizes(sizes ...int) { + tsr.shape.SetShapeSizes(sizes...) + nln := tsr.Len() + tsr.Values.SetLen(nln) +} + +// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor. +// It is safe to set this to 0. 
For incrementally growing tensors (e.g., a log) +// it is best to first set the anticipated full size, which allocates the +// full amount of memory, and then set to 0 and grow incrementally. +func (tsr *Bool) SetNumRows(rows int) { + _, cells := tsr.shape.RowCellSize() + nln := rows * cells + tsr.shape.Sizes[0] = rows + tsr.Values.SetLen(nln) +} + +// SubSpace is not possible with Bool. +func (tsr *Bool) SubSpace(offs ...int) Values { return nil } + +// RowTensor not possible with Bool. +func (tsr *Bool) RowTensor(row int) Values { return nil } + +// SetRowTensor not possible with Bool. +func (tsr *Bool) SetRowTensor(val Values, row int) {} + +// AppendRow not possible with Bool. +func (tsr *Bool) AppendRow(val Values) {} + +/////// Bool + +func (tsr *Bool) Value(i ...int) bool { + return tsr.Values.Index(tsr.shape.IndexTo1D(i...)) +} + +func (tsr *Bool) Set(val bool, i ...int) { + tsr.Values.Set(val, tsr.shape.IndexTo1D(i...)) +} + +func (tsr *Bool) Value1D(i int) bool { return tsr.Values.Index(i) } + +func (tsr *Bool) Set1D(val bool, i int) { tsr.Values.Set(val, i) } + +/////// Strings + +func (tsr *Bool) String1D(off int) string { + return reflectx.ToString(tsr.Values.Index(off)) +} + +func (tsr *Bool) SetString1D(val string, off int) { + if bv, err := reflectx.ToBool(val); err == nil { + tsr.Values.Set(bv, off) + } +} + +func (tsr *Bool) StringValue(i ...int) string { + return reflectx.ToString(tsr.Values.Index(tsr.shape.IndexTo1D(i...))) +} + +func (tsr *Bool) SetString(val string, i ...int) { + if bv, err := reflectx.ToBool(val); err == nil { + tsr.Values.Set(bv, tsr.shape.IndexTo1D(i...)) + } +} + +func (tsr *Bool) StringRow(row, cell int) string { + _, sz := tsr.shape.RowCellSize() + return reflectx.ToString(tsr.Values.Index(row*sz + cell)) +} + +func (tsr *Bool) SetStringRow(val string, row, cell int) { + if bv, err := reflectx.ToBool(val); err == nil { + _, sz := tsr.shape.RowCellSize() + tsr.Values.Set(bv, row*sz+cell) + } +} + +// AppendRowString not 
possible with Bool. +func (tsr *Bool) AppendRowString(val ...string) {} + +/////// Floats + +func (tsr *Bool) Float(i ...int) float64 { + return BoolToFloat64(tsr.Values.Index(tsr.shape.IndexTo1D(i...))) +} + +func (tsr *Bool) SetFloat(val float64, i ...int) { + tsr.Values.Set(Float64ToBool(val), tsr.shape.IndexTo1D(i...)) +} + +func (tsr *Bool) Float1D(off int) float64 { + return BoolToFloat64(tsr.Values.Index(off)) +} + +func (tsr *Bool) SetFloat1D(val float64, off int) { + tsr.Values.Set(Float64ToBool(val), off) +} + +func (tsr *Bool) FloatRow(row, cell int) float64 { + _, sz := tsr.shape.RowCellSize() + return BoolToFloat64(tsr.Values.Index(row*sz + cell)) +} + +func (tsr *Bool) SetFloatRow(val float64, row, cell int) { + _, sz := tsr.shape.RowCellSize() + tsr.Values.Set(Float64ToBool(val), row*sz+cell) +} + +// AppendRowFloat not possible with Bool. +func (tsr *Bool) AppendRowFloat(val ...float64) {} + +/////// Ints + +func (tsr *Bool) Int(i ...int) int { + return BoolToInt(tsr.Values.Index(tsr.shape.IndexTo1D(i...))) +} + +func (tsr *Bool) SetInt(val int, i ...int) { + tsr.Values.Set(IntToBool(val), tsr.shape.IndexTo1D(i...)) +} + +func (tsr *Bool) Int1D(off int) int { + return BoolToInt(tsr.Values.Index(off)) +} + +func (tsr *Bool) SetInt1D(val int, off int) { + tsr.Values.Set(IntToBool(val), off) +} + +func (tsr *Bool) IntRow(row, cell int) int { + _, sz := tsr.shape.RowCellSize() + return BoolToInt(tsr.Values.Index(row*sz + cell)) +} + +func (tsr *Bool) SetIntRow(val int, row, cell int) { + _, sz := tsr.shape.RowCellSize() + tsr.Values.Set(IntToBool(val), row*sz+cell) +} + +// AppendRowInt not possible with Bool. 
+func (tsr *Bool) AppendRowInt(val ...int) {}
+
+/////// Bools
+
+func (tsr *Bool) Bool(i ...int) bool {
+	return tsr.Values.Index(tsr.shape.IndexTo1D(i...))
+}
+
+func (tsr *Bool) SetBool(val bool, i ...int) {
+	tsr.Values.Set(val, tsr.shape.IndexTo1D(i...))
+}
+
+func (tsr *Bool) Bool1D(off int) bool {
+	return tsr.Values.Index(off)
+}
+
+func (tsr *Bool) SetBool1D(val bool, off int) {
+	tsr.Values.Set(val, off)
+}
+
+// SetZeros is a convenience function to initialize all values to 0 (false).
+func (tsr *Bool) SetZeros() {
+	ln := tsr.Len()
+	for j := 0; j < ln; j++ {
+		tsr.Values.Set(false, j)
+	}
+}
+
+// SetTrue is a convenience function to initialize all values to true (1).
+func (tsr *Bool) SetTrue() {
+	ln := tsr.Len()
+	for j := 0; j < ln; j++ {
+		tsr.Values.Set(true, j)
+	}
+}
+
+// Clone clones this tensor, creating a duplicate copy of itself with its
+// own separate memory representation of all the values, and returns
+// that as a Tensor (which can be converted into the known type as needed).
+func (tsr *Bool) Clone() Values {
+	csr := NewBoolShape(&tsr.shape)
+	csr.Values = tsr.Values.Clone()
+	return csr
+}
+
+// CopyFrom copies all available values from other tensor into this tensor, with an
+// optimized implementation if the other tensor is of the same type, and
+// otherwise it goes through appropriate standard type.
+func (tsr *Bool) CopyFrom(frm Values) {
+	if fsm, ok := frm.(*Bool); ok {
+		copy(tsr.Values, fsm.Values)
+		return
+	}
+	sz := min(len(tsr.Values), frm.Len())
+	for i := 0; i < sz; i++ {
+		tsr.Values.Set(Float64ToBool(frm.Float1D(i)), i)
+	}
+}
+
+// AppendFrom appends values from other tensor into this tensor,
+// which must have the same cell size as this tensor.
+// It uses an optimized implementation if the other tensor
+// is of the same type, and otherwise it goes through
+// appropriate standard type. 
+func (tsr *Bool) AppendFrom(frm Values) error { + rows, cell := tsr.shape.RowCellSize() + frows, fcell := frm.Shape().RowCellSize() + if cell != fcell { + return fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell) + } + tsr.SetNumRows(rows + frows) + st := rows * cell + fsz := frows * fcell + if fsm, ok := frm.(*Bool); ok { + copy(tsr.Values[st:st+fsz], fsm.Values) + return nil + } + for i := 0; i < fsz; i++ { + tsr.Values.Set(Float64ToBool(frm.Float1D(i)), st+i) + } + return nil +} + +// CopyCellsFrom copies given range of values from other tensor into this tensor, +// using flat 1D indexes: to = starting index in this Tensor to start copying into, +// start = starting index on from Tensor to start copying from, and n = number of +// values to copy. Uses an optimized implementation if the other tensor is +// of the same type, and otherwise it goes through appropriate standard type. +func (tsr *Bool) CopyCellsFrom(frm Values, to, start, n int) { + if fsm, ok := frm.(*Bool); ok { + for i := 0; i < n; i++ { + tsr.Values.Set(fsm.Values.Index(start+i), to+i) + } + return + } + for i := 0; i < n; i++ { + tsr.Values.Set(Float64ToBool(frm.Float1D(start+i)), to+i) + } +} diff --git a/tensor/cmd/tablecat/tablecat.go b/tensor/cmd/tablecat/tablecat.go index 5c3273257f..962be203f4 100644 --- a/tensor/cmd/tablecat/tablecat.go +++ b/tensor/cmd/tablecat/tablecat.go @@ -10,12 +10,12 @@ import ( "fmt" "os" "sort" - "strconv" "cogentcore.org/core/core" - "cogentcore.org/core/tensor/stats/split" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/stats/stats" "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" ) var ( @@ -110,13 +110,13 @@ func RawCat(files []string) { func AvgCat(files []string) { dts := make([]*table.Table, 0, len(files)) for _, fn := range files { - dt := &table.Table{} - err := dt.OpenCSV(core.Filename(fn), table.Tab) + dt := table.New() + err := dt.OpenCSV(core.Filename(fn), tensor.Tab) if err != nil { 
fmt.Println("Error opening file: ", err) continue } - if dt.Rows == 0 { + if dt.NumRows() == 0 { fmt.Printf("File %v empty\n", fn) continue } @@ -126,40 +126,43 @@ func AvgCat(files []string) { fmt.Println("No files or files are empty, exiting") return } - avgdt := stats.MeanTables(dts) - avgdt.SetMetaData("precision", strconv.Itoa(LogPrec)) - avgdt.SaveCSV(core.Filename(Output), table.Tab, table.Headers) + // todo: need meantables + // avgdt := stats.MeanTables(dts) + // tensor.SetPrecision(avgdt, LogPrec) + // avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers) } // AvgByColumn computes average by given column for given files // If column is empty, averages across all rows. func AvgByColumn(files []string, column string) { for _, fn := range files { - dt := table.NewTable() - err := dt.OpenCSV(core.Filename(fn), table.Tab) + dt := table.New() + err := dt.OpenCSV(core.Filename(fn), tensor.Tab) if err != nil { fmt.Println("Error opening file: ", err) continue } - if dt.Rows == 0 { + if dt.NumRows() == 0 { fmt.Printf("File %v empty\n", fn) continue } - ix := table.NewIndexView(dt) - var spl *table.Splits + dir, _ := tensorfs.NewDir("Groups") if column == "" { - spl = split.All(ix) + stats.GroupAll(dir, dt.ColumnByIndex(0)) } else { - spl = split.GroupBy(ix, column) + stats.TableGroups(dir, dt, column) } - for ci, cl := range dt.Columns { - if cl.IsString() || dt.ColumnNames[ci] == column { + var cols []string + for ci, cl := range dt.Columns.Values { + if cl.IsString() || dt.Columns.Keys[ci] == column { continue } - split.AggIndex(spl, ci, stats.Mean) + cols = append(cols, dt.Columns.Keys[ci]) } - avgdt := spl.AggsToTable(table.ColumnNameOnly) - avgdt.SetMetaData("precision", strconv.Itoa(LogPrec)) - avgdt.SaveCSV(core.Filename(Output), table.Tab, table.Headers) + stats.TableGroupStats(dir, stats.StatMean, dt, cols...) 
+ std := dir.Node("Stats") + avgdt := tensorfs.DirTable(std, nil) // todo: has stat name slash + tensor.SetPrecision(avgdt, LogPrec) + avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers) } } diff --git a/tensor/convert.go b/tensor/convert.go new file mode 100644 index 0000000000..d94a471130 --- /dev/null +++ b/tensor/convert.go @@ -0,0 +1,251 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "math" + + "cogentcore.org/core/base/errors" +) + +// Clone returns a copy of the given tensor. +// If it is raw [Values] then a [Values.Clone] is returned. +// Otherwise if it is a view, then [Tensor.AsValues] is returned. +// This is equivalent to the NumPy copy function. +func Clone(tsr Tensor) Values { + if vl, ok := tsr.(Values); ok { + return vl.Clone() + } + return tsr.AsValues() +} + +// Flatten returns a copy of the given tensor as a 1D flat list +// of values, by calling Clone(As1D(tsr)). +// It is equivalent to the NumPy flatten function. +func Flatten(tsr Tensor) Values { + if msk, ok := tsr.(*Masked); ok { + return msk.AsValues() + } + return Clone(As1D(tsr)) +} + +// Squeeze a [Reshaped] view of given tensor with all singleton +// (size = 1) dimensions removed (if none, just returns the tensor). +func Squeeze(tsr Tensor) Tensor { + nd := tsr.NumDims() + sh := tsr.ShapeSizes() + reshape := make([]int, 0, nd) + for _, sz := range sh { + if sz > 1 { + reshape = append(reshape, sz) + } + } + if len(reshape) == nd { + return tsr + } + return NewReshaped(tsr, reshape...) +} + +// As1D returns a 1D tensor, which is either the input tensor if it is +// already 1D, or a new [Reshaped] 1D view of it. +// This can be useful e.g., for stats and metric functions that operate +// on a 1D list of values. See also [Flatten]. 
+func As1D(tsr Tensor) Tensor { + if tsr.NumDims() == 1 { + return tsr + } + return NewReshaped(tsr, tsr.Len()) +} + +// Cells1D returns a flat 1D view of the innermost cells for given row index. +// For a [RowMajor] tensor, it uses the [RowTensor] subspace directly, +// otherwise it uses [Sliced] to extract the cells. In either case, +// [As1D] is used to ensure the result is a 1D tensor. +func Cells1D(tsr Tensor, row int) Tensor { + if rm, ok := tsr.(RowMajor); ok { + return As1D(rm.RowTensor(row)) + } + return As1D(NewSliced(tsr, []int{row})) +} + +// MustBeValues returns the given tensor as a [Values] subtype, or nil and +// an error if it is not one. Typically outputs of compute operations must +// be values, and are reshaped to hold the results as needed. +func MustBeValues(tsr Tensor) (Values, error) { + vl, ok := tsr.(Values) + if !ok { + return nil, errors.New("tensor.MustBeValues: tensor must be a Values type") + } + return vl, nil +} + +// MustBeSameShape returns an error if the two tensors do not have the same shape. +func MustBeSameShape(a, b Tensor) error { + if !a.Shape().IsEqual(b.Shape()) { + return errors.New("tensor.MustBeSameShape: tensors must have the same shape") + } + return nil +} + +// SetShape sets the dimension sizes from given Shape +func SetShape(vals Values, sh *Shape) { + vals.SetShapeSizes(sh.Sizes...) +} + +// SetShapeSizesFromTensor sets the dimension sizes as 1D int values from given tensor. +// The backing storage is resized appropriately, retaining all existing data that fits. +func SetShapeSizesFromTensor(vals Values, sizes Tensor) { + vals.SetShapeSizes(AsIntSlice(sizes)...) +} + +// SetShapeFrom sets shape of given tensor from a source tensor. +func SetShapeFrom(vals Values, from Tensor) { + vals.SetShapeSizes(from.ShapeSizes()...) +} + +// AsFloat64Scalar returns the first value of tensor as a float64 scalar. +// Returns 0 if no values. 
+func AsFloat64Scalar(tsr Tensor) float64 { + if tsr.Len() == 0 { + return 0 + } + return tsr.Float1D(0) +} + +// AsIntScalar returns the first value of tensor as an int scalar. +// Returns 0 if no values. +func AsIntScalar(tsr Tensor) int { + if tsr.Len() == 0 { + return 0 + } + return tsr.Int1D(0) +} + +// AsStringScalar returns the first value of tensor as a string scalar. +// Returns "" if no values. +func AsStringScalar(tsr Tensor) string { + if tsr.Len() == 0 { + return "" + } + return tsr.String1D(0) +} + +// AsFloat64Slice returns all the tensor values as a slice of float64's. +// This allocates a new slice for the return values, and is not +// a good option for performance-critical code. +func AsFloat64Slice(tsr Tensor) []float64 { + if tsr.Len() == 0 { + return nil + } + sz := tsr.Len() + slc := make([]float64, sz) + for i := range sz { + slc[i] = tsr.Float1D(i) + } + return slc +} + +// AsIntSlice returns all the tensor values as a slice of ints. +// This allocates a new slice for the return values, and is not +// a good option for performance-critical code. +func AsIntSlice(tsr Tensor) []int { + if tsr.Len() == 0 { + return nil + } + sz := tsr.Len() + slc := make([]int, sz) + for i := range sz { + slc[i] = tsr.Int1D(i) + } + return slc +} + +// AsStringSlice returns all the tensor values as a slice of strings. +// This allocates a new slice for the return values, and is not +// a good option for performance-critical code. +func AsStringSlice(tsr Tensor) []string { + if tsr.Len() == 0 { + return nil + } + sz := tsr.Len() + slc := make([]string, sz) + for i := range sz { + slc[i] = tsr.String1D(i) + } + return slc +} + +// AsFloat64 returns the tensor as a [Float64] tensor. +// If already is a Float64, it is returned as such. +// Otherwise, a new Float64 tensor is created and values are copied. +// Use this function for interfacing with gonum or other apis that +// only operate on float64 types. 
+func AsFloat64(tsr Tensor) *Float64 {
+	if f, ok := tsr.(*Float64); ok {
+		return f
+	}
+	f := NewFloat64(tsr.ShapeSizes()...)
+	f.CopyFrom(tsr.AsValues())
+	return f
+}
+
+// AsFloat32 returns the tensor as a [Float32] tensor.
+// If already is a Float32, it is returned as such.
+// Otherwise, a new Float32 tensor is created and values are copied.
+func AsFloat32(tsr Tensor) *Float32 {
+	if f, ok := tsr.(*Float32); ok {
+		return f
+	}
+	f := NewFloat32(tsr.ShapeSizes()...)
+	f.CopyFrom(tsr.AsValues())
+	return f
+}
+
+// AsString returns the tensor as a [String] tensor.
+// If already is a String, it is returned as such.
+// Otherwise, a new String tensor is created and values are copied.
+func AsString(tsr Tensor) *String {
+	if f, ok := tsr.(*String); ok {
+		return f
+	}
+	f := NewString(tsr.ShapeSizes()...)
+	f.CopyFrom(tsr.AsValues())
+	return f
+}
+
+// AsInt returns the tensor as a [Int] tensor.
+// If already is a Int, it is returned as such.
+// Otherwise, a new Int tensor is created and values are copied.
+func AsInt(tsr Tensor) *Int {
+	if f, ok := tsr.(*Int); ok {
+		return f
+	}
+	f := NewInt(tsr.ShapeSizes()...)
+	f.CopyFrom(tsr.AsValues())
+	return f
+}
+
+// Range returns the min, max (and associated indexes, -1 = no values) for the tensor.
+// This is needed for display and is thus in the tensor api on Values.
+func Range(vals Values) (min, max float64, minIndex, maxIndex int) {
+	minIndex = -1
+	maxIndex = -1
+	n := vals.Len()
+	for j := range n {
+		fv := vals.Float1D(j)
+		if math.IsNaN(fv) {
+			continue
+		}
+		if fv < min || minIndex < 0 {
+			min = fv
+			minIndex = j
+		}
+		if fv > max || maxIndex < 0 {
+			max = fv
+			maxIndex = j
+		}
+	}
+	return
+}
diff --git a/tensor/create.go b/tensor/create.go
new file mode 100644
index 0000000000..78bee349f7
--- /dev/null
+++ b/tensor/create.go
@@ -0,0 +1,191 @@
+// Copyright (c) 2024, Cogent Core. All rights reserved.
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "math/rand" + "slices" +) + +// NewFloat64Scalar is a convenience method for a Tensor +// representation of a single float64 scalar value. +func NewFloat64Scalar(val float64) *Float64 { + return NewNumberFromValues(val) +} + +// NewFloat32Scalar is a convenience method for a Tensor +// representation of a single float32 scalar value. +func NewFloat32Scalar(val float32) *Float32 { + return NewNumberFromValues(val) +} + +// NewIntScalar is a convenience method for a Tensor +// representation of a single int scalar value. +func NewIntScalar(val int) *Int { + return NewNumberFromValues(val) +} + +// NewStringScalar is a convenience method for a Tensor +// representation of a single string scalar value. +func NewStringScalar(val string) *String { + return NewStringFromValues(val) +} + +// NewFloat64FromValues returns a new 1-dimensional tensor of given value type +// initialized directly from the given slice values, which are not copied. +// The resulting Tensor thus "wraps" the given values. +func NewFloat64FromValues(vals ...float64) *Float64 { + return NewNumberFromValues(vals...) +} + +// NewFloat32FromValues returns a new 1-dimensional tensor of given value type +// initialized directly from the given slice values, which are not copied. +// The resulting Tensor thus "wraps" the given values. +func NewFloat32FromValues(vals ...float32) *Float32 { + return NewNumberFromValues(vals...) +} + +// NewIntFromValues returns a new 1-dimensional tensor of given value type +// initialized directly from the given slice values, which are not copied. +// The resulting Tensor thus "wraps" the given values. +func NewIntFromValues(vals ...int) *Int { + return NewNumberFromValues(vals...) +} + +// NewStringFromValues returns a new 1-dimensional tensor of given value type +// initialized directly from the given slice values, which are not copied. 
+// The resulting Tensor thus "wraps" the given values. +func NewStringFromValues(vals ...string) *String { + n := len(vals) + tsr := &String{} + tsr.Values = vals + tsr.SetShapeSizes(n) + return tsr +} + +// SetAllFloat64 sets all values of given tensor to given value. +func SetAllFloat64(tsr Tensor, val float64) { + VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() }, + func(idx int, tsr ...Tensor) { + tsr[0].SetFloat1D(val, idx) + }, tsr) +} + +// SetAllInt sets all values of given tensor to given value. +func SetAllInt(tsr Tensor, val int) { + VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() }, + func(idx int, tsr ...Tensor) { + tsr[0].SetInt1D(val, idx) + }, tsr) +} + +// SetAllString sets all values of given tensor to given value. +func SetAllString(tsr Tensor, val string) { + VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() }, + func(idx int, tsr ...Tensor) { + tsr[0].SetString1D(val, idx) + }, tsr) +} + +// NewFloat64Full returns a new tensor full of given scalar value, +// of given shape sizes. +func NewFloat64Full(val float64, sizes ...int) *Float64 { + tsr := NewFloat64(sizes...) + SetAllFloat64(tsr, val) + return tsr +} + +// NewFloat64Ones returns a new tensor full of 1s, +// of given shape sizes. +func NewFloat64Ones(sizes ...int) *Float64 { + tsr := NewFloat64(sizes...) + SetAllFloat64(tsr, 1.0) + return tsr +} + +// NewIntFull returns a new tensor full of given scalar value, +// of given shape sizes. +func NewIntFull(val int, sizes ...int) *Int { + tsr := NewInt(sizes...) + SetAllInt(tsr, val) + return tsr +} + +// NewStringFull returns a new tensor full of given scalar value, +// of given shape sizes. +func NewStringFull(val string, sizes ...int) *String { + tsr := NewString(sizes...) + SetAllString(tsr, val) + return tsr +} + +// NewFloat64Rand returns a new tensor full of random numbers from +// global random source, of given shape sizes. 
+func NewFloat64Rand(sizes ...int) *Float64 { + tsr := NewFloat64(sizes...) + FloatSetFunc(1, func(idx int) float64 { return rand.Float64() }, tsr) + return tsr +} + +// NewIntRange returns a new [Int] [Tensor] with given [Slice] +// range parameters, with the same semantics as NumPy arange based on +// the number of arguments passed: +// - 1 = stop +// - 2 = start, stop +// - 3 = start, stop, step +func NewIntRange(svals ...int) *Int { + if len(svals) == 0 { + return NewInt() + } + sl := Slice{} + switch len(svals) { + case 1: + sl.Stop = svals[0] + case 2: + sl.Start = svals[0] + sl.Stop = svals[1] + case 3: + sl.Start = svals[0] + sl.Stop = svals[1] + sl.Step = svals[2] + } + return sl.IntTensor(sl.Stop) +} + +// NewFloat64SpacedLinear returns a new [Float64] tensor with num linearly +// spaced numbers between start and stop values, as tensors, which +// must be the same length and determine the cell shape of the output. +// If num is 0, then a default of 50 is used. +// If endpoint = true, then the stop value is _inclusive_, i.e., it will +// be the final value, otherwise it is exclusive. +// This corresponds to the NumPy linspace function. +func NewFloat64SpacedLinear(start, stop Tensor, num int, endpoint bool) *Float64 { + if num <= 0 { + num = 50 + } + fnum := float64(num) + if endpoint { + fnum -= 1 + } + step := Clone(start) + n := step.Len() + for i := range n { + step.SetFloat1D((stop.Float1D(i)-start.Float1D(i))/fnum, i) + } + var tsr *Float64 + if start.Len() == 1 { + tsr = NewFloat64(num) + } else { + tsz := slices.Clone(start.Shape().Sizes) + tsz = append([]int{num}, tsz...) + tsr = NewFloat64(tsz...) + } + for r := range num { + for i := range n { + tsr.SetFloatRow(start.Float1D(i)+float64(r)*step.Float1D(i), r, i) + } + } + return tsr +} diff --git a/tensor/create_test.go b/tensor/create_test.go new file mode 100644 index 0000000000..be7927a543 --- /dev/null +++ b/tensor/create_test.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024, Cogent Core. 
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreate(t *testing.T) { + assert.Equal(t, 5.5, NewFloat64Scalar(5.5).Float1D(0)) + assert.Equal(t, 5, NewIntScalar(5).Int1D(0)) + assert.Equal(t, "test", NewStringScalar("test").String1D(0)) + + assert.Equal(t, []float64{5.5, 1.5}, NewFloat64FromValues(5.5, 1.5).Values) + assert.Equal(t, []int{5, 1}, NewIntFromValues(5, 1).Values) + assert.Equal(t, []string{"test1", "test2"}, NewStringFromValues("test1", "test2").Values) + + assert.Equal(t, []float64{5.5, 5.5, 5.5, 5.5}, NewFloat64Full(5.5, 2, 2).Values) + assert.Equal(t, []float64{1, 1, 1, 1}, NewFloat64Ones(2, 2).Values) + + ar := NewIntRange(5) + assert.Equal(t, []int{0, 1, 2, 3, 4}, AsIntSlice(ar)) + + ar = NewIntRange(2, 5) + assert.Equal(t, []int{2, 3, 4}, AsIntSlice(ar)) + + ar = NewIntRange(0, 5, 2) + assert.Equal(t, []int{0, 2, 4}, AsIntSlice(ar)) + + lr := NewFloat64SpacedLinear(NewFloat64Scalar(0), NewFloat64Scalar(5), 6, true) + assert.Equal(t, []float64{0, 1, 2, 3, 4, 5}, AsFloat64Slice(lr)) + + lr = NewFloat64SpacedLinear(NewFloat64Scalar(0), NewFloat64Scalar(5), 5, false) + assert.Equal(t, []float64{0, 1, 2, 3, 4}, AsFloat64Slice(lr)) + + lr2 := NewFloat64SpacedLinear(NewFloat64FromValues(0, 2), NewFloat64FromValues(5, 7), 5, false) + // fmt.Println(lr2) + assert.Equal(t, []float64{0, 2, 1, 3, 2, 4, 3, 5, 4, 6}, AsFloat64Slice(lr2)) +} diff --git a/tensor/databrowser/README.md b/tensor/databrowser/README.md new file mode 100644 index 0000000000..ce84a3a4c4 --- /dev/null +++ b/tensor/databrowser/README.md @@ -0,0 +1,16 @@ +# databrowser + +The databrowser package provides GUI elements for data exploration and visualization, and a simple `Browser` implementation that combines these elements. 
+
+* `FileTree` (with `FileNode` elements), implementing a [filetree](https://github.com/cogentcore/core/tree/main/filetree) that has support for a [tensorfs](../tensorfs) filesystem, and data files in an actual filesystem. It has a `Tabber` pointer that handles the viewing actions on `tensorfs` elements (showing a Plot, etc).
+
+* `Tabber` interface and `Tabs` base implementation provide methods for showing data plots and editors in tabs.
+
+* `Terminal` running a `goal` shell that supports interactive commands operating on the `tensorfs` data etc. TODO!
+
+* `Browser` provides a hub structure connecting the above elements, which can be included in an actual GUI widget, that also provides additional functionality / GUI elements.
+
+The basic `Browser` puts the `FileTree` in a left `Splits` and the `Tabs` in the right, and supports interactive exploration and visualization of data.
+
+In the [emergent](https://github.com/emer) framework, these elements are combined with other GUI elements to provide a full neural network simulation environment on top of the databrowser foundation.
+
diff --git a/tensor/databrowser/basic.go b/tensor/databrowser/basic.go
new file mode 100644
index 0000000000..d8dddb8a64
--- /dev/null
+++ b/tensor/databrowser/basic.go
@@ -0,0 +1,89 @@
+// Copyright (c) 2024, Cogent Core. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package databrowser
+
+import (
+	"io/fs"
+	"os"
+	"path/filepath"
+
+	"cogentcore.org/core/base/errors"
+	"cogentcore.org/core/base/fsx"
+	"cogentcore.org/core/core"
+	"cogentcore.org/core/events"
+	"cogentcore.org/core/styles"
+	"cogentcore.org/core/tree"
+)
+
+// Basic is a basic data browser with the files as the left panel,
+// and the Tabber as the right panel.
+type Basic struct { + core.Frame + Browser +} + +// Init initializes with the data and script directories +func (br *Basic) Init() { + br.Frame.Init() + br.Styler(func(s *styles.Style) { + s.Grow.Set(1, 1) + }) + br.InitInterp() + + br.OnShow(func(e events.Event) { + br.UpdateFiles() + }) + + tree.AddChildAt(br, "splits", func(w *core.Splits) { + br.Splits = w + w.SetSplits(.15, .85) + tree.AddChildAt(w, "fileframe", func(w *core.Frame) { + w.Styler(func(s *styles.Style) { + s.Direction = styles.Column + s.Overflow.Set(styles.OverflowAuto) + s.Grow.Set(1, 1) + }) + tree.AddChildAt(w, "filetree", func(w *DataTree) { + br.Files = w + }) + }) + tree.AddChildAt(w, "tabs", func(w *Tabs) { + br.Tabs = w + }) + }) + br.Updater(func() { + if br.Files != nil { + br.Files.Tabber = br.Tabs + } + }) + +} + +// NewBasicWindow returns a new data Browser window for given +// file system (nil for os files) and data directory. +// do RunWindow on resulting [core.Body] to open the window. +func NewBasicWindow(fsys fs.FS, dataDir string) (*core.Body, *Basic) { + startDir, _ := os.Getwd() + startDir = errors.Log1(filepath.Abs(startDir)) + b := core.NewBody("Cogent Data Browser: " + fsx.DirAndFile(startDir)) + br := NewBasic(b) + br.FS = fsys + ddr := dataDir + if fsys == nil { + ddr = errors.Log1(filepath.Abs(dataDir)) + } + b.AddTopBar(func(bar *core.Frame) { + tb := core.NewToolbar(bar) + br.Toolbar = tb + tb.Maker(br.MakeToolbar) + }) + br.SetDataRoot(ddr) + br.SetScriptsDir(filepath.Join(ddr, "dbscripts")) + TheBrowser = &br.Browser + CurTabber = br.Browser.Tabs + br.Interpreter.Eval("br := databrowser.TheBrowser") // grab it + br.UpdateScripts() + return b, br +} diff --git a/tensor/databrowser/browser.go b/tensor/databrowser/browser.go index 21d7cb268a..832013af4e 100644 --- a/tensor/databrowser/browser.go +++ b/tensor/databrowser/browser.go @@ -8,122 +8,99 @@ package databrowser import ( "io/fs" - "path/filepath" + "slices" - "cogentcore.org/core/base/errors" - 
"cogentcore.org/core/base/fsx" "cogentcore.org/core/core" "cogentcore.org/core/events" - "cogentcore.org/core/filetree" + "cogentcore.org/core/goal/interpreter" "cogentcore.org/core/icons" - "cogentcore.org/core/styles" "cogentcore.org/core/tree" - "cogentcore.org/core/types" + "golang.org/x/exp/maps" ) -// Browser is a data browser, for browsing data either on an os filesystem -// or as a datafs virtual data filesystem. -type Browser struct { - core.Frame +// TheBrowser is the current browser, +// which is valid immediately after NewBrowserWindow +// where it is used to get a local variable for subsequent use. +var TheBrowser *Browser - // FS is the filesystem, if browsing an FS +// Browser holds all the elements of a data browser, for browsing data +// either on an OS filesystem or as a tensorfs virtual data filesystem. +// It supports the automatic loading of [goal] scripts as toolbar actions to +// perform pre-programmed tasks on the data, to create app-like functionality. +// Scripts are ordered alphabetically and any leading #- prefix is automatically +// removed from the label, so you can use numbers to specify a custom order. +// It is not a [core.Widget] itself, and is intended to be incorporated into +// a [core.Frame] widget, potentially along with other custom elements. +// See [Basic] for a basic implementation. +type Browser struct { //types:add -setters + // FS is the filesystem, if browsing an FS. FS fs.FS - // DataRoot is the path to the root of the data to browse + // DataRoot is the path to the root of the data to browse. DataRoot string - toolbar *core.Toolbar - splits *core.Splits - files *filetree.Tree - tabs *core.Tabs -} + // StartDir is the starting directory, where the app was originally started. 
+ StartDir string -// Init initializes with the data and script directories -func (br *Browser) Init() { - br.Frame.Init() - br.Styler(func(s *styles.Style) { - s.Grow.Set(1, 1) - }) + // ScriptsDir is the directory containing scripts for toolbar actions. + // It defaults to DataRoot/dbscripts + ScriptsDir string - br.OnShow(func(e events.Event) { - br.UpdateFiles() - }) + // Scripts + Scripts map[string]string `set:"-"` - tree.AddChildAt(br, "splits", func(w *core.Splits) { - br.splits = w - w.SetSplits(.15, .85) - tree.AddChildAt(w, "fileframe", func(w *core.Frame) { - w.Styler(func(s *styles.Style) { - s.Direction = styles.Column - s.Overflow.Set(styles.OverflowAuto) - s.Grow.Set(1, 1) - }) - tree.AddChildAt(w, "filetree", func(w *filetree.Tree) { - br.files = w - w.FileNodeType = types.For[FileNode]() - // w.OnSelect(func(e events.Event) { - // e.SetHandled() - // sels := w.SelectedViews() - // if sels != nil { - // br.FileNodeSelected(sn) - // } - // }) - }) - }) - tree.AddChildAt(w, "tabs", func(w *core.Tabs) { - br.tabs = w - w.Type = core.FunctionalTabs - }) - }) -} + // Interpreter is the interpreter to use for running Browser scripts + Interpreter *interpreter.Interpreter `set:"-"` -// NewBrowserWindow opens a new data Browser for given -// file system (nil for os files) and data directory. -func NewBrowserWindow(fsys fs.FS, dataDir string) *Browser { - b := core.NewBody("Cogent Data Browser: " + fsx.DirAndFile(dataDir)) - br := NewBrowser(b) - br.FS = fsys - ddr := dataDir - if fsys == nil { - ddr = errors.Log1(filepath.Abs(dataDir)) - } - b.AddTopBar(func(bar *core.Frame) { - tb := core.NewToolbar(bar) - br.toolbar = tb - tb.Maker(br.MakeToolbar) - }) - br.SetDataRoot(ddr) - b.RunWindow() - return br -} + // Files is the [DataTree] tree browser of the tensorfs or files. 
+ Files *DataTree -// ParentBrowser returns the Browser parent of given node -func ParentBrowser(tn tree.Node) *Browser { - var res *Browser - tn.AsTree().WalkUp(func(n tree.Node) bool { - if c, ok := n.(*Browser); ok { - res = c - return false - } - return true - }) - return res + // Tabs is the [Tabber] element managing tabs of data views. + Tabs Tabber + + // Toolbar is the top-level toolbar for the browser, if used. + Toolbar *core.Toolbar + + // Splits is the overall [core.Splits] for the browser. + Splits *core.Splits } // UpdateFiles Updates the files list. func (br *Browser) UpdateFiles() { //types:add - files := br.files + if br.Files == nil { + return + } + files := br.Files if br.FS != nil { files.SortByModTime = true files.OpenPathFS(br.FS, br.DataRoot) } else { files.OpenPath(br.DataRoot) } - br.Update() } func (br *Browser) MakeToolbar(p *tree.Plan) { tree.Add(p, func(w *core.FuncButton) { w.SetFunc(br.UpdateFiles).SetText("").SetIcon(icons.Refresh).SetShortcut("Command+U") }) + tree.Add(p, func(w *core.FuncButton) { + w.SetFunc(br.UpdateScripts).SetText("").SetIcon(icons.Code) + }) + scr := maps.Keys(br.Scripts) + slices.Sort(scr) + for _, s := range scr { + lbl := TrimOrderPrefix(s) + tree.AddAt(p, lbl, func(w *core.Button) { + w.SetText(lbl).SetIcon(icons.RunCircle). + OnClick(func(e events.Event) { + br.RunScript(s) + }) + sc := br.Scripts[s] + tt := FirstComment(sc) + if tt == "" { + tt = "Run Script (add a comment to top of script to provide more useful info here)" + } + w.SetTooltip(tt) + }) + } } diff --git a/tensor/databrowser/datatab.go b/tensor/databrowser/datatab.go deleted file mode 100644 index 285a5c6a81..0000000000 --- a/tensor/databrowser/datatab.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package databrowser - -import ( - "cogentcore.org/core/core" - "cogentcore.org/core/plot/plotcore" - "cogentcore.org/core/styles" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/table" - "cogentcore.org/core/tensor/tensorcore" - "cogentcore.org/core/texteditor" -) - -// NewTab creates a tab with given label, or returns the existing one -// with given type of widget within it. mkfun function is called to create -// and configure a new widget if not already existing. -func NewTab[T any](br *Browser, label string, mkfun func(tab *core.Frame) T) T { - tab := br.tabs.RecycleTab(label) - if tab.HasChildren() { - return tab.Child(1).(T) - } - w := mkfun(tab) - return w -} - -// NewTabTensorTable creates a tab with a tensorcore.Table widget -// to view given table.Table, using its own table.IndexView as tv.Table. -// Use tv.Table.Table to get the underlying *table.Table -// Use tv.Table.Sequential to update the IndexView to view -// all of the rows when done updating the Table, and then call br.Update() -func (br *Browser) NewTabTensorTable(label string, dt *table.Table) *tensorcore.Table { - tv := NewTab[*tensorcore.Table](br, label, func(tab *core.Frame) *tensorcore.Table { - tb := core.NewToolbar(tab) - tv := tensorcore.NewTable(tab) - tb.Maker(tv.MakeToolbar) - return tv - }) - tv.SetTable(dt) - br.Update() - return tv -} - -// NewTabTensorEditor creates a tab with a tensorcore.TensorEditor widget -// to view given Tensor. -func (br *Browser) NewTabTensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor { - tv := NewTab[*tensorcore.TensorEditor](br, label, func(tab *core.Frame) *tensorcore.TensorEditor { - tb := core.NewToolbar(tab) - tv := tensorcore.NewTensorEditor(tab) - tb.Maker(tv.MakeToolbar) - return tv - }) - tv.SetTensor(tsr) - br.Update() - return tv -} - -// NewTabTensorGrid creates a tab with a tensorcore.TensorGrid widget -// to view given Tensor. 
-func (br *Browser) NewTabTensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid { - tv := NewTab[*tensorcore.TensorGrid](br, label, func(tab *core.Frame) *tensorcore.TensorGrid { - // tb := core.NewToolbar(tab) - tv := tensorcore.NewTensorGrid(tab) - // tb.Maker(tv.MakeToolbar) - return tv - }) - tv.SetTensor(tsr) - br.Update() - return tv -} - -// NewTabPlot creates a tab with a Plot of given table.Table. -func (br *Browser) NewTabPlot(label string, dt *table.Table) *plotcore.PlotEditor { - pl := NewTab[*plotcore.PlotEditor](br, label, func(tab *core.Frame) *plotcore.PlotEditor { - return plotcore.NewSubPlot(tab) - }) - pl.SetTable(dt) - br.Update() - return pl -} - -// NewTabSliceTable creates a tab with a core.Table widget -// to view the given slice of structs. -func (br *Browser) NewTabSliceTable(label string, slc any) *core.Table { - tv := NewTab[*core.Table](br, label, func(tab *core.Frame) *core.Table { - return core.NewTable(tab) - }) - tv.SetSlice(slc) - br.Update() - return tv -} - -// NewTabEditor opens a texteditor.Editor tab, displaying given string. 
-func (br *Browser) NewTabEditor(label, content string) *texteditor.Editor { - ed := NewTab[*texteditor.Editor](br, label, func(tab *core.Frame) *texteditor.Editor { - ed := texteditor.NewEditor(tab) - ed.Styler(func(s *styles.Style) { - s.Grow.Set(1, 1) - }) - return ed - }) - if content != "" { - ed.Buffer.SetText([]byte(content)) - } - br.Update() - return ed -} - -// NewTabEditorFile opens an editor tab for given file -func (br *Browser) NewTabEditorFile(label, filename string) *texteditor.Editor { - ed := NewTab[*texteditor.Editor](br, label, func(tab *core.Frame) *texteditor.Editor { - ed := texteditor.NewEditor(tab) - ed.Styler(func(s *styles.Style) { - s.Grow.Set(1, 1) - }) - return ed - }) - ed.Buffer.Open(core.Filename(filename)) - br.Update() - return ed -} diff --git a/tensor/databrowser/filetree.go b/tensor/databrowser/filetree.go index ad289ac312..1bdb3d5d79 100644 --- a/tensor/databrowser/filetree.go +++ b/tensor/databrowser/filetree.go @@ -6,8 +6,6 @@ package databrowser import ( "image" - "log" - "reflect" "strings" "cogentcore.org/core/base/errors" @@ -19,11 +17,47 @@ import ( "cogentcore.org/core/icons" "cogentcore.org/core/styles" "cogentcore.org/core/styles/states" - "cogentcore.org/core/tensor/datafs" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" "cogentcore.org/core/texteditor/diffbrowser" + "cogentcore.org/core/tree" + "cogentcore.org/core/types" ) +// Treer is an interface for getting the Root node as a DataTree struct. +type Treer interface { + AsDataTree() *DataTree +} + +// AsDataTree returns the given value as a [DataTree] if it has +// an AsDataTree() method, or nil otherwise. +func AsDataTree(n tree.Node) *DataTree { + if t, ok := n.(Treer); ok { + return t.AsDataTree() + } + return nil +} + +// DataTree is the databrowser version of [filetree.Tree], +// which provides the Tabber to show data editors. 
+type DataTree struct { + filetree.Tree + + // Tabber is the [Tabber] for this tree. + Tabber Tabber +} + +func (ft *DataTree) AsDataTree() *DataTree { + return ft +} + +func (ft *DataTree) Init() { + ft.Tree.Init() + ft.Root = ft + ft.FileNodeType = types.For[FileNode]() +} + // FileNode is databrowser version of FileNode for FileTree type FileNode struct { filetree.Node @@ -34,14 +68,23 @@ func (fn *FileNode) Init() { fn.AddContextMenu(fn.ContextMenu) } +// Tabber returns the [Tabber] for this filenode, from root tree. +func (fn *FileNode) Tabber() Tabber { + fr := AsDataTree(fn.Root) + if fr != nil { + return fr.Tabber + } + return nil +} + func (fn *FileNode) WidgetTooltip(pos image.Point) (string, image.Point) { res := fn.Tooltip if fn.Info.Cat == fileinfo.Data { ofn := fn.AsNode() switch fn.Info.Known { case fileinfo.Number, fileinfo.String: - dv := DataFS(ofn) - v, _ := dv.AsString() + dv := TensorFS(ofn) + v := dv.String() if res != "" { res += " " } @@ -51,10 +94,10 @@ func (fn *FileNode) WidgetTooltip(pos image.Point) (string, image.Point) { return res, fn.DefaultTooltipPos() } -// DataFS returns the datafs representation of this item. +// TensorFS returns the tensorfs representation of this item. // returns nil if not a dataFS item. 
-func DataFS(fn *filetree.Node) *datafs.Data { - dfs, ok := fn.FileRoot.FS.(*datafs.Data) +func TensorFS(fn *filetree.Node) *tensorfs.Node { + dfs, ok := fn.FileRoot().FS.(*tensorfs.Node) if !ok { return nil } @@ -62,15 +105,15 @@ func DataFS(fn *filetree.Node) *datafs.Data { if errors.Log(err) != nil { return nil } - return dfi.(*datafs.Data) + return dfi.(*tensorfs.Node) } func (fn *FileNode) GetFileInfo() error { err := fn.InitFileInfo() - if fn.FileRoot.FS == nil { + if fn.FileRoot().FS == nil { return err } - d := DataFS(fn.AsNode()) + d := TensorFS(fn.AsNode()) if d != nil { fn.Info.Known = d.KnownFileInfo() fn.Info.Cat = fileinfo.Data @@ -92,61 +135,55 @@ func (fn *FileNode) GetFileInfo() error { func (fn *FileNode) OpenFile() error { ofn := fn.AsNode() - br := ParentBrowser(fn.This) - if br == nil { + ts := fn.Tabber() + if ts == nil { return nil } df := fsx.DirAndFile(string(fn.Filepath)) switch { + case fn.IsDir(): + d := TensorFS(ofn) + dt := tensorfs.DirTable(d, nil) + ts.TensorTable(df, dt) case fn.Info.Cat == fileinfo.Data: switch fn.Info.Known { case fileinfo.Tensor: - d := DataFS(ofn) - tsr := d.AsTensor() - if tsr.IsString() || tsr.DataType() < reflect.Float32 { - br.NewTabTensorEditor(df, tsr) - } else { - br.NewTabTensorGrid(df, tsr) - } - case fileinfo.Table: - d := DataFS(ofn) - dt := d.AsTable() - br.NewTabTensorTable(df, dt) - br.Update() + d := TensorFS(ofn) + ts.TensorEditor(df, d.Tensor) case fileinfo.Number: - dv := DataFS(ofn) - v, _ := dv.AsFloat32() + dv := TensorFS(ofn) + v := dv.Tensor.Float1D(0) d := core.NewBody(df) core.NewText(d).SetType(core.TextSupporting).SetText(df) - sp := core.NewSpinner(d).SetValue(v) + sp := core.NewSpinner(d).SetValue(float32(v)) d.AddBottomBar(func(bar *core.Frame) { d.AddCancel(bar) d.AddOK(bar).OnClick(func(e events.Event) { - dv.SetFloat32(sp.Value) + dv.Tensor.SetFloat1D(float64(sp.Value), 0) }) }) - d.RunDialog(br) + d.RunDialog(fn) case fileinfo.String: - dv := DataFS(ofn) - v, _ := dv.AsString() 
+ dv := TensorFS(ofn) + v := dv.Tensor.String1D(0) d := core.NewBody(df) core.NewText(d).SetType(core.TextSupporting).SetText(df) tf := core.NewTextField(d).SetText(v) d.AddBottomBar(func(bar *core.Frame) { d.AddCancel(bar) d.AddOK(bar).OnClick(func(e events.Event) { - dv.SetString(tf.Text()) + dv.Tensor.SetString1D(tf.Text(), 0) }) }) - d.RunDialog(br) + d.RunDialog(fn) default: - dt := table.NewTable() - err := dt.OpenCSV(fn.Filepath, table.Tab) // todo: need more flexible data handling mode + dt := table.New() + err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab) // todo: need more flexible data handling mode if err != nil { - core.ErrorSnackbar(br, err) + core.ErrorSnackbar(fn, err) } else { - br.NewTabTensorTable(df, dt) + ts.TensorTable(df, dt) } } case fn.IsExec(): // todo: use exec? @@ -166,7 +203,7 @@ func (fn *FileNode) OpenFile() error { case fn.Info.Cat == fileinfo.Archive || fn.Info.Cat == fileinfo.Backup: // don't edit fn.OpenFilesDefault() default: - br.NewTabEditor(df, string(fn.Filepath)) + ts.EditorString(df, string(fn.Filepath)) } return nil } @@ -181,11 +218,11 @@ func (fn *FileNode) EditFiles() { //types:add // EditFile pulls up this file in a texteditor func (fn *FileNode) EditFile() { if fn.IsDir() { - log.Printf("FileNode Edit -- cannot view (edit) directories!\n") + fn.OpenFile() return } - br := ParentBrowser(fn.This) - if br == nil { + ts := fn.Tabber() + if ts == nil { return } if fn.Info.Cat == fileinfo.Data { @@ -193,7 +230,7 @@ func (fn *FileNode) EditFile() { return } df := fsx.DirAndFile(string(fn.Filepath)) - br.NewTabEditor(df, string(fn.Filepath)) + ts.EditorString(df, string(fn.Filepath)) } // PlotFiles calls PlotFile on selected files @@ -207,46 +244,27 @@ func (fn *FileNode) PlotFiles() { //types:add // PlotFile pulls up this file in a texteditor. 
func (fn *FileNode) PlotFile() { - br := ParentBrowser(fn.This) - if br == nil { + ts := fn.Tabber() + if ts == nil { + return + } + d := TensorFS(fn.AsNode()) + if d != nil { + ts.PlotTensorFS(d) + return + } + if fn.Info.Cat != fileinfo.Data { return } - d := DataFS(fn.AsNode()) df := fsx.DirAndFile(string(fn.Filepath)) ptab := df + " Plot" - var dt *table.Table - switch { - case fn.IsDir(): - dt = d.DirTable(nil) - case fn.Info.Cat == fileinfo.Data: - switch fn.Info.Known { - case fileinfo.Tensor: - tsr := d.AsTensor() - dt = table.NewTable(df) - dt.Rows = tsr.DimSize(0) - rc := dt.AddIntColumn("Row") - for r := range dt.Rows { - rc.Values[r] = r - } - dt.AddColumn(tsr, fn.Name) - case fileinfo.Table: - dt = d.AsTable() - default: - dt = table.NewTable(df) - err := dt.OpenCSV(fn.Filepath, table.Tab) // todo: need more flexible data handling mode - if err != nil { - core.ErrorSnackbar(br, err) - dt = nil - } - } - } - if dt == nil { + dt := table.New(df) + err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab) // todo: need more flexible data handling mode + if err != nil { + core.ErrorSnackbar(fn, err) return } - pl := br.NewTabPlot(ptab, dt) - pl.Options.Title = df - // TODO: apply column and plot level options. - br.Update() + ts.PlotTable(ptab, dt) } // DiffDirs displays a browser with differences between two selected directories diff --git a/tensor/databrowser/scripts.go b/tensor/databrowser/scripts.go new file mode 100644 index 0000000000..2825be2657 --- /dev/null +++ b/tensor/databrowser/scripts.go @@ -0,0 +1,149 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package databrowser + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" + "unicode" + + "cogentcore.org/core/base/fsx" + "cogentcore.org/core/base/logx" + "cogentcore.org/core/core" + "cogentcore.org/core/events" + "cogentcore.org/core/goal/goalib" + "cogentcore.org/core/goal/interpreter" + "cogentcore.org/core/styles" + "github.com/cogentcore/yaegi/interp" +) + +func (br *Browser) InitInterp() { + br.Interpreter = interpreter.NewInterpreter(interp.Options{}) + br.Interpreter.Config() + // logx.UserLevel = slog.LevelDebug // for debugging of init loading +} + +func (br *Browser) RunScript(snm string) { + sc, ok := br.Scripts[snm] + if !ok { + slog.Error("script not found:", "Script:", snm) + return + } + logx.PrintlnDebug("\n################\nrunning script:\n", sc, "\n") + _, _, err := br.Interpreter.Eval(sc) + if err == nil { + err = br.Interpreter.Goal.TrState.DepthError() + } + br.Interpreter.Goal.TrState.ResetDepth() +} + +// UpdateScripts updates the Scripts and updates the toolbar. +func (br *Browser) UpdateScripts() { //types:add + redo := (br.Scripts != nil) + scr := fsx.Filenames(br.ScriptsDir, ".goal") + br.Scripts = make(map[string]string) + for _, s := range scr { + snm := strings.TrimSuffix(s, ".goal") + sc, err := os.ReadFile(filepath.Join(br.ScriptsDir, s)) + if err == nil { + if unicode.IsLower(rune(snm[0])) { + if !redo { + fmt.Println("run init script:", snm) + br.Interpreter.Eval(string(sc)) + } + } else { + ssc := string(sc) + br.Scripts[snm] = ssc + } + } else { + slog.Error(err.Error()) + } + } + if br.Toolbar != nil { + br.Toolbar.Update() + } +} + +// TrimOrderPrefix trims any optional #- prefix from given string, +// used for ordering items by name. 
+func TrimOrderPrefix(s string) string { + i := strings.Index(s, "-") + if i < 0 { + return s + } + ds := s[:i] + if _, err := strconv.Atoi(ds); err != nil { + return s + } + return s[i+1:] +} + +// PromptOKCancel prompts the user for whether to do something, +// calling the given function if the user clicks OK. +func PromptOKCancel(ctx core.Widget, prompt string, fun func()) { + d := core.NewBody(prompt) + d.AddBottomBar(func(bar *core.Frame) { + d.AddCancel(bar) + d.AddOK(bar).OnClick(func(e events.Event) { + if fun != nil { + fun() + } + }) + }) + d.RunDialog(ctx) +} + +// PromptString prompts the user for a string value (initial value given), +// calling the given function if the user clicks OK. +func PromptString(ctx core.Widget, str string, prompt string, fun func(s string)) { + d := core.NewBody(prompt) + tf := core.NewTextField(d).SetText(str) + tf.Styler(func(s *styles.Style) { + s.Min.X.Ch(60) + }) + d.AddBottomBar(func(bar *core.Frame) { + d.AddCancel(bar) + d.AddOK(bar).OnClick(func(e events.Event) { + if fun != nil { + fun(tf.Text()) + } + }) + }) + d.RunDialog(ctx) +} + +// PromptStruct prompts the user for the values in given struct (pass a pointer), +// calling the given function if the user clicks OK. +func PromptStruct(ctx core.Widget, str any, prompt string, fun func()) { + d := core.NewBody(prompt) + core.NewForm(d).SetStruct(str) + d.AddBottomBar(func(bar *core.Frame) { + d.AddCancel(bar) + d.AddOK(bar).OnClick(func(e events.Event) { + if fun != nil { + fun() + } + }) + }) + d.RunDialog(ctx) +} + +// FirstComment returns the first comment lines from given .goal file, +// which is used to set the tooltip for scripts. 
+func FirstComment(sc string) string { + sl := goalib.SplitLines(sc) + cmt := "" + for _, l := range sl { + if !strings.HasPrefix(l, "// ") { + return cmt + } + cmt += strings.TrimSpace(l[3:]) + " " + } + return cmt +} diff --git a/tensor/databrowser/tabs.go b/tensor/databrowser/tabs.go new file mode 100644 index 0000000000..0e56dd2fde --- /dev/null +++ b/tensor/databrowser/tabs.go @@ -0,0 +1,273 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package databrowser + +import ( + "fmt" + + "cogentcore.org/core/base/fsx" + "cogentcore.org/core/core" + "cogentcore.org/core/plot/plotcore" + "cogentcore.org/core/styles" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorcore" + "cogentcore.org/core/tensor/tensorfs" + "cogentcore.org/core/texteditor" +) + +// CurTabber is the current Tabber. Set when one is created. +var CurTabber Tabber + +// Tabber is a [core.Tabs] based widget that has support for opening +// tabs for [plotcore.PlotEditor] and [tensorcore.Table] editors, +// among others. +type Tabber interface { + core.Tabber + + // AsDataTabs returns the underlying [databrowser.Tabs] widget. + AsDataTabs() *Tabs + + // TensorTable recycles a tab with a [tensorcore.Table] widget + // to view given [table.Table], using its own table.Table. + TensorTable(label string, dt *table.Table) *tensorcore.Table + + // TensorEditor recycles a tab with a [tensorcore.TensorEditor] widget + // to view given Tensor. + TensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor + + // TensorGrid recycles a tab with a [tensorcore.TensorGrid] widget + // to view given Tensor. + TensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid + + // PlotTable recycles a tab with a Plot of given [table.Table]. 
+ PlotTable(label string, dt *table.Table) *plotcore.PlotEditor + + // PlotTensorFS recycles a tab with a Plot of given [tensorfs.Node], + // automatically using the Dir/File name of the data node for the label. + PlotTensorFS(dfs *tensorfs.Node) *plotcore.PlotEditor + + // GoUpdatePlot calls GoUpdatePlot on plot at tab with given name. + // Does nothing if tab name doesn't exist (returns nil). + GoUpdatePlot(label string) *plotcore.PlotEditor + + // UpdatePlot calls UpdatePlot on plot at tab with given name. + // Does nothing if tab name doesn't exist (returns nil). + UpdatePlot(label string) *plotcore.PlotEditor + + // todo: PlotData of plot.Node + + // SliceTable recycles a tab with a [core.Table] widget + // to view the given slice of structs. + SliceTable(label string, slc any) *core.Table + + // EditorString recycles a [texteditor.Editor] tab, displaying given string. + EditorString(label, content string) *texteditor.Editor + + // EditorFile opens an editor tab for given file. + EditorFile(label, filename string) *texteditor.Editor +} + +// NewTab recycles a tab with given label, or returns the existing one +// with given type of widget within it. The existing that is returned +// is the last one in the frame, allowing for there to be a toolbar at the top. +// mkfun function is called to create and configure a new widget +// if not already existing. +func NewTab[T any](tb Tabber, label string, mkfun func(tab *core.Frame) T) T { + tab := tb.RecycleTab(label) + var zv T + if tab.HasChildren() { + nc := tab.NumChildren() + lc := tab.Child(nc - 1) + if tt, ok := lc.(T); ok { + return tt + } + err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: is %T", label, lc) + core.ErrorSnackbar(tb.AsDataTabs(), err) + return zv + } + w := mkfun(tab) + return w +} + +// TabAt returns widget of given type at tab of given name, nil if tab not found. 
+func TabAt[T any](tb Tabber, label string) T { + var zv T + tab := tb.TabByName(label) + if tab == nil { + return zv + } + if !tab.HasChildren() { // shouldn't happen + return zv + } + nc := tab.NumChildren() + lc := tab.Child(nc - 1) + if tt, ok := lc.(T); ok { + return tt + } + + err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: %T", label, lc) + core.ErrorSnackbar(tb.AsDataTabs(), err) + return zv +} + +// Tabs implements the [Tabber] interface. +type Tabs struct { + core.Tabs +} + +func (ts *Tabs) Init() { + ts.Tabs.Init() + ts.Type = core.FunctionalTabs +} + +func (ts *Tabs) AsDataTabs() *Tabs { + return ts +} + +// TensorTable recycles a tab with a tensorcore.Table widget +// to view given table.Table, using its own table.Table as tv.Table. +// Use tv.Table.Table to get the underlying *table.Table +// Use tv.Table.Sequential to update the Indexed to view +// all of the rows when done updating the Table, and then call br.Update() +func (ts *Tabs) TensorTable(label string, dt *table.Table) *tensorcore.Table { + tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.Table { + tb := core.NewToolbar(tab) + tv := tensorcore.NewTable(tab) + tb.Maker(tv.MakeToolbar) + return tv + }) + tv.SetTable(dt) + ts.Update() + return tv +} + +// TensorEditor recycles a tab with a tensorcore.TensorEditor widget +// to view given Tensor. +func (ts *Tabs) TensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor { + tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorEditor { + tb := core.NewToolbar(tab) + tv := tensorcore.NewTensorEditor(tab) + tb.Maker(tv.MakeToolbar) + return tv + }) + tv.SetTensor(tsr) + ts.Update() + return tv +} + +// TensorGrid recycles a tab with a tensorcore.TensorGrid widget +// to view given Tensor. 
+func (ts *Tabs) TensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid { + tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorGrid { + // tb := core.NewToolbar(tab) + tv := tensorcore.NewTensorGrid(tab) + // tb.Maker(tv.MakeToolbar) + return tv + }) + tv.SetTensor(tsr) + ts.Update() + return tv +} + +// PlotTable recycles a tab with a Plot of given table.Table. +func (ts *Tabs) PlotTable(label string, dt *table.Table) *plotcore.PlotEditor { + pl := NewTab(ts, label, func(tab *core.Frame) *plotcore.PlotEditor { + tb := core.NewToolbar(tab) + pl := plotcore.NewPlotEditor(tab) + tab.Styler(func(s *styles.Style) { + s.Direction = styles.Column + s.Grow.Set(1, 1) + }) + tb.Maker(pl.MakeToolbar) + return pl + }) + if pl != nil { + pl.SetTable(dt) + ts.Update() + } + return pl +} + +// PlotTensorFS recycles a tab with a Plot of given [tensorfs.Node]. +func (ts *Tabs) PlotTensorFS(dfs *tensorfs.Node) *plotcore.PlotEditor { + label := fsx.DirAndFile(dfs.Path()) + " Plot" + if dfs.IsDir() { + return ts.PlotTable(label, tensorfs.DirTable(dfs, nil)) + } + tsr := dfs.Tensor + dt := table.New(label) + dt.Columns.Rows = tsr.DimSize(0) + if ix, ok := tsr.(*tensor.Rows); ok { + dt.Indexes = ix.Indexes + } + rc := dt.AddIntColumn("Row") + for r := range dt.Columns.Rows { + rc.Values[r] = r + } + dt.AddColumn(dfs.Name(), tsr.AsValues()) + return ts.PlotTable(label, dt) +} + +// GoUpdatePlot calls GoUpdatePlot on plot at tab with given name. +// Does nothing if tab name doesn't exist (returns nil). +func (ts *Tabs) GoUpdatePlot(label string) *plotcore.PlotEditor { + pl := TabAt[*plotcore.PlotEditor](ts, label) + if pl != nil { + pl.GoUpdatePlot() + } + return pl +} + +// UpdatePlot calls UpdatePlot on plot at tab with given name. +// Does nothing if tab name doesn't exist (returns nil). 
+func (ts *Tabs) UpdatePlot(label string) *plotcore.PlotEditor { + pl := TabAt[*plotcore.PlotEditor](ts, label) + if pl != nil { + pl.UpdatePlot() + } + return pl +} + +// SliceTable recycles a tab with a core.Table widget +// to view the given slice of structs. +func (ts *Tabs) SliceTable(label string, slc any) *core.Table { + tv := NewTab(ts, label, func(tab *core.Frame) *core.Table { + return core.NewTable(tab) + }) + tv.SetSlice(slc) + ts.Update() + return tv +} + +// EditorString recycles a [texteditor.Editor] tab, displaying given string. +func (ts *Tabs) EditorString(label, content string) *texteditor.Editor { + ed := NewTab(ts, label, func(tab *core.Frame) *texteditor.Editor { + ed := texteditor.NewEditor(tab) + ed.Styler(func(s *styles.Style) { + s.Grow.Set(1, 1) + }) + return ed + }) + if content != "" { + ed.Buffer.SetText([]byte(content)) + } + ts.Update() + return ed +} + +// EditorFile opens an editor tab for given file. +func (ts *Tabs) EditorFile(label, filename string) *texteditor.Editor { + ed := NewTab(ts, label, func(tab *core.Frame) *texteditor.Editor { + ed := texteditor.NewEditor(tab) + ed.Styler(func(s *styles.Style) { + s.Grow.Set(1, 1) + }) + return ed + }) + ed.Buffer.Open(core.Filename(filename)) + ts.Update() + return ed +} diff --git a/tensor/databrowser/typegen.go b/tensor/databrowser/typegen.go index 7ffdb629a5..700a1748c0 100644 --- a/tensor/databrowser/typegen.go +++ b/tensor/databrowser/typegen.go @@ -5,27 +5,72 @@ package databrowser import ( "io/fs" + "cogentcore.org/core/core" "cogentcore.org/core/tree" "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.Browser", IDName: "browser", Doc: "Browser is a data browser, for browsing data either on an os filesystem\nor as a datafs virtual data filesystem.", Methods: []types.Method{{Name: "UpdateFiles", Doc: "UpdateFiles Updates the file picker with current files in DataRoot,", Directives: []types.Directive{{Tool: "types", 
Directive: "add"}}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "FS", Doc: "Filesystem, if browsing an FS"}, {Name: "DataRoot", Doc: "DataRoot is the path to the root of the data to browse"}, {Name: "toolbar"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.Basic", IDName: "basic", Doc: "Basic is a basic data browser with the files as the left panel,\nand the Tabber as the right panel.", Embeds: []types.Field{{Name: "Frame"}, {Name: "Browser"}}}) -// NewBrowser returns a new [Browser] with the given optional parent: -// Browser is a data browser, for browsing data either on an os filesystem -// or as a datafs virtual data filesystem. -func NewBrowser(parent ...tree.Node) *Browser { return tree.New[Browser](parent...) } +// NewBasic returns a new [Basic] with the given optional parent: +// Basic is a basic data browser with the files as the left panel, +// and the Tabber as the right panel. +func NewBasic(parent ...tree.Node) *Basic { return tree.New[Basic](parent...) 
} + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.Browser", IDName: "browser", Doc: "Browser holds all the elements of a data browser, for browsing data\neither on an OS filesystem or as a tensorfs virtual data filesystem.\nIt supports the automatic loading of [goal] scripts as toolbar actions to\nperform pre-programmed tasks on the data, to create app-like functionality.\nScripts are ordered alphabetically and any leading #- prefix is automatically\nremoved from the label, so you can use numbers to specify a custom order.\nIt is not a [core.Widget] itself, and is intended to be incorporated into\na [core.Frame] widget, potentially along with other custom elements.\nSee [Basic] for a basic implementation.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Methods: []types.Method{{Name: "UpdateFiles", Doc: "UpdateFiles Updates the files list.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "UpdateScripts", Doc: "UpdateScripts updates the Scripts and updates the toolbar.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Fields: []types.Field{{Name: "FS", Doc: "FS is the filesystem, if browsing an FS."}, {Name: "DataRoot", Doc: "DataRoot is the path to the root of the data to browse."}, {Name: "StartDir", Doc: "StartDir is the starting directory, where the app was originally started."}, {Name: "ScriptsDir", Doc: "ScriptsDir is the directory containing scripts for toolbar actions.\nIt defaults to DataRoot/dbscripts"}, {Name: "Scripts", Doc: "Scripts"}, {Name: "Interpreter", Doc: "Interpreter is the interpreter to use for running Browser scripts"}, {Name: "Files", Doc: "Files is the [DataTree] tree browser of the tensorfs or files."}, {Name: "Tabs", Doc: "Tabs is the [Tabber] element managing tabs of data views."}, {Name: "Toolbar", Doc: "Toolbar is the top-level toolbar for the browser, if used."}, {Name: "Splits", Doc: "Splits is the overall 
[core.Splits] for the browser."}}}) // SetFS sets the [Browser.FS]: -// Filesystem, if browsing an FS +// FS is the filesystem, if browsing an FS. func (t *Browser) SetFS(v fs.FS) *Browser { t.FS = v; return t } // SetDataRoot sets the [Browser.DataRoot]: -// DataRoot is the path to the root of the data to browse +// DataRoot is the path to the root of the data to browse. func (t *Browser) SetDataRoot(v string) *Browser { t.DataRoot = v; return t } +// SetStartDir sets the [Browser.StartDir]: +// StartDir is the starting directory, where the app was originally started. +func (t *Browser) SetStartDir(v string) *Browser { t.StartDir = v; return t } + +// SetScriptsDir sets the [Browser.ScriptsDir]: +// ScriptsDir is the directory containing scripts for toolbar actions. +// It defaults to DataRoot/dbscripts +func (t *Browser) SetScriptsDir(v string) *Browser { t.ScriptsDir = v; return t } + +// SetFiles sets the [Browser.Files]: +// Files is the [DataTree] tree browser of the tensorfs or files. +func (t *Browser) SetFiles(v *DataTree) *Browser { t.Files = v; return t } + +// SetTabs sets the [Browser.Tabs]: +// Tabs is the [Tabber] element managing tabs of data views. +func (t *Browser) SetTabs(v Tabber) *Browser { t.Tabs = v; return t } + +// SetToolbar sets the [Browser.Toolbar]: +// Toolbar is the top-level toolbar for the browser, if used. +func (t *Browser) SetToolbar(v *core.Toolbar) *Browser { t.Toolbar = v; return t } + +// SetSplits sets the [Browser.Splits]: +// Splits is the overall [core.Splits] for the browser. 
+func (t *Browser) SetSplits(v *core.Splits) *Browser { t.Splits = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.DataTree", IDName: "data-tree", Doc: "DataTree is the databrowser version of [filetree.Tree],\nwhich provides the Tabber to show data editors.", Embeds: []types.Field{{Name: "Tree"}}, Fields: []types.Field{{Name: "Tabber", Doc: "Tabber is the [Tabber] for this tree."}}}) + +// NewDataTree returns a new [DataTree] with the given optional parent: +// DataTree is the databrowser version of [filetree.Tree], +// which provides the Tabber to show data editors. +func NewDataTree(parent ...tree.Node) *DataTree { return tree.New[DataTree](parent...) } + +// SetTabber sets the [DataTree.Tabber]: +// Tabber is the [Tabber] for this tree. +func (t *DataTree) SetTabber(v Tabber) *DataTree { t.Tabber = v; return t } + var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.FileNode", IDName: "file-node", Doc: "FileNode is databrowser version of FileNode for FileTree", Methods: []types.Method{{Name: "EditFiles", Doc: "EditFiles calls EditFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "PlotFiles", Doc: "PlotFiles calls PlotFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "DiffDirs", Doc: "DiffDirs displays a browser with differences between two selected directories", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Node"}}}) // NewFileNode returns a new [FileNode] with the given optional parent: // FileNode is databrowser version of FileNode for FileTree func NewFileNode(parent ...tree.Node) *FileNode { return tree.New[FileNode](parent...) 
} + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/databrowser.Tabs", IDName: "tabs", Doc: "Tabs implements the [Tabber] interface.", Embeds: []types.Field{{Name: "Tabs"}}}) + +// NewTabs returns a new [Tabs] with the given optional parent: +// Tabs implements the [Tabber] interface. +func NewTabs(parent ...tree.Node) *Tabs { return tree.New[Tabs](parent...) } diff --git a/tensor/datafs/README.md b/tensor/datafs/README.md deleted file mode 100644 index c2054b8ab6..0000000000 --- a/tensor/datafs/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# datafs: a virtual filesystem for data - -TODO: write docs - -# Constraints - -* no pointers -- GPU does not like pointers -- use Set / As accessors -* names within directory must be unique - diff --git a/tensor/datafs/data.go b/tensor/datafs/data.go deleted file mode 100644 index 8c83a4d94d..0000000000 --- a/tensor/datafs/data.go +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package datafs - -import ( - "errors" - "reflect" - "time" - "unsafe" - - "cogentcore.org/core/base/fileinfo" - "cogentcore.org/core/base/metadata" - "cogentcore.org/core/base/reflectx" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/table" -) - -// Data is a single item of data, the "file" or "directory" in the data filesystem. -type Data struct { - // Parent is the parent data directory - Parent *Data - - // name is the name of this item. it is not a path. - name string - - // modTime tracks time added to directory, used for ordering. - modTime time.Time - - // Meta has metadata, including standardized support for - // plotting options, compute functions. - Meta metadata.Data - - // Value is the underlying value of data; - // is a map[string]*Data for directories. - Value any -} - -// NewData returns a new Data item in given directory Data item, -// which can be nil. 
If not a directory, an error will be generated. -// The modTime is automatically set to now, and can be used for sorting -// by order created. The name must be unique within parent. -func NewData(dir *Data, name string) (*Data, error) { - d := &Data{Parent: dir, name: name, modTime: time.Now()} - var err error - if dir != nil { - err = dir.Add(d) - } - return d, err -} - -// New adds new data item(s) of given basic type to given directory, -// with given name(s) (at least one is required). -// Values are initialized to zero value for type. -// All names must be unique in the directory. -// Returns the first item created, for immediate use of one value. -func New[T any](dir *Data, names ...string) (*Data, error) { - if len(names) == 0 { - err := errors.New("datafs.New requires at least 1 name") - return nil, err - } - var r *Data - var errs []error - for _, nm := range names { - var v T - d, err := NewData(dir, nm) - if err != nil { - errs = append(errs, err) - continue - } - d.Value = v - if r == nil { - r = d - } - } - return r, errors.Join(errs...) -} - -// NewTensor returns a new Tensor of given data type, shape sizes, -// and optional dimension names, in given directory Data item. -// The name must be unique in the directory. -func NewTensor[T string | bool | float32 | float64 | int | int32 | byte](dir *Data, name string, sizes []int, names ...string) (tensor.Tensor, error) { - tsr := tensor.New[T](sizes, names...) - d, err := NewData(dir, name) - d.Value = tsr - return tsr, err -} - -// NewTable makes new table.Table(s) in given directory, -// for given name(s) (at least one is required). -// All names must be unique in the directory. -// Returns the first table created, for immediate use of one item. 
-func NewTable(dir *Data, names ...string) (*table.Table, error) { - if len(names) == 0 { - err := errors.New("datafs.New requires at least 1 name") - return nil, err - } - var r *table.Table - var errs []error - for _, nm := range names { - t := table.NewTable(nm) - d, err := NewData(dir, nm) - if err != nil { - errs = append(errs, err) - continue - } - d.Value = t - if r == nil { - r = t - } - } - return r, errors.Join(errs...) -} - -/////////////////////////////// -// Data Access - -// IsNumeric returns true if the [DataType] is a basic scalar -// numerical value, e.g., float32, int, etc. -func (d *Data) IsNumeric() bool { - return reflectx.KindIsNumber(d.DataType()) -} - -// DataType returns the type of the data elements in the tensor. -// Bool is returned for the Bits tensor type. -func (d *Data) DataType() reflect.Kind { - if d.Value == nil { - return reflect.Invalid - } - return reflect.TypeOf(d.Value).Kind() -} - -func (d *Data) KnownFileInfo() fileinfo.Known { - if tsr := d.AsTensor(); tsr != nil { - return fileinfo.Tensor - } - kind := d.DataType() - if reflectx.KindIsNumber(kind) { - return fileinfo.Number - } - if kind == reflect.String { - return fileinfo.String - } - return fileinfo.Unknown -} - -// AsTensor returns the data as a tensor if it is one, else nil. -func (d *Data) AsTensor() tensor.Tensor { - tsr, _ := d.Value.(tensor.Tensor) - return tsr -} - -// AsTable returns the data as a table if it is one, else nil. -func (d *Data) AsTable() *table.Table { - dt, _ := d.Value.(*table.Table) - return dt -} - -// AsFloat64 returns data as a float64 if it is a scalar value -// that can be so converted. Returns false if not. 
-func (d *Data) AsFloat64() (float64, bool) { - // fast path for actual floats - if f, ok := d.Value.(float64); ok { - return f, true - } - if f, ok := d.Value.(float32); ok { - return float64(f), true - } - if tsr := d.AsTensor(); tsr != nil { - return 0, false - } - if dt := d.AsTable(); dt != nil { - return 0, false - } - v, err := reflectx.ToFloat(d.Value) - if err != nil { - return 0, false - } - return v, true -} - -// SetFloat64 sets data from given float64 if it is a scalar value -// that can be so set. Returns false if not. -func (d *Data) SetFloat64(v float64) bool { - // fast path for actual floats - if _, ok := d.Value.(float64); ok { - d.Value = v - return true - } - if _, ok := d.Value.(float32); ok { - d.Value = float32(v) - return true - } - if tsr := d.AsTensor(); tsr != nil { - return false - } - if dt := d.AsTable(); dt != nil { - return false - } - err := reflectx.SetRobust(&d.Value, v) - if err != nil { - return false - } - return true -} - -// AsFloat32 returns data as a float32 if it is a scalar value -// that can be so converted. Returns false if not. -func (d *Data) AsFloat32() (float32, bool) { - v, ok := d.AsFloat64() - return float32(v), ok -} - -// SetFloat32 sets data from given float32 if it is a scalar value -// that can be so set. Returns false if not. -func (d *Data) SetFloat32(v float32) bool { - return d.SetFloat64(float64(v)) -} - -// AsString returns data as a string if it is a scalar value -// that can be so converted. Returns false if not. -func (d *Data) AsString() (string, bool) { - // fast path for actual strings - if s, ok := d.Value.(string); ok { - return s, true - } - if tsr := d.AsTensor(); tsr != nil { - return "", false - } - if dt := d.AsTable(); dt != nil { - return "", false - } - s := reflectx.ToString(d.Value) - return s, true -} - -// SetString sets data from given string if it is a scalar value -// that can be so set. Returns false if not. 
-func (d *Data) SetString(v string) bool { - // fast path for actual strings - if _, ok := d.Value.(string); ok { - d.Value = v - return true - } - if tsr := d.AsTensor(); tsr != nil { - return false - } - if dt := d.AsTable(); dt != nil { - return false - } - err := reflectx.SetRobust(&d.Value, v) - if err != nil { - return false - } - return true -} - -// AsInt returns data as a int if it is a scalar value -// that can be so converted. Returns false if not. -func (d *Data) AsInt() (int, bool) { - // fast path for actual ints - if f, ok := d.Value.(int); ok { - return f, true - } - if tsr := d.AsTensor(); tsr != nil { - return 0, false - } - if dt := d.AsTable(); dt != nil { - return 0, false - } - v, err := reflectx.ToInt(d.Value) - if err != nil { - return 0, false - } - return int(v), true -} - -// SetInt sets data from given int if it is a scalar value -// that can be so set. Returns false if not. -func (d *Data) SetInt(v int) bool { - // fast path for actual ints - if _, ok := d.Value.(int); ok { - d.Value = v - return true - } - if tsr := d.AsTensor(); tsr != nil { - return false - } - if dt := d.AsTable(); dt != nil { - return false - } - err := reflectx.SetRobust(&d.Value, v) - if err != nil { - return false - } - return true -} - -// Bytes returns the byte-wise representation of the data Value. -// This is the actual underlying data, so make a copy if it can be -// unintentionally modified or retained more than for immediate use. -func (d *Data) Bytes() []byte { - if tsr := d.AsTensor(); tsr != nil { - return tsr.Bytes() - } - size := d.Size() - switch x := d.Value.(type) { - // todo: other things here? - default: - return unsafe.Slice((*byte)(unsafe.Pointer(&x)), size) - } -} diff --git a/tensor/datafs/dir.go b/tensor/datafs/dir.go deleted file mode 100644 index 9824f66509..0000000000 --- a/tensor/datafs/dir.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package datafs - -import ( - "errors" - "fmt" - "io/fs" - "path" - "slices" - "sort" - - "golang.org/x/exp/maps" -) - -// NewDir returns a new datafs directory with given name. -// if parent != nil and a directory, this dir is added to it. -// if name is empty, then it is set to "root", the root directory. -// Note that "/" is not allowed for the root directory in Go [fs]. -// Names must be unique within a directory. -func NewDir(name string, parent ...*Data) (*Data, error) { - if name == "" { - name = "root" - } - var par *Data - if len(parent) == 1 { - par = parent[0] - } - d, err := NewData(par, name) - d.Value = make(map[string]*Data) - return d, err -} - -// Item returns data item in given directory by name. -// This is for fast access and direct usage of known -// items, and it will crash if item is not found or -// this data is not a directory. -func (d *Data) Item(name string) *Data { - fm := d.filemap() - return fm[name] -} - -// Items returns data items in given directory by name. -// error reports any items not found, or if not a directory. -func (d *Data) Items(names ...string) ([]*Data, error) { - if err := d.mustDir("Items", ""); err != nil { - return nil, err - } - fm := d.filemap() - var errs []error - var its []*Data - for _, nm := range names { - dt := fm[nm] - if dt != nil { - its = append(its, dt) - } else { - err := fmt.Errorf("datafs Dir %q item not found: %q", d.Path(), nm) - errs = append(errs, err) - } - } - return its, errors.Join(errs...) -} - -// ItemsFunc returns data items in given directory -// filtered by given function, in alpha order. -// If func is nil, all items are returned. -// Any directories within this directory are returned, -// unless specifically filtered. 
-func (d *Data) ItemsFunc(fun func(item *Data) bool) []*Data { - if err := d.mustDir("ItemsFunc", ""); err != nil { - return nil - } - fm := d.filemap() - names := d.DirNamesAlpha() - var its []*Data - for _, nm := range names { - dt := fm[nm] - if fun != nil && !fun(dt) { - continue - } - its = append(its, dt) - } - return its -} - -// ItemsByTimeFunc returns data items in given directory -// filtered by given function, in time order (i.e., order added). -// If func is nil, all items are returned. -// Any directories within this directory are returned, -// unless specifically filtered. -func (d *Data) ItemsByTimeFunc(fun func(item *Data) bool) []*Data { - if err := d.mustDir("ItemsByTimeFunc", ""); err != nil { - return nil - } - fm := d.filemap() - names := d.DirNamesByTime() - var its []*Data - for _, nm := range names { - dt := fm[nm] - if fun != nil && !fun(dt) { - continue - } - its = append(its, dt) - } - return its -} - -// FlatItemsFunc returns all "leaf" (non directory) data items -// in given directory, recursively descending into directories -// to return a flat list of the entire subtree, -// filtered by given function, in alpha order. The function can -// filter out directories to prune the tree. -// If func is nil, all items are returned. -func (d *Data) FlatItemsFunc(fun func(item *Data) bool) []*Data { - if err := d.mustDir("FlatItemsFunc", ""); err != nil { - return nil - } - fm := d.filemap() - names := d.DirNamesAlpha() - var its []*Data - for _, nm := range names { - dt := fm[nm] - if fun != nil && !fun(dt) { - continue - } - if dt.IsDir() { - subs := dt.FlatItemsFunc(fun) - its = append(its, subs...) - } else { - its = append(its, dt) - } - } - return its -} - -// FlatItemsByTimeFunc returns all "leaf" (non directory) data items -// in given directory, recursively descending into directories -// to return a flat list of the entire subtree, -// filtered by given function, in time order (i.e., order added). 
-// The function can filter out directories to prune the tree. -// If func is nil, all items are returned. -func (d *Data) FlatItemsByTimeFunc(fun func(item *Data) bool) []*Data { - if err := d.mustDir("FlatItemsByTimeFunc", ""); err != nil { - return nil - } - fm := d.filemap() - names := d.DirNamesByTime() - var its []*Data - for _, nm := range names { - dt := fm[nm] - if fun != nil && !fun(dt) { - continue - } - if dt.IsDir() { - subs := dt.FlatItemsByTimeFunc(fun) - its = append(its, subs...) - } else { - its = append(its, dt) - } - } - return its -} - -// DirAtPath returns directory at given relative path -// from this starting dir. -func (d *Data) DirAtPath(dir string) (*Data, error) { - var err error - dir = path.Clean(dir) - sdf, err := d.Sub(dir) // this ensures that d is a dir - if err != nil { - return nil, err - } - return sdf.(*Data), nil -} - -// Path returns the full path to this data item -func (d *Data) Path() string { - pt := d.name - cur := d.Parent - loops := make(map[*Data]struct{}) - for { - if cur == nil { - return pt - } - if _, ok := loops[cur]; ok { - return pt - } - pt = path.Join(cur.name, pt) - loops[cur] = struct{}{} - cur = cur.Parent - } -} - -// filemap returns the Value as map[string]*Data, or nil if not a dir -func (d *Data) filemap() map[string]*Data { - fm, ok := d.Value.(map[string]*Data) - if !ok { - return nil - } - return fm -} - -// DirNamesAlpha returns the names of items in the directory -// sorted alphabetically. Data must be dir by this point. -func (d *Data) DirNamesAlpha() []string { - fm := d.filemap() - names := maps.Keys(fm) - sort.Strings(names) - return names -} - -// DirNamesByTime returns the names of items in the directory -// sorted by modTime (order added). Data must be dir by this point. 
-func (d *Data) DirNamesByTime() []string { - fm := d.filemap() - names := maps.Keys(fm) - slices.SortFunc(names, func(a, b string) int { - return fm[a].ModTime().Compare(fm[b].ModTime()) - }) - return names -} - -// mustDir returns an error for given operation and path -// if this data item is not a directory. -func (d *Data) mustDir(op, path string) error { - if !d.IsDir() { - return &fs.PathError{Op: op, Path: path, Err: errors.New("datafs item is not a directory")} - } - return nil -} - -// Add adds an item to this directory data item. -// The only errors are if this item is not a directory, -// or the name already exists. -// Names must be unique within a directory. -func (d *Data) Add(it *Data) error { - if err := d.mustDir("Add", it.name); err != nil { - return err - } - fm := d.filemap() - _, ok := fm[it.name] - if ok { - return &fs.PathError{Op: "add", Path: it.name, Err: errors.New("data item already exists; names must be unique")} - } - fm[it.name] = it - return nil -} - -// Mkdir creates a new directory with the specified name. -// The only error is if this item is not a directory. -func (d *Data) Mkdir(name string) (*Data, error) { - if err := d.mustDir("Mkdir", name); err != nil { - return nil, err - } - return NewDir(name, d) -} diff --git a/tensor/datafs/fs.go b/tensor/datafs/fs.go deleted file mode 100644 index 034994cb99..0000000000 --- a/tensor/datafs/fs.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package datafs - -import ( - "bytes" - "errors" - "io/fs" - "path" - "slices" - "sort" - "time" - "unsafe" - - "cogentcore.org/core/base/fsx" - "golang.org/x/exp/maps" -) - -// fs.go contains all the io/fs interface implementations - -// Open opens the given data Value within this datafs filesystem. 
-func (d *Data) Open(name string) (fs.File, error) { - if !fs.ValidPath(name) { - return nil, &fs.PathError{Op: "open", Path: name, Err: errors.New("invalid name")} - } - dir, file := path.Split(name) - sd, err := d.DirAtPath(dir) - if err != nil { - return nil, err - } - fm := sd.filemap() - itm, ok := fm[file] - if !ok { - if dir == "" && (file == d.name || file == ".") { - return &DirFile{File: File{Reader: *bytes.NewReader(d.Bytes()), Data: d}}, nil - } - return nil, &fs.PathError{Op: "open", Path: name, Err: errors.New("file not found")} - } - if itm.IsDir() { - return &DirFile{File: File{Reader: *bytes.NewReader(itm.Bytes()), Data: itm}}, nil - } - return &File{Reader: *bytes.NewReader(itm.Bytes()), Data: itm}, nil -} - -// Stat returns a FileInfo describing the file. -// If there is an error, it should be of type *PathError. -func (d *Data) Stat(name string) (fs.FileInfo, error) { - if !fs.ValidPath(name) { - return nil, &fs.PathError{Op: "open", Path: name, Err: errors.New("invalid name")} - } - dir, file := path.Split(name) - sd, err := d.DirAtPath(dir) - if err != nil { - return nil, err - } - fm := sd.filemap() - itm, ok := fm[file] - if !ok { - if dir == "" && (file == d.name || file == ".") { - return d, nil - } - return nil, &fs.PathError{Op: "stat", Path: name, Err: errors.New("file not found")} - } - return itm, nil -} - -// Sub returns a data FS corresponding to the subtree rooted at dir. -func (d *Data) Sub(dir string) (fs.FS, error) { - if err := d.mustDir("sub", dir); err != nil { - return nil, err - } - if !fs.ValidPath(dir) { - return nil, &fs.PathError{Op: "sub", Path: dir, Err: errors.New("invalid name")} - } - if dir == "." || dir == "" || dir == d.name { - return d, nil - } - cd := dir - cur := d - root, rest := fsx.SplitRootPathFS(dir) - if root == "." || root == d.name { - cd = rest - } - for { - if cd == "." || cd == "" { - return cur, nil - } - root, rest := fsx.SplitRootPathFS(cd) - if root == "." 
&& rest == "" { - return cur, nil - } - cd = rest - fm := cur.filemap() - sd, ok := fm[root] - if !ok { - return nil, &fs.PathError{Op: "sub", Path: dir, Err: errors.New("directory not found")} - } - if !sd.IsDir() { - return nil, &fs.PathError{Op: "sub", Path: dir, Err: errors.New("is not a directory")} - } - cur = sd - } -} - -// ReadDir returns the contents of the given directory within this filesystem. -// Use "." (or "") to refer to the current directory. -func (d *Data) ReadDir(dir string) ([]fs.DirEntry, error) { - sd, err := d.DirAtPath(dir) - if err != nil { - return nil, err - } - fm := sd.filemap() - names := maps.Keys(fm) - sort.Strings(names) - ents := make([]fs.DirEntry, len(names)) - for i, nm := range names { - ents[i] = fm[nm] - } - return ents, nil -} - -// ReadFile reads the named file and returns its contents. -// A successful call returns a nil error, not io.EOF. -// (Because ReadFile reads the whole file, the expected EOF -// from the final Read is not treated as an error to be reported.) -// -// The caller is permitted to modify the returned byte slice. -// This method should return a copy of the underlying data. -func (d *Data) ReadFile(name string) ([]byte, error) { - if err := d.mustDir("readFile", name); err != nil { - return nil, err - } - if !fs.ValidPath(name) { - return nil, &fs.PathError{Op: "readFile", Path: name, Err: errors.New("invalid name")} - } - dir, file := path.Split(name) - sd, err := d.DirAtPath(dir) - if err != nil { - return nil, err - } - fm := sd.filemap() - itm, ok := fm[file] - if !ok { - return nil, &fs.PathError{Op: "readFile", Path: name, Err: errors.New("file not found")} - } - if itm.IsDir() { - return nil, &fs.PathError{Op: "readFile", Path: name, Err: errors.New("Value is a directory")} - } - return slices.Clone(itm.Bytes()), nil -} - -/////////////////////////////// -// FileInfo interface: - -// Sizer is an interface to allow an arbitrary data Value -// to report its size in bytes. 
Size is automatically computed for -// known basic data Values supported by datafs directly. -type Sizer interface { - Sizeof() int64 -} - -func (d *Data) Name() string { return d.name } - -// Size returns the size of known data Values, or it uses -// the Sizer interface, otherwise returns 0. -func (d *Data) Size() int64 { - if szr, ok := d.Value.(Sizer); ok { // tensor implements Sizer - return szr.Sizeof() - } - switch x := d.Value.(type) { - case float32, int32, uint32: - return 4 - case float64, int64: - return 8 - case int: - return int64(unsafe.Sizeof(x)) - case complex64: - return 16 - case complex128: - return 32 - } - return 0 -} - -func (d *Data) IsDir() bool { - _, ok := d.Value.(map[string]*Data) - return ok -} - -func (d *Data) ModTime() time.Time { - return d.modTime -} - -func (d *Data) Mode() fs.FileMode { - if d.IsDir() { - return 0755 | fs.ModeDir - } - return 0444 -} - -// Sys returns the metadata for Value -func (d *Data) Sys() any { return d.Meta } - -/////////////////////////////// -// DirEntry interface - -func (d *Data) Type() fs.FileMode { - return d.Mode().Type() -} - -func (d *Data) Info() (fs.FileInfo, error) { - return d, nil -} diff --git a/tensor/datafs/metadata.go b/tensor/datafs/metadata.go deleted file mode 100644 index 3ff8d86c8d..0000000000 --- a/tensor/datafs/metadata.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package datafs - -import ( - "cogentcore.org/core/base/errors" - "cogentcore.org/core/base/fsx" - "cogentcore.org/core/base/metadata" - "cogentcore.org/core/plot/plotcore" - "cogentcore.org/core/tensor/table" -) - -// This file provides standardized metadata options for frequent -// use cases, using codified key names to eliminate typos. - -// SetMetaItems sets given metadata for items in given directory -// with given names. Returns error for any items not found. 
-func (d *Data) SetMetaItems(key string, value any, names ...string) error { - its, err := d.Items(names...) - for _, it := range its { - it.Meta.Set(key, value) - } - return err -} - -// PlotColumnZeroOne returns plot options with a fixed 0-1 range -func PlotColumnZeroOne() *plotcore.ColumnOptions { - opts := &plotcore.ColumnOptions{} - opts.Range.SetMin(0) - opts.Range.SetMax(1) - return opts -} - -// SetPlotColumnOptions sets given plotting options for named items -// within this directory (stored in Metadata). -func (d *Data) SetPlotColumnOptions(opts *plotcore.ColumnOptions, names ...string) error { - return d.SetMetaItems("PlotColumnOptions", opts, names...) -} - -// PlotColumnOptions returns plotting options if they have been set, else nil. -func (d *Data) PlotColumnOptions() *plotcore.ColumnOptions { - return errors.Ignore1(metadata.Get[*plotcore.ColumnOptions](d.Meta, "PlotColumnOptions")) -} - -// SetCalcFunc sets a function to compute an updated Value for this data item. -// Function is stored as CalcFunc in Metadata. Can be called by [Data.Calc] method. -func (d *Data) SetCalcFunc(fun func() error) { - d.Meta.Set("CalcFunc", fun) -} - -// Calc calls function set by [Data.SetCalcFunc] to compute an updated Value -// for this data item. Returns an error if func not set, or any error from func itself. -// Function is stored as CalcFunc in Metadata. -func (d *Data) Calc() error { - fun, err := metadata.Get[func() error](d.Meta, "CalcFunc") - if err != nil { - return err - } - return fun() -} - -// CalcAll calls function set by [Data.SetCalcFunc] for all items -// in this directory and all of its subdirectories. -// Calls Calc on items from FlatItemsByTimeFunc(nil) -func (d *Data) CalcAll() error { - var errs []error - items := d.FlatItemsByTimeFunc(nil) - for _, it := range items { - err := it.Calc() - if err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) 
-} - -// DirTable returns a table.Table for this directory item, with columns -// as the Tensor elements in the directory and any subdirectories, -// from FlatItemsByTimeFunc using given filter function. -// This is a convenient mechanism for creating a plot of all the data -// in a given directory. -// If such was previously constructed, it is returned from "DirTable" -// Metadata key where the table is stored. -// Row count is updated to current max row. -// Delete that key to reconstruct if items have changed. -func (d *Data) DirTable(fun func(item *Data) bool) *table.Table { - dt, err := metadata.Get[*table.Table](d.Meta, "DirTable") - if err == nil { - var maxRow int - for _, tsr := range dt.Columns { - maxRow = max(maxRow, tsr.DimSize(0)) - } - dt.Rows = maxRow - return dt - } - items := d.FlatItemsByTimeFunc(fun) - dt = table.NewTable(fsx.DirAndFile(string(d.Path()))) - for _, it := range items { - tsr := it.AsTensor() - if tsr == nil { - continue - } - if dt.Rows == 0 { - dt.Rows = tsr.DimSize(0) - } - nm := it.Name() - if it.Parent != d { - nm = fsx.DirAndFile(string(it.Path())) - } - dt.AddColumn(tsr, nm) - } - d.Meta.Set("DirTable", dt) - return dt -} diff --git a/tensor/enumgen.go b/tensor/enumgen.go new file mode 100644 index 0000000000..86353996ec --- /dev/null +++ b/tensor/enumgen.go @@ -0,0 +1,89 @@ +// Code generated by "core generate"; DO NOT EDIT. + +package tensor + +import ( + "cogentcore.org/core/enums" +) + +var _DelimsValues = []Delims{0, 1, 2, 3} + +// DelimsN is the highest valid value for type Delims, plus one. 
+const DelimsN Delims = 4 + +var _DelimsValueMap = map[string]Delims{`Tab`: 0, `Comma`: 1, `Space`: 2, `Detect`: 3} + +var _DelimsDescMap = map[Delims]string{0: `Tab is the tab rune delimiter, for TSV tab separated values`, 1: `Comma is the comma rune delimiter, for CSV comma separated values`, 2: `Space is the space rune delimiter, for SSV space separated value`, 3: `Detect is used during reading a file -- reads the first line and detects tabs or commas`} + +var _DelimsMap = map[Delims]string{0: `Tab`, 1: `Comma`, 2: `Space`, 3: `Detect`} + +// String returns the string representation of this Delims value. +func (i Delims) String() string { return enums.String(i, _DelimsMap) } + +// SetString sets the Delims value from its string representation, +// and returns an error if the string is invalid. +func (i *Delims) SetString(s string) error { return enums.SetString(i, s, _DelimsValueMap, "Delims") } + +// Int64 returns the Delims value as an int64. +func (i Delims) Int64() int64 { return int64(i) } + +// SetInt64 sets the Delims value from an int64. +func (i *Delims) SetInt64(in int64) { *i = Delims(in) } + +// Desc returns the description of the Delims value. +func (i Delims) Desc() string { return enums.Desc(i, _DelimsDescMap) } + +// DelimsValues returns all possible values for the type Delims. +func DelimsValues() []Delims { return _DelimsValues } + +// Values returns all possible values for the type Delims. +func (i Delims) Values() []enums.Enum { return enums.Values(_DelimsValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i Delims) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *Delims) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Delims") } + +var _SlicesMagicValues = []SlicesMagic{0, 1, 2} + +// SlicesMagicN is the highest valid value for type SlicesMagic, plus one. 
+const SlicesMagicN SlicesMagic = 3 + +var _SlicesMagicValueMap = map[string]SlicesMagic{`FullAxis`: 0, `NewAxis`: 1, `Ellipsis`: 2} + +var _SlicesMagicDescMap = map[SlicesMagic]string{0: `FullAxis indicates that the full existing axis length should be used. This is equivalent to Slice{}, but is more semantic. In NumPy it is equivalent to a single : colon.`, 1: `NewAxis creates a new singleton (length=1) axis, used to to reshape without changing the size. Can also be used in [Reshaped].`, 2: `Ellipsis (...) is used in [NewSliced] expressions to produce a flexibly-sized stretch of FullAxis dimensions, which automatically aligns the remaining slice elements based on the source dimensionality.`} + +var _SlicesMagicMap = map[SlicesMagic]string{0: `FullAxis`, 1: `NewAxis`, 2: `Ellipsis`} + +// String returns the string representation of this SlicesMagic value. +func (i SlicesMagic) String() string { return enums.String(i, _SlicesMagicMap) } + +// SetString sets the SlicesMagic value from its string representation, +// and returns an error if the string is invalid. +func (i *SlicesMagic) SetString(s string) error { + return enums.SetString(i, s, _SlicesMagicValueMap, "SlicesMagic") +} + +// Int64 returns the SlicesMagic value as an int64. +func (i SlicesMagic) Int64() int64 { return int64(i) } + +// SetInt64 sets the SlicesMagic value from an int64. +func (i *SlicesMagic) SetInt64(in int64) { *i = SlicesMagic(in) } + +// Desc returns the description of the SlicesMagic value. +func (i SlicesMagic) Desc() string { return enums.Desc(i, _SlicesMagicDescMap) } + +// SlicesMagicValues returns all possible values for the type SlicesMagic. +func SlicesMagicValues() []SlicesMagic { return _SlicesMagicValues } + +// Values returns all possible values for the type SlicesMagic. +func (i SlicesMagic) Values() []enums.Enum { return enums.Values(_SlicesMagicValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. 
+func (i SlicesMagic) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *SlicesMagic) UnmarshalText(text []byte) error { + return enums.UnmarshalText(i, text, "SlicesMagic") +} diff --git a/tensor/examples/datafs-sim/sim.go b/tensor/examples/datafs-sim/sim.go deleted file mode 100644 index 2062c560b5..0000000000 --- a/tensor/examples/datafs-sim/sim.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "math/rand/v2" - "reflect" - "strconv" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/core" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/databrowser" - "cogentcore.org/core/tensor/datafs" - "cogentcore.org/core/tensor/stats/stats" -) - -type Sim struct { - Root *datafs.Data - Config *datafs.Data - Stats *datafs.Data - Logs *datafs.Data -} - -// ConfigAll configures the sim -func (ss *Sim) ConfigAll() { - ss.Root = errors.Log1(datafs.NewDir("Root")) - ss.Config = errors.Log1(ss.Root.Mkdir("Config")) - errors.Log1(datafs.New[int](ss.Config, "NRun", "NEpoch", "NTrial")) - ss.Config.Item("NRun").SetInt(5) - ss.Config.Item("NEpoch").SetInt(20) - ss.Config.Item("NTrial").SetInt(25) - - ss.Stats = ss.ConfigStats(ss.Root) - ss.Logs = ss.ConfigLogs(ss.Root) -} - -// ConfigStats adds basic stats that we record for our simulation. 
-func (ss *Sim) ConfigStats(dir *datafs.Data) *datafs.Data { - stats := errors.Log1(dir.Mkdir("Stats")) - errors.Log1(datafs.New[int](stats, "Run", "Epoch", "Trial")) // counters - errors.Log1(datafs.New[string](stats, "TrialName")) - errors.Log1(datafs.New[float32](stats, "SSE", "AvgSSE", "TrlErr")) - z1 := datafs.PlotColumnZeroOne() - stats.SetPlotColumnOptions(z1, "AvgErr", "TrlErr") - zmax := datafs.PlotColumnZeroOne() - zmax.Range.FixMax = false - stats.SetPlotColumnOptions(z1, "SSE") - return stats -} - -// ConfigLogs adds first-level logging of stats into tensors -func (ss *Sim) ConfigLogs(dir *datafs.Data) *datafs.Data { - logd := errors.Log1(dir.Mkdir("Log")) - trial := ss.ConfigTrialLog(logd) - ss.ConfigAggLog(logd, "Epoch", trial, stats.Mean, stats.Sem, stats.Min) - return logd -} - -// ConfigTrialLog adds first-level logging of stats into tensors -func (ss *Sim) ConfigTrialLog(dir *datafs.Data) *datafs.Data { - logd := errors.Log1(dir.Mkdir("Trial")) - ntrial, _ := ss.Config.Item("NTrial").AsInt() - sitems := ss.Stats.ItemsByTimeFunc(nil) - for _, st := range sitems { - dt := errors.Log1(datafs.NewData(logd, st.Name())) - tsr := tensor.NewOfType(st.DataType(), []int{ntrial}, "row") - dt.Value = tsr - dt.Meta.Copy(st.Meta) // key affordance: we get meta data from source - dt.SetCalcFunc(func() error { - trl, _ := ss.Stats.Item("Trial").AsInt() - if st.IsNumeric() { - v, _ := st.AsFloat64() - tsr.SetFloat1D(trl, v) - } else { - v, _ := st.AsString() - tsr.SetString1D(trl, v) - } - return nil - }) - } - return logd -} - -// ConfigAggLog adds a higher-level logging of lower-level into higher-level tensors -func (ss *Sim) ConfigAggLog(dir *datafs.Data, level string, from *datafs.Data, aggs ...stats.Stats) *datafs.Data { - logd := errors.Log1(dir.Mkdir(level)) - sitems := ss.Stats.ItemsByTimeFunc(nil) - nctr, _ := ss.Config.Item("N" + level).AsInt() - for _, st := range sitems { - if !st.IsNumeric() { - continue - } - src := from.Item(st.Name()).AsTensor() - 
if st.DataType() >= reflect.Float32 { - dd := errors.Log1(logd.Mkdir(st.Name())) - for _, ag := range aggs { // key advantage of dir structure: multiple stats per item - dt := errors.Log1(datafs.NewData(dd, ag.String())) - tsr := tensor.NewOfType(st.DataType(), []int{nctr}, "row") - dt.Value = tsr - dt.Meta.Copy(st.Meta) - dt.SetCalcFunc(func() error { - ctr, _ := ss.Stats.Item(level).AsInt() - v := stats.StatTensor(src, ag) - tsr.SetFloat1D(ctr, v) - return nil - }) - } - } else { - dt := errors.Log1(datafs.NewData(logd, st.Name())) - tsr := tensor.NewOfType(st.DataType(), []int{nctr}, "row") - // todo: set level counter as default x axis in plot config - dt.Value = tsr - dt.Meta.Copy(st.Meta) - dt.SetCalcFunc(func() error { - ctr, _ := ss.Stats.Item(level).AsInt() - v, _ := st.AsFloat64() - tsr.SetFloat1D(ctr, v) - return nil - }) - } - } - return logd -} - -func (ss *Sim) Run() { - nepc, _ := ss.Config.Item("NEpoch").AsInt() - ntrl, _ := ss.Config.Item("NTrial").AsInt() - for epc := range nepc { - ss.Stats.Item("Epoch").SetInt(epc) - for trl := range ntrl { - ss.Stats.Item("Trial").SetInt(trl) - ss.RunTrial(trl) - } - ss.EpochDone() - } -} - -func (ss *Sim) RunTrial(trl int) { - ss.Stats.Item("TrialName").SetString("Trial_" + strconv.Itoa(trl)) - sse := rand.Float32() - avgSSE := rand.Float32() - ss.Stats.Item("SSE").SetFloat32(sse) - ss.Stats.Item("AvgSSE").SetFloat32(avgSSE) - trlErr := float32(1) - if sse < 0.5 { - trlErr = 0 - } - ss.Stats.Item("TrlErr").SetFloat32(trlErr) - ss.Logs.Item("Trial").CalcAll() -} - -func (ss *Sim) EpochDone() { - ss.Logs.Item("Epoch").CalcAll() -} - -func main() { - ss := &Sim{} - ss.ConfigAll() - ss.Run() - - databrowser.NewBrowserWindow(ss.Root, "Root") - core.Wait() -} diff --git a/tensor/examples/dataproc/README.md b/tensor/examples/dataproc/README.md deleted file mode 100644 index a12d4d1266..0000000000 --- a/tensor/examples/dataproc/README.md +++ /dev/null @@ -1,6 +0,0 @@ -Build and run this `main` package to see a full 
demo of how to use this system for data analysis, paralleling the example in [Python Data Science](https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html) using pandas, to see directly how that translates into this framework. - -Most of the code is in the `AnalyzePlanets` function, which opens a .csv file, and then uses a number of `IndexView` views of the data to perform various analyses as shown in the GUI tables. Click on the tabs at the very top of the window to see the various analyzed versions of the data shown in the first tab. - -You can also click on headers of the columns to sort by those columns (toggles between ascending and descending), - diff --git a/tensor/examples/dataproc/dataproc.go b/tensor/examples/dataproc/dataproc.go deleted file mode 100644 index d5455b695d..0000000000 --- a/tensor/examples/dataproc/dataproc.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package main - -import ( - "embed" - "fmt" - "math" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/core" - "cogentcore.org/core/events" - "cogentcore.org/core/icons" - "cogentcore.org/core/tensor/stats/split" - "cogentcore.org/core/tensor/stats/stats" - "cogentcore.org/core/tensor/table" - "cogentcore.org/core/tensor/tensorcore" - "cogentcore.org/core/tree" -) - -// Planets is raw data -var Planets *table.Table - -// PlanetsDesc are descriptive stats of all (non-Null) data -var PlanetsDesc *table.Table - -// PlanetsNNDesc are descriptive stats of planets where entire row is non-null -var PlanetsNNDesc *table.Table - -// GpMethodOrbit shows the median of orbital period as a function of method -var GpMethodOrbit *table.Table - -// GpMethodYear shows all stats of year described by orbit -var GpMethodYear *table.Table - -// GpMethodDecade shows number of planets found in each decade by given method -var GpMethodDecade *table.Table - -// GpDecade shows number of planets found in each decade -var GpDecade *table.Table - -//go:embed *.csv -var csv embed.FS - -// AnalyzePlanets analyzes planets.csv data following some of the examples -// given here, using pandas: -// -// https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html -func AnalyzePlanets() { - Planets = table.NewTable("planets") - Planets.OpenFS(csv, "planets.csv", table.Comma) - - PlanetsAll := table.NewIndexView(Planets) // full original data - - PlanetsDesc = stats.DescAll(PlanetsAll) // individually excludes Null values in each col, but not row-wise - PlanetsNNDesc = stats.DescAll(PlanetsAll) // standard descriptive stats for row-wise non-nulls - - byMethod := split.GroupBy(PlanetsAll, "method") - split.AggColumn(byMethod, "orbital_period", stats.Median) - GpMethodOrbit = byMethod.AggsToTable(table.AddAggName) - - byMethod.DeleteAggs() - split.DescColumn(byMethod, "year") // full desc stats of year - - byMethod.Filter(func(idx int) bool { - ag := 
errors.Log1(byMethod.AggByColumnName("year:Std")) - return ag.Aggs[idx][0] > 0 // exclude results with 0 std - }) - - GpMethodYear = byMethod.AggsToTable(table.AddAggName) - - byMethodDecade := split.GroupByFunc(PlanetsAll, func(row int) []string { - meth := Planets.StringValue("method", row) - yr := Planets.Float("year", row) - decade := math.Floor(yr/10) * 10 - return []string{meth, fmt.Sprintf("%gs", decade)} - }) - byMethodDecade.SetLevels("method", "decade") - - split.AggColumn(byMethodDecade, "number", stats.Sum) - - // uncomment this to switch to decade first, then method - // byMethodDecade.ReorderLevels([]int{1, 0}) - // byMethodDecade.SortLevels() - - decadeOnly := errors.Log1(byMethodDecade.ExtractLevels([]int{1})) - split.AggColumn(decadeOnly, "number", stats.Sum) - GpDecade = decadeOnly.AggsToTable(table.AddAggName) - - GpMethodDecade = byMethodDecade.AggsToTable(table.AddAggName) // here to ensure that decadeOnly didn't mess up.. - - // todo: need unstack -- should be specific to the splits data because we already have the cols and - // groups etc -- the ExtractLevels method provides key starting point. - - // todo: pivot table -- neeeds unstack function. - - // todo: could have a generic unstack-like method that takes a column for the data to turn into columns - // and another that has the data to put in the cells. -} - -func main() { - AnalyzePlanets() - - b := core.NewBody("dataproc") - tv := core.NewTabs(b) - - nt, _ := tv.NewTab("Planets Data") - tbv := tensorcore.NewTable(nt).SetTable(Planets) - b.AddTopBar(func(bar *core.Frame) { - tb := core.NewToolbar(bar) - tb.Maker(tbv.MakeToolbar) - tb.Maker(func(p *tree.Plan) { - tree.Add(p, func(w *core.Button) { - w.SetText("README").SetIcon(icons.FileMarkdown). 
- SetTooltip("open README help file").OnClick(func(e events.Event) { - core.TheApp.OpenURL("https://github.com/cogentcore/core/blob/main/tensor/examples/dataproc/README.md") - }) - }) - }) - }) - - nt, _ = tv.NewTab("Non-Null Rows Desc") - tensorcore.NewTable(nt).SetTable(PlanetsNNDesc) - nt, _ = tv.NewTab("All Desc") - tensorcore.NewTable(nt).SetTable(PlanetsDesc) - nt, _ = tv.NewTab("By Method Orbit") - tensorcore.NewTable(nt).SetTable(GpMethodOrbit) - nt, _ = tv.NewTab("By Method Year") - tensorcore.NewTable(nt).SetTable(GpMethodYear) - nt, _ = tv.NewTab("By Method Decade") - tensorcore.NewTable(nt).SetTable(GpMethodDecade) - nt, _ = tv.NewTab("By Decade") - tensorcore.NewTable(nt).SetTable(GpDecade) - - tv.SelectTabIndex(0) - - b.RunMainWindow() -} diff --git a/tensor/examples/grids/grids.go b/tensor/examples/grids/grids.go index f56658e1d6..cb2a40a792 100644 --- a/tensor/examples/grids/grids.go +++ b/tensor/examples/grids/grids.go @@ -8,7 +8,9 @@ import ( "embed" "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/metadata" "cogentcore.org/core/core" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tensor/tensorcore" ) @@ -17,19 +19,19 @@ import ( var tsv embed.FS func main() { - pats := table.NewTable("pats") - pats.SetMetaData("name", "TrainPats") - pats.SetMetaData("desc", "Training patterns") + pats := table.New("TrainPats") + metadata.SetDoc(pats, "Training patterns") // todo: meta data for grid size - errors.Log(pats.OpenFS(tsv, "random_5x5_25.tsv", table.Tab)) + errors.Log(pats.OpenFS(tsv, "random_5x5_25.tsv", tensor.Tab)) b := core.NewBody("grids") - tv := core.NewTabs(b) - - // nt, _ := tv.NewTab("First") nt, _ := tv.NewTab("Patterns") - etv := tensorcore.NewTable(nt).SetTable(pats) + etv := tensorcore.NewTable(nt) + tensorcore.AddGridStylerTo(pats, func(s *tensorcore.GridStyle) { + s.TotalSize = 200 + }) + etv.SetTable(pats) b.AddTopBar(func(bar *core.Frame) { core.NewToolbar(bar).Maker(etv.MakeToolbar) 
}) diff --git a/tensor/examples/planets/README.md b/tensor/examples/planets/README.md new file mode 100644 index 0000000000..ac2c25e2eb --- /dev/null +++ b/tensor/examples/planets/README.md @@ -0,0 +1,10 @@ +# Planets + +Build and run this `main` package to see a full demo of how to use this system for data analysis, paralleling the example in [Python Data Science](https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html) using pandas, to see directly how that translates into this framework. + +Important: you must run from an interactive terminal shell: it will quit immediately if not. + +Most of the code is in the `AnalyzePlanets` function, which opens a .csv file, and then performs various analyses, using various `tensor/stats` functions. + +The GUI is from the `tensor/databrowser`, showing various views of the data in the left browser. The `Describe` directory has summary stats of various columns of the overall data.by Click into `Stats` and explore from there, doing context menu, `Plot` or double-clicking to get a tabular representation of the data. + diff --git a/tensor/examples/dataproc/planets.csv b/tensor/examples/planets/planets.csv similarity index 100% rename from tensor/examples/dataproc/planets.csv rename to tensor/examples/planets/planets.csv diff --git a/tensor/examples/planets/planets.go b/tensor/examples/planets/planets.go new file mode 100644 index 0000000000..f6672183cc --- /dev/null +++ b/tensor/examples/planets/planets.go @@ -0,0 +1,123 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package main + +import ( + "embed" + "math" + + "cogentcore.org/core/cli" + "cogentcore.org/core/core" + "cogentcore.org/core/events" + "cogentcore.org/core/goal/interpreter" + "cogentcore.org/core/icons" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/databrowser" + "cogentcore.org/core/tensor/stats/stats" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" + "cogentcore.org/core/tree" + "cogentcore.org/core/yaegicore/symbols" +) + +//go:embed *.csv +var csv embed.FS + +// AnalyzePlanets analyzes planets.csv data following some of the examples +// in pandas from: +// https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html +func AnalyzePlanets(dir *tensorfs.Node) { + Planets := table.New("planets") + Planets.OpenFS(csv, "planets.csv", tensor.Comma) + + vals := []string{"number", "orbital_period", "mass", "distance", "year"} + stats.DescribeTable(dir, Planets, vals...) + + decade := Planets.AddFloat64Column("decade") + year := Planets.Column("year") + for row := range Planets.NumRows() { + yr := year.FloatRow(row, 0) + dec := math.Floor(yr/10) * 10 + decade.SetFloatRow(dec, row, 0) + } + + stats.TableGroups(dir, Planets, "method", "decade") + stats.TableGroupDescribe(dir, Planets, vals...) 
+ + // byMethod := split.GroupBy(PlanetsAll, "method") + // split.AggColumn(byMethod, "orbital_period", stats.Median) + // GpMethodOrbit = byMethod.AggsToTable(table.AddAggName) + + // byMethod.DeleteAggs() + // split.DescColumn(byMethod, "year") // full desc stats of year + + // byMethod.Filter(func(idx int) bool { + // ag := errors.Log1(byMethod.AggByColumnName("year:Std")) + // return ag.Aggs[idx][0] > 0 // exclude results with 0 std + // }) + + // GpMethodYear = byMethod.AggsToTable(table.AddAggName) + + // split.AggColumn(byMethodDecade, "number", stats.Sum) + + // uncomment this to switch to decade first, then method + // byMethodDecade.ReorderLevels([]int{1, 0}) + // byMethodDecade.SortLevels() + + // decadeOnly := errors.Log1(byMethodDecade.ExtractLevels([]int{1})) + // split.AggColumn(decadeOnly, "number", stats.Sum) + // GpDecade = decadeOnly.AggsToTable(table.AddAggName) + // + // GpMethodDecade = byMethodDecade.AggsToTable(table.AddAggName) // here to ensure that decadeOnly didn't mess up.. + + // todo: need unstack -- should be specific to the splits data because we already have the cols and + // groups etc -- the ExtractLevels method provides key starting point. + + // todo: pivot table -- neeeds unstack function. + + // todo: could have a generic unstack-like method that takes a column for the data to turn into columns + // and another that has the data to put in the cells. +} + +// important: must be run from an interactive terminal. +// Will quit immediately if not! 
+func main() { + dir := tensorfs.Mkdir("Planets") + AnalyzePlanets(dir) + + opts := cli.DefaultOptions("planets", "interactive data analysis.") + cfg := &interpreter.Config{} + cfg.InteractiveFunc = Interactive + cli.Run(opts, cfg, interpreter.Run, interpreter.Build) +} + +func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error { + in.Interp.Use(symbols.Symbols) // gui imports + in.Config() + b, _ := databrowser.NewBasicWindow(tensorfs.CurRoot, "Planets") + b.AddTopBar(func(bar *core.Frame) { + tb := core.NewToolbar(bar) + // tb.Maker(tbv.MakeToolbar) + tb.Maker(func(p *tree.Plan) { + tree.Add(p, func(w *core.Button) { + w.SetText("README").SetIcon(icons.FileMarkdown). + SetTooltip("open README help file").OnClick(func(e events.Event) { + core.TheApp.OpenURL("https://github.com/cogentcore/core/blob/main/tensor/examples/planets/README.md") + }) + }) + }) + }) + b.OnShow(func(e events.Event) { + go func() { + if c.Expr != "" { + in.Eval(c.Expr) + } + in.Interactive() + }() + }) + b.RunWindow() + core.Wait() + return nil +} diff --git a/tensor/examples/simstats/README.md b/tensor/examples/simstats/README.md new file mode 100644 index 0000000000..5738af2006 --- /dev/null +++ b/tensor/examples/simstats/README.md @@ -0,0 +1,5 @@ +# tensorfs sim + +This is a prototype for neural network simulation statistics computation using the [tensorfs](../tensorfs) framework, now implemented in the [emergent](https://github.com/emer) framework. + + diff --git a/tensor/examples/simstats/enumgen.go b/tensor/examples/simstats/enumgen.go new file mode 100644 index 0000000000..dffa071159 --- /dev/null +++ b/tensor/examples/simstats/enumgen.go @@ -0,0 +1,89 @@ +// Code generated by "core generate"; DO NOT EDIT. + +package main + +import ( + "cogentcore.org/core/enums" +) + +var _TimesValues = []Times{0, 1, 2} + +// TimesN is the highest valid value for type Times, plus one. 
+const TimesN Times = 3 + +var _TimesValueMap = map[string]Times{`Trial`: 0, `Epoch`: 1, `Run`: 2} + +var _TimesDescMap = map[Times]string{0: ``, 1: ``, 2: ``} + +var _TimesMap = map[Times]string{0: `Trial`, 1: `Epoch`, 2: `Run`} + +// String returns the string representation of this Times value. +func (i Times) String() string { return enums.String(i, _TimesMap) } + +// SetString sets the Times value from its string representation, +// and returns an error if the string is invalid. +func (i *Times) SetString(s string) error { return enums.SetString(i, s, _TimesValueMap, "Times") } + +// Int64 returns the Times value as an int64. +func (i Times) Int64() int64 { return int64(i) } + +// SetInt64 sets the Times value from an int64. +func (i *Times) SetInt64(in int64) { *i = Times(in) } + +// Desc returns the description of the Times value. +func (i Times) Desc() string { return enums.Desc(i, _TimesDescMap) } + +// TimesValues returns all possible values for the type Times. +func TimesValues() []Times { return _TimesValues } + +// Values returns all possible values for the type Times. +func (i Times) Values() []enums.Enum { return enums.Values(_TimesValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i Times) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *Times) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Times") } + +var _LoopPhaseValues = []LoopPhase{0, 1} + +// LoopPhaseN is the highest valid value for type LoopPhase, plus one. 
+const LoopPhaseN LoopPhase = 2 + +var _LoopPhaseValueMap = map[string]LoopPhase{`Start`: 0, `Step`: 1} + +var _LoopPhaseDescMap = map[LoopPhase]string{0: `Start is the start of the loop: resets accumulated stats, initializes.`, 1: `Step is each iteration of the loop.`} + +var _LoopPhaseMap = map[LoopPhase]string{0: `Start`, 1: `Step`} + +// String returns the string representation of this LoopPhase value. +func (i LoopPhase) String() string { return enums.String(i, _LoopPhaseMap) } + +// SetString sets the LoopPhase value from its string representation, +// and returns an error if the string is invalid. +func (i *LoopPhase) SetString(s string) error { + return enums.SetString(i, s, _LoopPhaseValueMap, "LoopPhase") +} + +// Int64 returns the LoopPhase value as an int64. +func (i LoopPhase) Int64() int64 { return int64(i) } + +// SetInt64 sets the LoopPhase value from an int64. +func (i *LoopPhase) SetInt64(in int64) { *i = LoopPhase(in) } + +// Desc returns the description of the LoopPhase value. +func (i LoopPhase) Desc() string { return enums.Desc(i, _LoopPhaseDescMap) } + +// LoopPhaseValues returns all possible values for the type LoopPhase. +func LoopPhaseValues() []LoopPhase { return _LoopPhaseValues } + +// Values returns all possible values for the type LoopPhase. +func (i LoopPhase) Values() []enums.Enum { return enums.Values(_LoopPhaseValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i LoopPhase) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (i *LoopPhase) UnmarshalText(text []byte) error { + return enums.UnmarshalText(i, text, "LoopPhase") +} diff --git a/tensor/examples/simstats/sim.go b/tensor/examples/simstats/sim.go new file mode 100644 index 0000000000..f7bc68c4ac --- /dev/null +++ b/tensor/examples/simstats/sim.go @@ -0,0 +1,219 @@ +// Code generated by "goal build"; DO NOT EDIT. 
+//line sim.goal:1 +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +//go:generate core generate + +import ( + "math/rand/v2" + + "cogentcore.org/core/core" + "cogentcore.org/core/plot" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/databrowser" + "cogentcore.org/core/tensor/stats/stats" + "cogentcore.org/core/tensor/tensorfs" +) + +// Times are the looping time levels for running and statistics. +type Times int32 //enums:enum + +const ( + Trial Times = iota + Epoch + Run +) + +// LoopPhase is the phase of loop processing for given time. +type LoopPhase int32 //enums:enum + +const ( + // Start is the start of the loop: resets accumulated stats, initializes. + Start LoopPhase = iota + + // Step is each iteration of the loop. + Step +) + +type Sim struct { + // Root is the root data dir. + Root *tensorfs.Node + + // Config has config data. + Config *tensorfs.Node + + // Stats has all stats data. + Stats *tensorfs.Node + + // Current has current value of all stats + Current *tensorfs.Node + + // StatFuncs are statistics functions, per stat, handles everything. + StatFuncs []func(ltime Times, lphase LoopPhase) + + // Counters are current values of counters: normally in looper. 
+ Counters [TimesN]int +} + +// ConfigAll configures the sim +func (ss *Sim) ConfigAll() { + ss.Root, _ = tensorfs.NewDir("Root") + ss.Config = ss.Root.Dir("Config") + mx := tensorfs.Value[int](ss.Config, "Max", int(TimesN)).(*tensor.Int) + mx.Set1D(5, int(Trial)) + mx.Set1D(4, int(Epoch)) + mx.Set1D(3, int(Run)) + // todo: failing - assigns 3 to all + // # mx[Trial] = 5 + // # mx[Epoch] = 4 + // # mx[Run] = 3 + ss.ConfigStats() +} + +func (ss *Sim) AddStat(f func(ltime Times, lphase LoopPhase)) { + ss.StatFuncs = append(ss.StatFuncs, f) +} + +func (ss *Sim) RunStats(ltime Times, lphase LoopPhase) { + for _, sf := range ss.StatFuncs { + sf(ltime, lphase) + } +} + +func (ss *Sim) ConfigStats() { + ss.Stats = ss.Root.Dir("Stats") + ss.Current = ss.Stats.Dir("Current") + ctrs := []Times{Run, Epoch, Trial} + for _, ctr := range ctrs { + ss.AddStat(func(ltime Times, lphase LoopPhase) { + if ltime > ctr { // don't record counter for time above it + return + } + name := ctr.String() // name of stat = counter + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[int](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0) + }) + plot.SetStylersTo(tsr, ps) + } + return + } + ctv := ss.Counters[ctr] + tensorfs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0) + tsr.AppendRowInt(ctv) + }) + } + // note: it is essential to only have 1 per func + // so generic names can be used for everything. 
+ ss.AddStat(func(ltime Times, lphase LoopPhase) { + name := "SSE" + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case Trial: + stat := rand.Float64() + tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0) + tsr.AppendRowFloat(stat) + case Epoch: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case Run: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) + ss.AddStat(func(ltime Times, lphase LoopPhase) { + name := "Err" + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case Trial: + sse := tensorfs.Scalar[float64](ss.Current, "SSE").Float1D(0) + stat := 1.0 + if sse < 0.5 { + stat = 0 + } + tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0) + tsr.AppendRowFloat(stat) + case Epoch: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case Run: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) +} + +func (ss *Sim) Run() { + mx := ss.Config.Value("Max").(*tensor.Int) + nrun := mx.Value1D(int(Run)) + nepc := mx.Value1D(int(Epoch)) + ntrl := mx.Value1D(int(Trial)) + ss.RunStats(Run, Start) + for run := range nrun { + ss.Counters[Run] = run + ss.RunStats(Epoch, Start) + for epc := range nepc { + ss.Counters[Epoch] = epc + 
ss.RunStats(Trial, Start) + for trl := range ntrl { + ss.Counters[Trial] = trl + ss.RunStats(Trial, Step) + } + ss.RunStats(Epoch, Step) + } + ss.RunStats(Run, Step) + } + // todo: could do final analysis here + // alldt := ss.Logs.Item("AllTrials").GetDirTable(nil) + // dir := ss.Logs.Dir("Stats") + // stats.TableGroups(dir, alldt, "Run", "Epoch", "Trial") + // sts := []string{"SSE", "AvgSSE", "TrlErr"} + // stats.TableGroupStats(dir, stats.StatMean, alldt, sts...) + // stats.TableGroupStats(dir, stats.StatSem, alldt, sts...) +} + +func main() { + ss := &Sim{} + ss.ConfigAll() + ss.Run() + + b, _ := databrowser.NewBasicWindow(ss.Root, "Root") + b.RunWindow() + core.Wait() +} diff --git a/tensor/examples/simstats/sim.goal b/tensor/examples/simstats/sim.goal new file mode 100644 index 0000000000..23ee5ebf5e --- /dev/null +++ b/tensor/examples/simstats/sim.goal @@ -0,0 +1,217 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +//go:generate core generate + +import ( + "math/rand/v2" + + "cogentcore.org/core/core" + "cogentcore.org/core/plot" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/databrowser" + "cogentcore.org/core/tensor/stats/stats" + "cogentcore.org/core/tensor/tensorfs" +) + +// Times are the looping time levels for running and statistics. +type Times int32 //enums:enum + +const ( + Trial Times = iota + Epoch + Run +) + +// LoopPhase is the phase of loop processing for given time. +type LoopPhase int32 //enums:enum + +const ( + // Start is the start of the loop: resets accumulated stats, initializes. + Start LoopPhase = iota + + // Step is each iteration of the loop. + Step +) + +type Sim struct { + // Root is the root data dir. + Root *tensorfs.Node + + // Config has config data. + Config *tensorfs.Node + + // Stats has all stats data. 
+ Stats *tensorfs.Node + + // Current has current value of all stats + Current *tensorfs.Node + + // StatFuncs are statistics functions, per stat, handles everything. + StatFuncs []func(ltime Times, lphase LoopPhase) + + // Counters are current values of counters: normally in looper. + Counters [TimesN]int +} + +// ConfigAll configures the sim +func (ss *Sim) ConfigAll() { + ss.Root, _ = tensorfs.NewDir("Root") + ss.Config = ss.Root.Dir("Config") + mx := tensorfs.Value[int](ss.Config, "Max", int(TimesN)).(*tensor.Int) + mx.Set1D(5, int(Trial)) + mx.Set1D(4, int(Epoch)) + mx.Set1D(3, int(Run)) + // todo: failing - assigns 3 to all + // # mx[Trial] = 5 + // # mx[Epoch] = 4 + // # mx[Run] = 3 + ss.ConfigStats() +} + +func (ss *Sim) AddStat(f func(ltime Times, lphase LoopPhase)) { + ss.StatFuncs = append(ss.StatFuncs, f) +} + +func (ss *Sim) RunStats(ltime Times, lphase LoopPhase) { + for _, sf := range ss.StatFuncs { + sf(ltime, lphase) + } +} + +func (ss *Sim) ConfigStats() { + ss.Stats = ss.Root.Dir("Stats") + ss.Current = ss.Stats.Dir("Current") + ctrs := []Times{Run, Epoch, Trial} + for _, ctr := range ctrs { + ss.AddStat(func(ltime Times, lphase LoopPhase) { + if ltime > ctr { // don't record counter for time above it + return + } + name := ctr.String() // name of stat = counter + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[int](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0) + }) + plot.SetStylersTo(tsr, ps) + } + return + } + ctv := ss.Counters[ctr] + tensorfs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0) + tsr.AppendRowInt(ctv) + }) + } + // note: it is essential to only have 1 per func + // so generic names can be used for everything. 
+ ss.AddStat(func(ltime Times, lphase LoopPhase) { + name := "SSE" + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case Trial: + stat := rand.Float64() + tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0) + tsr.AppendRowFloat(stat) + case Epoch: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case Run: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) + ss.AddStat(func(ltime Times, lphase LoopPhase) { + name := "Err" + timeDir := ss.Stats.Dir(ltime.String()) + tsr := tensorfs.Value[float64](timeDir, name) + if lphase == Start { + tsr.SetNumRows(0) + if ps := plot.GetStylersFrom(tsr); ps == nil { + ps.Add(func(s *plot.Style) { + s.Range.SetMin(0).SetMax(1) + s.On = true + }) + plot.SetStylersTo(tsr, ps) + } + return + } + switch ltime { + case Trial: + sse := tensorfs.Scalar[float64](ss.Current, "SSE").Float1D(0) + stat := 1.0 + if sse < 0.5 { + stat = 0 + } + tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0) + tsr.AppendRowFloat(stat) + case Epoch: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + case Run: + subd := ss.Stats.Dir((ltime - 1).String()) + stat := stats.StatMean.Call(subd.Value(name)) + tsr.AppendRow(stat) + } + }) +} + +func (ss *Sim) Run() { + mx := ss.Config.Value("Max").(*tensor.Int) + nrun := mx.Value1D(int(Run)) + nepc := mx.Value1D(int(Epoch)) + ntrl := mx.Value1D(int(Trial)) + ss.RunStats(Run, Start) + for run := range nrun { + ss.Counters[Run] = run + ss.RunStats(Epoch, Start) + for epc := range nepc { + ss.Counters[Epoch] = epc + 
ss.RunStats(Trial, Start) + for trl := range ntrl { + ss.Counters[Trial] = trl + ss.RunStats(Trial, Step) + } + ss.RunStats(Epoch, Step) + } + ss.RunStats(Run, Step) + } + // todo: could do final analysis here + // alldt := ss.Logs.Item("AllTrials").GetDirTable(nil) + // dir := ss.Logs.Dir("Stats") + // stats.TableGroups(dir, alldt, "Run", "Epoch", "Trial") + // sts := []string{"SSE", "AvgSSE", "TrlErr"} + // stats.TableGroupStats(dir, stats.StatMean, alldt, sts...) + // stats.TableGroupStats(dir, stats.StatSem, alldt, sts...) +} + +func main() { + ss := &Sim{} + ss.ConfigAll() + ss.Run() + + b, _ := databrowser.NewBasicWindow(ss.Root, "Root") + b.RunWindow() + core.Wait() +} diff --git a/tensor/funcs.go b/tensor/funcs.go new file mode 100644 index 0000000000..1e7dc54684 --- /dev/null +++ b/tensor/funcs.go @@ -0,0 +1,253 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "fmt" + "reflect" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/metadata" +) + +// Func represents a registered tensor function, which has +// In number of input Tensor arguments, and Out number of output +// arguments (typically 1). There can also be an 'any' first +// argument to support other kinds of parameters. +// This is used to make tensor functions available to the Goal language. +type Func struct { + // Name is the original CamelCase Go name for function + Name string + + // Fun is the function, which must _only_ take some number of Tensor + // args, with an optional any first arg. + Fun any + + // Args has parsed information about the function args, for Goal. + Args []*Arg +} + +// Arg has key information that Goal needs about each arg, for converting +// expressions into the appropriate type. +type Arg struct { + // Type has full reflection type info. 
+ Type reflect.Type + + // IsTensor is true if it satisfies the Tensor interface. + IsTensor bool + + // IsInt is true if Kind = Int, for shape, slice etc params. + IsInt bool + + // IsVariadic is true if this is the last arg and has ...; type will be an array. + IsVariadic bool +} + +// NewFunc creates a new Func desciption of the given +// function, which must have a signature like this: +// func([opt any,] a, b, out tensor.Tensor) error +// i.e., taking some specific number of Tensor arguments (up to 5). +// Functions can also take an 'any' first argument to handle other +// non-tensor inputs (e.g., function pointer, dirfs directory, etc). +// The name should be a standard 'package.FuncName' qualified, exported +// CamelCase name, with 'out' indicating the number of output arguments, +// and an optional arg indicating an 'any' first argument. +// The remaining arguments in the function (automatically +// determined) are classified as input arguments. +func NewFunc(name string, fun any) (*Func, error) { + fn := &Func{Name: name, Fun: fun} + fn.GetArgs() + return fn, nil +} + +// GetArgs gets key info about each arg, for use by Goal transpiler. +func (fn *Func) GetArgs() { + ft := reflect.TypeOf(fn.Fun) + n := ft.NumIn() + if n == 0 { + return + } + fn.Args = make([]*Arg, n) + tsrt := reflect.TypeFor[Tensor]() + for i := range n { + at := ft.In(i) + ag := &Arg{Type: at} + if ft.IsVariadic() && i == n-1 { + ag.IsVariadic = true + } + if at.Kind() == reflect.Int || (at.Kind() == reflect.Slice && at.Elem().Kind() == reflect.Int) { + ag.IsInt = true + } else if at.Implements(tsrt) { + ag.IsTensor = true + } + fn.Args[i] = ag + } +} + +func (fn *Func) String() string { + s := fn.Name + "(" + na := len(fn.Args) + for i, a := range fn.Args { + if a.IsVariadic { + s += "..." 
+ } + ts := a.Type.String() + if ts == "interface {}" { + ts = "any" + } + s += ts + if i < na-1 { + s += ", " + } + } + s += ")" + return s +} + +// Funcs is the global tensor named function registry. +// All functions must have a signature like this: +// func([opt any,] a, b, out tensor.Tensor) error +// i.e., taking some specific number of Tensor arguments (up to 5), +// with the number of output vs. input arguments registered. +// Functions can also take an 'any' first argument to handle other +// non-tensor inputs (e.g., function pointer, dirfs directory, etc). +// This is used to make tensor functions available to the Goal +// language. +var Funcs map[string]*Func + +// AddFunc adds given named function to the global tensor named function +// registry, which is used by Goal to call functions by name. +// See [NewFunc] for more informa.tion. +func AddFunc(name string, fun any) error { + if Funcs == nil { + Funcs = make(map[string]*Func) + } + _, ok := Funcs[name] + if ok { + return fmt.Errorf("tensor.AddFunc: function of name %q already exists, not added", name) + } + fn, err := NewFunc(name, fun) + if errors.Log(err) != nil { + return err + } + Funcs[name] = fn + // note: can record orig camel name if needed for docs etc later. + return nil +} + +// FuncByName finds function of given name in the registry, +// returning an error if the function name has not been registered. +func FuncByName(name string) (*Func, error) { + fn, ok := Funcs[name] + if !ok { + return nil, fmt.Errorf("tensor.FuncByName: function of name %q not registered", name) + } + return fn, nil +} + +// These generic functions provide a one liner for wrapping functions +// that take an output Tensor as the last argument, which is important +// for memory re-use of the output in performance-critical cases. +// The names indicate the number of input tensor arguments. +// Additional generic non-Tensor inputs are supported up to 2, +// with Gen1 and Gen2 versions. 
+ +// FloatPromoteType returns the DataType for Tensor(s) that promotes +// the Float type if any of the elements are of that type. +// Otherwise it returns the type of the first tensor. +func FloatPromoteType(tsr ...Tensor) reflect.Kind { + ft := tsr[0].DataType() + for i := 1; i < len(tsr); i++ { + t := tsr[i].DataType() + if t == reflect.Float64 { + ft = t + } else if t == reflect.Float32 && ft != reflect.Float64 { + ft = t + } + } + return ft +} + +// CallOut1 adds output [Values] tensor for function. +func CallOut1(fun func(a Tensor, out Values) error, a Tensor) Values { + out := NewOfType(a.DataType()) + errors.Log(fun(a, out)) + return out +} + +// CallOut1Float64 adds Float64 output [Values] tensor for function. +func CallOut1Float64(fun func(a Tensor, out Values) error, a Tensor) Values { + out := NewFloat64() + errors.Log(fun(a, out)) + return out +} + +func CallOut2Float64(fun func(a, b Tensor, out Values) error, a, b Tensor) Values { + out := NewFloat64() + errors.Log(fun(a, b, out)) + return out +} + +func CallOut2(fun func(a, b Tensor, out Values) error, a, b Tensor) Values { + out := NewOfType(FloatPromoteType(a, b)) + errors.Log(fun(a, b, out)) + return out +} + +func CallOut3(fun func(a, b, c Tensor, out Values) error, a, b, c Tensor) Values { + out := NewOfType(FloatPromoteType(a, b, c)) + errors.Log(fun(a, b, c, out)) + return out +} + +func CallOut2Bool(fun func(a, b Tensor, out *Bool) error, a, b Tensor) *Bool { + out := NewBool() + errors.Log(fun(a, b, out)) + return out +} + +func CallOut1Gen1[T any](fun func(g T, a Tensor, out Values) error, g T, a Tensor) Values { + out := NewOfType(a.DataType()) + errors.Log(fun(g, a, out)) + return out +} + +func CallOut1Gen2[T any, S any](fun func(g T, h S, a Tensor, out Values) error, g T, h S, a Tensor) Values { + out := NewOfType(a.DataType()) + errors.Log(fun(g, h, a, out)) + return out +} + +func CallOut2Gen1[T any](fun func(g T, a, b Tensor, out Values) error, g T, a, b Tensor) Values { + out := 
NewOfType(FloatPromoteType(a, b)) + errors.Log(fun(g, a, b, out)) + return out +} + +func CallOut2Gen2[T any, S any](fun func(g T, h S, a, b Tensor, out Values) error, g T, h S, a, b Tensor) Values { + out := NewOfType(FloatPromoteType(a, b)) + errors.Log(fun(g, h, a, b, out)) + return out +} + +//////// Metadata + +// SetCalcFunc sets a function to calculate updated value for given tensor, +// storing the function pointer in the Metadata "CalcFunc" key for the tensor. +// Can be called by [Calc] function. +func SetCalcFunc(tsr Tensor, fun func() error) { + tsr.Metadata().Set("CalcFunc", fun) +} + +// Calc calls function set by [SetCalcFunc] to compute an updated value for +// given tensor. Returns an error if func not set, or any error from func itself. +// Function is stored as CalcFunc in Metadata. +func Calc(tsr Tensor) error { + fun, err := metadata.Get[func() error](*tsr.Metadata(), "CalcFunc") + if err != nil { + return err + } + return fun() +} diff --git a/tensor/funcs_test.go b/tensor/funcs_test.go new file mode 100644 index 0000000000..e2465ab132 --- /dev/null +++ b/tensor/funcs_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tensor + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +// prototype of a simple compute function: +func absout(in Tensor, out Values) error { + SetShapeFrom(out, in) + VectorizeThreaded(1, NFirstLen, func(idx int, tsr ...Tensor) { + tsr[1].SetFloat1D(math.Abs(tsr[0].Float1D(idx)), idx) + }, in, out) + return nil +} + +func TestFuncs(t *testing.T) { + err := AddFunc("Abs", absout) + assert.NoError(t, err) + + err = AddFunc("Abs", absout) + assert.Error(t, err) // already + + vals := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 1.507556722888818, .3} + + oned := NewNumberFromValues(vals...) + oneout := oned.Clone() + + fn, err := FuncByName("Abs") + assert.NoError(t, err) + + // fmt.Println(fn.Args[0], fn.Args[1]) + assert.Equal(t, true, fn.Args[0].IsTensor) + assert.Equal(t, true, fn.Args[1].IsTensor) + assert.Equal(t, false, fn.Args[0].IsInt) + assert.Equal(t, false, fn.Args[1].IsInt) + + absout(oned, oneout) + assert.Equal(t, 1.507556722888818, oneout.Float1D(0)) +} + +func TestAlign(t *testing.T) { + a := NewFloat64(3, 4) + b := NewFloat64(1, 3, 4) + as, bs, os, err := AlignShapes(a, b) + assert.NoError(t, err) + assert.Equal(t, []int{1, 3, 4}, os.Sizes) + assert.Equal(t, []int{1, 3, 4}, as.Sizes) + assert.Equal(t, []int{1, 3, 4}, bs.Sizes) + + ars := NewReshaped(a, 12) + as, bs, os, err = AlignShapes(ars, b) + assert.Error(t, err) + + brs := NewReshaped(b, 12) + as, bs, os, err = AlignShapes(ars, brs) + assert.NoError(t, err) + + ars = NewReshaped(a, 3, 1, 4) + as, bs, os, err = AlignShapes(ars, b) + assert.NoError(t, err) + assert.Equal(t, []int{3, 3, 4}, os.Sizes) + assert.Equal(t, []int{3, 1, 4}, as.Sizes) + assert.Equal(t, []int{1, 3, 4}, bs.Sizes) +} diff --git a/tensor/indexed.go b/tensor/indexed.go new file mode 100644 index 0000000000..ad778ca808 --- /dev/null +++ 
b/tensor/indexed.go @@ -0,0 +1,202 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "reflect" + "slices" + + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/reflectx" +) + +// Indexed provides an arbitrarily indexed view onto another "source" [Tensor] +// with each index value providing a full n-dimensional index into the source. +// The shape of this view is determined by the shape of the [Indexed.Indexes] +// tensor up to the final innermost dimension, which holds the index values. +// Thus the innermost dimension size of the indexes is equal to the number +// of dimensions in the source tensor. Given the essential role of the +// indexes in this view, it is not usable without the indexes. +// This view is not memory-contiguous and does not support the [RowMajor] +// interface or efficient access to inner-dimensional subspaces. +// To produce a new concrete [Values] that has raw data actually +// organized according to the indexed order (i.e., the copy function +// of numpy), call [Indexed.AsValues]. +type Indexed struct { //types:add + + // Tensor source that we are an indexed view onto. + Tensor Tensor + + // Indexes is the list of indexes into the source tensor, + // with the innermost dimension providing the index values + // (size = number of dimensions in the source tensor), and + // the remaining outer dimensions determine the shape + // of this [Indexed] tensor view. + Indexes *Int +} + +// NewIndexed returns a new [Indexed] view of given tensor, +// with tensor of indexes into the source tensor. +func NewIndexed(tsr Tensor, idx *Int) *Indexed { + ix := &Indexed{Tensor: tsr} + ix.Indexes = idx + return ix +} + +// AsIndexed returns the tensor as a [Indexed] view, if it is one. +// Otherwise, it returns nil; there is no usable "null" Indexed view. 
+func AsIndexed(tsr Tensor) *Indexed { + if ix, ok := tsr.(*Indexed); ok { + return ix + } + return nil +} + +// SetTensor sets as indexes into given tensor with sequential initial indexes. +func (ix *Indexed) SetTensor(tsr Tensor) { + ix.Tensor = tsr +} + +// SourceIndexes returns the actual indexes into underlying source tensor +// based on given list of indexes into the [Indexed.Indexes] tensor, +// _excluding_ the final innermost dimension. +func (ix *Indexed) SourceIndexes(i ...int) []int { + idx := slices.Clone(i) + idx = append(idx, 0) // first index + oned := ix.Indexes.Shape().IndexTo1D(idx...) + nd := ix.Tensor.NumDims() + return ix.Indexes.Values[oned : oned+nd] +} + +// SourceIndexesFrom1D returns the full indexes into source tensor based on the +// given 1d index, which is based on the outer dimensions, excluding the +// final innermost dimension. +func (ix *Indexed) SourceIndexesFrom1D(oned int) []int { + nd := ix.Tensor.NumDims() + oned *= nd + return ix.Indexes.Values[oned : oned+nd] +} + +func (ix *Indexed) Label() string { return label(metadata.Name(ix), ix.Shape()) } +func (ix *Indexed) String() string { return Sprintf("", ix, 0) } +func (ix *Indexed) Metadata() *metadata.Data { return ix.Tensor.Metadata() } +func (ix *Indexed) IsString() bool { return ix.Tensor.IsString() } +func (ix *Indexed) DataType() reflect.Kind { return ix.Tensor.DataType() } +func (ix *Indexed) Shape() *Shape { return NewShape(ix.ShapeSizes()...) } +func (ix *Indexed) Len() int { return ix.Shape().Len() } +func (ix *Indexed) NumDims() int { return ix.Indexes.NumDims() - 1 } +func (ix *Indexed) DimSize(dim int) int { return ix.Indexes.DimSize(dim) } + +func (ix *Indexed) ShapeSizes() []int { + si := slices.Clone(ix.Indexes.ShapeSizes()) + return si[:len(si)-1] // exclude last dim +} + +// AsValues returns a copy of this tensor as raw [Values]. 
+// This "renders" the Indexed view into a fully contiguous +// and optimized memory representation of that view, which will be faster +// to access for further processing, and enables all the additional +// functionality provided by the [Values] interface. +func (ix *Indexed) AsValues() Values { + dt := ix.Tensor.DataType() + vt := NewOfType(dt, ix.ShapeSizes()...) + n := ix.Len() + switch { + case ix.Tensor.IsString(): + for i := range n { + vt.SetString1D(ix.String1D(i), i) + } + case reflectx.KindIsFloat(dt): + for i := range n { + vt.SetFloat1D(ix.Float1D(i), i) + } + default: + for i := range n { + vt.SetInt1D(ix.Int1D(i), i) + } + } + return vt +} + +//////// Floats + +// Float returns the value of given index as a float64. +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) Float(i ...int) float64 { + return ix.Tensor.Float(ix.SourceIndexes(i...)...) +} + +// SetFloat sets the value of given index as a float64 +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) SetFloat(val float64, i ...int) { + ix.Tensor.SetFloat(val, ix.SourceIndexes(i...)...) +} + +// Float1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) Float1D(i int) float64 { + return ix.Tensor.Float(ix.SourceIndexesFrom1D(i)...) +} + +// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) SetFloat1D(val float64, i int) { + ix.Tensor.SetFloat(val, ix.SourceIndexesFrom1D(i)...) +} + +//////// Strings + +// StringValue returns the value of given index as a string. +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) StringValue(i ...int) string { + return ix.Tensor.StringValue(ix.SourceIndexes(i...)...) 
+} + +// SetString sets the value of given index as a string +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) SetString(val string, i ...int) { + ix.Tensor.SetString(val, ix.SourceIndexes(i...)...) +} + +// String1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) String1D(i int) string { + return ix.Tensor.StringValue(ix.SourceIndexesFrom1D(i)...) +} + +// SetString1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) SetString1D(val string, i int) { + ix.Tensor.SetString(val, ix.SourceIndexesFrom1D(i)...) +} + +//////// Ints + +// Int returns the value of given index as an int. +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) Int(i ...int) int { + return ix.Tensor.Int(ix.SourceIndexes(i...)...) +} + +// SetInt sets the value of given index as an int +// The indexes are indirected through the [Indexed.Indexes]. +func (ix *Indexed) SetInt(val int, i ...int) { + ix.Tensor.SetInt(val, ix.SourceIndexes(i...)...) +} + +// Int1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) Int1D(i int) int { + return ix.Tensor.Int(ix.SourceIndexesFrom1D(i)...) +} + +// SetInt1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ix *Indexed) SetInt1D(val int, i int) { + ix.Tensor.SetInt(val, ix.SourceIndexesFrom1D(i)...) 
+} + +// check for interface impl +var _ Tensor = (*Indexed)(nil) diff --git a/tensor/io.go b/tensor/io.go index 00b5548fd2..cbc77c5308 100644 --- a/tensor/io.go +++ b/tensor/io.go @@ -5,20 +5,67 @@ package tensor import ( + "bytes" "encoding/csv" + "fmt" "io" "log" "os" "strconv" + "strings" - "cogentcore.org/core/core" + "cogentcore.org/core/base/fsx" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/reflectx" ) +// Delim are standard CSV delimiter options (Tab, Comma, Space) +type Delims int32 //enums:enum + +const ( + // Tab is the tab rune delimiter, for TSV tab separated values + Tab Delims = iota + + // Comma is the comma rune delimiter, for CSV comma separated values + Comma + + // Space is the space rune delimiter, for SSV space separated value + Space + + // Detect is used during reading a file -- reads the first line and detects tabs or commas + Detect +) + +func (dl Delims) Rune() rune { + switch dl { + case Tab: + return '\t' + case Comma: + return ',' + case Space: + return ' ' + } + return '\t' +} + +// SetPrecision sets the "precision" metadata value that determines +// the precision to use in writing floating point numbers to files. +func SetPrecision(obj any, prec int) { + metadata.SetTo(obj, "Precision", prec) +} + +// Precision gets the "precision" metadata value that determines +// the precision to use in writing floating point numbers to files. +// returns an error if not set. +func Precision(obj any) (int, error) { + return metadata.GetFrom[int](obj, "Precision") +} + // SaveCSV writes a tensor to a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg). // Outer-most dims are rows in the file, and inner-most is column -- // Reading just grabs all values and doesn't care about shape. 
-func SaveCSV(tsr Tensor, filename core.Filename, delim rune) error { +func SaveCSV(tsr Tensor, filename fsx.Filename, delim Delims) error { fp, err := os.Create(string(filename)) defer fp.Close() if err != nil { @@ -34,7 +81,7 @@ func SaveCSV(tsr Tensor, filename core.Filename, delim rune) error { // using the Go standard encoding/csv reader conforming // to the official CSV standard. // Reads all values and assigns as many as fit. -func OpenCSV(tsr Tensor, filename core.Filename, delim rune) error { +func OpenCSV(tsr Tensor, filename fsx.Filename, delim Delims) error { fp, err := os.Open(string(filename)) defer fp.Close() if err != nil { @@ -44,22 +91,19 @@ func OpenCSV(tsr Tensor, filename core.Filename, delim rune) error { return ReadCSV(tsr, fp, delim) } -////////////////////////////////////////////////////////////////////////// -// WriteCSV +//////// WriteCSV // WriteCSV writes a tensor to a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg). // Outer-most dims are rows in the file, and inner-most is column -- // Reading just grabs all values and doesn't care about shape. -func WriteCSV(tsr Tensor, w io.Writer, delim rune) error { +func WriteCSV(tsr Tensor, w io.Writer, delim Delims) error { prec := -1 - if ps, ok := tsr.MetaData("precision"); ok { - prec, _ = strconv.Atoi(ps) + if ps, err := Precision(tsr); err == nil { + prec = ps } cw := csv.NewWriter(w) - if delim != 0 { - cw.Comma = delim - } + cw.Comma = delim.Rune() nrow := tsr.DimSize(0) nin := tsr.Len() / nrow rec := make([]string, nin) @@ -88,11 +132,9 @@ func WriteCSV(tsr Tensor, w io.Writer, delim rune) error { // using the Go standard encoding/csv reader conforming // to the official CSV standard. // Reads all values and assigns as many as fit. 
-func ReadCSV(tsr Tensor, r io.Reader, delim rune) error { +func ReadCSV(tsr Tensor, r io.Reader, delim Delims) error { cr := csv.NewReader(r) - if delim != 0 { - cr.Comma = delim - } + cr.Comma = delim.Rune() rec, err := cr.ReadAll() // todo: lazy, avoid resizing if err != nil || len(rec) == 0 { return err @@ -104,7 +146,7 @@ func ReadCSV(tsr Tensor, r io.Reader, delim rune) error { for ri := 0; ri < rows; ri++ { for ci := 0; ci < cols; ci++ { str := rec[ri][ci] - tsr.SetString1D(idx, str) + tsr.SetString1D(str, idx) idx++ if idx >= sz { goto done @@ -114,3 +156,154 @@ func ReadCSV(tsr Tensor, r io.Reader, delim rune) error { done: return nil } + +func label(nm string, sh *Shape) string { + if nm != "" { + nm += " " + sh.String() + } else { + nm = sh.String() + } + return nm +} + +// padToLength returns the given string with added spaces +// to pad out to target length. at least 1 space will be added +func padToLength(str string, tlen int) string { + slen := len(str) + if slen < tlen-1 { + return str + strings.Repeat(" ", tlen-slen) + } + return str + " " +} + +// prepadToLength returns the given string with added spaces +// to pad out to target length at start (for numbers). +// at least 1 space will be added +func prepadToLength(str string, tlen int) string { + slen := len(str) + if slen < tlen-1 { + return strings.Repeat(" ", tlen-slen-1) + str + " " + } + return str + " " +} + +// MaxPrintLineWidth is the maximum line width in characters +// to generate for tensor Sprintf function. +var MaxPrintLineWidth = 80 + +// Sprintf returns a string representation of the given tensor, +// with a maximum length of as given: output is terminated +// when it exceeds that length. If maxLen = 0, [MaxSprintLength] is used. +// The format is the per-element format string. +// If empty it uses general %g for number or %s for string. 
+func Sprintf(format string, tsr Tensor, maxLen int) string { + if maxLen == 0 { + maxLen = MaxSprintLength + } + defFmt := format == "" + if defFmt { + switch { + case tsr.IsString(): + format = "%s" + case reflectx.KindIsInt(tsr.DataType()): + format = "%.10g" + default: + format = "%.10g" + } + } + nd := tsr.NumDims() + if nd == 1 && tsr.DimSize(0) == 1 { // scalar special case + if tsr.IsString() { + return fmt.Sprintf(format, tsr.String1D(0)) + } else { + return fmt.Sprintf(format, tsr.Float1D(0)) + } + } + mxwd := 0 + n := min(tsr.Len(), maxLen) + for i := range n { + s := "" + if tsr.IsString() { + s = fmt.Sprintf(format, tsr.String1D(i)) + } else { + s = fmt.Sprintf(format, tsr.Float1D(i)) + } + if len(s) > mxwd { + mxwd = len(s) + } + } + onedRow := false + shp := tsr.Shape() + rowShape, colShape, _, colIdxs := Projection2DDimShapes(shp, onedRow) + rows, cols, _, _ := Projection2DShape(shp, onedRow) + + rowWd := len(rowShape.String()) + 1 + legend := "" + if nd > 2 { + leg := bytes.Repeat([]byte("r "), nd) + for _, i := range colIdxs { + leg[2*i] = 'c' + } + legend = "[" + string(leg[:len(leg)-1]) + "]" + } + rowWd = max(rowWd, len(legend)+1) + hdrWd := len(colShape.String()) + 1 + colWd := mxwd + 1 + + var b strings.Builder + b.WriteString(tsr.Label()) + noidx := false + if tsr.NumDims() == 1 { + b.WriteString(" ") + rowWd = len(tsr.Label()) + 1 + noidx = true + } else { + b.WriteString("\n") + } + if !noidx && nd > 1 && cols > 1 { + colWd = max(colWd, hdrWd) + b.WriteString(padToLength(legend, rowWd)) + totWd := rowWd + for c := 0; c < cols; c++ { + _, cc := Projection2DCoords(shp, onedRow, 0, c) + s := prepadToLength(fmt.Sprintf("%v", cc), colWd) + if totWd+len(s) > MaxPrintLineWidth { + b.WriteString("\n" + strings.Repeat(" ", rowWd)) + totWd = rowWd + } + b.WriteString(s) + totWd += len(s) + } + b.WriteString("\n") + } + ctr := 0 + for r := range rows { + rc, _ := Projection2DCoords(shp, onedRow, r, 0) + if !noidx { + 
b.WriteString(padToLength(fmt.Sprintf("%v", rc), rowWd)) + } + ri := r + totWd := rowWd + for c := 0; c < cols; c++ { + s := "" + if tsr.IsString() { + s = padToLength(fmt.Sprintf(format, Projection2DString(tsr, onedRow, ri, c)), colWd) + } else { + s = prepadToLength(fmt.Sprintf(format, Projection2DValue(tsr, onedRow, ri, c)), colWd) + } + if totWd+len(s) > MaxPrintLineWidth { + b.WriteString("\n" + strings.Repeat(" ", rowWd)) + totWd = rowWd + } + b.WriteString(s) + totWd += len(s) + } + b.WriteString("\n") + ctr += cols + if ctr > maxLen { + b.WriteString("...\n") + break + } + } + return b.String() +} diff --git a/tensor/masked.go b/tensor/masked.go new file mode 100644 index 0000000000..53d5caba99 --- /dev/null +++ b/tensor/masked.go @@ -0,0 +1,264 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "math" + "reflect" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/reflectx" +) + +// Masked is a filtering wrapper around another "source" [Tensor], +// that provides a bit-masked view onto the Tensor defined by a [Bool] [Values] +// tensor with a matching shape. If the bool mask has a 'false' +// then the corresponding value cannot be Set, and Float access returns +// NaN indicating missing data (other type access returns the zero value). +// A new Masked view defaults to a full transparent view of the source tensor. +// To produce a new [Values] tensor with only the 'true' cases, +// (i.e., the copy function of numpy), call [Masked.AsValues]. +type Masked struct { //types:add + + // Tensor source that we are a masked view onto. + Tensor Tensor + + // Bool tensor with same shape as source tensor, providing mask. + Mask *Bool +} + +// NewMasked returns a new [Masked] view of given tensor, +// with given [Bool] mask values. 
If no mask is provided, +// a default full transparent (all bool values = true) mask is used. +func NewMasked(tsr Tensor, mask ...*Bool) *Masked { + ms := &Masked{Tensor: tsr} + if len(mask) == 1 { + ms.Mask = mask[0] + ms.SyncShape() + } else { + ms.Mask = NewBoolShape(tsr.Shape()) + ms.Mask.SetTrue() + } + return ms +} + +// Mask is the general purpose masking function, which checks +// if the mask arg is a Bool and uses if so. +// Otherwise, it logs an error. +func Mask(tsr, mask Tensor) Tensor { + if mb, ok := mask.(*Bool); ok { + return NewMasked(tsr, mb) + } + errors.Log(errors.New("tensor.Mask: provided tensor is not a Bool tensor")) + return tsr +} + +// AsMasked returns the tensor as a [Masked] view. +// If it already is one, then it is returned, otherwise it is wrapped +// with an initially fully transparent mask. +func AsMasked(tsr Tensor) *Masked { + if ms, ok := tsr.(*Masked); ok { + return ms + } + return NewMasked(tsr) +} + +// SetTensor sets the given source tensor. If the shape does not match +// the current Mask, then a new transparent mask is established. +func (ms *Masked) SetTensor(tsr Tensor) { + ms.Tensor = tsr + ms.SyncShape() +} + +// SyncShape ensures that [Masked.Mask] shape is the same as source tensor. +// If the Mask does not exist or is a different shape from the source, +// then it is created or reshaped, and all values set to true ("transparent"). 
+func (ms *Masked) SyncShape() { + if ms.Mask == nil { + ms.Mask = NewBoolShape(ms.Tensor.Shape()) + ms.Mask.SetTrue() + return + } + if !ms.Mask.Shape().IsEqual(ms.Tensor.Shape()) { + SetShapeFrom(ms.Mask, ms.Tensor) + ms.Mask.SetTrue() + } +} + +func (ms *Masked) Label() string { return label(metadata.Name(ms), ms.Shape()) } +func (ms *Masked) String() string { return Sprintf("", ms, 0) } +func (ms *Masked) Metadata() *metadata.Data { return ms.Tensor.Metadata() } +func (ms *Masked) IsString() bool { return ms.Tensor.IsString() } +func (ms *Masked) DataType() reflect.Kind { return ms.Tensor.DataType() } +func (ms *Masked) ShapeSizes() []int { return ms.Tensor.ShapeSizes() } +func (ms *Masked) Shape() *Shape { return ms.Tensor.Shape() } +func (ms *Masked) Len() int { return ms.Tensor.Len() } +func (ms *Masked) NumDims() int { return ms.Tensor.NumDims() } +func (ms *Masked) DimSize(dim int) int { return ms.Tensor.DimSize(dim) } + +// AsValues returns a copy of this tensor as raw [Values]. +// This "renders" the Masked view into a fully contiguous +// and optimized memory representation of that view. +// Because the masking pattern is unpredictable, only a 1D shape is possible. +func (ms *Masked) AsValues() Values { + dt := ms.Tensor.DataType() + n := ms.Len() + switch { + case ms.Tensor.IsString(): + vals := make([]string, 0, n) + for i := range n { + if !ms.Mask.Bool1D(i) { + continue + } + vals = append(vals, ms.Tensor.String1D(i)) + } + return NewStringFromValues(vals...) + case reflectx.KindIsFloat(dt): + vals := make([]float64, 0, n) + for i := range n { + if !ms.Mask.Bool1D(i) { + continue + } + vals = append(vals, ms.Tensor.Float1D(i)) + } + return NewFloat64FromValues(vals...) + default: + vals := make([]int, 0, n) + for i := range n { + if !ms.Mask.Bool1D(i) { + continue + } + vals = append(vals, ms.Tensor.Int1D(i)) + } + return NewIntFromValues(vals...) 
+ } +} + +// SourceIndexes returns a flat [Int] tensor of the mask values +// that match the given getTrue argument state. +// These can be used as indexes in the [Indexed] view, for example. +// The resulting tensor is 2D with inner dimension = number of source +// tensor dimensions, to hold the indexes, and outer dimension = number +// of indexes. +func (ms *Masked) SourceIndexes(getTrue bool) *Int { + n := ms.Len() + nd := ms.Tensor.NumDims() + idxs := make([]int, 0, n*nd) + for i := range n { + if ms.Mask.Bool1D(i) != getTrue { + continue + } + ix := ms.Tensor.Shape().IndexFrom1D(i) + idxs = append(idxs, ix...) + } + it := NewIntFromValues(idxs...) + it.SetShapeSizes(len(idxs)/nd, nd) + return it +} + +//////// Floats + +func (ms *Masked) Float(i ...int) float64 { + if !ms.Mask.Bool(i...) { + return math.NaN() + } + return ms.Tensor.Float(i...) +} + +func (ms *Masked) SetFloat(val float64, i ...int) { + if !ms.Mask.Bool(i...) { + return + } + ms.Tensor.SetFloat(val, i...) +} + +func (ms *Masked) Float1D(i int) float64 { + if !ms.Mask.Bool1D(i) { + return math.NaN() + } + return ms.Tensor.Float1D(i) +} + +func (ms *Masked) SetFloat1D(val float64, i int) { + if !ms.Mask.Bool1D(i) { + return + } + ms.Tensor.SetFloat1D(val, i) +} + +//////// Strings + +func (ms *Masked) StringValue(i ...int) string { + if !ms.Mask.Bool(i...) { + return "" + } + return ms.Tensor.StringValue(i...) +} + +func (ms *Masked) SetString(val string, i ...int) { + if !ms.Mask.Bool(i...) { + return + } + ms.Tensor.SetString(val, i...) +} + +func (ms *Masked) String1D(i int) string { + if !ms.Mask.Bool1D(i) { + return "" + } + return ms.Tensor.String1D(i) +} + +func (ms *Masked) SetString1D(val string, i int) { + if !ms.Mask.Bool1D(i) { + return + } + ms.Tensor.SetString1D(val, i) +} + +//////// Ints + +func (ms *Masked) Int(i ...int) int { + if !ms.Mask.Bool(i...) { + return 0 + } + return ms.Tensor.Int(i...) +} + +func (ms *Masked) SetInt(val int, i ...int) { + if !ms.Mask.Bool(i...) 
{ + return + } + ms.Tensor.SetInt(val, i...) +} + +func (ms *Masked) Int1D(i int) int { + if !ms.Mask.Bool1D(i) { + return 0 + } + return ms.Tensor.Int1D(i) +} + +// SetInt1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (ms *Masked) SetInt1D(val int, i int) { + if !ms.Mask.Bool1D(i) { + return + } + ms.Tensor.SetInt1D(val, i) +} + +// Filter sets the mask values using given Filter function. +// The filter function gets the 1D index into the source tensor. +func (ms *Masked) Filter(filterer func(tsr Tensor, idx int) bool) { + n := ms.Tensor.Len() + for i := range n { + ms.Mask.SetBool1D(filterer(ms.Tensor, i), i) + } +} + +// check for interface impl +var _ Tensor = (*Masked)(nil) diff --git a/tensor/matrix/README.md b/tensor/matrix/README.md new file mode 100644 index 0000000000..c27c88a623 --- /dev/null +++ b/tensor/matrix/README.md @@ -0,0 +1,13 @@ +# matrix: linear algebra with tensors + +This package provides interfaces for `Tensor` types to the [gonum](https://github.com/gonum/gonum) functions for linear algebra, defined on the 2D `mat.Matrix` interface. + +# TODO + +Add following functions here: + +* `eye` +* `identity` +* `diag` +* `diagonal` + diff --git a/tensor/matrix/eigen.go b/tensor/matrix/eigen.go new file mode 100644 index 0000000000..fef7af50af --- /dev/null +++ b/tensor/matrix/eigen.go @@ -0,0 +1,397 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/stats/stats" + "cogentcore.org/core/tensor/tmath" + "gonum.org/v1/gonum/mat" +) + +// Eig performs the eigen decomposition of the given square matrix, +// which is not symmetric. See EigSym for a symmetric square matrix. 
+// In this non-symmetric case, the results are typically complex valued, +// so the outputs are complex tensors. TODO: need complex support! +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *lowest* to *highest* across the columns, +// i.e., maximum vector is the last column. +// The values are the size of one row, ordered *lowest* to *highest*. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func Eig(a tensor.Tensor) (vecs, vals *tensor.Float64) { + vecs = tensor.NewFloat64() + vals = tensor.NewFloat64() + errors.Log(EigOut(a, vecs, vals)) + return +} + +// EigOut performs the eigen decomposition of the given square matrix, +// which is not symmetric. See EigSym for a symmetric square matrix. +// In this non-symmetric case, the results are typically complex valued, +// so the outputs are complex tensors. TODO: need complex support! +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *lowest* to *highest* across the columns, +// i.e., maximum vector is the last column. +// The values are the size of one row, ordered *lowest* to *highest*. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func EigOut(a tensor.Tensor, vecs, vals *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + na := a.NumDims() + if na == 1 { + return mat.ErrShape + } + var asz []int + ea := a + if na > 2 { + asz = tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) 
+ na = 2 + } + } + if na == 2 { + if a.DimSize(0) != a.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewMatrix(a) + vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1)) + vals.SetShapeSizes(a.DimSize(0)) + do, _ := NewDense(vecs) + var eig mat.Eigen + ok := eig.Factorize(ma, mat.EigenRight) + if !ok { + return errors.New("gonum mat.Eigen Factorize failed") + } + _ = do + // eig.VectorsTo(do) // todo: requires complex! + // eig.Values(vals.Values) + return nil + } + ea = tensor.Reshape(a, asz...) + if ea.DimSize(1) != ea.DimSize(2) { + return mat.ErrShape + } + nr := ea.DimSize(0) + sz := ea.DimSize(1) + vecs.SetShapeSizes(nr, sz, sz) + vals.SetShapeSizes(nr, sz) + var errs []error + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewMatrix(sa) + do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64)) + var eig mat.Eigen + ok := eig.Factorize(ma, mat.EigenRight) + if !ok { + errs = append(errs, errors.New("gonum mat.Eigen Factorize failed")) + } + _ = do + // eig.VectorsTo(do) // todo: requires complex! + // eig.Values(vals.Values[r*sz : (r+1)*sz]) + }) + return errors.Join(errs...) +} + +// EigSym performs the eigen decomposition of the given symmetric square matrix, +// which produces real-valued results. When input is the [metric.CovarianceMatrix], +// this is known as Principal Components Analysis (PCA). +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *lowest* to *highest* across the columns, +// i.e., maximum vector is the last column. +// The values are the size of one row, ordered *lowest* to *highest*. +// Note that Eig produces results in the *opposite* order of [SVD] (which is much faster). +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. 
+func EigSym(a tensor.Tensor) (vecs, vals *tensor.Float64) { + vecs = tensor.NewFloat64() + vals = tensor.NewFloat64() + errors.Log(EigSymOut(a, vecs, vals)) + return +} + +// EigSymOut performs the eigen decomposition of the given symmetric square matrix, +// which produces real-valued results. When input is the [metric.CovarianceMatrix], +// this is known as Principal Components Analysis (PCA). +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *lowest* to *highest* across the columns, +// i.e., maximum vector is the last column. +// The values are the size of one row, ordered *lowest* to *highest*. +// Note that Eig produces results in the *opposite* order of [SVD] (which is much faster). +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func EigSymOut(a tensor.Tensor, vecs, vals *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + na := a.NumDims() + if na == 1 { + return mat.ErrShape + } + var asz []int + ea := a + if na > 2 { + asz = tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) + na = 2 + } + } + if na == 2 { + if a.DimSize(0) != a.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewSymmetric(a) + vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1)) + vals.SetShapeSizes(a.DimSize(0)) + do, _ := NewDense(vecs) + var eig mat.EigenSym + ok := eig.Factorize(ma, true) + if !ok { + return errors.New("gonum mat.EigenSym Factorize failed") + } + eig.VectorsTo(do) + eig.Values(vals.Values) + return nil + } + ea = tensor.Reshape(a, asz...) 
+ if ea.DimSize(1) != ea.DimSize(2) { + return mat.ErrShape + } + nr := ea.DimSize(0) + sz := ea.DimSize(1) + vecs.SetShapeSizes(nr, sz, sz) + vals.SetShapeSizes(nr, sz) + var errs []error + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewSymmetric(sa) + do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64)) + var eig mat.EigenSym + ok := eig.Factorize(ma, true) + if !ok { + errs = append(errs, errors.New("gonum mat.Eigen Factorize failed")) + } + eig.VectorsTo(do) + eig.Values(vals.Values[r*sz : (r+1)*sz]) + }) + return errors.Join(errs...) +} + +// SVD performs the singular value decomposition of the given symmetric square matrix, +// which produces real-valued results, and is generally much faster than [EigSym], +// while producing the same results. +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *highest* to *lowest* across the columns, +// i.e., maximum vector is the first column. +// The values are the size of one row ordered in alignment with the vectors. +// Note that SVD produces results in the *opposite* order of [EigSym]. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func SVD(a tensor.Tensor) (vecs, vals *tensor.Float64) { + vecs = tensor.NewFloat64() + vals = tensor.NewFloat64() + errors.Log(SVDOut(a, vecs, vals)) + return +} + +// SVDOut performs the singular value decomposition of the given symmetric square matrix, +// which produces real-valued results, and is generally much faster than [EigSym], +// while producing the same results. +// The vectors are same size as the input. Each vector is a column +// in this 2D square matrix, ordered *highest* to *lowest* across the columns, +// i.e., maximum vector is the first column. 
+// The values are the size of one row ordered in alignment with the vectors. +// Note that SVD produces results in the *opposite* order of [EigSym]. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func SVDOut(a tensor.Tensor, vecs, vals *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + na := a.NumDims() + if na == 1 { + return mat.ErrShape + } + var asz []int + ea := a + if na > 2 { + asz = tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) + na = 2 + } + } + if na == 2 { + if a.DimSize(0) != a.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewSymmetric(a) + vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1)) + vals.SetShapeSizes(a.DimSize(0)) + do, _ := NewDense(vecs) + var eig mat.SVD + ok := eig.Factorize(ma, mat.SVDFull) + if !ok { + return errors.New("gonum mat.SVD Factorize failed") + } + eig.UTo(do) + eig.Values(vals.Values) + return nil + } + ea = tensor.Reshape(a, asz...) + if ea.DimSize(1) != ea.DimSize(2) { + return mat.ErrShape + } + nr := ea.DimSize(0) + sz := ea.DimSize(1) + vecs.SetShapeSizes(nr, sz, sz) + vals.SetShapeSizes(nr, sz) + var errs []error + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewSymmetric(sa) + do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64)) + var eig mat.SVD + ok := eig.Factorize(ma, mat.SVDFull) + if !ok { + errs = append(errs, errors.New("gonum mat.SVD Factorize failed")) + } + eig.UTo(do) + eig.Values(vals.Values[r*sz : (r+1)*sz]) + }) + return errors.Join(errs...) +} + +// SVDValues performs the singular value decomposition of the given +// symmetric square matrix, which produces real-valued results, +// and is generally much faster than [EigSym], while producing the same results. 
+// This version only generates eigenvalues, not vectors: see [SVD]. +// The values are the size of one row ordered highest to lowest, +// which is the opposite of [EigSym]. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func SVDValues(a tensor.Tensor) *tensor.Float64 { + vals := tensor.NewFloat64() + errors.Log(SVDValuesOut(a, vals)) + return vals +} + +// SVDValuesOut performs the singular value decomposition of the given +// symmetric square matrix, which produces real-valued results, +// and is generally much faster than [EigSym], while producing the same results. +// This version only generates eigenvalues, not vectors: see [SVDOut]. +// The values are the size of one row ordered highest to lowest, +// which is the opposite of [EigSym]. +// If the input tensor is > 2D, it is treated as a list of 2D matricies, +// and parallel threading is used where beneficial. +func SVDValuesOut(a tensor.Tensor, vals *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + na := a.NumDims() + if na == 1 { + return mat.ErrShape + } + var asz []int + ea := a + if na > 2 { + asz = tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) + na = 2 + } + } + if na == 2 { + if a.DimSize(0) != a.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewSymmetric(a) + vals.SetShapeSizes(a.DimSize(0)) + var eig mat.SVD + ok := eig.Factorize(ma, mat.SVDNone) + if !ok { + return errors.New("gonum mat.SVD Factorize failed") + } + eig.Values(vals.Values) + return nil + } + ea = tensor.Reshape(a, asz...) 
+ if ea.DimSize(1) != ea.DimSize(2) { + return mat.ErrShape + } + nr := ea.DimSize(0) + sz := ea.DimSize(1) + vals.SetShapeSizes(nr, sz) + var errs []error + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewSymmetric(sa) + var eig mat.SVD + ok := eig.Factorize(ma, mat.SVDNone) + if !ok { + errs = append(errs, errors.New("gonum mat.SVD Factorize failed")) + } + eig.Values(vals.Values[r*sz : (r+1)*sz]) + }) + return errors.Join(errs...) +} + +// ProjectOnMatrixColumn is a convenience function for projecting given vector +// of values along a specific column (2nd dimension) of the given 2D matrix, +// specified by the scalar colindex, putting results into out. +// If the vec is more than 1 dimensional, then it is treated as rows x cells, +// and each row of cells is projected through the matrix column, producing a +// 1D output with the number of rows. Otherwise a single number is produced. +// This is typically done with results from SVD or EigSym (PCA). +func ProjectOnMatrixColumn(mtx, vec, colindex tensor.Tensor) tensor.Values { + out := tensor.NewOfType(vec.DataType()) + errors.Log(ProjectOnMatrixColumnOut(mtx, vec, colindex, out)) + return out +} + +// ProjectOnMatrixColumnOut is a convenience function for projecting given vector +// of values along a specific column (2nd dimension) of the given 2D matrix, +// specified by the scalar colindex, putting results into out. +// If the vec is more than 1 dimensional, then it is treated as rows x cells, +// and each row of cells is projected through the matrix column, producing a +// 1D output with the number of rows. Otherwise a single number is produced. +// This is typically done with results from SVD or EigSym (PCA). 
+func ProjectOnMatrixColumnOut(mtx, vec, colindex tensor.Tensor, out tensor.Values) error { + ci := int(colindex.Float1D(0)) + col := tensor.As1D(tensor.Reslice(mtx, tensor.Slice{}, ci)) + // fmt.Println(mtx.String(), col.String()) + rows, cells := vec.Shape().RowCellSize() + if rows > 0 && cells > 0 { + msum := tensor.NewFloat64Scalar(0) + out.SetShapeSizes(rows) + mout := tensor.NewFloat64(cells) + for i := range rows { + err := tmath.MulOut(tensor.Cells1D(vec, i), col, mout) + if err != nil { + return err + } + stats.SumOut(mout, msum) + out.SetFloat1D(msum.Float1D(0), i) + } + } else { + mout := tensor.NewFloat64(1) + tmath.MulOut(vec, col, mout) + stats.SumOut(mout, out) + } + return nil +} diff --git a/tensor/matrix/eigen_test.go b/tensor/matrix/eigen_test.go new file mode 100644 index 0000000000..6165a5bfc5 --- /dev/null +++ b/tensor/matrix/eigen_test.go @@ -0,0 +1,88 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package matrix + +import ( + "testing" + + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "github.com/stretchr/testify/assert" +) + +func TestPCAIris(t *testing.T) { + dt := table.New() + dt.AddFloat64Column("data", 4) + dt.AddStringColumn("class") + err := dt.OpenCSV("testdata/iris.data", tensor.Comma) + if err != nil { + t.Error(err) + } + data := dt.Column("data") + + covar := tensor.NewFloat64(4, 4) + tensor.OpenCSV(covar, "testdata/iris-covar.tsv", tensor.Tab) + + vecs, vals := EigSym(covar) + + // fmt.Printf("correl vec: %v\n", vecs) + // fmt.Printf("correl val: %v\n", vals) + errtol := 1.0e-9 + corvals := []float64{0.020607707235624825, 0.14735327830509573, 0.9212209307072254, 2.910818083752054} + for i, v := range vals.Values { + assert.InDelta(t, corvals[i], v, errtol) + } + + colidx := tensor.NewFloat64Scalar(3) // strongest at end + prjns := tensor.NewFloat64() + err = ProjectOnMatrixColumnOut(vecs, data, colidx, prjns) + assert.NoError(t, err) + // tensor.SaveCSV(prjns, "testdata/pca_projection.csv", tensor.Comma) + trgprjns := []float64{ + 2.6692308782935146, + 2.696434011868953, + 2.4811633041648684, + 2.5715124347750256, + 2.5906582247213543, + 3.0080988099460613, + 2.490941664609344, + 2.7014546083439073, + 2.4615836931965167, + 2.6716628159090594, + } + for i, v := range prjns.Values[:10] { + assert.InDelta(t, trgprjns[i], v, errtol) + } + + //////// SVD + + err = SVDOut(covar, vecs, vals) + assert.NoError(t, err) + // fmt.Printf("correl vec: %v\n", vecs) + // fmt.Printf("correl val: %v\n", vals) + for i, v := range vals.Values { + assert.InDelta(t, corvals[3-i], v, errtol) // opposite order + } + + colidx.SetFloat1D(0, 0) // strongest at start + err = ProjectOnMatrixColumnOut(vecs, data, colidx, prjns) + assert.NoError(t, err) + // tensor.SaveCSV(prjns, "testdata/svd_projection.csv", tensor.Comma) + trgprjns = []float64{ + -2.6692308782935172, + -2.696434011868955, + -2.48116330416487, + -2.5715124347750273, + 
-2.590658224721357, + -3.008098809946064, + -2.4909416646093456, + -2.70145460834391, + -2.4615836931965185, + -2.671662815909061, + } + for i, v := range prjns.Values[:10] { + assert.InDelta(t, trgprjns[i], v, errtol) + } +} diff --git a/tensor/matrix/indices.go b/tensor/matrix/indices.go new file mode 100644 index 0000000000..40f0e01c62 --- /dev/null +++ b/tensor/matrix/indices.go @@ -0,0 +1,322 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/num" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/vector" +) + +// offCols is a helper function to process the optional offset_cols args +func offCols(size int, offset_cols ...int) (off, cols int) { + off = 0 + cols = size + if len(offset_cols) >= 1 { + off = offset_cols[0] + } + if len(offset_cols) == 2 { + cols = offset_cols[1] + } + return +} + +// Identity returns a new 2D Float64 tensor with 1s along the diagonal and +// 0s elsewhere, with the given row and column size. +// - If one additional parameter is passed, it is the offset, +// to set values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func Identity(size int, offset_cols ...int) *tensor.Float64 { + off, cols := offCols(size, offset_cols...) + tsr := tensor.NewFloat64(size, cols) + for r := range size { + c := r + off + if c < 0 || c >= cols { + continue + } + tsr.SetFloat(1, r, c) + } + return tsr +} + +// DiagonalN returns the number of elements in the along the diagonal +// of a 2D matrix of given row and column size. +// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. 
+// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func DiagonalN(size int, offset_cols ...int) int { + off, cols := offCols(size, offset_cols...) + rows := size + if num.Abs(off) > 0 { + oa := num.Abs(off) + if off > 0 { + if cols > rows { + return DiagonalN(rows, 0, cols-oa) + } else { + return DiagonalN(rows-oa, 0, cols-oa) + } + } else { + if rows > cols { + return DiagonalN(rows-oa, 0, cols) + } else { + return DiagonalN(rows-oa, 0, cols-oa) + } + } + } + n := min(rows, cols) + return n +} + +// DiagonalIndices returns a list of indices for the diagonal elements of +// a 2D matrix of given row and column size. +// The result is a 2D list of indices, where the outer (row) dimension +// is the number of indices, and the inner dimension is 2 for the r, c coords. +// - If one additional parameter is passed, it is the offset, +// to set values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func DiagonalIndices(size int, offset_cols ...int) *tensor.Int { + off, cols := offCols(size, offset_cols...) + dn := DiagonalN(size, off, cols) + tsr := tensor.NewInt(dn, 2) + idx := 0 + for r := range size { + c := r + off + if c < 0 || c >= cols { + continue + } + tsr.SetInt(r, idx, 0) + tsr.SetInt(c, idx, 1) + idx++ + } + return tsr +} + +// Diagonal returns an [Indexed] view of the given tensor for the diagonal +// values, as a 1D list. An error is logged if the tensor is not 2D. +// Use the optional offset parameter to get values above (positive) or +// below (negative) the diagonal. 
+func Diagonal(tsr tensor.Tensor, offset ...int) *tensor.Indexed { + if tsr.NumDims() != 2 { + errors.Log(errors.New("matrix.TriLView requires a 2D tensor")) + return nil + } + off := 0 + if len(offset) == 1 { + off = offset[0] + } + return tensor.NewIndexed(tsr, DiagonalIndices(tsr.DimSize(0), off, tsr.DimSize(1))) +} + +// Trace returns the sum of the [Diagonal] elements of the given +// tensor, as a tensor scalar. +// An error is logged if the tensor is not 2D. +// Use the optional offset parameter to get values above (positive) or +// below (negative) the diagonal. +func Trace(tsr tensor.Tensor, offset ...int) tensor.Values { + return vector.Sum(Diagonal(tsr, offset...)) +} + +// Tri returns a new 2D Float64 tensor with 1s along the diagonal and +// below it, and 0s elsewhere (i.e., a filled lower triangle). +// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func Tri(size int, offset_cols ...int) *tensor.Float64 { + off, cols := offCols(size, offset_cols...) + tsr := tensor.NewFloat64(size, cols) + for r := range size { + for c := range cols { + if c <= r+off { + tsr.SetFloat(1, r, c) + } + } + } + return tsr +} + +// TriUpper returns a new 2D Float64 tensor with 1s along the diagonal and +// above it, and 0s elsewhere (i.e., a filled upper triangle). +// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func TriUpper(size int, offset_cols ...int) *tensor.Float64 { + off, cols := offCols(size, offset_cols...) 
+ tsr := tensor.NewFloat64(size, cols) + for r := range size { + for c := range cols { + if c >= r+off { + tsr.SetFloat(1, r, c) + } + } + } + return tsr +} + +// TriUNum returns the number of elements in the upper triangular region +// of a 2D matrix of given row and column size, where the triangle includes the +// elements along the diagonal. +// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func TriUNum(size int, offset_cols ...int) int { + off, cols := offCols(size, offset_cols...) + rows := size + if off > 0 { + if cols > rows { + return TriUNum(rows, 0, cols-off) + } else { + return TriUNum(rows-off, 0, cols-off) + } + } else if off < 0 { // invert + return cols*rows - TriUNum(cols, -(off-1), rows) + } + if cols <= size { + return cols + (cols*(cols-1))/2 + } + return rows + (rows*(2*cols-rows-1))/2 +} + +// TriLNum returns the number of elements in the lower triangular region +// of a 2D matrix of given row and column size, where the triangle includes the +// elements along the diagonal. +// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix (first size parameter = number of rows). +func TriLNum(size int, offset_cols ...int) int { + off, cols := offCols(size, offset_cols...) + return TriUNum(cols, -off, size) +} + +// TriLIndicies returns the list of r, c indexes for the lower triangular +// portion of a square matrix of size n, including the diagonal. +// The result is a 2D list of indices, where the outer (row) dimension +// is the number of indices, and the inner dimension is 2 for the r, c coords. 
+// - If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// - If a second additional parameter is passed, it is the number of columns +// for a non-square matrix. +func TriLIndicies(size int, offset_cols ...int) *tensor.Int { + off, cols := offCols(size, offset_cols...) + trin := TriLNum(size, off, cols) + coords := tensor.NewInt(trin, 2) + i := 0 + for r := range size { + for c := range cols { + if c <= r+off { + coords.SetInt(r, i, 0) + coords.SetInt(c, i, 1) + i++ + } + } + } + return coords +} + +// TriUIndicies returns the list of r, c indexes for the upper triangular +// portion of a square matrix of size n, including the diagonal. +// If one additional parameter is passed, it is the offset, +// to include values above (positive) or below (negative) the diagonal. +// If a second additional parameter is passed, it is the number of columns +// for a non-square matrix. +// The result is a 2D list of indices, where the outer (row) dimension +// is the number of indices, and the inner dimension is 2 for the r, c coords. +func TriUIndicies(size int, offset_cols ...int) *tensor.Int { + off, cols := offCols(size, offset_cols...) + trin := TriUNum(size, off, cols) + coords := tensor.NewInt(trin, 2) + i := 0 + for r := range size { + for c := range cols { + if c >= r+off { + coords.SetInt(r, i, 0) + coords.SetInt(c, i, 1) + i++ + } + } + } + return coords +} + +// TriLView returns an [Indexed] view of the given tensor for the lower triangular +// region of values, as a 1D list. An error is logged if the tensor is not 2D. +// Use the optional offset parameter to get values above (positive) or +// below (negative) the diagonal. 
+func TriLView(tsr tensor.Tensor, offset ...int) *tensor.Indexed { + if tsr.NumDims() != 2 { + errors.Log(errors.New("matrix.TriLView requires a 2D tensor")) + return nil + } + off := 0 + if len(offset) == 1 { + off = offset[0] + } + return tensor.NewIndexed(tsr, TriLIndicies(tsr.DimSize(0), off, tsr.DimSize(1))) +} + +// TriUView returns an [Indexed] view of the given tensor for the upper triangular +// region of values, as a 1D list. An error is logged if the tensor is not 2D. +// Use the optional offset parameter to get values above (positive) or +// below (negative) the diagonal. +func TriUView(tsr tensor.Tensor, offset ...int) *tensor.Indexed { + if tsr.NumDims() != 2 { + errors.Log(errors.New("matrix.TriUView requires a 2D tensor")) + return nil + } + off := 0 + if len(offset) == 1 { + off = offset[0] + } + return tensor.NewIndexed(tsr, TriUIndicies(tsr.DimSize(0), off, tsr.DimSize(1))) +} + +// TriL returns a copy of the given tensor containing the lower triangular +// region of values (including the diagonal), with the lower triangular region +// zeroed. An error is logged if the tensor is not 2D. +// Use the optional offset parameter to include values above (positive) or +// below (negative) the diagonal. +func TriL(tsr tensor.Tensor, offset ...int) tensor.Tensor { + if tsr.NumDims() != 2 { + errors.Log(errors.New("matrix.TriL requires a 2D tensor")) + return nil + } + off := 0 + if len(offset) == 1 { + off = offset[0] + } + off += 1 + tc := tensor.Clone(tsr) + tv := TriUView(tc, off) // opposite + tensor.SetAllFloat64(tv, 0) + return tc +} + +// TriU returns a copy of the given tensor containing the upper triangular +// region of values (including the diagonal), with the lower triangular region +// zeroed. An error is logged if the tensor is not 2D. +// Use the optional offset parameter to include values above (positive) or +// below (negative) the diagonal. 
+func TriU(tsr tensor.Tensor, offset ...int) tensor.Tensor { + if tsr.NumDims() != 2 { + errors.Log(errors.New("matrix.TriU requires a 2D tensor")) + return nil + } + off := 0 + if len(offset) == 1 { + off = offset[0] + } + off -= 1 + tc := tensor.Clone(tsr) + tv := TriLView(tc, off) // opposite + tensor.SetAllFloat64(tv, 0) + return tc +} diff --git a/tensor/matrix/indices_test.go b/tensor/matrix/indices_test.go new file mode 100644 index 0000000000..df40fd0771 --- /dev/null +++ b/tensor/matrix/indices_test.go @@ -0,0 +1,105 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "testing" + + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/vector" + "github.com/stretchr/testify/assert" +) + +func TestIndices(t *testing.T) { + assert.Equal(t, []float64{1, 0, 0, 1}, Identity(2).Values) + assert.Equal(t, []float64{1, 0, 0, 0, 1, 0, 0, 0, 1}, Identity(3).Values) + assert.Equal(t, []float64{0, 0, 0, 1, 0, 0, 0, 1, 0}, Identity(3, -1).Values) + + assert.Equal(t, int(vector.Sum(Identity(3)).Float1D(0)), DiagonalN(3)) + assert.Equal(t, int(vector.Sum(Identity(3, 0, 4)).Float1D(0)), DiagonalN(3, 0, 4)) + assert.Equal(t, int(vector.Sum(Identity(3, 0, 2)).Float1D(0)), DiagonalN(3, 0, 2)) + assert.Equal(t, int(vector.Sum(Identity(3, 1)).Float1D(0)), DiagonalN(3, 1)) + assert.Equal(t, int(vector.Sum(Identity(10, 4, 7)).Float1D(0)), DiagonalN(10, 4, 7)) + assert.Equal(t, int(vector.Sum(Identity(10, 4, 12)).Float1D(0)), DiagonalN(10, 4, 12)) + assert.Equal(t, int(vector.Sum(Identity(3, -1)).Float1D(0)), DiagonalN(3, -1)) + assert.Equal(t, int(vector.Sum(Identity(10, -4, 7)).Float1D(0)), DiagonalN(10, -4, 7)) + assert.Equal(t, int(vector.Sum(Identity(10, -4, 12)).Float1D(0)), DiagonalN(10, -4, 12)) + + assert.Equal(t, []int{0, 0, 1, 1, 2, 2}, DiagonalIndices(3).Values) + assert.Equal(t, []int{0, 1, 1, 2}, DiagonalIndices(3, 
1).Values) + assert.Equal(t, []int{1, 0, 2, 1}, DiagonalIndices(3, -1).Values) + assert.Equal(t, []int{1, 0, 2, 1}, DiagonalIndices(3, -1, 4).Values) + assert.Equal(t, []int{0, 1, 1, 2, 2, 3}, DiagonalIndices(3, 1, 4).Values) + + a := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, 10), 3, 3)) + assert.Equal(t, []float64{1, 5, 9}, tensor.Flatten(Diagonal(a)).(*tensor.Float64).Values) + assert.Equal(t, []float64{4, 8}, tensor.Flatten(Diagonal(a, -1)).(*tensor.Float64).Values) + assert.Equal(t, []float64{2, 6}, tensor.Flatten(Diagonal(a, 1)).(*tensor.Float64).Values) + + assert.Equal(t, 15.0, Trace(a).Float1D(0)) + assert.Equal(t, 12.0, Trace(a, -1).Float1D(0)) + assert.Equal(t, 8.0, Trace(a, 1).Float1D(0)) + + assert.Equal(t, []float64{1, 0, 0, 1, 1, 0, 1, 1, 1}, Tri(3).Values) + assert.Equal(t, []float64{1, 1, 0, 1, 1, 1, 1, 1, 1}, Tri(3, 1).Values) + assert.Equal(t, []float64{0, 0, 0, 1, 0, 0, 1, 1, 0}, Tri(3, -1).Values) + + assert.Equal(t, []float64{0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0}, Tri(3, -1, 4).Values) + assert.Equal(t, []float64{0, 0, 1, 0, 1, 1}, Tri(3, -1, 2).Values) + + assert.Equal(t, []float64{1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1}, Tri(3, 1, 4).Values) + assert.Equal(t, []float64{1, 1, 1, 1, 1, 1}, Tri(3, 1, 2).Values) + + assert.Equal(t, int(vector.Sum(TriUpper(3)).Float1D(0)), TriUNum(3)) + assert.Equal(t, int(vector.Sum(TriUpper(3, 0, 4)).Float1D(0)), TriUNum(3, 0, 4)) + assert.Equal(t, int(vector.Sum(TriUpper(3, 0, 2)).Float1D(0)), TriUNum(3, 0, 2)) + assert.Equal(t, int(vector.Sum(TriUpper(3, 1)).Float1D(0)), TriUNum(3, 1)) + assert.Equal(t, int(vector.Sum(TriUpper(3, 1, 4)).Float1D(0)), TriUNum(3, 1, 4)) + assert.Equal(t, int(vector.Sum(TriUpper(10, 4, 7)).Float1D(0)), TriUNum(10, 4, 7)) + assert.Equal(t, int(vector.Sum(TriUpper(10, 4, 12)).Float1D(0)), TriUNum(10, 4, 12)) + assert.Equal(t, int(vector.Sum(TriUpper(3, -1)).Float1D(0)), TriUNum(3, -1)) + assert.Equal(t, int(vector.Sum(TriUpper(3, -1, 4)).Float1D(0)), TriUNum(3, -1, 4)) + 
assert.Equal(t, int(vector.Sum(TriUpper(3, -1, 2)).Float1D(0)), TriUNum(3, -1, 2)) + assert.Equal(t, int(vector.Sum(TriUpper(10, -4, 7)).Float1D(0)), TriUNum(10, -4, 7)) + assert.Equal(t, int(vector.Sum(TriUpper(10, -4, 12)).Float1D(0)), TriUNum(10, -4, 12)) + + assert.Equal(t, int(vector.Sum(Tri(3)).Float1D(0)), TriLNum(3)) + assert.Equal(t, int(vector.Sum(Tri(3, 0, 4)).Float1D(0)), TriLNum(3, 0, 4)) + assert.Equal(t, int(vector.Sum(Tri(3, 0, 2)).Float1D(0)), TriLNum(3, 0, 2)) + assert.Equal(t, int(vector.Sum(Tri(3, 1)).Float1D(0)), TriLNum(3, 1)) + assert.Equal(t, int(vector.Sum(Tri(3, 1, 4)).Float1D(0)), TriLNum(3, 1, 4)) + assert.Equal(t, int(vector.Sum(Tri(10, 4, 7)).Float1D(0)), TriLNum(10, 4, 7)) + assert.Equal(t, int(vector.Sum(Tri(10, 4, 12)).Float1D(0)), TriLNum(10, 4, 12)) + assert.Equal(t, int(vector.Sum(Tri(3, -1)).Float1D(0)), TriLNum(3, -1)) + assert.Equal(t, int(vector.Sum(Tri(3, -1, 4)).Float1D(0)), TriLNum(3, -1, 4)) + assert.Equal(t, int(vector.Sum(Tri(3, -1, 2)).Float1D(0)), TriLNum(3, -1, 2)) + assert.Equal(t, int(vector.Sum(Tri(10, -4, 7)).Float1D(0)), TriLNum(10, -4, 7)) + assert.Equal(t, int(vector.Sum(Tri(10, -4, 12)).Float1D(0)), TriLNum(10, -4, 12)) + + tli := TriLIndicies(3) + assert.Equal(t, []int{0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2}, tli.Values) + + tli = TriLIndicies(3, -1) + assert.Equal(t, []int{1, 0, 2, 0, 2, 1}, tli.Values) + + tli = TriLIndicies(3, 1) + assert.Equal(t, []int{0, 0, 0, 1, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1, 2, 2}, tli.Values) + + tli = TriUIndicies(3, 1) + assert.Equal(t, []int{0, 1, 0, 2, 1, 2}, tli.Values) + + tli = TriUIndicies(3, -1) + assert.Equal(t, []int{0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 1, 2, 2}, tli.Values) + + tf := tensor.NewFloat64Ones(3, 4) + + assert.Equal(t, Tri(3, -1, 4).Values, TriL(tf, -1).(*tensor.Float64).Values) + assert.Equal(t, Tri(3, 0, 4).Values, TriL(tf).(*tensor.Float64).Values) + assert.Equal(t, Tri(3, 1, 4).Values, TriL(tf, 1).(*tensor.Float64).Values) + + assert.Equal(t, TriUpper(3, -1, 
4).Values, TriU(tf, -1).(*tensor.Float64).Values) + assert.Equal(t, TriUpper(3, 0, 4).Values, TriU(tf).(*tensor.Float64).Values) + assert.Equal(t, TriUpper(3, 1, 4).Values, TriU(tf, 1).(*tensor.Float64).Values) +} diff --git a/tensor/matrix/matrix.go b/tensor/matrix/matrix.go new file mode 100644 index 0000000000..bb4ada60ef --- /dev/null +++ b/tensor/matrix/matrix.go @@ -0,0 +1,121 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "errors" + + "cogentcore.org/core/tensor" + "gonum.org/v1/gonum/mat" +) + +// Matrix provides a view of the given [tensor.Tensor] as a [gonum] +// [mat.Matrix] interface type. +type Matrix struct { + Tensor tensor.Tensor +} + +func StringCheck(tsr tensor.Tensor) error { + if tsr.IsString() { + return errors.New("matrix: tensor has string values; must be numeric") + } + return nil +} + +// NewMatrix returns given [tensor.Tensor] as a [gonum] [mat.Matrix]. +// It returns an error if the tensor is not 2D. +func NewMatrix(tsr tensor.Tensor) (*Matrix, error) { + if err := StringCheck(tsr); err != nil { + return nil, err + } + nd := tsr.NumDims() + if nd != 2 { + err := errors.New("matrix.NewMatrix: tensor is not 2D") + return nil, err + } + return &Matrix{Tensor: tsr}, nil +} + +// Dims is the gonum/mat.Matrix interface method for returning the +// dimension sizes of the 2D Matrix. Assumes Row-major ordering. +func (mx *Matrix) Dims() (r, c int) { + return mx.Tensor.DimSize(0), mx.Tensor.DimSize(1) +} + +// At is the gonum/mat.Matrix interface method for returning 2D +// matrix element at given row, column index. Assumes Row-major ordering. +func (mx *Matrix) At(i, j int) float64 { + return mx.Tensor.Float(i, j) +} + +// T is the gonum/mat.Matrix transpose method. +// It performs an implicit transpose by returning the receiver inside a Transpose. 
+func (mx *Matrix) T() mat.Matrix { + return mat.Transpose{mx} +} + +//////// Symmetric + +// Symmetric provides a view of the given [tensor.Tensor] as a [gonum] +// [mat.Symmetric] matrix interface type. +type Symmetric struct { + Matrix +} + +// NewSymmetric returns given [tensor.Tensor] as a [gonum] [mat.Symmetric] matrix. +// It returns an error if the tensor is not 2D or not symmetric. +func NewSymmetric(tsr tensor.Tensor) (*Symmetric, error) { + if tsr.IsString() { + err := errors.New("matrix.NewSymmetric: tensor has string values; must be numeric") + return nil, err + } + nd := tsr.NumDims() + if nd != 2 { + err := errors.New("matrix.NewSymmetric: tensor is not 2D") + return nil, err + } + if tsr.DimSize(0) != tsr.DimSize(1) { + err := errors.New("matrix.NewSymmetric: tensor is not symmetric") + return nil, err + } + sy := &Symmetric{} + sy.Tensor = tsr + return sy, nil +} + +// SymmetricDim is the gonum/mat.Matrix interface method for returning the +// dimensionality of a symmetric 2D Matrix. +func (sy *Symmetric) SymmetricDim() (r int) { + return sy.Tensor.DimSize(0) +} + +// NewDense returns given [tensor.Float64] as a [gonum] [mat.Dense] +// Matrix, on which many of the matrix operations are defined. +// It functions similar to the [tensor.Values] type, as the output +// of matrix operations. The Dense type serves as a view onto +// the tensor's data, so operations directly modify it. 
+func NewDense(tsr *tensor.Float64) (*mat.Dense, error) { + nd := tsr.NumDims() + if nd != 2 { + err := errors.New("matrix.NewDense: tensor is not 2D") + return nil, err + } + return mat.NewDense(tsr.DimSize(0), tsr.DimSize(1), tsr.Values), nil +} + +// CopyFromDense copies a gonum mat.Dense matrix into given Tensor +// using standard Float64 interface +func CopyFromDense(to tensor.Values, dm *mat.Dense) { + nr, nc := dm.Dims() + to.SetShapeSizes(nr, nc) + idx := 0 + for ri := 0; ri < nr; ri++ { + for ci := 0; ci < nc; ci++ { + v := dm.At(ri, ci) + to.SetFloat1D(v, idx) + idx++ + } + } +} diff --git a/tensor/matrix/matrix_test.go b/tensor/matrix/matrix_test.go new file mode 100644 index 0000000000..566345cee2 --- /dev/null +++ b/tensor/matrix/matrix_test.go @@ -0,0 +1,111 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "fmt" + "testing" + + "cogentcore.org/core/base/tolassert" + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +func TestMatrix(t *testing.T) { + a := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, 5), 2, 2)) + // fmt.Println(a) + + v := tensor.NewFloat64FromValues(2, 3) + _ = v + + o := Mul(a, a) + assert.Equal(t, []float64{7, 10, 15, 22}, o.Values) + + o = Mul(a, v) + assert.Equal(t, []float64{8, 18}, o.Values) + assert.Equal(t, []int{2}, o.Shape().Sizes) + + o = Mul(v, a) + assert.Equal(t, []float64{11, 16}, o.Values) + assert.Equal(t, []int{2}, o.Shape().Sizes) + + nr := 3 + b := tensor.NewFloat64(nr, 1, 2, 2) + for r := range nr { + b.SetRowTensor(a, r) + } + // fmt.Println(b) + + o = Mul(b, a) + assert.Equal(t, []float64{7, 10, 15, 22, 7, 10, 15, 22, 7, 10, 15, 22}, o.Values) + assert.Equal(t, []int{3, 2, 2}, o.Shape().Sizes) + + o = Mul(a, b) + assert.Equal(t, []float64{7, 10, 15, 22, 7, 10, 15, 22, 7, 10, 15, 22}, o.Values) + assert.Equal(t, []int{3, 2, 2}, 
o.Shape().Sizes) + + o = Mul(b, b) + assert.Equal(t, []float64{7, 10, 15, 22, 7, 10, 15, 22, 7, 10, 15, 22}, o.Values) + assert.Equal(t, []int{3, 2, 2}, o.Shape().Sizes) + + o = Mul(v, b) + assert.Equal(t, []float64{11, 16, 11, 16, 11, 16}, o.Values) + assert.Equal(t, []int{3, 2}, o.Shape().Sizes) + + o = Mul(b, v) + assert.Equal(t, []float64{8, 18, 8, 18, 8, 18}, o.Values) + assert.Equal(t, []int{3, 2}, o.Shape().Sizes) + + o = Mul(a, tensor.Transpose(a)) + assert.Equal(t, []float64{5, 11, 11, 25}, o.Values) + + d := Det(a) + assert.Equal(t, -2.0, d.Float1D(0)) + + inv := Inverse(a) + tolassert.EqualTolSlice(t, []float64{-2, 1, 1.5, -0.5}, inv.Values, 1.0e-8) + + inv = Inverse(b) + tolassert.EqualTolSlice(t, []float64{-2, 1, 1.5, -0.5, -2, 1, 1.5, -0.5, -2, 1, 1.5, -0.5}, inv.Values, 1.0e-8) +} + +func runBenchMult(b *testing.B, n int, thread bool) { + if thread { + tensor.ThreadingThreshold = 1 + } else { + tensor.ThreadingThreshold = 100_000_000 + } + nrows := 10 // benefits even at 10 x 2 x 2 = 40 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, nrows*n*n+1), nrows, n, n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, nrows*n*n+1), nrows, n, n)) + ov := tensor.NewFloat64(nrows, n, n) + b.ResetTimer() + for range b.N { + MulOut(av, bv, ov) + } +} + +// to run this benchmark, do: +// go test -bench BenchmarkMult -count 10 >bench.txt +// go install golang.org/x/perf/cmd/benchstat@latest +// benchstat -row /n -col .name bench.txt + +var ns = []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40} + +func BenchmarkMultThreaded(b *testing.B) { + for _, n := range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchMult(b, n, true) + }) + } +} + +func BenchmarkMultSingle(b *testing.B) { + for _, n := range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchMult(b, n, false) + }) + } +} diff --git a/tensor/matrix/ops.go b/tensor/matrix/ops.go new file mode 100644 index 0000000000..6ab669e4a5 --- /dev/null +++ 
b/tensor/matrix/ops.go @@ -0,0 +1,260 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package matrix + +import ( + "slices" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" + "gonum.org/v1/gonum/mat" +) + +// CallOut1 calls an Out function with 1 input arg. All matrix functions +// require *tensor.Float64 outputs. +func CallOut1(fun func(a tensor.Tensor, out *tensor.Float64) error, a tensor.Tensor) *tensor.Float64 { + out := tensor.NewFloat64() + errors.Log(fun(a, out)) + return out +} + +// CallOut2 calls an Out function with 2 input args. All matrix functions +// require *tensor.Float64 outputs. +func CallOut2(fun func(a, b tensor.Tensor, out *tensor.Float64) error, a, b tensor.Tensor) *tensor.Float64 { + out := tensor.NewFloat64() + errors.Log(fun(a, b, out)) + return out +} + +// Mul performs matrix multiplication, using the following rules based +// on the shapes of the relevant tensors. If the tensor shapes are not +// suitable, an error is logged (see [MulOut] for a version returning the error). +// N > 2 dimensional cases use parallel threading where beneficial. +// - If both arguments are 2-D they are multiplied like conventional matrices. +// - If either argument is N-D, N > 2, it is treated as a stack of matrices +// residing in the last two indexes and broadcast accordingly. +// - If the first argument is 1-D, it is promoted to a matrix by prepending +// a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. +// - If the second argument is 1-D, it is promoted to a matrix by appending +// a 1 to its dimensions. After matrix multiplication the appended 1 is removed. 
+func Mul(a, b tensor.Tensor) *tensor.Float64 { + return CallOut2(MulOut, a, b) +} + +// MulOut performs matrix multiplication, into the given output tensor, +// using the following rules based on the shapes of the relevant tensors. +// If the tensor shapes are not suitable, a [gonum] [mat.ErrShape] error is returned. +// N > 2 dimensional cases use parallel threading where beneficial. +// - If both arguments are 2-D they are multiplied like conventional matrices. +// The result has shape a.Rows, b.Columns. +// - If either argument is N-D, N > 2, it is treated as a stack of matrices +// residing in the last two indexes and broadcast accordingly. Both cannot +// be > 2 dimensional, unless their outer dimension size is 1 or the same. +// - If the first argument is 1-D, it is promoted to a matrix by prepending +// a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. +// - If the second argument is 1-D, it is promoted to a matrix by appending +// a 1 to its dimensions. After matrix multiplication the appended 1 is removed. +func MulOut(a, b tensor.Tensor, out *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + if err := StringCheck(b); err != nil { + return err + } + na := a.NumDims() + nb := b.NumDims() + ea := a + eb := b + collapse := false + colDim := 0 + if na == 1 { + ea = tensor.Reshape(a, 1, a.DimSize(0)) + collapse = true + colDim = -2 + na = 2 + } + if nb == 1 { + eb = tensor.Reshape(b, b.DimSize(0), 1) + collapse = true + colDim = -1 + nb = 2 + } + if na > 2 { + asz := tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) + na = 2 + } else { + ea = tensor.Reshape(a, asz...) + } + } + if nb > 2 { + bsz := tensor.SplitAtInnerDims(b, 2) + if bsz[0] == 1 { + eb = tensor.Reshape(b, bsz[1:]...) + nb = 2 + } else { + eb = tensor.Reshape(b, bsz...) 
+ } + } + switch { + case na == nb && na == 2: + if ea.DimSize(1) != eb.DimSize(0) { + return mat.ErrShape + } + ma, _ := NewMatrix(ea) + mb, _ := NewMatrix(eb) + out.SetShapeSizes(ea.DimSize(0), eb.DimSize(1)) + do, _ := NewDense(out) + do.Mul(ma, mb) + case na > 2 && nb == 2: + if ea.DimSize(2) != eb.DimSize(0) { + return mat.ErrShape + } + mb, _ := NewMatrix(eb) + nr := ea.DimSize(0) + out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(1)) + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.Len()*100, // always beneficial + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewMatrix(sa) + do, _ := NewDense(out.RowTensor(r).(*tensor.Float64)) + do.Mul(ma, mb) + }) + case nb > 2 && na == 2: + if ea.DimSize(1) != eb.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewMatrix(ea) + nr := eb.DimSize(0) + out.SetShapeSizes(nr, ea.DimSize(0), eb.DimSize(2)) + tensor.VectorizeThreaded(ea.Len()*eb.DimSize(1)*eb.DimSize(2)*100, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis) + mb, _ := NewMatrix(sb) + do, _ := NewDense(out.RowTensor(r).(*tensor.Float64)) + do.Mul(ma, mb) + }) + case na > 2 && nb > 2: + if ea.DimSize(0) != eb.DimSize(0) { + return errors.New("matrix.Mul: a and b input matricies are > 2 dimensional; must have same outer dimension sizes") + } + if ea.DimSize(2) != eb.DimSize(1) { + return mat.ErrShape + } + nr := ea.DimSize(0) + out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(2)) + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.DimSize(1)*eb.DimSize(2), + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewMatrix(sa) + sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis) + mb, _ := NewMatrix(sb) + do, _ := 
NewDense(out.RowTensor(r).(*tensor.Float64)) + do.Mul(ma, mb) + }) + default: + return mat.ErrShape + } + if collapse { + nd := out.NumDims() + sz := slices.Clone(out.Shape().Sizes) + if colDim == -1 { + out.SetShapeSizes(sz[:nd-1]...) + } else { + out.SetShapeSizes(append(sz[:nd-2], sz[nd-1])...) + } + } + return nil +} + +// todo: following should handle N>2 dim case. + +// Det returns the determinant of the given tensor. +// For a 2D matrix [[a, b], [c, d]] this is ad - bc. +// See also [LogDet] for a version that is more numerically +// stable for large matrices. +func Det(a tensor.Tensor) *tensor.Float64 { + m, err := NewMatrix(a) + if errors.Log(err) != nil { + return tensor.NewFloat64Scalar(0) + } + return tensor.NewFloat64Scalar(mat.Det(m)) +} + +// LogDet returns the determinant of the given tensor, +// as the log and sign of the value, which is more +// numerically stable. The return is a 1D vector of length 2, +// with the first value being the log, and the second the sign. +func LogDet(a tensor.Tensor) *tensor.Float64 { + m, err := NewMatrix(a) + if errors.Log(err) != nil { + return tensor.NewFloat64Scalar(0) + } + l, s := mat.LogDet(m) + return tensor.NewFloat64FromValues(l, s) +} + +// Inverse performs matrix inversion of a square matrix, +// logging an error for non-invertible cases. +// See [InverseOut] for a version that returns an error. +// If the input tensor is > 2D, it is treated as a list of 2D matrices +// which are each inverted. +func Inverse(a tensor.Tensor) *tensor.Float64 { + return CallOut1(InverseOut, a) +} + +// InverseOut performs matrix inversion of a square matrix, +// returning an error for non-invertible cases. If the input tensor +// is > 2D, it is treated as a list of 2D matrices which are each inverted. 
+func InverseOut(a tensor.Tensor, out *tensor.Float64) error { + if err := StringCheck(a); err != nil { + return err + } + na := a.NumDims() + if na == 1 { + return mat.ErrShape + } + var asz []int + ea := a + if na > 2 { + asz = tensor.SplitAtInnerDims(a, 2) + if asz[0] == 1 { + ea = tensor.Reshape(a, asz[1:]...) + na = 2 + } + } + if na == 2 { + if a.DimSize(0) != a.DimSize(1) { + return mat.ErrShape + } + ma, _ := NewMatrix(a) + out.SetShapeSizes(a.DimSize(0), a.DimSize(1)) + do, _ := NewDense(out) + return do.Inverse(ma) + } + ea = tensor.Reshape(a, asz...) + if ea.DimSize(1) != ea.DimSize(2) { + return mat.ErrShape + } + nr := ea.DimSize(0) + out.SetShapeSizes(nr, ea.DimSize(1), ea.DimSize(2)) + var errs []error + tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*100, + func(tsr ...tensor.Tensor) int { return nr }, + func(r int, tsr ...tensor.Tensor) { + sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis) + ma, _ := NewMatrix(sa) + do, _ := NewDense(out.RowTensor(r).(*tensor.Float64)) + err := do.Inverse(ma) + if err != nil { + errs = append(errs, err) + } + }) + return errors.Join(errs...) 
+} diff --git a/tensor/matrix/testdata/iris-covar.tsv b/tensor/matrix/testdata/iris-covar.tsv new file mode 100644 index 0000000000..d1493e2ec7 --- /dev/null +++ b/tensor/matrix/testdata/iris-covar.tsv @@ -0,0 +1,4 @@ +1 -0.10936924995064935 0.8717541573048719 0.8179536333691635 +-0.10936924995064935 1 -0.4205160964011548 -0.3565440896138057 +0.8717541573048719 -0.4205160964011548 1 0.9627570970509667 +0.8179536333691635 -0.3565440896138057 0.9627570970509667 1 diff --git a/tensor/stats/pca/testdata/iris.data b/tensor/matrix/testdata/iris.data similarity index 100% rename from tensor/stats/pca/testdata/iris.data rename to tensor/matrix/testdata/iris.data diff --git a/tensor/number.go b/tensor/number.go index 91a6adb142..d45aa83dde 100644 --- a/tensor/number.go +++ b/tensor/number.go @@ -6,14 +6,10 @@ package tensor import ( "fmt" - "log" - "math" "strconv" - "strings" "cogentcore.org/core/base/num" - "cogentcore.org/core/base/slicesx" - "gonum.org/v1/gonum/mat" + "cogentcore.org/core/base/reflectx" ) // Number is a tensor of numerical values @@ -33,44 +29,53 @@ type Int = Number[int] // Int32 is an alias for Number[int32]. type Int32 = Number[int32] +// Uint32 is an alias for Number[uint32]. +type Uint32 = Number[uint32] + // Byte is an alias for Number[byte]. type Byte = Number[byte] -// NewFloat32 returns a new Float32 tensor -// with the given sizes per dimension (shape), and optional dimension names. -func NewFloat32(sizes []int, names ...string) *Float32 { - return New[float32](sizes, names...).(*Float32) +// NewFloat32 returns a new [Float32] tensor +// with the given sizes per dimension (shape). +func NewFloat32(sizes ...int) *Float32 { + return New[float32](sizes...).(*Float32) } -// NewFloat64 returns a new Float64 tensor -// with the given sizes per dimension (shape), and optional dimension names. 
-func NewFloat64(sizes []int, names ...string) *Float64 { - return New[float64](sizes, names...).(*Float64) +// NewFloat64 returns a new [Float64] tensor +// with the given sizes per dimension (shape). +func NewFloat64(sizes ...int) *Float64 { + return New[float64](sizes...).(*Float64) } // NewInt returns a new Int tensor -// with the given sizes per dimension (shape), and optional dimension names. -func NewInt(sizes []int, names ...string) *Int { - return New[float64](sizes, names...).(*Int) +// with the given sizes per dimension (shape). +func NewInt(sizes ...int) *Int { + return New[int](sizes...).(*Int) } // NewInt32 returns a new Int32 tensor -// with the given sizes per dimension (shape), and optional dimension names. -func NewInt32(sizes []int, names ...string) *Int32 { - return New[float64](sizes, names...).(*Int32) +// with the given sizes per dimension (shape). +func NewInt32(sizes ...int) *Int32 { + return New[int32](sizes...).(*Int32) +} + +// NewUint32 returns a new Uint32 tensor +// with the given sizes per dimension (shape). +func NewUint32(sizes ...int) *Uint32 { + return New[uint32](sizes...).(*Uint32) } // NewByte returns a new Byte tensor -// with the given sizes per dimension (shape), and optional dimension names. -func NewByte(sizes []int, names ...string) *Byte { - return New[float64](sizes, names...).(*Byte) +// with the given sizes per dimension (shape). +func NewByte(sizes ...int) *Byte { + return New[uint8](sizes...).(*Byte) } // NewNumber returns a new n-dimensional tensor of numerical values -// with the given sizes per dimension (shape), and optional dimension names. -func NewNumber[T num.Number](sizes []int, names ...string) *Number[T] { +// with the given sizes per dimension (shape). +func NewNumber[T num.Number](sizes ...int) *Number[T] { tsr := &Number[T]{} - tsr.SetShape(sizes, names...) + tsr.SetShapeSizes(sizes...) 
tsr.Values = make([]T, tsr.Len()) return tsr } @@ -79,171 +84,161 @@ func NewNumber[T num.Number](sizes []int, names ...string) *Number[T] { // using given shape. func NewNumberShape[T num.Number](shape *Shape) *Number[T] { tsr := &Number[T]{} - tsr.Shp.CopyShape(shape) + tsr.shape.CopyFrom(shape) tsr.Values = make([]T, tsr.Len()) return tsr } -func (tsr *Number[T]) IsString() bool { - return false -} +// todo: this should in principle work with yaegi:add but it is crashing +// will come back to it later. -func (tsr *Number[T]) AddScalar(i []int, val float64) float64 { - j := tsr.Shp.Offset(i) - tsr.Values[j] += T(val) - return float64(tsr.Values[j]) +// NewNumberFromValues returns a new 1-dimensional tensor of given value type +// initialized directly from the given slice values, which are not copied. +// The resulting Tensor thus "wraps" the given values. +func NewNumberFromValues[T num.Number](vals ...T) *Number[T] { + n := len(vals) + tsr := &Number[T]{} + tsr.Values = vals + tsr.SetShapeSizes(n) + return tsr } -func (tsr *Number[T]) MulScalar(i []int, val float64) float64 { - j := tsr.Shp.Offset(i) - tsr.Values[j] *= T(val) - return float64(tsr.Values[j]) +// String satisfies the fmt.Stringer interface for string of tensor data. 
+func (tsr *Number[T]) String() string { return Sprintf("", tsr, 0) } + +func (tsr *Number[T]) IsString() bool { return false } + +func (tsr *Number[T]) AsValues() Values { return tsr } + +func (tsr *Number[T]) SetAdd(val T, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] += val +} +func (tsr *Number[T]) SetSub(val T, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] -= val } +func (tsr *Number[T]) SetMul(val T, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] *= val +} +func (tsr *Number[T]) SetDiv(val T, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] /= val +} + +/////// Strings -func (tsr *Number[T]) SetString(i []int, val string) { +func (tsr *Number[T]) SetString(val string, i ...int) { if fv, err := strconv.ParseFloat(val, 64); err == nil { - j := tsr.Shp.Offset(i) - tsr.Values[j] = T(fv) + tsr.Values[tsr.shape.IndexTo1D(i...)] = T(fv) } } -func (tsr Number[T]) SetString1D(off int, val string) { +func (tsr Number[T]) SetString1D(val string, i int) { if fv, err := strconv.ParseFloat(val, 64); err == nil { - tsr.Values[off] = T(fv) + tsr.Values[i] = T(fv) } } -func (tsr *Number[T]) SetStringRowCell(row, cell int, val string) { + +func (tsr *Number[T]) SetStringRow(val string, row, cell int) { if fv, err := strconv.ParseFloat(val, 64); err == nil { - _, sz := tsr.Shp.RowCellSize() + _, sz := tsr.shape.RowCellSize() tsr.Values[row*sz+cell] = T(fv) } } -// String satisfies the fmt.Stringer interface for string of tensor data -func (tsr *Number[T]) String() string { - str := tsr.Label() - sz := len(tsr.Values) - if sz > 1000 { - return str +// AppendRowString adds a row and sets string value(s), up to number of cells. 
+func (tsr *Number[T]) AppendRowString(val ...string) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) } - var b strings.Builder - b.WriteString(str) - b.WriteString("\n") - oddRow := true - rows, cols, _, _ := Projection2DShape(&tsr.Shp, oddRow) - for r := 0; r < rows; r++ { - rc, _ := Projection2DCoords(&tsr.Shp, oddRow, r, 0) - b.WriteString(fmt.Sprintf("%v: ", rc)) - for c := 0; c < cols; c++ { - vl := Projection2DValue(tsr, oddRow, r, c) - b.WriteString(fmt.Sprintf("%7g ", vl)) - } - b.WriteString("\n") + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetStringRow(val[i], nrow, i) } - return b.String() } -func (tsr *Number[T]) Float(i []int) float64 { - j := tsr.Shp.Offset(i) - return float64(tsr.Values[j]) +/////// Floats + +func (tsr *Number[T]) Float(i ...int) float64 { + return float64(tsr.Values[tsr.shape.IndexTo1D(i...)]) } -func (tsr *Number[T]) SetFloat(i []int, val float64) { - j := tsr.Shp.Offset(i) - tsr.Values[j] = T(val) +func (tsr *Number[T]) SetFloat(val float64, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val) } func (tsr *Number[T]) Float1D(i int) float64 { - return float64(tsr.Values[i]) + return float64(tsr.Values[NegIndex(i, len(tsr.Values))]) } -func (tsr *Number[T]) SetFloat1D(i int, val float64) { - tsr.Values[i] = T(val) +func (tsr *Number[T]) SetFloat1D(val float64, i int) { + tsr.Values[NegIndex(i, len(tsr.Values))] = T(val) } -func (tsr *Number[T]) FloatRowCell(row, cell int) float64 { - _, sz := tsr.Shp.RowCellSize() +func (tsr *Number[T]) FloatRow(row, cell int) float64 { + _, sz := tsr.shape.RowCellSize() i := row*sz + cell - return float64(tsr.Values[i]) + return float64(tsr.Values[NegIndex(i, len(tsr.Values))]) } -func (tsr *Number[T]) SetFloatRowCell(row, cell int, val float64) { - _, sz := tsr.Shp.RowCellSize() +func (tsr *Number[T]) SetFloatRow(val float64, row, cell int) { + _, sz := tsr.shape.RowCellSize() tsr.Values[row*sz+cell] = T(val) } 
-// Floats sets []float64 slice of all elements in the tensor -// (length is ensured to be sufficient). -// This can be used for all of the gonum/floats methods -// for basic math, gonum/stats, etc. -func (tsr *Number[T]) Floats(flt *[]float64) { - *flt = slicesx.SetLength(*flt, len(tsr.Values)) - switch vals := any(tsr.Values).(type) { - case []float64: - copy(*flt, vals) - default: - for i, v := range tsr.Values { - (*flt)[i] = float64(v) - } +// AppendRowFloat adds a row and sets float value(s), up to number of cells. +func (tsr *Number[T]) AppendRowFloat(val ...float64) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetFloatRow(val[i], nrow, i) } } -// SetFloats sets tensor values from a []float64 slice (copies values). -func (tsr *Number[T]) SetFloats(flt []float64) { - switch vals := any(tsr.Values).(type) { - case []float64: - copy(vals, flt) - default: - for i, v := range flt { - tsr.Values[i] = T(v) - } - } +/////// Ints + +func (tsr *Number[T]) Int(i ...int) int { + return int(tsr.Values[tsr.shape.IndexTo1D(i...)]) } -// At is the gonum/mat.Matrix interface method for returning 2D matrix element at given -// row, column index. Assumes Row-major ordering and logs an error if NumDims < 2. -func (tsr *Number[T]) At(i, j int) float64 { - nd := tsr.NumDims() - if nd < 2 { - log.Println("tensor Dims gonum Matrix call made on Tensor with dims < 2") - return 0 - } else if nd == 2 { - return tsr.Float([]int{i, j}) - } else { - ix := make([]int, nd) - ix[nd-2] = i - ix[nd-1] = j - return tsr.Float(ix) - } +func (tsr *Number[T]) SetInt(val int, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val) } -// T is the gonum/mat.Matrix transpose method. -// It performs an implicit transpose by returning the receiver inside a Transpose. 
-func (tsr *Number[T]) T() mat.Matrix { - return mat.Transpose{tsr} +func (tsr *Number[T]) Int1D(i int) int { + return int(tsr.Values[NegIndex(i, len(tsr.Values))]) } -// Range returns the min, max (and associated indexes, -1 = no values) for the tensor. -// This is needed for display and is thus in the core api in optimized form -// Other math operations can be done using gonum/floats package. -func (tsr *Number[T]) Range() (min, max float64, minIndex, maxIndex int) { - minIndex = -1 - maxIndex = -1 - for j, vl := range tsr.Values { - fv := float64(vl) - if math.IsNaN(fv) { - continue - } - if fv < min || minIndex < 0 { - min = fv - minIndex = j - } - if fv > max || maxIndex < 0 { - max = fv - maxIndex = j - } +func (tsr *Number[T]) SetInt1D(val int, i int) { + tsr.Values[NegIndex(i, len(tsr.Values))] = T(val) +} + +func (tsr *Number[T]) IntRow(row, cell int) int { + _, sz := tsr.shape.RowCellSize() + i := row*sz + cell + return int(tsr.Values[i]) +} + +func (tsr *Number[T]) SetIntRow(val int, row, cell int) { + _, sz := tsr.shape.RowCellSize() + tsr.Values[row*sz+cell] = T(val) +} + +// AppendRowInt adds a row and sets int value(s), up to number of cells. +func (tsr *Number[T]) AppendRowInt(val ...int) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetIntRow(val[i], nrow, i) } - return } // SetZeros is simple convenience function initialize all values to 0 @@ -254,10 +249,9 @@ func (tsr *Number[T]) SetZeros() { } // Clone clones this tensor, creating a duplicate copy of itself with its -// own separate memory representation of all the values, and returns -// that as a Tensor (which can be converted into the known type as needed). -func (tsr *Number[T]) Clone() Tensor { - csr := NewNumberShape[T](&tsr.Shp) +// own separate memory representation of all the values. 
+func (tsr *Number[T]) Clone() Values { + csr := NewNumberShape[T](&tsr.shape) copy(csr.Values, tsr.Values) return csr } @@ -265,21 +259,45 @@ func (tsr *Number[T]) Clone() Tensor { // CopyFrom copies all avail values from other tensor into this tensor, with an // optimized implementation if the other tensor is of the same type, and // otherwise it goes through appropriate standard type. -func (tsr *Number[T]) CopyFrom(frm Tensor) { +func (tsr *Number[T]) CopyFrom(frm Values) { if fsm, ok := frm.(*Number[T]); ok { copy(tsr.Values, fsm.Values) return } - sz := min(len(tsr.Values), frm.Len()) - for i := 0; i < sz; i++ { - tsr.Values[i] = T(frm.Float1D(i)) + sz := min(tsr.Len(), frm.Len()) + if reflectx.KindIsInt(tsr.DataType()) { + for i := range sz { + tsr.Values[i] = T(frm.Int1D(i)) + } + } else { + for i := range sz { + tsr.Values[i] = T(frm.Float1D(i)) + } } } -// CopyShapeFrom copies just the shape from given source tensor -// calling SetShape with the shape params from source (see for more docs). -func (tsr *Number[T]) CopyShapeFrom(frm Tensor) { - tsr.SetShape(frm.Shape().Sizes, frm.Shape().Names...) +// AppendFrom appends values from other tensor into this tensor, +// which must have the same cell size as this tensor. +// It uses and optimized implementation if the other tensor +// is of the same type, and otherwise it goes through +// appropriate standard type. 
+func (tsr *Number[T]) AppendFrom(frm Values) error { + rows, cell := tsr.shape.RowCellSize() + frows, fcell := frm.Shape().RowCellSize() + if cell != fcell { + return fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell) + } + tsr.SetNumRows(rows + frows) + st := rows * cell + fsz := frows * fcell + if fsm, ok := frm.(*Number[T]); ok { + copy(tsr.Values[st:st+fsz], fsm.Values) + return nil + } + for i := 0; i < fsz; i++ { + tsr.Values[st+i] = T(frm.Float1D(i)) + } + return nil } // CopyCellsFrom copies given range of values from other tensor into this tensor, @@ -287,14 +305,12 @@ func (tsr *Number[T]) CopyShapeFrom(frm Tensor) { // start = starting index on from Tensor to start copying from, and n = number of // values to copy. Uses an optimized implementation if the other tensor is // of the same type, and otherwise it goes through appropriate standard type. -func (tsr *Number[T]) CopyCellsFrom(frm Tensor, to, start, n int) { +func (tsr *Number[T]) CopyCellsFrom(frm Values, to, start, n int) { if fsm, ok := frm.(*Number[T]); ok { - for i := 0; i < n; i++ { - tsr.Values[to+i] = fsm.Values[start+i] - } + copy(tsr.Values[to:to+n], fsm.Values[start:start+n]) return } - for i := 0; i < n; i++ { + for i := range n { tsr.Values[to+i] = T(frm.Float1D(start + i)) } } @@ -304,9 +320,34 @@ func (tsr *Number[T]) CopyCellsFrom(frm Tensor, to, start, n int) { // The new tensor points to the values of the this tensor (i.e., modifications // will affect both), as its Values slice is a view onto the original (which // is why only inner-most contiguous supsaces are supported). -// Use Clone() method to separate the two. -func (tsr *Number[T]) SubSpace(offs []int) Tensor { - b := tsr.subSpaceImpl(offs) +// Use AsValues() method to separate the two. +func (tsr *Number[T]) SubSpace(offs ...int) Values { + b := tsr.subSpaceImpl(offs...) 
rt := &Number[T]{Base: *b} return rt } + +// RowTensor is a convenience version of [RowMajor.SubSpace] to return the +// SubSpace for the outermost row dimension. [Rows] defines a version +// of this that indirects through the row indexes. +func (tsr *Number[T]) RowTensor(row int) Values { + return tsr.SubSpace(row) +} + +// SetRowTensor sets the values of the SubSpace at given row to given values. +func (tsr *Number[T]) SetRowTensor(val Values, row int) { + _, cells := tsr.shape.RowCellSize() + st := row * cells + mx := min(val.Len(), cells) + tsr.CopyCellsFrom(val, st, 0, mx) +} + +// AppendRow adds a row and sets values to given values. +func (tsr *Number[T]) AppendRow(val Values) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow := tsr.DimSize(0) + tsr.SetNumRows(nrow + 1) + tsr.SetRowTensor(val, nrow) +} diff --git a/tensor/projection2d.go b/tensor/projection2d.go index 592cecc4e1..531be1daad 100644 --- a/tensor/projection2d.go +++ b/tensor/projection2d.go @@ -5,176 +5,190 @@ package tensor const ( - // OddRow is for oddRow arguments to Projection2D functions, - // specifies that the odd dimension goes along the row. - OddRow = true + // OnedRow is for onedRow arguments to Projection2D functions, + // specifies that the 1D case goes along the row. + OnedRow = true - // OddColumn is for oddRow arguments to Projection2D functions, - // specifies that the odd dimension goes along the column. - OddColumn = false + // OnedColumn is for onedRow arguments to Projection2D functions, + // specifies that the 1D case goes along the column. + OnedColumn = false ) // Projection2DShape returns the size of a 2D projection of the given tensor Shape, // collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. -// Even multiples of inner-most dimensions are assumed to be row, then column. 
-// rowEx returns the number of "extra" (higher dimensional) rows -// and colEx returns the number of extra cols -func Projection2DShape(shp *Shape, oddRow bool) (rows, cols, rowEx, colEx int) { +// For the 1D case, onedRow determines if the values are row-wise or not. +// Even multiples of inner-most dimensions are placed along the row, odd in the column. +// If there are an odd number of dimensions, the first dimension is row-wise, and +// the remaining inner dimensions use the above logic from there, as if it was even. +// rowEx returns the number of "extra" (outer-dimensional) rows +// and colEx returns the number of extra cols, to add extra spacing between these dimensions. +func Projection2DShape(shp *Shape, onedRow bool) (rows, cols, rowEx, colEx int) { if shp.Len() == 0 { return 1, 1, 0, 0 } nd := shp.NumDims() - switch nd { - case 1: - if oddRow { + if nd == 1 { + if onedRow { return shp.DimSize(0), 1, 0, 0 - } else { - return 1, shp.DimSize(0), 0, 0 } - case 2: + return 1, shp.DimSize(0), 0, 0 + } + if nd == 2 { return shp.DimSize(0), shp.DimSize(1), 0, 0 - case 3: - if oddRow { - return shp.DimSize(0) * shp.DimSize(1), shp.DimSize(2), shp.DimSize(0), 0 - } else { - return shp.DimSize(1), shp.DimSize(0) * shp.DimSize(2), 0, shp.DimSize(0) + } + rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow) + rows = rowShape.Len() + cols = colShape.Len() + nri := len(rowIdxs) + if nri > 1 { + rowEx = 1 + for i := range nri - 1 { + rowEx *= shp.DimSize(rowIdxs[i]) } - case 4: - return shp.DimSize(0) * shp.DimSize(2), shp.DimSize(1) * shp.DimSize(3), shp.DimSize(0), shp.DimSize(1) - case 5: - if oddRow { - return shp.DimSize(0) * shp.DimSize(1) * shp.DimSize(3), shp.DimSize(2) * shp.DimSize(4), shp.DimSize(0) * shp.DimSize(1), 0 + } + nci := len(colIdxs) + if nci > 1 { + colEx = 1 + for i := range nci - 1 { + colEx *= shp.DimSize(colIdxs[i]) + } + } + return +} + +// Projection2DDimShapes returns the shapes and dimension indexes for a 2D 
projection +// of given tensor Shape, collapsing higher dimensions down to 2D (and 1D up to 2D). +// For the 1D case, onedRow determines if the values are row-wise or not. +// Even multiples of inner-most dimensions are placed along the row, odd in the column. +// If there are an odd number of dimensions, the first dimension is row-wise, and +// the remaining inner dimensions use the above logic from there, as if it was even. +// This is the main organizing function for all Projection2D calls. +func Projection2DDimShapes(shp *Shape, onedRow bool) (rowShape, colShape *Shape, rowIdxs, colIdxs []int) { + nd := shp.NumDims() + if nd == 1 { + if onedRow { + return NewShape(shp.DimSize(0)), NewShape(1), []int{0}, nil + } + return NewShape(1), NewShape(shp.DimSize(0)), nil, []int{0} + } + if nd == 2 { + return NewShape(shp.DimSize(0)), NewShape(shp.DimSize(1)), []int{0}, []int{1} + } + var rs, cs []int + odd := nd%2 == 1 + sd := 0 + end := nd + if odd { + end = nd - 1 + sd = 1 + rs = []int{shp.DimSize(0)} + rowIdxs = []int{0} + } + for d := range end { + ad := d + sd + if d%2 == 0 { // even goes to row + rs = append(rs, shp.DimSize(ad)) + rowIdxs = append(rowIdxs, ad) } else { - return shp.DimSize(1) * shp.DimSize(3), shp.DimSize(0) * shp.DimSize(2) * shp.DimSize(4), 0, shp.DimSize(0) * shp.DimSize(1) + cs = append(cs, shp.DimSize(ad)) + colIdxs = append(colIdxs, ad) } } - return 1, 1, 0, 0 + rowShape = NewShape(rs...) + colShape = NewShape(cs...) + return } // Projection2DIndex returns the flat 1D index for given row, col coords for a 2D projection // of the given tensor shape, collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. -// Even multiples of inner-most dimensions are assumed to be row, then column. 
-func Projection2DIndex(shp *Shape, oddRow bool, row, col int) int { +// See [Projection2DShape] for full info. +func Projection2DIndex(shp *Shape, onedRow bool, row, col int) int { + if shp.Len() == 0 { + return 0 + } nd := shp.NumDims() - switch nd { - case 1: - if oddRow { + if nd == 1 { + if onedRow { return row - } else { - return col } - case 2: - return shp.Offset([]int{row, col}) - case 3: - if oddRow { - ny := shp.DimSize(1) - yy := row / ny - y := row % ny - return shp.Offset([]int{yy, y, col}) - } else { - nx := shp.DimSize(2) - xx := col / nx - x := col % nx - return shp.Offset([]int{xx, row, x}) - } - case 4: - ny := shp.DimSize(2) - yy := row / ny - y := row % ny - nx := shp.DimSize(3) - xx := col / nx - x := col % nx - return shp.Offset([]int{yy, xx, y, x}) - case 5: - // todo: oddRows version! - nyy := shp.DimSize(1) - ny := shp.DimSize(3) - yyy := row / (nyy * ny) - yy := row % (nyy * ny) - y := yy % ny - yy = yy / ny - nx := shp.DimSize(4) - xx := col / nx - x := col % nx - return shp.Offset([]int{yyy, yy, xx, y, x}) - } - return 0 + return col + } + if nd == 2 { + return shp.IndexTo1D(row, col) + } + rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow) + ris := rowShape.IndexFrom1D(row) + cis := colShape.IndexFrom1D(col) + ixs := make([]int, nd) + for i, ri := range rowIdxs { + ixs[ri] = ris[i] + } + for i, ci := range colIdxs { + ixs[ci] = cis[i] + } + return shp.IndexTo1D(ixs...) } // Projection2DCoords returns the corresponding full-dimensional coordinates // that go into the given row, col coords for a 2D projection of the given tensor, // collapsing higher dimensions down to 2D (and 1D up to 2D). -func Projection2DCoords(shp *Shape, oddRow bool, row, col int) (rowCoords, colCoords []int) { - idx := Projection2DIndex(shp, oddRow, row, col) - dims := shp.Index(idx) +// See [Projection2DShape] for full info. 
+func Projection2DCoords(shp *Shape, onedRow bool, row, col int) (rowCoords, colCoords []int) { + if shp.Len() == 0 { + return []int{0}, []int{0} + } + idx := Projection2DIndex(shp, onedRow, row, col) + dims := shp.IndexFrom1D(idx) nd := shp.NumDims() - switch nd { - case 1: - if oddRow { + if nd == 1 { + if onedRow { return dims, []int{0} - } else { - return []int{0}, dims } - case 2: + return []int{0}, dims + } + if nd == 2 { return dims[:1], dims[1:] - case 3: - if oddRow { - return dims[:2], dims[2:] - } else { - return dims[:1], dims[1:] - } - case 4: - return []int{dims[0], dims[2]}, []int{dims[1], dims[3]} - case 5: - if oddRow { - return []int{dims[0], dims[1], dims[3]}, []int{dims[2], dims[4]} - } else { - return []int{dims[1], dims[3]}, []int{dims[0], dims[2], dims[4]} - } } - return nil, nil + _, _, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow) + rowCoords = make([]int, len(rowIdxs)) + colCoords = make([]int, len(colIdxs)) + for i, ri := range rowIdxs { + rowCoords[i] = dims[ri] + } + for i, ci := range colIdxs { + colCoords[i] = dims[ci] + } + return } // Projection2DValue returns the float64 value at given row, col coords for a 2D projection // of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. -// Even multiples of inner-most dimensions are assumed to be row, then column. -func Projection2DValue(tsr Tensor, oddRow bool, row, col int) float64 { - idx := Projection2DIndex(tsr.Shape(), oddRow, row, col) +// See [Projection2DShape] for full info. 
+func Projection2DValue(tsr Tensor, onedRow bool, row, col int) float64 { + idx := Projection2DIndex(tsr.Shape(), onedRow, row, col) return tsr.Float1D(idx) } // Projection2DString returns the string value at given row, col coords for a 2D projection // of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. -// Even multiples of inner-most dimensions are assumed to be row, then column. -func Projection2DString(tsr Tensor, oddRow bool, row, col int) string { - idx := Projection2DIndex(tsr.Shape(), oddRow, row, col) +// See [Projection2DShape] for full info. +func Projection2DString(tsr Tensor, onedRow bool, row, col int) string { + idx := Projection2DIndex(tsr.Shape(), onedRow, row, col) return tsr.String1D(idx) } // Projection2DSet sets a float64 value at given row, col coords for a 2D projection // of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. -// Even multiples of inner-most dimensions are assumed to be row, then column. -func Projection2DSet(tsr Tensor, oddRow bool, row, col int, val float64) { - idx := Projection2DIndex(tsr.Shape(), oddRow, row, col) - tsr.SetFloat1D(idx, val) +// See [Projection2DShape] for full info. +func Projection2DSet(tsr Tensor, onedRow bool, row, col int, val float64) { + idx := Projection2DIndex(tsr.Shape(), onedRow, row, col) + tsr.SetFloat1D(val, idx) } // Projection2DSetString sets a string value at given row, col coords for a 2D projection // of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D). -// For any odd number of dimensions, the remaining outer-most dimension -// can either be multipliexed across the row or column, given the oddRow arg. 
-// Even multiples of inner-most dimensions are assumed to be row, then column. -func Projection2DSetString(tsr Tensor, oddRow bool, row, col int, val string) { - idx := Projection2DIndex(tsr.Shape(), oddRow, row, col) - tsr.SetString1D(idx, val) +// See [Projection2DShape] for full info. +func Projection2DSetString(tsr Tensor, onedRow bool, row, col int, val string) { + idx := Projection2DIndex(tsr.Shape(), onedRow, row, col) + tsr.SetString1D(val, idx) } diff --git a/tensor/reshaped.go b/tensor/reshaped.go new file mode 100644 index 0000000000..6c59e7de49 --- /dev/null +++ b/tensor/reshaped.go @@ -0,0 +1,215 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "reflect" + "slices" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/metadata" +) + +// Reshaped is a reshaping wrapper around another "source" [Tensor], +// that provides a length-preserving reshaped view onto the source Tensor. +// Reshaping by adding new size=1 dimensions (via [NewAxis] value) is +// often important for properly aligning two tensors in a computationally +// compatible manner; see the [AlignShapes] function. +// [Reshaped.AsValues] on this view returns a new [Values] with the view +// shape, calling [Clone] on the source tensor to get the values. +type Reshaped struct { //types:add + + // Tensor source that we are a reshaped view onto. + Tensor Tensor + + // Reshape is the effective shape we use for access. + // This must have the same Len() as the source Tensor. + Reshape Shape +} + +// NewReshaped returns a new [Reshaped] view of given tensor, with given shape +// sizes. If no such sizes are provided, the source shape is used. +// A single -1 value can be used to automatically specify the remaining tensor +// length, as long as the other sizes are an even multiple of the total length. 
+// A single -1 returns a 1D view of the entire tensor. +func NewReshaped(tsr Tensor, sizes ...int) *Reshaped { + rs := &Reshaped{Tensor: tsr} + if len(sizes) == 0 { + rs.Reshape.CopyFrom(tsr.Shape()) + } else { + errors.Log(rs.SetShapeSizes(sizes...)) + } + return rs +} + +// Reshape returns a view of the given tensor with given shape sizes. +// A single -1 value can be used to automatically specify the remaining tensor +// length, as long as the other sizes are an even multiple of the total length. +// A single -1 returns a 1D view of the entire tensor. +func Reshape(tsr Tensor, sizes ...int) Tensor { + if len(sizes) == 0 { + err := errors.New("tensor.Reshape: must pass shape sizes") + errors.Log(err) + return tsr + } + if len(sizes) == 1 { + sz := sizes[0] + if sz == -1 { + return As1D(tsr) + } + } + rs := &Reshaped{Tensor: tsr} + errors.Log(rs.SetShapeSizes(sizes...)) + return rs +} + +// Transpose returns a new [Reshaped] tensor with the strides +// switched so that rows and column dimensions are effectively +// reversed. +func Transpose(tsr Tensor) Tensor { + rs := &Reshaped{Tensor: tsr} + rs.Reshape.CopyFrom(tsr.Shape()) + rs.Reshape.Strides = ColumnMajorStrides(rs.Reshape.Sizes...) + return rs +} + +// NewRowCellsView returns a 2D [Reshaped] view onto the given tensor, +// with a single outer "row" dimension and a single inner "cells" dimension, +// with the given 'split' dimension specifying where the cells start. +// All dimensions prior to split are collapsed to form the new outer row dimension, +// and the remainder are collapsed to form the 1D cells dimension. +// This is useful for stats, metrics and other packages that operate +// on data in this shape. 
+func NewRowCellsView(tsr Tensor, split int) *Reshaped { + sizes := tsr.ShapeSizes() + rows := sizes[:split] + cells := sizes[split:] + nr := 1 + for _, r := range rows { + nr *= r + } + nc := 1 + for _, c := range cells { + nc *= c + } + return NewReshaped(tsr, nr, nc) +} + +// AsReshaped returns the tensor as a [Reshaped] view. +// If it already is one, then it is returned, otherwise it is wrapped +// with an initial shape equal to the source tensor. +func AsReshaped(tsr Tensor) *Reshaped { + if rs, ok := tsr.(*Reshaped); ok { + return rs + } + return NewReshaped(tsr) +} + +// SetShapeSizes sets our shape sizes to the given values, which must result in +// the same length as the source tensor. An error is returned if not. +// If a different subset of content is desired, use another view such as [Sliced]. +// Note that any number of size = 1 dimensions can be added without affecting +// the length, and the [NewAxis] value can be used to semantically +// indicate when such a new dimension is being inserted. This is often useful +// for aligning two tensors to achieve a desired computation; see [AlignShapes] +// function. A single -1 can be used to specify a dimension size that takes the +// remaining length, as long as the other sizes are an even multiple of the length. +// A single -1 indicates to use the full length. +func (rs *Reshaped) SetShapeSizes(sizes ...int) error { + sln := rs.Tensor.Len() + if sln == 0 { + return nil + } + if sln == 1 { + sz := sizes[0] + if sz < 0 { + rs.Reshape.SetShapeSizes(sln) + return nil + } + } + sz := slices.Clone(sizes) + ln := 1 + negIdx := -1 + for i, s := range sz { + if s < 0 { + negIdx = i + } else { + ln *= s + } + } + if negIdx >= 0 { + if sln%ln != 0 { + return errors.New("tensor.Reshaped SetShapeSizes: -1 cannot be used because the remaining dimensions are not an even multiple of the source tensor length") + } + sz[negIdx] = sln / ln + } + rs.Reshape.SetShapeSizes(sz...) 
+ if rs.Reshape.Len() != sln { + return errors.New("tensor.Reshaped SetShapeSizes: new length is different from source tensor; use Sliced or other views to change view content") + } + return nil +} + +func (rs *Reshaped) Label() string { return label(metadata.Name(rs), rs.Shape()) } +func (rs *Reshaped) String() string { return Sprintf("", rs, 0) } +func (rs *Reshaped) Metadata() *metadata.Data { return rs.Tensor.Metadata() } +func (rs *Reshaped) IsString() bool { return rs.Tensor.IsString() } +func (rs *Reshaped) DataType() reflect.Kind { return rs.Tensor.DataType() } +func (rs *Reshaped) ShapeSizes() []int { return rs.Reshape.Sizes } +func (rs *Reshaped) Shape() *Shape { return &rs.Reshape } +func (rs *Reshaped) Len() int { return rs.Reshape.Len() } +func (rs *Reshaped) NumDims() int { return rs.Reshape.NumDims() } +func (rs *Reshaped) DimSize(dim int) int { return rs.Reshape.DimSize(dim) } + +// AsValues returns a copy of this tensor as raw [Values], with +// the same shape as our view. This calls [Clone] on the source +// tensor to get the Values and then sets our shape sizes to it. +func (rs *Reshaped) AsValues() Values { + vals := Clone(rs.Tensor) + vals.SetShapeSizes(rs.Reshape.Sizes...) 
+ return vals +} + +//////// Floats + +func (rs *Reshaped) Float(i ...int) float64 { + return rs.Tensor.Float1D(rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) SetFloat(val float64, i ...int) { + rs.Tensor.SetFloat1D(val, rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) Float1D(i int) float64 { return rs.Tensor.Float1D(i) } +func (rs *Reshaped) SetFloat1D(val float64, i int) { rs.Tensor.SetFloat1D(val, i) } + +//////// Strings + +func (rs *Reshaped) StringValue(i ...int) string { + return rs.Tensor.String1D(rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) SetString(val string, i ...int) { + rs.Tensor.SetString1D(val, rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) String1D(i int) string { return rs.Tensor.String1D(i) } +func (rs *Reshaped) SetString1D(val string, i int) { rs.Tensor.SetString1D(val, i) } + +//////// Ints + +func (rs *Reshaped) Int(i ...int) int { + return rs.Tensor.Int1D(rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) SetInt(val int, i ...int) { + rs.Tensor.SetInt1D(val, rs.Reshape.IndexTo1D(i...)) +} + +func (rs *Reshaped) Int1D(i int) int { return rs.Tensor.Int1D(i) } +func (rs *Reshaped) SetInt1D(val int, i int) { rs.Tensor.SetInt1D(val, i) } + +// check for interface impl +var _ Tensor = (*Reshaped)(nil) diff --git a/tensor/rowmajor.go b/tensor/rowmajor.go new file mode 100644 index 0000000000..3bce95bc8a --- /dev/null +++ b/tensor/rowmajor.go @@ -0,0 +1,79 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +// RowMajor is subtype of [Tensor] that maintains a row major memory organization +// that thereby supports efficient access via the outermost 'row' dimension, +// with all remaining inner dimensions comprising the 'cells' of data per row +// (1 scalar value in the case of a 1D tensor). 
+// It is implemented by raw [Values] tensors, and the [Rows] indexed view of
+// raw Values tensors. Other views however do not retain the underlying
+// outer to inner row major memory structure and thus do not implement this interface.
+type RowMajor interface {
+	Tensor
+
+	// SubSpace returns a new tensor with innermost subspace at given
+	// offset(s) in outermost dimension(s) (len(offs) < [NumDims]).
+	// The new tensor points to the values of this tensor (i.e., modifications
+	// will affect both), as its Values slice is a view onto the original (which
+	// is why only inner-most contiguous subspaces are supported).
+	// Use AsValues() method to separate the two. See [Slice] function to
+	// extract arbitrary subspaces along ranges of each dimension.
+	SubSpace(offs ...int) Values
+
+	// RowTensor is a convenience version of [RowMajor.SubSpace] to return the
+	// SubSpace for the outermost row dimension. [Rows] defines a version
+	// of this that indirects through the row indexes.
+	RowTensor(row int) Values
+
+	// SetRowTensor sets the values of the [RowMajor.SubSpace] at given row to given values.
+	SetRowTensor(val Values, row int)
+
+	// AppendRow adds a row and sets values to given values.
+	AppendRow(val Values)
+
+	//////// Floats
+
+	// FloatRow returns the value at given row and cell, where row is the outermost
+	// dimension, and cell is a 1D index into remaining inner dimensions (0 for scalar).
+	FloatRow(row, cell int) float64
+
+	// SetFloatRow sets the value at given row and cell, where row is the outermost
+	// dimension, and cell is a 1D index into remaining inner dimensions.
+	SetFloatRow(val float64, row, cell int)
+
+	// AppendRowFloat adds a row and sets float value(s), up to number of cells.
+	AppendRowFloat(val ...float64)
+
+	//////// Ints
+
+	// IntRow returns the value at given row and cell, where row is the outermost
+	// dimension, and cell is a 1D index into remaining inner dimensions.
+ IntRow(row, cell int) int + + // SetIntRow sets the value at given row and cell, where row is the outermost + // dimension, and cell is a 1D index into remaining inner dimensions. + SetIntRow(val int, row, cell int) + + // AppendRowInt adds a row and sets int value(s), up to number of cells. + AppendRowInt(val ...int) + + //////// Strings + + // StringRow returns the value at given row and cell, where row is the outermost + // dimension, and cell is a 1D index into remaining inner dimensions. + // [Rows] tensors index along the row, and use this interface extensively. + // This is useful for lists of patterns, and the [table.Table] container. + StringRow(row, cell int) string + + // SetStringRow sets the value at given row and cell, where row is the outermost + // dimension, and cell is a 1D index into remaining inner dimensions. + // [Rows] tensors index along the row, and use this interface extensively. + // This is useful for lists of patterns, and the [table.Table] container. + SetStringRow(val string, row, cell int) + + // AppendRowString adds a row and sets string value(s), up to number of cells. + AppendRowString(val ...string) +} diff --git a/tensor/rows.go b/tensor/rows.go new file mode 100644 index 0000000000..7d1925128c --- /dev/null +++ b/tensor/rows.go @@ -0,0 +1,679 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "cmp" + "math" + "math/rand" + "reflect" + "slices" + "sort" + "strings" + + "cogentcore.org/core/base/metadata" +) + +// Rows is a row-indexed wrapper view around a [Values] [Tensor] that allows +// arbitrary row-wise ordering and filtering according to the [Rows.Indexes]. +// Sorting and filtering a tensor along this outermost row dimension only +// requires updating the indexes while leaving the underlying Tensor alone. 
+// Unlike the more general [Sliced] view, Rows maintains memory contiguity +// for the inner dimensions ("cells") within each row, and supports the [RowMajor] +// interface, with the [Set]FloatRow[Cell] methods providing efficient access. +// Use [Rows.AsValues] to obtain a concrete [Values] representation with the +// current row sorting. +type Rows struct { //types:add + + // Tensor source that we are an indexed view onto. + // Note that this must be a concrete [Values] tensor, to enable efficient + // [RowMajor] access and subspace functions. + Tensor Values + + // Indexes are the indexes into Tensor rows, with nil = sequential. + // Only set if order is different from default sequential order. + // Use the [Rows.RowIndex] method for nil-aware logic. + Indexes []int +} + +// NewRows returns a new [Rows] view of given tensor, +// with optional list of indexes (none / nil = sequential). +func NewRows(tsr Values, idxs ...int) *Rows { + rw := &Rows{Tensor: tsr, Indexes: slices.Clone(idxs)} + return rw +} + +// AsRows returns the tensor as a [Rows] view. +// If it already is one, then it is returned, otherwise +// a new Rows is created to wrap around the given tensor, which is +// enforced to be a [Values] tensor either because it already is one, +// or by calling [Tensor.AsValues] on it. +func AsRows(tsr Tensor) *Rows { + if rw, ok := tsr.(*Rows); ok { + return rw + } + return NewRows(tsr.AsValues()) +} + +// SetTensor sets as indexes into given [Values] tensor with sequential initial indexes. +func (rw *Rows) SetTensor(tsr Values) { + rw.Tensor = tsr + rw.Sequential() +} + +func (rw *Rows) IsString() bool { return rw.Tensor.IsString() } + +func (rw *Rows) DataType() reflect.Kind { return rw.Tensor.DataType() } + +// RowIndex returns the actual index into underlying tensor row based on given +// index value. If Indexes == nil, index is passed through. 
+func (rw *Rows) RowIndex(idx int) int { + if rw.Indexes == nil { + return idx + } + return rw.Indexes[idx] +} + +// NumRows returns the effective number of rows in this Rows view, +// which is the length of the index list or number of outer +// rows dimension of tensor if no indexes (full sequential view). +func (rw *Rows) NumRows() int { + if rw.Indexes == nil { + return rw.Tensor.DimSize(0) + } + return len(rw.Indexes) +} + +func (rw *Rows) String() string { return Sprintf("", rw.Tensor, 0) } +func (rw *Rows) Label() string { return rw.Tensor.Label() } +func (rw *Rows) Metadata() *metadata.Data { return rw.Tensor.Metadata() } +func (rw *Rows) NumDims() int { return rw.Tensor.NumDims() } + +// If we have Indexes, this is the effective shape sizes using +// the current number of indexes as the outermost row dimension size. +func (rw *Rows) ShapeSizes() []int { + if rw.Indexes == nil || rw.Tensor.NumDims() == 0 { + return rw.Tensor.ShapeSizes() + } + sh := slices.Clone(rw.Tensor.ShapeSizes()) + sh[0] = len(rw.Indexes) + return sh +} + +// Shape() returns a [Shape] representation of the tensor shape +// (dimension sizes). If we have Indexes, this is the effective +// shape using the current number of indexes as the outermost row dimension size. +func (rw *Rows) Shape() *Shape { + if rw.Indexes == nil { + return rw.Tensor.Shape() + } + return NewShape(rw.ShapeSizes()...) +} + +// Len returns the total number of elements in the tensor, +// taking into account the Indexes via [Rows], +// as NumRows() * cell size. +func (rw *Rows) Len() int { + rows := rw.NumRows() + _, cells := rw.Tensor.Shape().RowCellSize() + return cells * rows +} + +// DimSize returns size of given dimension, returning NumRows() +// for first dimension. 
+func (rw *Rows) DimSize(dim int) int { + if dim == 0 { + return rw.NumRows() + } + return rw.Tensor.DimSize(dim) +} + +// RowCellSize returns the size of the outermost Row shape dimension +// (via [Rows.NumRows] method), and the size of all the remaining +// inner dimensions (the "cell" size). +func (rw *Rows) RowCellSize() (rows, cells int) { + _, cells = rw.Tensor.Shape().RowCellSize() + rows = rw.NumRows() + return +} + +// ValidIndexes deletes all invalid indexes from the list. +// Call this if rows (could) have been deleted from tensor. +func (rw *Rows) ValidIndexes() { + if rw.Tensor.DimSize(0) <= 0 || rw.Indexes == nil { + rw.Indexes = nil + return + } + ni := rw.NumRows() + for i := ni - 1; i >= 0; i-- { + if rw.Indexes[i] >= rw.Tensor.DimSize(0) { + rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...) + } + } +} + +// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor. +func (rw *Rows) Sequential() { //types:add + rw.Indexes = nil +} + +// IndexesNeeded is called prior to an operation that needs actual indexes, +// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise +// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure +// all rows are represented. +func (rw *Rows) IndexesNeeded() { + if rw.Tensor.DimSize(0) <= 0 { + rw.Indexes = nil + return + } + if rw.Indexes != nil { + return + } + rw.Indexes = make([]int, rw.Tensor.DimSize(0)) + for i := range rw.Indexes { + rw.Indexes[i] = i + } +} + +// ExcludeMissing deletes indexes where the values are missing, as indicated by NaN. +// Uses first cell of higher dimensional data. +func (rw *Rows) ExcludeMissing() { //types:add + if rw.Tensor.DimSize(0) <= 0 { + rw.Indexes = nil + return + } + rw.IndexesNeeded() + ni := rw.NumRows() + for i := ni - 1; i >= 0; i-- { + if math.IsNaN(rw.Tensor.FloatRow(rw.Indexes[i], 0)) { + rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...) 
+ } + } +} + +// Permuted sets indexes to a permuted order. If indexes already exist +// then existing list of indexes is permuted, otherwise a new set of +// permuted indexes are generated +func (rw *Rows) Permuted() { + if rw.Tensor.DimSize(0) <= 0 { + rw.Indexes = nil + return + } + if rw.Indexes == nil { + rw.Indexes = rand.Perm(rw.Tensor.DimSize(0)) + } else { + rand.Shuffle(len(rw.Indexes), func(i, j int) { + rw.Indexes[i], rw.Indexes[j] = rw.Indexes[j], rw.Indexes[i] + }) + } +} + +const ( + // Ascending specifies an ascending sort direction for tensor Sort routines + Ascending = true + + // Descending specifies a descending sort direction for tensor Sort routines + Descending = false + + // StableSort specifies using stable, original order-preserving sort, which is slower. + StableSort = true + + // Unstable specifies using faster but unstable sorting. + UnstableSort = false +) + +// SortFunc sorts the row-wise indexes using given compare function. +// The compare function operates directly on row numbers into the Tensor +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +func (rw *Rows) SortFunc(cmp func(tsr Values, i, j int) int) { + rw.IndexesNeeded() + slices.SortFunc(rw.Indexes, func(a, b int) int { + return cmp(rw.Tensor, a, b) // key point: these are already indirected through indexes!! + }) +} + +// SortIndexes sorts the indexes into our Tensor directly in +// numerical order, producing the native ordering, while preserving +// any filtering that might have occurred. +func (rw *Rows) SortIndexes() { + if rw.Indexes == nil { + return + } + sort.Ints(rw.Indexes) +} + +// CompareAscending is a sort compare function that reverses direction +// based on the ascending bool. 
+func CompareAscending[T cmp.Ordered](a, b T, ascending bool) int { + if ascending { + return cmp.Compare(a, b) + } + return cmp.Compare(b, a) +} + +// Sort does default alpha or numeric sort of row-wise data. +// Uses first cell of higher dimensional data. +func (rw *Rows) Sort(ascending bool) { + if rw.Tensor.IsString() { + rw.SortFunc(func(tsr Values, i, j int) int { + return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending) + }) + } else { + rw.SortFunc(func(tsr Values, i, j int) int { + return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending) + }) + } +} + +// SortStableFunc stably sorts the row-wise indexes using given compare function. +// The compare function operates directly on row numbers into the Tensor +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +// It is *essential* that it always returns 0 when the two are equal +// for the stable function to actually work. +func (rw *Rows) SortStableFunc(cmp func(tsr Values, i, j int) int) { + rw.IndexesNeeded() + slices.SortStableFunc(rw.Indexes, func(a, b int) int { + return cmp(rw.Tensor, a, b) // key point: these are already indirected through indexes!! + }) +} + +// SortStable does stable default alpha or numeric sort. +// Uses first cell of higher dimensional data. +func (rw *Rows) SortStable(ascending bool) { + if rw.Tensor.IsString() { + rw.SortStableFunc(func(tsr Values, i, j int) int { + return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending) + }) + } else { + rw.SortStableFunc(func(tsr Values, i, j int) int { + return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending) + }) + } +} + +// FilterFunc is a function used for filtering that returns +// true if Tensor row should be included in the current filtered +// view of the tensor, and false if it should be removed. 
+type FilterFunc func(tsr Values, row int) bool + +// Filter filters the indexes using given Filter function. +// The Filter function operates directly on row numbers into the Tensor +// as these row numbers have already been projected through the indexes. +func (rw *Rows) Filter(filterer func(tsr Values, row int) bool) { + rw.IndexesNeeded() + sz := len(rw.Indexes) + for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering + if !filterer(rw.Tensor, rw.Indexes[i]) { // delete + rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...) + } + } +} + +// FilterOptions are options to a Filter function +// determining how the string filter value is used for matching. +type FilterOptions struct { //types:add + + // Exclude means to exclude matches, + // with the default (false) being to include + Exclude bool + + // Contains means the string only needs to contain the target string, + // with the default (false) requiring a complete match to entire string. + Contains bool + + // IgnoreCase means that differences in case are ignored in comparing strings, + // with the default (false) using case. + IgnoreCase bool +} + +// FilterString filters the indexes using string values compared to given +// string. Includes rows with matching values unless the Exclude option is set. +// If Contains option is set, it only checks if row contains string; +// if IgnoreCase, ignores case, otherwise filtering is case sensitive. +// Uses first cell of higher dimensional data. 
+func (rw *Rows) FilterString(str string, opts FilterOptions) { //types:add + lowstr := strings.ToLower(str) + rw.Filter(func(tsr Values, row int) bool { + val := tsr.StringRow(row, 0) + has := false + switch { + case opts.Contains && opts.IgnoreCase: + has = strings.Contains(strings.ToLower(val), lowstr) + case opts.Contains: + has = strings.Contains(val, str) + case opts.IgnoreCase: + has = strings.EqualFold(val, str) + default: + has = (val == str) + } + if opts.Exclude { + return !has + } + return has + }) +} + +// AsValues returns this tensor as raw [Values]. +// If the row [Rows.Indexes] are nil, then the wrapped Values tensor +// is returned. Otherwise, it "renders" the Rows view into a fully contiguous +// and optimized memory representation of that view, which will be faster +// to access for further processing, and enables all the additional +// functionality provided by the [Values] interface. +func (rw *Rows) AsValues() Values { + if rw.Indexes == nil { + return rw.Tensor + } + vt := NewOfType(rw.Tensor.DataType(), rw.ShapeSizes()...) + rows := rw.NumRows() + for r := range rows { + vt.SetRowTensor(rw.RowTensor(r), r) + } + return vt +} + +// CloneIndexes returns a copy of the current Rows view with new indexes, +// with a pointer to the same underlying Tensor as the source. +func (rw *Rows) CloneIndexes() *Rows { + nix := &Rows{} + nix.Tensor = rw.Tensor + nix.CopyIndexes(rw) + return nix +} + +// CopyIndexes copies indexes from other Rows view. 
+func (rw *Rows) CopyIndexes(oix *Rows) { + if oix.Indexes == nil { + rw.Indexes = nil + } else { + rw.Indexes = slices.Clone(oix.Indexes) + } +} + +// addRowsIndexes adds n rows to indexes starting at end of current tensor size +func (rw *Rows) addRowsIndexes(n int) { //types:add + if rw.Indexes == nil { + return + } + stidx := rw.Tensor.DimSize(0) + for i := stidx; i < stidx+n; i++ { + rw.Indexes = append(rw.Indexes, i) + } +} + +// AddRows adds n rows to end of underlying Tensor, and to the indexes in this view +func (rw *Rows) AddRows(n int) { //types:add + stidx := rw.Tensor.DimSize(0) + rw.addRowsIndexes(n) + rw.Tensor.SetNumRows(stidx + n) +} + +// InsertRows adds n rows to end of underlying Tensor, and to the indexes starting at +// given index in this view +func (rw *Rows) InsertRows(at, n int) { + stidx := rw.Tensor.DimSize(0) + rw.IndexesNeeded() + rw.Tensor.SetNumRows(stidx + n) + nw := make([]int, n, n+len(rw.Indexes)-at) + for i := 0; i < n; i++ { + nw[i] = stidx + i + } + rw.Indexes = append(rw.Indexes[:at], append(nw, rw.Indexes[at:]...)...) +} + +// DeleteRows deletes n rows of indexes starting at given index in the list of indexes +func (rw *Rows) DeleteRows(at, n int) { + rw.IndexesNeeded() + rw.Indexes = append(rw.Indexes[:at], rw.Indexes[at+n:]...) +} + +// Swap switches the indexes for i and j +func (rw *Rows) Swap(i, j int) { + if rw.Indexes == nil { + return + } + rw.Indexes[i], rw.Indexes[j] = rw.Indexes[j], rw.Indexes[i] +} + +/////// Floats + +// Float returns the value of given index as a float64. +// The first index value is indirected through the indexes. +func (rw *Rows) Float(i ...int) float64 { + if rw.Indexes == nil { + return rw.Tensor.Float(i...) + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + return rw.Tensor.Float(ic...) +} + +// SetFloat sets the value of given index as a float64 +// The first index value is indirected through the [Rows.Indexes]. 
+func (rw *Rows) SetFloat(val float64, i ...int) { + if rw.Indexes == nil { + rw.Tensor.SetFloat(val, i...) + return + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + rw.Tensor.SetFloat(val, ic...) +} + +// FloatRow returns the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) FloatRow(row, cell int) float64 { + return rw.Tensor.FloatRow(rw.RowIndex(row), cell) +} + +// SetFloatRow sets the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) SetFloatRow(val float64, row, cell int) { + rw.Tensor.SetFloatRow(val, rw.RowIndex(row), cell) +} + +// Float1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (rw *Rows) Float1D(i int) float64 { + if rw.Indexes == nil { + return rw.Tensor.Float1D(i) + } + return rw.Float(rw.Tensor.Shape().IndexFrom1D(i)...) +} + +// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (rw *Rows) SetFloat1D(val float64, i int) { + if rw.Indexes == nil { + rw.Tensor.SetFloat1D(val, i) + return + } + rw.SetFloat(val, rw.Tensor.Shape().IndexFrom1D(i)...) +} + +/////// Strings + +// StringValue returns the value of given index as a string. +// The first index value is indirected through the indexes. +func (rw *Rows) StringValue(i ...int) string { + if rw.Indexes == nil { + return rw.Tensor.StringValue(i...) + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + return rw.Tensor.StringValue(ic...) 
+} + +// SetString sets the value of given index as a string +// The first index value is indirected through the [Rows.Indexes]. +func (rw *Rows) SetString(val string, i ...int) { + if rw.Indexes == nil { + rw.Tensor.SetString(val, i...) + return + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + rw.Tensor.SetString(val, ic...) +} + +// StringRow returns the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) StringRow(row, cell int) string { + return rw.Tensor.StringRow(rw.RowIndex(row), cell) +} + +// SetStringRow sets the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) SetStringRow(val string, row, cell int) { + rw.Tensor.SetStringRow(val, rw.RowIndex(row), cell) +} + +// AppendRowFloat adds a row and sets float value(s), up to number of cells. +func (rw *Rows) AppendRowFloat(val ...float64) { + rw.addRowsIndexes(1) + rw.Tensor.AppendRowFloat(val...) +} + +// String1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (rw *Rows) String1D(i int) string { + if rw.Indexes == nil { + return rw.Tensor.String1D(i) + } + return rw.StringValue(rw.Tensor.Shape().IndexFrom1D(i)...) +} + +// SetString1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (rw *Rows) SetString1D(val string, i int) { + if rw.Indexes == nil { + rw.Tensor.SetString1D(val, i) + return + } + rw.SetString(val, rw.Tensor.Shape().IndexFrom1D(i)...) 
+} + +// AppendRowString adds a row and sets string value(s), up to number of cells. +func (rw *Rows) AppendRowString(val ...string) { + rw.addRowsIndexes(1) + rw.Tensor.AppendRowString(val...) +} + +/////// Ints + +// Int returns the value of given index as an int. +// The first index value is indirected through the indexes. +func (rw *Rows) Int(i ...int) int { + if rw.Indexes == nil { + return rw.Tensor.Int(i...) + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + return rw.Tensor.Int(ic...) +} + +// SetInt sets the value of given index as an int +// The first index value is indirected through the [Rows.Indexes]. +func (rw *Rows) SetInt(val int, i ...int) { + if rw.Indexes == nil { + rw.Tensor.SetInt(val, i...) + return + } + ic := slices.Clone(i) + ic[0] = rw.Indexes[ic[0]] + rw.Tensor.SetInt(val, ic...) +} + +// IntRow returns the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) IntRow(row, cell int) int { + return rw.Tensor.IntRow(rw.RowIndex(row), cell) +} + +// SetIntRow sets the value at given row and cell, +// where row is outermost dim, and cell is 1D index into remaining inner dims. +// Row is indirected through the [Rows.Indexes]. +// This is the preferred interface for all Rows operations. +func (rw *Rows) SetIntRow(val int, row, cell int) { + rw.Tensor.SetIntRow(val, rw.RowIndex(row), cell) +} + +// AppendRowInt adds a row and sets int value(s), up to number of cells. +func (rw *Rows) AppendRowInt(val ...int) { + rw.addRowsIndexes(1) + rw.Tensor.AppendRowInt(val...) +} + +// Int1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. 
+func (rw *Rows) Int1D(i int) int {
+	if rw.Indexes == nil {
+		return rw.Tensor.Int1D(i)
+	}
+	return rw.Int(rw.Tensor.Shape().IndexFrom1D(i)...)
+}
+
+// SetInt1D is somewhat expensive if indexes are set, because it needs to convert
+// the flat index back into a full n-dimensional index and then use that api.
+func (rw *Rows) SetInt1D(val int, i int) {
+	if rw.Indexes == nil {
+		rw.Tensor.SetInt1D(val, i)
+		return
+	}
+	rw.SetInt(val, rw.Tensor.Shape().IndexFrom1D(i)...)
+}
+
+/////// SubSpaces
+
+// SubSpace returns a new tensor with innermost subspace at given
+// offset(s) in outermost dimension(s) (len(offs) < NumDims).
+// The new tensor points to the values of this tensor (i.e., modifications
+// will affect both), as its Values slice is a view onto the original (which
+// is why only inner-most contiguous subspaces are supported).
+// Use Clone() method to separate the two.
+// Rows version does indexed indirection of the outermost row dimension
+// of the offsets.
+func (rw *Rows) SubSpace(offs ...int) Values {
+	if len(offs) == 0 {
+		return nil
+	}
+	offs[0] = rw.RowIndex(offs[0])
+	return rw.Tensor.SubSpace(offs...)
+}
+
+// RowTensor is a convenience version of [Rows.SubSpace] to return the
+// SubSpace for the outermost row dimension, indirected through the indexes.
+func (rw *Rows) RowTensor(row int) Values {
+	return rw.Tensor.RowTensor(rw.RowIndex(row))
+}
+
+// SetRowTensor sets the values of the SubSpace at given row to given values,
+// with row indirected through the indexes.
+func (rw *Rows) SetRowTensor(val Values, row int) {
+	rw.Tensor.SetRowTensor(val, rw.RowIndex(row))
+}
+
+// AppendRow adds a row and sets values to given values.
+func (rw *Rows) AppendRow(val Values) { + nrow := rw.Tensor.DimSize(0) + rw.AddRows(1) + rw.Tensor.SetRowTensor(val, nrow) +} + +// check for interface impl +var _ RowMajor = (*Rows)(nil) diff --git a/tensor/shape.go b/tensor/shape.go index e31f7fe553..2f010ae793 100644 --- a/tensor/shape.go +++ b/tensor/shape.go @@ -9,91 +9,76 @@ import ( "slices" ) -// Shape manages a tensor's shape information, including strides and dimension names +// Shape manages a tensor's shape information, including sizes and strides, // and can compute the flat index into an underlying 1D data storage array based on an // n-dimensional index (and vice-versa). -// Per C / Go / Python conventions, indexes are Row-Major, ordered from +// Per Go / C / Python conventions, indexes are Row-Major, ordered from // outer to inner left-to-right, so the inner-most is right-most. type Shape struct { - // size per dimension + // size per dimension. Sizes []int - // offsets for each dimension + // offsets for each dimension. Strides []int `display:"-"` - - // names of each dimension - Names []string `display:"-"` } -// NewShape returns a new shape with given sizes and optional dimension names. +// NewShape returns a new shape with given sizes. // RowMajor ordering is used by default. -func NewShape(sizes []int, names ...string) *Shape { +func NewShape(sizes ...int) *Shape { sh := &Shape{} - sh.SetShape(sizes, names...) + sh.SetShapeSizes(sizes...) return sh } -// SetShape sets the shape size and optional names +// SetShapeSizes sets the shape sizes from list of ints. // RowMajor ordering is used by default. -func (sh *Shape) SetShape(sizes []int, names ...string) { +func (sh *Shape) SetShapeSizes(sizes ...int) { sh.Sizes = slices.Clone(sizes) - sh.Strides = RowMajorStrides(sizes) - sh.Names = make([]string, len(sh.Sizes)) - if len(names) == len(sizes) { - copy(sh.Names, names) - } + sh.Strides = RowMajorStrides(sizes...) +} + +// SetShapeSizesFromTensor sets the shape sizes from given tensor. 
+// RowMajor ordering is used by default. +func (sh *Shape) SetShapeSizesFromTensor(sizes Tensor) { + sh.SetShapeSizes(AsIntSlice(sizes)...) } -// CopyShape copies the shape parameters from another Shape struct. +// SizesAsTensor returns shape sizes as an Int Tensor. +func (sh *Shape) SizesAsTensor() *Int { + return NewIntFromValues(sh.Sizes...) +} + +// CopyFrom copies the shape parameters from another Shape struct. // copies the data so it is not accidentally subject to updates. -func (sh *Shape) CopyShape(cp *Shape) { +func (sh *Shape) CopyFrom(cp *Shape) { sh.Sizes = slices.Clone(cp.Sizes) sh.Strides = slices.Clone(cp.Strides) - sh.Names = slices.Clone(cp.Names) } // Len returns the total length of elements in the tensor -// (i.e., the product of the shape sizes) +// (i.e., the product of the shape sizes). func (sh *Shape) Len() int { if len(sh.Sizes) == 0 { return 0 } - o := int(1) + ln := 1 for _, v := range sh.Sizes { - o *= v + ln *= v } - return int(o) + return ln } // NumDims returns the total number of dimensions. func (sh *Shape) NumDims() int { return len(sh.Sizes) } // DimSize returns the size of given dimension. -func (sh *Shape) DimSize(i int) int { return sh.Sizes[i] } - -// DimName returns the name of given dimension. -func (sh *Shape) DimName(i int) string { return sh.Names[i] } - -// DimByName returns the index of the given dimension name. -// returns -1 if not found. -func (sh *Shape) DimByName(name string) int { - for i, nm := range sh.Names { - if nm == name { - return i - } - } - return -1 -} - -// DimSizeByName returns the size of given dimension, specified by name. -// will crash if name not found. 
-func (sh *Shape) DimSizeByName(name string) int { - return sh.DimSize(sh.DimByName(name)) +func (sh *Shape) DimSize(i int) int { + return sh.Sizes[i] } // IndexIsValid() returns true if given index is valid (within ranges for all dimensions) -func (sh *Shape) IndexIsValid(idx []int) bool { +func (sh *Shape) IndexIsValid(idx ...int) bool { if len(idx) != sh.NumDims() { return false } @@ -107,45 +92,58 @@ func (sh *Shape) IndexIsValid(idx []int) bool { // IsEqual returns true if this shape is same as other (does not compare names) func (sh *Shape) IsEqual(oth *Shape) bool { - if !EqualInts(sh.Sizes, oth.Sizes) { + if slices.Compare(sh.Sizes, oth.Sizes) != 0 { return false } - if !EqualInts(sh.Strides, oth.Strides) { + if slices.Compare(sh.Strides, oth.Strides) != 0 { return false } return true } -// RowCellSize returns the size of the outer-most Row shape dimension, +// RowCellSize returns the size of the outermost Row shape dimension, // and the size of all the remaining inner dimensions (the "cell" size). // Used for Tensors that are columns in a data table. func (sh *Shape) RowCellSize() (rows, cells int) { + if len(sh.Sizes) == 0 { + return 0, 1 + } rows = sh.Sizes[0] if len(sh.Sizes) == 1 { cells = 1 - } else { + } else if rows > 0 { cells = sh.Len() / rows + } else { + ln := 1 + for _, v := range sh.Sizes[1:] { + ln *= v + } + cells = ln } return } -// Offset returns the "flat" 1D array index into an element at the given n-dimensional index. -// No checking is done on the length or size of the index values relative to the shape of the tensor. -func (sh *Shape) Offset(index []int) int { - var offset int +// IndexTo1D returns the flat 1D index from given n-dimensional indicies. +// No checking is done on the length or size of the index values relative +// to the shape of the tensor. 
+func (sh *Shape) IndexTo1D(index ...int) int { + oned := 0 for i, v := range index { - offset += v * sh.Strides[i] + oned += v * sh.Strides[i] } - return offset + return oned } -// Index returns the n-dimensional index from a "flat" 1D array index. -func (sh *Shape) Index(offset int) []int { +// IndexFrom1D returns the n-dimensional index from a "flat" 1D array index. +func (sh *Shape) IndexFrom1D(oned int) []int { nd := len(sh.Sizes) index := make([]int, nd) - rem := offset + rem := oned for i := nd - 1; i >= 0; i-- { s := sh.Sizes[i] + if s == 0 { + return index + } iv := rem % s rem /= s index[i] = iv @@ -155,24 +153,16 @@ func (sh *Shape) Index(offset int) []int { // String satisfies the fmt.Stringer interface func (sh *Shape) String() string { - str := "[" - for i := range sh.Sizes { - nm := sh.Names[i] - if nm != "" { - str += nm + ": " - } - str += fmt.Sprintf("%d", sh.Sizes[i]) - if i < len(sh.Sizes)-1 { - str += ", " - } - } - str += "]" - return str + return fmt.Sprintf("%v", sh.Sizes) } -// RowMajorStrides returns strides for sizes where the first dimension is outer-most +// RowMajorStrides returns strides for sizes where the first dimension is outermost // and subsequent dimensions are progressively inner. 
-func RowMajorStrides(sizes []int) []int { +func RowMajorStrides(sizes ...int) []int { + if len(sizes) == 0 { + return nil + } + sizes[0] = max(1, sizes[0]) // critical for strides to not be nil due to rows = 0 rem := int(1) for _, v := range sizes { rem *= v @@ -180,7 +170,6 @@ func RowMajorStrides(sizes []int) []int { if rem == 0 { strides := make([]int, len(sizes)) - rem := int(1) for i := range strides { strides[i] = rem } @@ -195,9 +184,9 @@ func RowMajorStrides(sizes []int) []int { return strides } -// ColMajorStrides returns strides for sizes where the first dimension is inner-most +// ColumnMajorStrides returns strides for sizes where the first dimension is inner-most // and subsequent dimensions are progressively outer -func ColMajorStrides(sizes []int) []int { +func ColumnMajorStrides(sizes ...int) []int { total := int(1) for _, v := range sizes { if v == 0 { @@ -217,19 +206,6 @@ func ColMajorStrides(sizes []int) []int { return strides } -// EqualInts compares two int slices and returns true if they are equal -func EqualInts(a, b []int) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - // AddShapes returns a new shape by adding two shapes one after the other. func AddShapes(shape1, shape2 *Shape) *Shape { sh1 := shape1.Sizes @@ -237,8 +213,19 @@ func AddShapes(shape1, shape2 *Shape) *Shape { nsh := make([]int, len(sh1)+len(sh2)) copy(nsh, sh1) copy(nsh[len(sh1):], sh2) - nms := make([]string, len(sh1)+len(sh2)) - copy(nms, shape1.Names) - copy(nms[len(sh1):], shape2.Names) - return NewShape(nsh, nms...) + sh := NewShape(nsh...) + return sh +} + +// CellsSizes returns the sizes of inner cells dimensions given +// overall tensor sizes. It returns []int{1} for the 1D case. +// Used for ensuring cell-wise outputs are the right size. 
+func CellsSize(sizes []int) []int { + csz := slices.Clone(sizes) + if len(csz) == 1 { + csz[0] = 1 + } else { + csz = csz[1:] + } + return csz } diff --git a/tensor/sliced.go b/tensor/sliced.go new file mode 100644 index 0000000000..d1cdf0b3d5 --- /dev/null +++ b/tensor/sliced.go @@ -0,0 +1,435 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "math/rand" + "reflect" + "slices" + "sort" + + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/base/reflectx" + "cogentcore.org/core/base/slicesx" +) + +// Sliced provides a re-sliced view onto another "source" [Tensor], +// defined by a set of [Sliced.Indexes] for each dimension (must have +// at least 1 index per dimension to avoid a null view). +// Thus, each dimension can be transformed in arbitrary ways relative +// to the original tensor (filtered subsets, reversals, sorting, etc). +// This view is not memory-contiguous and does not support the [RowMajor] +// interface or efficient access to inner-dimensional subspaces. +// A new Sliced view defaults to a full transparent view of the source tensor. +// There is additional cost for every access operation associated with the +// indexed indirection, and access is always via the full n-dimensional indexes. +// See also [Rows] for a version that only indexes the outermost row dimension, +// which is much more efficient for this common use-case, and does support [RowMajor]. +// To produce a new concrete [Values] that has raw data actually organized according +// to the indexed order (i.e., the copy function of numpy), call [Sliced.AsValues]. +type Sliced struct { //types:add + + // Tensor source that we are an indexed view onto. 
+ Tensor Tensor + + // Indexes are the indexes for each dimension, with dimensions as the outer + // slice (enforced to be the same length as the NumDims of the source Tensor), + // and a list of dimension index values (within range of DimSize(d)). + // A nil list of indexes for a dimension automatically provides a full, + // sequential view of that dimension. + Indexes [][]int +} + +// NewSliced returns a new [Sliced] view of given tensor, +// with optional list of indexes for each dimension (none / nil = sequential). +// Any dimensions without indexes default to nil = full sequential view. +func NewSliced(tsr Tensor, idxs ...[]int) *Sliced { + sl := &Sliced{Tensor: tsr, Indexes: idxs} + sl.ValidIndexes() + return sl +} + +// Reslice returns a new [Sliced] (and potentially [Reshaped]) view of given tensor, +// with given slice expressions for each dimension, which can be: +// - an integer, indicating a specific index value along that dimension. +// Can use negative numbers to index from the end. +// This axis will also be removed using a [Reshaped]. +// - a [Slice] object expressing a range of indexes. +// - [FullAxis] includes the full original axis (equivalent to `Slice{}`). +// - [Ellipsis] creates a flexibly-sized stretch of FullAxis dimensions, +// which automatically aligns the remaining slice elements based on the source +// dimensionality. +// - [NewAxis] creates a new singleton (length=1) axis, used to to reshape +// without changing the size. This triggers a [Reshaped]. +// - any remaining dimensions without indexes default to nil = full sequential view. 
+func Reslice(tsr Tensor, sls ...any) Tensor { + ns := len(sls) + if ns == 0 { + return NewSliced(tsr) + } + nd := tsr.NumDims() + ed := nd - ns // extra dimensions + ixs := make([][]int, nd) + doReshape := false // indicates if we need a Reshaped + reshape := make([]int, 0, nd+2) // if we need one, this is the target shape + ci := 0 + for d := range ns { + s := sls[d] + switch x := s.(type) { + case int: + doReshape = true // doesn't add to new shape. + if x < 0 { + ixs[ci] = []int{tsr.DimSize(ci) + x} + } else { + ixs[ci] = []int{x} + } + case Slice: + ixs[ci] = x.IntSlice(tsr.DimSize(ci)) + reshape = append(reshape, len(ixs[ci])) + case SlicesMagic: + switch x { + case FullAxis: + ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci)) + reshape = append(reshape, len(ixs[ci])) + case NewAxis: + ed++ // we are not real + doReshape = true + reshape = append(reshape, 1) + continue // skip the increment in ci + case Ellipsis: + ed++ // extra for us + for range ed { + ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci)) + reshape = append(reshape, len(ixs[ci])) + ci++ + } + if ed > 0 { + ci-- + } + ed = 0 // ate them up + } + } + ci++ + } + for range ed { // fill any extra dimensions + ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci)) + reshape = append(reshape, len(ixs[ci])) + ci++ + } + sl := NewSliced(tsr, ixs...) + if doReshape { + if len(reshape) == 0 { // all indexes + reshape = []int{1} + } + return NewReshaped(sl, reshape...) + } + return sl +} + +// AsSliced returns the tensor as a [Sliced] view. +// If it already is one, then it is returned, otherwise it is wrapped +// in a new Sliced, with default full sequential ("transparent") view. +func AsSliced(tsr Tensor) *Sliced { + if sl, ok := tsr.(*Sliced); ok { + return sl + } + return NewSliced(tsr) +} + +// SetTensor sets tensor as source for this view, and initializes a full +// transparent view onto source (calls [Sliced.Sequential]). 
+func (sl *Sliced) SetTensor(tsr Tensor) { + sl.Tensor = tsr + sl.Sequential() +} + +// SourceIndex returns the actual index into source tensor dimension +// based on given index value. +func (sl *Sliced) SourceIndex(dim, idx int) int { + ix := sl.Indexes[dim] + if ix == nil { + return idx + } + return ix[idx] +} + +// SourceIndexes returns the actual n-dimensional indexes into source tensor +// based on given list of indexes based on the Sliced view shape. +func (sl *Sliced) SourceIndexes(i ...int) []int { + ix := slices.Clone(i) + for d, idx := range i { + ix[d] = sl.SourceIndex(d, idx) + } + return ix +} + +// SourceIndexesFrom1D returns the n-dimensional indexes into source tensor +// based on the given 1D index based on the Sliced view shape. +func (sl *Sliced) SourceIndexesFrom1D(oned int) []int { + sh := sl.Shape() + oix := sh.IndexFrom1D(oned) // full indexes in our coords + return sl.SourceIndexes(oix...) +} + +// ValidIndexes ensures that [Sliced.Indexes] are valid, +// removing any out-of-range values and setting the view to nil (full sequential) +// for any dimension with no indexes (which is an invalid condition). +// Call this when any structural changes are made to underlying Tensor. +func (sl *Sliced) ValidIndexes() { + nd := sl.Tensor.NumDims() + sl.Indexes = slicesx.SetLength(sl.Indexes, nd) + for d := range nd { + ni := len(sl.Indexes[d]) + if ni == 0 { // invalid + sl.Indexes[d] = nil // full + continue + } + ds := sl.Tensor.DimSize(d) + ix := sl.Indexes[d] + for i := ni - 1; i >= 0; i-- { + if ix[i] >= ds { + ix = append(ix[:i], ix[i+1:]...) + } + } + sl.Indexes[d] = ix + } +} + +// Sequential sets all Indexes to nil, resulting in full sequential access into tensor. +func (sl *Sliced) Sequential() { //types:add + nd := sl.Tensor.NumDims() + sl.Indexes = slicesx.SetLength(sl.Indexes, nd) + for d := range nd { + sl.Indexes[d] = nil + } +} + +// IndexesNeeded is called prior to an operation that needs actual indexes, +// on given dimension. 
If Indexes == nil, they are set to all items, otherwise +// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure +// all dimension indexes are represented. +func (sl *Sliced) IndexesNeeded(d int) { + ix := sl.Indexes[d] + if ix != nil { + return + } + ix = make([]int, sl.Tensor.DimSize(d)) + for i := range ix { + ix[i] = i + } + sl.Indexes[d] = ix +} + +func (sl *Sliced) Label() string { return label(metadata.Name(sl), sl.Shape()) } +func (sl *Sliced) String() string { return Sprintf("", sl, 0) } +func (sl *Sliced) Metadata() *metadata.Data { return sl.Tensor.Metadata() } +func (sl *Sliced) IsString() bool { return sl.Tensor.IsString() } +func (sl *Sliced) DataType() reflect.Kind { return sl.Tensor.DataType() } +func (sl *Sliced) Shape() *Shape { return NewShape(sl.ShapeSizes()...) } +func (sl *Sliced) Len() int { return sl.Shape().Len() } +func (sl *Sliced) NumDims() int { return sl.Tensor.NumDims() } + +// For each dimension, we return the effective shape sizes using +// the current number of indexes per dimension. +func (sl *Sliced) ShapeSizes() []int { + nd := sl.Tensor.NumDims() + if nd == 0 { + return sl.Tensor.ShapeSizes() + } + sh := slices.Clone(sl.Tensor.ShapeSizes()) + for d := range nd { + if sl.Indexes[d] != nil { + sh[d] = len(sl.Indexes[d]) + } + } + return sh +} + +// DimSize returns the effective view size of given dimension. +func (sl *Sliced) DimSize(dim int) int { + if sl.Indexes[dim] != nil { + return len(sl.Indexes[dim]) + } + return sl.Tensor.DimSize(dim) +} + +// AsValues returns a copy of this tensor as raw [Values]. +// This "renders" the Sliced view into a fully contiguous +// and optimized memory representation of that view, which will be faster +// to access for further processing, and enables all the additional +// functionality provided by the [Values] interface. +func (sl *Sliced) AsValues() Values { + dt := sl.Tensor.DataType() + vt := NewOfType(dt, sl.ShapeSizes()...) 
+ n := sl.Len() + switch { + case sl.Tensor.IsString(): + for i := range n { + vt.SetString1D(sl.String1D(i), i) + } + case reflectx.KindIsFloat(dt): + for i := range n { + vt.SetFloat1D(sl.Float1D(i), i) + } + default: + for i := range n { + vt.SetInt1D(sl.Int1D(i), i) + } + } + return vt +} + +//////// Floats + +// Float returns the value of given index as a float64. +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) Float(i ...int) float64 { + return sl.Tensor.Float(sl.SourceIndexes(i...)...) +} + +// SetFloat sets the value of given index as a float64 +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) SetFloat(val float64, i ...int) { + sl.Tensor.SetFloat(val, sl.SourceIndexes(i...)...) +} + +// Float1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (sl *Sliced) Float1D(i int) float64 { + return sl.Tensor.Float(sl.SourceIndexesFrom1D(i)...) +} + +// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (sl *Sliced) SetFloat1D(val float64, i int) { + sl.Tensor.SetFloat(val, sl.SourceIndexesFrom1D(i)...) +} + +//////// Strings + +// StringValue returns the value of given index as a string. +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) StringValue(i ...int) string { + return sl.Tensor.StringValue(sl.SourceIndexes(i...)...) +} + +// SetString sets the value of given index as a string +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) SetString(val string, i ...int) { + sl.Tensor.SetString(val, sl.SourceIndexes(i...)...) +} + +// String1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. 
+func (sl *Sliced) String1D(i int) string { + return sl.Tensor.StringValue(sl.SourceIndexesFrom1D(i)...) +} + +// SetString1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (sl *Sliced) SetString1D(val string, i int) { + sl.Tensor.SetString(val, sl.SourceIndexesFrom1D(i)...) +} + +//////// Ints + +// Int returns the value of given index as an int. +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) Int(i ...int) int { + return sl.Tensor.Int(sl.SourceIndexes(i...)...) +} + +// SetInt sets the value of given index as an int +// The indexes are indirected through the [Sliced.Indexes]. +func (sl *Sliced) SetInt(val int, i ...int) { + sl.Tensor.SetInt(val, sl.SourceIndexes(i...)...) +} + +// Int1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (sl *Sliced) Int1D(i int) int { + return sl.Tensor.Int(sl.SourceIndexesFrom1D(i)...) +} + +// SetInt1D is somewhat expensive if indexes are set, because it needs to convert +// the flat index back into a full n-dimensional index and then use that api. +func (sl *Sliced) SetInt1D(val int, i int) { + sl.Tensor.SetInt(val, sl.SourceIndexesFrom1D(i)...) +} + +// Permuted sets indexes in given dimension to a permuted order. +// If indexes already exist then existing list of indexes is permuted, +// otherwise a new set of permuted indexes are generated +func (sl *Sliced) Permuted(dim int) { + ix := sl.Indexes[dim] + if ix == nil { + ix = rand.Perm(sl.Tensor.DimSize(dim)) + } else { + rand.Shuffle(len(ix), func(i, j int) { + ix[i], ix[j] = ix[j], ix[i] + }) + } + sl.Indexes[dim] = ix +} + +// SortFunc sorts the indexes along given dimension using given compare function. 
+// The compare function operates directly on indexes into the Tensor +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +func (sl *Sliced) SortFunc(dim int, cmp func(tsr Tensor, dim, i, j int) int) { + sl.IndexesNeeded(dim) + ix := sl.Indexes[dim] + slices.SortFunc(ix, func(a, b int) int { + return cmp(sl.Tensor, dim, a, b) // key point: these are already indirected through indexes!! + }) + sl.Indexes[dim] = ix +} + +// SortIndexes sorts the indexes along given dimension directly in +// numerical order, producing the native ordering, while preserving +// any filtering that might have occurred. +func (sl *Sliced) SortIndexes(dim int) { + ix := sl.Indexes[dim] + if ix == nil { + return + } + sort.Ints(ix) + sl.Indexes[dim] = ix +} + +// SortStableFunc stably sorts along given dimension using given compare function. +// The compare function operates directly on row numbers into the Tensor +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +// It is *essential* that it always returns 0 when the two are equal +// for the stable function to actually work. +func (sl *Sliced) SortStableFunc(dim int, cmp func(tsr Tensor, dim, i, j int) int) { + sl.IndexesNeeded(dim) + ix := sl.Indexes[dim] + slices.SortStableFunc(ix, func(a, b int) int { + return cmp(sl.Tensor, dim, a, b) // key point: these are already indirected through indexes!! + }) + sl.Indexes[dim] = ix +} + +// Filter filters the indexes using the given Filter function +// for setting the indexes for given dimension, and index into the +// source data. 
+func (sl *Sliced) Filter(dim int, filterer func(tsr Tensor, dim, idx int) bool) { + sl.IndexesNeeded(dim) + ix := sl.Indexes[dim] + sz := len(ix) + for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering + if !filterer(sl, dim, ix[i]) { // delete + ix = append(ix[:i], ix[i+1:]...) + } + } + sl.Indexes[dim] = ix +} + +// check for interface impl +var _ Tensor = (*Sliced)(nil) diff --git a/tensor/slices.go b/tensor/slices.go new file mode 100644 index 0000000000..0707f04eb7 --- /dev/null +++ b/tensor/slices.go @@ -0,0 +1,160 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +// SlicesMagic are special elements in slice expressions, including +// NewAxis, FullAxis, and Ellipsis in [NewSliced] expressions. +type SlicesMagic int //enums:enum + +const ( + // FullAxis indicates that the full existing axis length should be used. + // This is equivalent to Slice{}, but is more semantic. In NumPy it is + // equivalent to a single : colon. + FullAxis SlicesMagic = iota + + // NewAxis creates a new singleton (length=1) axis, used to to reshape + // without changing the size. Can also be used in [Reshaped]. + NewAxis + + // Ellipsis (...) is used in [NewSliced] expressions to produce + // a flexibly-sized stretch of FullAxis dimensions, which automatically + // aligns the remaining slice elements based on the source dimensionality. + Ellipsis +) + +// Slice represents a slice of index values, for extracting slices of data, +// along a dimension of a given size, which is provided separately as an argument. +// Uses standard 'for' loop logic with a Start and _exclusive_ Stop value, +// and a Step increment: for i := Start; i < Stop; i += Step. +// The values stored in this struct are the _inputs_ for computing the actual +// slice values based on the actual size parameter for the dimension. 
+// Negative numbers count back from the end (i.e., size + val),
+// and the zero value results in a list of all values in the dimension, with Step = 1 if 0.
+// The behavior is identical to the NumPy slice.
+type Slice struct {
+	// Start is the starting value. If 0 and Step < 0, = size-1;
+	// If negative, = size+Start.
+	Start int
+
+	// Stop value. If 0 and Step >= 0, = size;
+	// If 0 and Step < 0, = -1, to include whole range.
+	// If negative = size+Stop.
+	Stop int
+
+	// Step increment. If 0, = 1; if negative then Start must be > Stop
+	// to produce anything.
+	Step int
+}
+
+// NewSlice returns a new Slice with given start, stop, step values.
+func NewSlice(start, stop, step int) Slice {
+	return Slice{Start: start, Stop: stop, Step: step}
+}
+
+// GetStart is the actual start value given the size of the dimension.
+func (sl Slice) GetStart(size int) int {
+	if sl.Start == 0 && sl.Step < 0 {
+		return size - 1
+	}
+	if sl.Start < 0 {
+		return size + sl.Start
+	}
+	return sl.Start
+}
+
+// GetStop is the actual end value given the size of the dimension.
+func (sl Slice) GetStop(size int) int {
+	if sl.Stop == 0 && sl.Step >= 0 {
+		return size
+	}
+	if sl.Stop == 0 && sl.Step < 0 {
+		return -1
+	}
+	if sl.Stop < 0 {
+		return size + sl.Stop
+	}
+	return min(sl.Stop, size)
+}
+
+// GetStep is the actual increment value.
+func (sl Slice) GetStep() int {
+	if sl.Step == 0 {
+		return 1
+	}
+	return sl.Step
+}
+
+// Len is the number of elements in the actual slice given
+// size of the dimension.
+func (sl Slice) Len(size int) int {
+	s := sl.GetStart(size)
+	e := sl.GetStop(size)
+	i := sl.GetStep()
+	n := max((e-s)/i, 0)
+	pe := s + n*i
+	if i < 0 {
+		if pe > e {
+			n++
+		}
+	} else {
+		if pe < e {
+			n++
+		}
+	}
+	return n
+}
+
+// ToIntSlice writes values to given []int slice, with given size parameter
+// for the dimension being sliced. If slice is wrong size to hold values,
+// not all are written: allocate ints using Len(size) to fit.
+func (sl Slice) ToIntSlice(size int, ints []int) { + n := len(ints) + if n == 0 { + return + } + s := sl.GetStart(size) + e := sl.GetStop(size) + inc := sl.GetStep() + idx := 0 + if inc < 0 { + for i := s; i > e; i += inc { + ints[idx] = i + idx++ + if idx >= n { + break + } + } + } else { + for i := s; i < e; i += inc { + ints[idx] = i + idx++ + if idx >= n { + break + } + } + } +} + +// IntSlice returns []int slice with slice index values, up to given actual size. +func (sl Slice) IntSlice(size int) []int { + n := sl.Len(size) + if n == 0 { + return nil + } + ints := make([]int, n) + sl.ToIntSlice(size, ints) + return ints +} + +// IntTensor returns an [Int] [Tensor] for slice, using actual size. +func (sl Slice) IntTensor(size int) *Int { + n := sl.Len(size) + if n == 0 { + return nil + } + tsr := NewInt(n) + sl.ToIntSlice(size, tsr.Values) + return tsr +} diff --git a/tensor/slices_test.go b/tensor/slices_test.go new file mode 100644 index 0000000000..f4048aae82 --- /dev/null +++ b/tensor/slices_test.go @@ -0,0 +1,119 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSlice(t *testing.T) { + assert.Equal(t, 3, Slice{}.Len(3)) + assert.Equal(t, 3, Slice{0, 3, 0}.Len(3)) + assert.Equal(t, 3, Slice{0, 3, 1}.Len(3)) + + assert.Equal(t, 2, Slice{0, 0, 2}.Len(3)) + assert.Equal(t, 2, Slice{0, 0, 2}.Len(4)) + assert.Equal(t, 1, Slice{0, 0, 3}.Len(3)) + assert.Equal(t, 2, Slice{0, 0, 3}.Len(4)) + assert.Equal(t, 2, Slice{0, 0, 3}.Len(6)) + assert.Equal(t, 3, Slice{0, 0, 3}.Len(7)) + + assert.Equal(t, 1, Slice{-1, 0, 0}.Len(3)) + assert.Equal(t, 2, Slice{0, -1, 0}.Len(3)) + assert.Equal(t, 3, Slice{0, 0, -1}.Len(3)) + assert.Equal(t, 3, Slice{-1, 0, -1}.Len(3)) + assert.Equal(t, 1, Slice{-1, -2, -1}.Len(3)) + assert.Equal(t, 2, Slice{-1, -3, -1}.Len(3)) + + assert.Equal(t, 2, Slice{0, 0, -2}.Len(3)) + assert.Equal(t, 2, Slice{0, 0, -2}.Len(4)) + assert.Equal(t, 1, Slice{0, 0, -3}.Len(3)) + assert.Equal(t, 2, Slice{0, 0, -3}.Len(4)) + assert.Equal(t, 2, Slice{0, 0, -3}.Len(6)) + assert.Equal(t, 3, Slice{0, 0, -3}.Len(7)) + + assert.Equal(t, []int{0, 1, 2}, Slice{}.IntSlice(3)) + assert.Equal(t, []int{0, 1, 2}, Slice{0, 3, 0}.IntSlice(3)) + assert.Equal(t, []int{0, 1, 2}, Slice{0, 3, 1}.IntSlice(3)) + + assert.Equal(t, []int{0, 2}, Slice{0, 0, 2}.IntSlice(3)) + assert.Equal(t, []int{0, 2}, Slice{0, 0, 2}.IntSlice(4)) + assert.Equal(t, []int{0}, Slice{0, 0, 3}.IntSlice(3)) + assert.Equal(t, []int{0, 3}, Slice{0, 0, 3}.IntSlice(4)) + assert.Equal(t, []int{0, 3}, Slice{0, 0, 3}.IntSlice(6)) + assert.Equal(t, []int{0, 3, 6}, Slice{0, 0, 3}.IntSlice(7)) + + assert.Equal(t, []int{2}, Slice{-1, 0, 0}.IntSlice(3)) + assert.Equal(t, []int{0, 1}, Slice{0, -1, 0}.IntSlice(3)) + assert.Equal(t, []int{2, 1, 0}, Slice{0, 0, -1}.IntSlice(3)) + assert.Equal(t, []int{2, 1, 0}, Slice{-1, 0, -1}.IntSlice(3)) + assert.Equal(t, []int{2}, Slice{-1, -2, -1}.IntSlice(3)) + assert.Equal(t, []int{2, 1}, Slice{-1, -3, -1}.IntSlice(3)) + + assert.Equal(t, []int{2, 0}, 
Slice{0, 0, -2}.IntSlice(3)) + assert.Equal(t, []int{3, 1}, Slice{0, 0, -2}.IntSlice(4)) + assert.Equal(t, []int{2}, Slice{0, 0, -3}.IntSlice(3)) + assert.Equal(t, []int{3, 0}, Slice{0, 0, -3}.IntSlice(4)) + assert.Equal(t, []int{5, 2}, Slice{0, 0, -3}.IntSlice(6)) + assert.Equal(t, []int{6, 3, 0}, Slice{0, 0, -3}.IntSlice(7)) +} + +func TestSlicedExpr(t *testing.T) { + ft := NewFloat64(3, 4) + for y := range 3 { + for x := range 4 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + + rf := []float64{0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23} + assert.Equal(t, rf, ft.Values) + // fmt.Println(ft) + + sl := Reslice(ft, 1, 2) + assert.Equal(t, "12", sl.String()) + + res := `[4] 10 11 12 13 +` + sl = Reslice(ft, 1) + assert.Equal(t, res, sl.String()) + + res = `[3] 2 12 22 +` + sl = Reslice(ft, Ellipsis, 2) + assert.Equal(t, res, sl.String()) + + res = `[3 4] + [0] [1] [2] [3] +[0] 3 2 1 0 +[1] 13 12 11 10 +[2] 23 22 21 20 +` + sl = Reslice(ft, Ellipsis, Slice{Step: -1}) + assert.Equal(t, res, sl.String()) + + res = `[1 4] + [0] [1] [2] [3] +[0] 10 11 12 13 +` + sl = Reslice(ft, NewAxis, 1) + assert.Equal(t, res, sl.String()) + + res = `[1 3] + [0] [1] [2] +[0] 1 11 21 +` + sl = Reslice(ft, NewAxis, FullAxis, 1) // keeps result as a column vector + assert.Equal(t, res, sl.String()) + + res = `[3] 1 11 21 +` + sl = Reslice(ft, FullAxis, 1) + // fmt.Println(sl.String()) + assert.Equal(t, res, sl.String()) +} diff --git a/tensor/stats/README.md b/tensor/stats/README.md index 9e40210158..3c84a69d38 100644 --- a/tensor/stats/README.md +++ b/tensor/stats/README.md @@ -1,15 +1,11 @@ # stats -There are several packages here for operating on vector, [tensor](../tensor), and [table](../table) data, for computing standard statistics and performing related computations, such as normalizing the data. 
+There are several packages here for operating on [tensor](../), and [table](../table) data, for computing standard statistics and performing related computations, such as normalizing the data. -* [clust](clust) implements agglomerative clustering of items based on [simat](simat) similarity matrix data. +* [cluster](cluster) implements agglomerative clustering of items based on [metric](metric) distance / similarity matrix data. * [convolve](convolve) convolves data (e.g., for smoothing). * [glm](glm) fits a general linear model for one or more dependent variables as a function of one or more independent variables. This encompasses all forms of regression. * [histogram](histogram) bins data into groups and reports the frequency of elements in the bins. -* [metric](metric) computes similarity / distance metrics for comparing two vectors -* [norm](norm) normalizes vector data -* [pca](pca) computes principal components analysis (PCA) or singular value decomposition (SVD) on correlation matricies, which is a widely-used way of reducing the dimensionality of high-dimensional data. -* [simat](simat) computes a similarity matrix for the [metric](metric) similarity of two vectors. -* [split](split) provides grouping and aggregation functions operating on `table.Table` data, e.g., like a "pivot table" in a spreadsheet. -* [stats](stats) provides a set of standard summary statistics on a range of different data types, including basic slices of floats, to tensor and table data. +* [metric](metric) computes similarity / distance metrics for comparing two tensors, and associated distance / similarity matrix functions, including PCA and SVD analysis functions that operate on a covariance matrix. +* [stats](stats) provides a set of standard summary statistics on a range of different data types, including basic slices of floats, to tensor and table data. 
It also includes the ability to extract Groups of values and generate statistics for each group, as in a "pivot table" in a spreadsheet. diff --git a/tensor/stats/clust/README.md b/tensor/stats/clust/README.md deleted file mode 100644 index a11a82a5d9..0000000000 --- a/tensor/stats/clust/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# clust - -`clust` implements agglomerative clustering of items based on [simat](../simat) similarity matrix data. - -`GlomClust` is the main function, taking different `DistFunc` options for comparing distance between items. - - diff --git a/tensor/stats/clust/clust.go b/tensor/stats/clust/clust.go deleted file mode 100644 index 666990b250..0000000000 --- a/tensor/stats/clust/clust.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package clust - -//go:generate core generate - -import ( - "fmt" - "math" - "math/rand" - - "cogentcore.org/core/base/indent" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/simat" - "cogentcore.org/core/tensor/stats/stats" -) - -// Node is one node in the cluster -type Node struct { - - // index into original distance matrix -- only valid for for terminal leaves - Index int - - // distance for this node -- how far apart were all the kids from each other when this node was created -- is 0 for leaf nodes - Dist float64 - - // total aggregate distance from parents -- the X axis offset at which our cluster starts - ParDist float64 - - // y-axis value for this node -- if a parent, it is the average of its kids Y's, otherwise it counts down - Y float64 - - // child nodes under this one - Kids []*Node -} - -// IsLeaf returns true if node is a leaf of the tree with no kids -func (nn *Node) IsLeaf() bool { - return len(nn.Kids) == 0 -} - -// Sprint prints to string -func (nn *Node) Sprint(smat *simat.SimMat, depth int) string { - if nn.IsLeaf() { - 
return smat.Rows[nn.Index] + " " - } - sv := fmt.Sprintf("\n%v%v: ", indent.Tabs(depth), nn.Dist) - for _, kn := range nn.Kids { - sv += kn.Sprint(smat, depth+1) - } - return sv -} - -// Indexes collects all the indexes in this node -func (nn *Node) Indexes(ix []int, ctr *int) { - if nn.IsLeaf() { - ix[*ctr] = nn.Index - (*ctr)++ - } else { - for _, kn := range nn.Kids { - kn.Indexes(ix, ctr) - } - } -} - -// NewNode merges two nodes into a new node -func NewNode(na, nb *Node, dst float64) *Node { - nn := &Node{Dist: dst} - nn.Kids = []*Node{na, nb} - return nn -} - -// Glom implements basic agglomerative clustering, based on a raw similarity matrix as given. -// This calls GlomInit to initialize the root node with all of the leaves, and the calls -// GlomClust to do the iterative clustering process. If you want to start with pre-defined -// initial clusters, then call GlomClust with a root node so-initialized. -// The smat.Mat matrix must be an tensor.Float64. -func Glom(smat *simat.SimMat, dfunc DistFunc) *Node { - ntot := smat.Mat.DimSize(0) // number of leaves - root := GlomInit(ntot) - return GlomClust(root, smat, dfunc) -} - -// GlomStd implements basic agglomerative clustering, based on a raw similarity matrix as given. -// This calls GlomInit to initialize the root node with all of the leaves, and the calls -// GlomClust to do the iterative clustering process. If you want to start with pre-defined -// initial clusters, then call GlomClust with a root node so-initialized. -// The smat.Mat matrix must be an tensor.Float64. 
-// Std version uses std distance functions -func GlomStd(smat *simat.SimMat, std StdDists) *Node { - return Glom(smat, StdFunc(std)) -} - -// GlomInit returns a standard root node initialized with all of the leaves -func GlomInit(ntot int) *Node { - root := &Node{} - root.Kids = make([]*Node, ntot) - for i := 0; i < ntot; i++ { - root.Kids[i] = &Node{Index: i} - } - return root -} - -// GlomClust does the iterative agglomerative clustering, based on a raw similarity matrix as given, -// using a root node that has already been initialized with the starting clusters (all of the -// leaves by default, but could be anything if you want to start with predefined clusters). -// The smat.Mat matrix must be an tensor.Float64. -func GlomClust(root *Node, smat *simat.SimMat, dfunc DistFunc) *Node { - ntot := smat.Mat.DimSize(0) // number of leaves - smatf := smat.Mat.(*tensor.Float64).Values - maxd := stats.Max64(smatf) - // indexes in each group - aidx := make([]int, ntot) - bidx := make([]int, ntot) - for { - var ma, mb []int - mval := math.MaxFloat64 - for ai, ka := range root.Kids { - actr := 0 - ka.Indexes(aidx, &actr) - aix := aidx[0:actr] - for bi := 0; bi < ai; bi++ { - kb := root.Kids[bi] - bctr := 0 - kb.Indexes(bidx, &bctr) - bix := bidx[0:bctr] - dv := dfunc(aix, bix, ntot, maxd, smatf) - if dv < mval { - mval = dv - ma = []int{ai} - mb = []int{bi} - } else if dv == mval { // do all ties at same time - ma = append(ma, ai) - mb = append(mb, bi) - } - } - } - ni := 0 - if len(ma) > 1 { - ni = rand.Intn(len(ma)) - } - na := ma[ni] - nb := mb[ni] - // fmt.Printf("merging nodes at dist: %v: %v and %v\nA: %v\nB: %v\n", mval, na, nb, root.Kids[na].Sprint(smat, 0), root.Kids[nb].Sprint(smat, 0)) - nn := NewNode(root.Kids[na], root.Kids[nb], mval) - for i := len(root.Kids) - 1; i >= 0; i-- { - if i == na || i == nb { - root.Kids = append(root.Kids[:i], root.Kids[i+1:]...) 
- } - } - root.Kids = append(root.Kids, nn) - if len(root.Kids) == 1 { - break - } - } - return root -} diff --git a/tensor/stats/clust/dist.go b/tensor/stats/clust/dist.go deleted file mode 100644 index a60acf62b9..0000000000 --- a/tensor/stats/clust/dist.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package clust - -import ( - "math" -) - -// DistFunc is a clustering distance function that evaluates aggregate distance -// between nodes, given the indexes of leaves in a and b clusters -// which are indexs into an ntot x ntot similarity (distance) matrix smat. -// maxd is the maximum distance value in the smat, which is needed by the -// ContrastDist function and perhaps others. -type DistFunc func(aix, bix []int, ntot int, maxd float64, smat []float64) float64 - -// MinDist is the minimum-distance or single-linkage weighting function for comparing -// two clusters a and b, given by their list of indexes. -// ntot is total number of nodes, and smat is the square similarity matrix [ntot x ntot]. -func MinDist(aix, bix []int, ntot int, maxd float64, smat []float64) float64 { - md := math.MaxFloat64 - for _, ai := range aix { - for _, bi := range bix { - d := smat[ai*ntot+bi] - if d < md { - md = d - } - } - } - return md -} - -// MaxDist is the maximum-distance or complete-linkage weighting function for comparing -// two clusters a and b, given by their list of indexes. -// ntot is total number of nodes, and smat is the square similarity matrix [ntot x ntot]. 
-func MaxDist(aix, bix []int, ntot int, maxd float64, smat []float64) float64 { - md := -math.MaxFloat64 - for _, ai := range aix { - for _, bi := range bix { - d := smat[ai*ntot+bi] - if d > md { - md = d - } - } - } - return md -} - -// AvgDist is the average-distance or average-linkage weighting function for comparing -// two clusters a and b, given by their list of indexes. -// ntot is total number of nodes, and smat is the square similarity matrix [ntot x ntot]. -func AvgDist(aix, bix []int, ntot int, maxd float64, smat []float64) float64 { - md := 0.0 - n := 0 - for _, ai := range aix { - for _, bi := range bix { - d := smat[ai*ntot+bi] - md += d - n++ - } - } - if n > 0 { - md /= float64(n) - } - return md -} - -// ContrastDist computes maxd + (average within distance - average between distance) -// for two clusters a and b, given by their list of indexes. -// avg between is average distance between all items in a & b versus all outside that. -// ntot is total number of nodes, and smat is the square similarity matrix [ntot x ntot]. -// maxd is the maximum distance and is needed to ensure distances are positive. -func ContrastDist(aix, bix []int, ntot int, maxd float64, smat []float64) float64 { - wd := AvgDist(aix, bix, ntot, maxd, smat) - nab := len(aix) + len(bix) - abix := append(aix, bix...) 
- abmap := make(map[int]struct{}, ntot-nab) - for _, ix := range abix { - abmap[ix] = struct{}{} - } - oix := make([]int, ntot-nab) - octr := 0 - for ix := 0; ix < ntot; ix++ { - if _, has := abmap[ix]; !has { - oix[octr] = ix - octr++ - } - } - bd := AvgDist(abix, oix, ntot, maxd, smat) - return maxd + (wd - bd) -} - -// StdDists are standard clustering distance functions -type StdDists int32 //enums:enum - -const ( - // Min is the minimum-distance or single-linkage weighting function - Min StdDists = iota - - // Max is the maximum-distance or complete-linkage weighting function - Max - - // Avg is the average-distance or average-linkage weighting function - Avg - - // Contrast computes maxd + (average within distance - average between distance) - Contrast -) - -// StdFunc returns a standard distance function as specified -func StdFunc(std StdDists) DistFunc { - switch std { - case Min: - return MinDist - case Max: - return MaxDist - case Avg: - return AvgDist - case Contrast: - return ContrastDist - } - return nil -} diff --git a/tensor/stats/clust/enumgen.go b/tensor/stats/clust/enumgen.go deleted file mode 100644 index 1cbbc383b5..0000000000 --- a/tensor/stats/clust/enumgen.go +++ /dev/null @@ -1,48 +0,0 @@ -// Code generated by "core generate"; DO NOT EDIT. - -package clust - -import ( - "cogentcore.org/core/enums" -) - -var _StdDistsValues = []StdDists{0, 1, 2, 3} - -// StdDistsN is the highest valid value for type StdDists, plus one. 
-const StdDistsN StdDists = 4 - -var _StdDistsValueMap = map[string]StdDists{`Min`: 0, `Max`: 1, `Avg`: 2, `Contrast`: 3} - -var _StdDistsDescMap = map[StdDists]string{0: `Min is the minimum-distance or single-linkage weighting function`, 1: `Max is the maximum-distance or complete-linkage weighting function`, 2: `Avg is the average-distance or average-linkage weighting function`, 3: `Contrast computes maxd + (average within distance - average between distance)`} - -var _StdDistsMap = map[StdDists]string{0: `Min`, 1: `Max`, 2: `Avg`, 3: `Contrast`} - -// String returns the string representation of this StdDists value. -func (i StdDists) String() string { return enums.String(i, _StdDistsMap) } - -// SetString sets the StdDists value from its string representation, -// and returns an error if the string is invalid. -func (i *StdDists) SetString(s string) error { - return enums.SetString(i, s, _StdDistsValueMap, "StdDists") -} - -// Int64 returns the StdDists value as an int64. -func (i StdDists) Int64() int64 { return int64(i) } - -// SetInt64 sets the StdDists value from an int64. -func (i *StdDists) SetInt64(in int64) { *i = StdDists(in) } - -// Desc returns the description of the StdDists value. -func (i StdDists) Desc() string { return enums.Desc(i, _StdDistsDescMap) } - -// StdDistsValues returns all possible values for the type StdDists. -func StdDistsValues() []StdDists { return _StdDistsValues } - -// Values returns all possible values for the type StdDists. -func (i StdDists) Values() []enums.Enum { return enums.Values(_StdDistsValues) } - -// MarshalText implements the [encoding.TextMarshaler] interface. -func (i StdDists) MarshalText() ([]byte, error) { return []byte(i.String()), nil } - -// UnmarshalText implements the [encoding.TextUnmarshaler] interface. 
-func (i *StdDists) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "StdDists") } diff --git a/tensor/stats/clust/testdata/faces.dat b/tensor/stats/clust/testdata/faces.dat deleted file mode 100644 index 53912dc13a..0000000000 --- a/tensor/stats/clust/testdata/faces.dat +++ /dev/null @@ -1,13 +0,0 @@ -_H: $Name %Input[2:0,0]<2:16,16> %Input[2:0,1] %Input[2:0,2] %Input[2:0,3] %Input[2:0,4] %Input[2:0,5] %Input[2:0,6] %Input[2:0,7] %Input[2:0,8] %Input[2:0,9] %Input[2:0,10] %Input[2:0,11] %Input[2:0,12] %Input[2:0,13] %Input[2:0,14] %Input[2:0,15] %Input[2:1,0] %Input[2:1,1] %Input[2:1,2] %Input[2:1,3] %Input[2:1,4] %Input[2:1,5] %Input[2:1,6] %Input[2:1,7] %Input[2:1,8] %Input[2:1,9] %Input[2:1,10] %Input[2:1,11] %Input[2:1,12] %Input[2:1,13] %Input[2:1,14] %Input[2:1,15] %Input[2:2,0] %Input[2:2,1] %Input[2:2,2] %Input[2:2,3] %Input[2:2,4] %Input[2:2,5] %Input[2:2,6] %Input[2:2,7] %Input[2:2,8] %Input[2:2,9] %Input[2:2,10] %Input[2:2,11] %Input[2:2,12] %Input[2:2,13] %Input[2:2,14] %Input[2:2,15] %Input[2:3,0] %Input[2:3,1] %Input[2:3,2] %Input[2:3,3] %Input[2:3,4] %Input[2:3,5] %Input[2:3,6] %Input[2:3,7] %Input[2:3,8] %Input[2:3,9] %Input[2:3,10] %Input[2:3,11] %Input[2:3,12] %Input[2:3,13] %Input[2:3,14] %Input[2:3,15] %Input[2:4,0] %Input[2:4,1] %Input[2:4,2] %Input[2:4,3] %Input[2:4,4] %Input[2:4,5] %Input[2:4,6] %Input[2:4,7] %Input[2:4,8] %Input[2:4,9] %Input[2:4,10] %Input[2:4,11] %Input[2:4,12] %Input[2:4,13] %Input[2:4,14] %Input[2:4,15] %Input[2:5,0] %Input[2:5,1] %Input[2:5,2] %Input[2:5,3] %Input[2:5,4] %Input[2:5,5] %Input[2:5,6] %Input[2:5,7] %Input[2:5,8] %Input[2:5,9] %Input[2:5,10] %Input[2:5,11] %Input[2:5,12] %Input[2:5,13] %Input[2:5,14] %Input[2:5,15] %Input[2:6,0] %Input[2:6,1] %Input[2:6,2] %Input[2:6,3] %Input[2:6,4] %Input[2:6,5] %Input[2:6,6] %Input[2:6,7] %Input[2:6,8] %Input[2:6,9] %Input[2:6,10] %Input[2:6,11] %Input[2:6,12] %Input[2:6,13] %Input[2:6,14] %Input[2:6,15] %Input[2:7,0] %Input[2:7,1] %Input[2:7,2] 
%Input[2:7,3] %Input[2:7,4] %Input[2:7,5] %Input[2:7,6] %Input[2:7,7] %Input[2:7,8] %Input[2:7,9] %Input[2:7,10] %Input[2:7,11] %Input[2:7,12] %Input[2:7,13] %Input[2:7,14] %Input[2:7,15] %Input[2:8,0] %Input[2:8,1] %Input[2:8,2] %Input[2:8,3] %Input[2:8,4] %Input[2:8,5] %Input[2:8,6] %Input[2:8,7] %Input[2:8,8] %Input[2:8,9] %Input[2:8,10] %Input[2:8,11] %Input[2:8,12] %Input[2:8,13] %Input[2:8,14] %Input[2:8,15] %Input[2:9,0] %Input[2:9,1] %Input[2:9,2] %Input[2:9,3] %Input[2:9,4] %Input[2:9,5] %Input[2:9,6] %Input[2:9,7] %Input[2:9,8] %Input[2:9,9] %Input[2:9,10] %Input[2:9,11] %Input[2:9,12] %Input[2:9,13] %Input[2:9,14] %Input[2:9,15] %Input[2:10,0] %Input[2:10,1] %Input[2:10,2] %Input[2:10,3] %Input[2:10,4] %Input[2:10,5] %Input[2:10,6] %Input[2:10,7] %Input[2:10,8] %Input[2:10,9] %Input[2:10,10] %Input[2:10,11] %Input[2:10,12] %Input[2:10,13] %Input[2:10,14] %Input[2:10,15] %Input[2:11,0] %Input[2:11,1] %Input[2:11,2] %Input[2:11,3] %Input[2:11,4] %Input[2:11,5] %Input[2:11,6] %Input[2:11,7] %Input[2:11,8] %Input[2:11,9] %Input[2:11,10] %Input[2:11,11] %Input[2:11,12] %Input[2:11,13] %Input[2:11,14] %Input[2:11,15] %Input[2:12,0] %Input[2:12,1] %Input[2:12,2] %Input[2:12,3] %Input[2:12,4] %Input[2:12,5] %Input[2:12,6] %Input[2:12,7] %Input[2:12,8] %Input[2:12,9] %Input[2:12,10] %Input[2:12,11] %Input[2:12,12] %Input[2:12,13] %Input[2:12,14] %Input[2:12,15] %Input[2:13,0] %Input[2:13,1] %Input[2:13,2] %Input[2:13,3] %Input[2:13,4] %Input[2:13,5] %Input[2:13,6] %Input[2:13,7] %Input[2:13,8] %Input[2:13,9] %Input[2:13,10] %Input[2:13,11] %Input[2:13,12] %Input[2:13,13] %Input[2:13,14] %Input[2:13,15] %Input[2:14,0] %Input[2:14,1] %Input[2:14,2] %Input[2:14,3] %Input[2:14,4] %Input[2:14,5] %Input[2:14,6] %Input[2:14,7] %Input[2:14,8] %Input[2:14,9] %Input[2:14,10] %Input[2:14,11] %Input[2:14,12] %Input[2:14,13] %Input[2:14,14] %Input[2:14,15] %Input[2:15,0] %Input[2:15,1] %Input[2:15,2] %Input[2:15,3] %Input[2:15,4] %Input[2:15,5] %Input[2:15,6] %Input[2:15,7] 
%Input[2:15,8] %Input[2:15,9] %Input[2:15,10] %Input[2:15,11] %Input[2:15,12] %Input[2:15,13] %Input[2:15,14] %Input[2:15,15] %Emotion[2:0,0]<2:1,2> %Emotion[2:0,1] %Gender[2:0,0]<2:1,2> %Gender[2:0,1] %Identity[2:0,0]<2:1,10> %Identity[2:0,1] %Identity[2:0,2] %Identity[2:0,3] %Identity[2:0,4] %Identity[2:0,5] %Identity[2:0,6] %Identity[2:0,7] %Identity[2:0,8] %Identity[2:0,9] -_D: Alberto_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 -_D: Alberto_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 -_D: Betty_happy 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 1 0 1 0 1 0 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 0 -_D: Betty_sad 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 1 1 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 -_D: Lisa_happy 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 -_D: Lisa_sad 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 1 1 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 -_D: Mark_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 
1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 -_D: Mark_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 -_D: Wendy_happy 0 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0 1 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0 0 -_D: Wendy_sad 0 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 0 0 0 0 1 1 1 0 0 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 0 1 0 0 0 1 1 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 -_D: Zane_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 
0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 -_D: Zane_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0 diff --git a/tensor/stats/cluster/README.md b/tensor/stats/cluster/README.md new file mode 100644 index 0000000000..9057a68d40 --- /dev/null +++ b/tensor/stats/cluster/README.md @@ -0,0 +1,14 @@ +# cluster + +`cluster` implements agglomerative clustering of items based on [metric](../metric) distance `Matrix` data (which is provided as an input, and must have been generated with a distance-like metric (increasing with dissimiliarity). + +There are different standard ways of accumulating the aggregate distance of a node based on its leaves: + +* `Min`: the minimum-distance across leaves, i.e., the single-linkage weighting function. +* `Max`: the maximum-distance across leaves, i.e,. the complete-linkage weighting function. +* `Avg`: the average-distance across leaves, i.e., the average-linkage weighting function. +* `Contrast`: is Max + (average within distance - average between distance). 
+ +`GlomCluster` is the main function, taking different `ClusterFunc` options for comparing distance between items. + + diff --git a/tensor/stats/clust/clust_test.go b/tensor/stats/cluster/clust_test.go similarity index 57% rename from tensor/stats/clust/clust_test.go rename to tensor/stats/cluster/clust_test.go index 858dc0dd6b..bdfc4df69b 100644 --- a/tensor/stats/clust/clust_test.go +++ b/tensor/stats/cluster/clust_test.go @@ -2,14 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package clust +package cluster import ( "testing" "cogentcore.org/core/base/tolassert" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/stats/simat" "cogentcore.org/core/tensor/table" ) @@ -28,20 +28,15 @@ var clustres = ` 3.605551275463989: Wendy_sad Wendy_happy ` func TestClust(t *testing.T) { - dt := &table.Table{} - err := dt.OpenCSV("testdata/faces.dat", table.Tab) + dt := table.New() + err := dt.OpenCSV("testdata/faces.dat", tensor.Tab) if err != nil { t.Error(err) } - ix := table.NewIndexView(dt) - smat := &simat.SimMat{} - smat.TableColumn(ix, "Input", "Name", false, metric.Euclidean64) + in := dt.Column("Input") + out := metric.Matrix(metric.L2Norm, in) - // fmt.Printf("%v\n", smat.Mat) - // cl := Glom(smat, MinDist) - cl := Glom(smat, AvgDist) - // s := cl.Sprint(smat, 0) - // fmt.Println(s) + cl := Cluster(Avg.String(), out, dt.Column("Name")) var dists []float64 @@ -54,7 +49,7 @@ func TestClust(t *testing.T) { } gather(cl) - exdists := []float64{0, 9.181170003996987, 5.534356399283667, 4.859933131085473, 3.4641016151377544, 0, 0, 3.4641016151377544, 0, 0, 3.4641016151377544, 0, 0, 5.111664626761644, 4.640135790634417, 4, 0, 0, 3.4641016151377544, 0, 0, 3.605551275463989, 0, 0} + exdists := []float64{0, 9.181170119179619, 5.534356355667114, 4.859933137893677, 3.464101552963257, 0, 0, 3.464101552963257, 0, 0, 3.464101552963257, 0, 0, 5.111664593219757, 
4.640135824680328, 4, 0, 0, 3.464101552963257, 0, 0, 3.605551242828369, 0, 0} - tolassert.EqualTolSlice(t, exdists, dists, 1.0e-8) + tolassert.EqualTolSlice(t, exdists, dists, 1.0e-7) } diff --git a/tensor/stats/cluster/cluster.go b/tensor/stats/cluster/cluster.go new file mode 100644 index 0000000000..8570662689 --- /dev/null +++ b/tensor/stats/cluster/cluster.go @@ -0,0 +1,163 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cluster + +//go:generate core generate + +import ( + "fmt" + "math" + "math/rand" + + "cogentcore.org/core/base/indent" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/stats/stats" ) + +// todo: all of this data goes into the tensorfs +// Cluster makes a new dir, stuffs results in there! +// need a global "cwd" that it uses, so basically you cd +// to a dir, then call it. + +// Node is one node in the cluster +type Node struct { + // index into original distance matrix; only valid for terminal leaves. + Index int + + // Distance value for this node, i.e., how far apart were all the kids from + // each other when this node was created. is 0 for leaf nodes + Dist float64 + + // ParDist is total aggregate distance from parents; The X axis offset at which our cluster starts. + ParDist float64 + + // Y is y-axis value for this node; if a parent, it is the average of its kids Y's, + // otherwise it counts down. + + // Kids are child nodes under this one. 
+ Kids []*Node +} + +// IsLeaf returns true if node is a leaf of the tree with no kids +func (nn *Node) IsLeaf() bool { + return len(nn.Kids) == 0 +} + +// Sprint prints to string +func (nn *Node) Sprint(labels tensor.Tensor, depth int) string { + if nn.IsLeaf() && labels != nil { + return labels.String1D(nn.Index) + " " + } + sv := fmt.Sprintf("\n%v%v: ", indent.Tabs(depth), nn.Dist) + for _, kn := range nn.Kids { + sv += kn.Sprint(labels, depth+1) + } + return sv +} + +// Indexes collects all the indexes in this node +func (nn *Node) Indexes(ix []int, ctr *int) { + if nn.IsLeaf() { + ix[*ctr] = nn.Index + (*ctr)++ + } else { + for _, kn := range nn.Kids { + kn.Indexes(ix, ctr) + } + } +} + +// NewNode merges two nodes into a new node +func NewNode(na, nb *Node, dst float64) *Node { + nn := &Node{Dist: dst} + nn.Kids = []*Node{na, nb} + return nn +} + +// TODO: this call signature does not fit with standard +// not sure how one might pack Node into a tensor + +// Cluster implements agglomerative clustering, based on a +// distance matrix dmat, e.g., as computed by metric.Matrix method, +// using a metric that increases in value with greater dissimilarity. +// labels provides an optional String tensor list of labels for the elements +// of the distance matrix. +// This calls InitAllLeaves to initialize the root node with all of the leaves, +// and then Glom to do the iterative agglomerative clustering process. +// If you want to start with pre-defined initial clusters, +// then call Glom with a root node so-initialized. +func Cluster(funcName string, dmat, labels tensor.Tensor) *Node { + ntot := dmat.DimSize(0) // number of leaves + root := InitAllLeaves(ntot) + return Glom(root, funcName, dmat) +} + +// InitAllLeaves returns a standard root node initialized with all of the leaves. 
+func InitAllLeaves(ntot int) *Node { + root := &Node{} + root.Kids = make([]*Node, ntot) + for i := 0; i < ntot; i++ { + root.Kids[i] = &Node{Index: i} + } + return root +} + +// Glom does the iterative agglomerative clustering, +// based on a raw similarity matrix as given, +// using a root node that has already been initialized +// with the starting clusters, which is all of the +// leaves by default, but could be anything if you want +// to start with predefined clusters. +func Glom(root *Node, funcName string, dmat tensor.Tensor) *Node { + ntot := dmat.DimSize(0) // number of leaves + mout := tensor.NewFloat64Scalar(0) + stats.MaxOut(tensor.As1D(dmat), mout) + maxd := mout.Float1D(0) + // indexes in each group + aidx := make([]int, ntot) + bidx := make([]int, ntot) + for { + var ma, mb []int + mval := math.MaxFloat64 + for ai, ka := range root.Kids { + actr := 0 + ka.Indexes(aidx, &actr) + aix := aidx[0:actr] + for bi := 0; bi < ai; bi++ { + kb := root.Kids[bi] + bctr := 0 + kb.Indexes(bidx, &bctr) + bix := bidx[0:bctr] + dv := Call(funcName, aix, bix, ntot, maxd, dmat) + if dv < mval { + mval = dv + ma = []int{ai} + mb = []int{bi} + } else if dv == mval { // do all ties at same time + ma = append(ma, ai) + mb = append(mb, bi) + } + } + } + ni := 0 + if len(ma) > 1 { + ni = rand.Intn(len(ma)) + } + na := ma[ni] + nb := mb[ni] + nn := NewNode(root.Kids[na], root.Kids[nb], mval) + for i := len(root.Kids) - 1; i >= 0; i-- { + if i == na || i == nb { + root.Kids = append(root.Kids[:i], root.Kids[i+1:]...) + } + } + root.Kids = append(root.Kids, nn) + if len(root.Kids) == 1 { + break + } + } + return root +} diff --git a/tensor/stats/cluster/enumgen.go b/tensor/stats/cluster/enumgen.go new file mode 100644 index 0000000000..d13044ae4e --- /dev/null +++ b/tensor/stats/cluster/enumgen.go @@ -0,0 +1,48 @@ +// Code generated by "core generate"; DO NOT EDIT. 
+ +package cluster + +import ( + "cogentcore.org/core/enums" +) + +var _MetricsValues = []Metrics{0, 1, 2, 3} + +// MetricsN is the highest valid value for type Metrics, plus one. +const MetricsN Metrics = 4 + +var _MetricsValueMap = map[string]Metrics{`Min`: 0, `Max`: 1, `Avg`: 2, `Contrast`: 3} + +var _MetricsDescMap = map[Metrics]string{0: `Min is the minimum-distance or single-linkage weighting function.`, 1: `Max is the maximum-distance or complete-linkage weighting function.`, 2: `Avg is the average-distance or average-linkage weighting function.`, 3: `Contrast computes maxd + (average within distance - average between distance).`} + +var _MetricsMap = map[Metrics]string{0: `Min`, 1: `Max`, 2: `Avg`, 3: `Contrast`} + +// String returns the string representation of this Metrics value. +func (i Metrics) String() string { return enums.String(i, _MetricsMap) } + +// SetString sets the Metrics value from its string representation, +// and returns an error if the string is invalid. +func (i *Metrics) SetString(s string) error { + return enums.SetString(i, s, _MetricsValueMap, "Metrics") +} + +// Int64 returns the Metrics value as an int64. +func (i Metrics) Int64() int64 { return int64(i) } + +// SetInt64 sets the Metrics value from an int64. +func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) } + +// Desc returns the description of the Metrics value. +func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) } + +// MetricsValues returns all possible values for the type Metrics. +func MetricsValues() []Metrics { return _MetricsValues } + +// Values returns all possible values for the type Metrics. +func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) } + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil } + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. 
+func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") } diff --git a/tensor/stats/cluster/funcs.go b/tensor/stats/cluster/funcs.go new file mode 100644 index 0000000000..a46ac8ccda --- /dev/null +++ b/tensor/stats/cluster/funcs.go @@ -0,0 +1,135 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cluster + +import ( + "log/slog" + "math" + + "cogentcore.org/core/tensor" +) + +// Metrics are standard clustering distance metric functions, +// specifying how a node computes its distance based on its leaves. +type Metrics int32 //enums:enum + +const ( + // Min is the minimum-distance or single-linkage weighting function. + Min Metrics = iota + + // Max is the maximum-distance or complete-linkage weighting function. + Max + + // Avg is the average-distance or average-linkage weighting function. + Avg + + // Contrast computes maxd + (average within distance - average between distance). + Contrast +) + +// MetricFunc is a clustering distance metric function that evaluates aggregate distance +// between nodes, given the indexes of leaves in a and b clusters +// which are indexs into an ntot x ntot distance matrix dmat. +// maxd is the maximum distance value in the dmat, which is needed by the +// ContrastDist function and perhaps others. +type MetricFunc func(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 + +// Funcs is a registry of clustering metric functions, +// initialized with the standard options. +var Funcs map[string]MetricFunc + +func init() { + Funcs = make(map[string]MetricFunc) + Funcs[Min.String()] = MinFunc + Funcs[Max.String()] = MaxFunc + Funcs[Avg.String()] = AvgFunc + Funcs[Contrast.String()] = ContrastFunc +} + +// Call calls a cluster metric function by name. 
+func Call(funcName string, aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 { + fun, ok := Funcs[funcName] + if !ok { + slog.Error("cluster.Call: function not found", "function:", funcName) + return 0 + } + return fun(aix, bix, ntot, maxd, dmat) +} + +// MinFunc is the minimum-distance or single-linkage weighting function for comparing +// two clusters a and b, given by their list of indexes. +// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot]. +func MinFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 { + md := math.MaxFloat64 + for _, ai := range aix { + for _, bi := range bix { + d := dmat.Float(ai, bi) + if d < md { + md = d + } + } + } + return md +} + +// MaxFunc is the maximum-distance or complete-linkage weighting function for comparing +// two clusters a and b, given by their list of indexes. +// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot]. +func MaxFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 { + md := -math.MaxFloat64 + for _, ai := range aix { + for _, bi := range bix { + d := dmat.Float(ai, bi) + if d > md { + md = d + } + } + } + return md +} + +// AvgFunc is the average-distance or average-linkage weighting function for comparing +// two clusters a and b, given by their list of indexes. +// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot]. +func AvgFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 { + md := 0.0 + n := 0 + for _, ai := range aix { + for _, bi := range bix { + d := dmat.Float(ai, bi) + md += d + n++ + } + } + if n > 0 { + md /= float64(n) + } + return md +} + +// ContrastFunc computes maxd + (average within distance - average between distance) +// for two clusters a and b, given by their list of indexes. +// avg between is average distance between all items in a & b versus all outside that. 
+// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot]. +// maxd is the maximum distance and is needed to ensure distances are positive. +func ContrastFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 { + wd := AvgFunc(aix, bix, ntot, maxd, dmat) + nab := len(aix) + len(bix) + abix := append(aix, bix...) + abmap := make(map[int]struct{}, ntot-nab) + for _, ix := range abix { + abmap[ix] = struct{}{} + } + oix := make([]int, ntot-nab) + octr := 0 + for ix := 0; ix < ntot; ix++ { + if _, has := abmap[ix]; !has { + oix[octr] = ix + octr++ + } + } + bd := AvgFunc(abix, oix, ntot, maxd, dmat) + return maxd + (wd - bd) +} diff --git a/tensor/stats/clust/plot.go b/tensor/stats/cluster/plot.go similarity index 66% rename from tensor/stats/clust/plot.go rename to tensor/stats/cluster/plot.go index 51f9144926..dba9039850 100644 --- a/tensor/stats/clust/plot.go +++ b/tensor/stats/cluster/plot.go @@ -2,17 +2,17 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package clust +package cluster import ( - "cogentcore.org/core/tensor/stats/simat" + "cogentcore.org/core/tensor" "cogentcore.org/core/tensor/table" ) // Plot sets the rows of given data table to trace out lines with labels that // will render cluster plot starting at root node when plotted with a standard plotting package. // The lines double-back on themselves to form a continuous line to be plotted. 
-func Plot(pt *table.Table, root *Node, smat *simat.SimMat) { +func Plot(pt *table.Table, root *Node, dmat, labels tensor.Tensor) { pt.DeleteAll() pt.AddFloat64Column("X") pt.AddFloat64Column("Y") @@ -20,39 +20,42 @@ func Plot(pt *table.Table, root *Node, smat *simat.SimMat) { nextY := 0.5 root.SetYs(&nextY) root.SetParDist(0.0) - root.Plot(pt, smat) + root.Plot(pt, dmat, labels) } // Plot sets the rows of given data table to trace out lines with labels that // will render this node in a cluster plot when plotted with a standard plotting package. // The lines double-back on themselves to form a continuous line to be plotted. -func (nn *Node) Plot(pt *table.Table, smat *simat.SimMat) { - row := pt.Rows +func (nn *Node) Plot(pt *table.Table, dmat, labels tensor.Tensor) { + row := pt.NumRows() + xc := pt.ColumnByIndex(0) + yc := pt.ColumnByIndex(1) + lbl := pt.ColumnByIndex(2) if nn.IsLeaf() { pt.SetNumRows(row + 1) - pt.SetFloatIndex(0, row, nn.ParDist) - pt.SetFloatIndex(1, row, nn.Y) - if len(smat.Rows) > nn.Index { - pt.SetStringIndex(2, row, smat.Rows[nn.Index]) + xc.SetFloatRow(nn.ParDist, row, 0) + yc.SetFloatRow(nn.Y, row, 0) + if labels.Len() > nn.Index { + lbl.SetStringRow(labels.StringValue(nn.Index), row, 0) } } else { for _, kn := range nn.Kids { pt.SetNumRows(row + 2) - pt.SetFloatIndex(0, row, nn.ParDist) - pt.SetFloatIndex(1, row, kn.Y) + xc.SetFloatRow(nn.ParDist, row, 0) + yc.SetFloatRow(kn.Y, row, 0) row++ - pt.SetFloatIndex(0, row, nn.ParDist+nn.Dist) - pt.SetFloatIndex(1, row, kn.Y) - kn.Plot(pt, smat) - row = pt.Rows + xc.SetFloatRow(nn.ParDist+nn.Dist, row, 0) + yc.SetFloatRow(kn.Y, row, 0) + kn.Plot(pt, dmat, labels) + row = pt.NumRows() pt.SetNumRows(row + 1) - pt.SetFloatIndex(0, row, nn.ParDist) - pt.SetFloatIndex(1, row, kn.Y) + xc.SetFloatRow(nn.ParDist, row, 0) + yc.SetFloatRow(kn.Y, row, 0) row++ } pt.SetNumRows(row + 1) - pt.SetFloatIndex(0, row, nn.ParDist) - pt.SetFloatIndex(1, row, nn.Y) + xc.SetFloatRow(nn.ParDist, row, 0) + 
yc.SetFloatRow(nn.Y, row, 0) } } diff --git a/tensor/stats/cluster/testdata/faces.dat b/tensor/stats/cluster/testdata/faces.dat new file mode 100644 index 0000000000..ad88d27b6f --- /dev/null +++ b/tensor/stats/cluster/testdata/faces.dat @@ -0,0 +1,13 @@ +$Name %Input[2:0,0]<2:16,16> %Input[2:0,1] %Input[2:0,2] %Input[2:0,3] %Input[2:0,4] %Input[2:0,5] %Input[2:0,6] %Input[2:0,7] %Input[2:0,8] %Input[2:0,9] %Input[2:0,10] %Input[2:0,11] %Input[2:0,12] %Input[2:0,13] %Input[2:0,14] %Input[2:0,15] %Input[2:1,0] %Input[2:1,1] %Input[2:1,2] %Input[2:1,3] %Input[2:1,4] %Input[2:1,5] %Input[2:1,6] %Input[2:1,7] %Input[2:1,8] %Input[2:1,9] %Input[2:1,10] %Input[2:1,11] %Input[2:1,12] %Input[2:1,13] %Input[2:1,14] %Input[2:1,15] %Input[2:2,0] %Input[2:2,1] %Input[2:2,2] %Input[2:2,3] %Input[2:2,4] %Input[2:2,5] %Input[2:2,6] %Input[2:2,7] %Input[2:2,8] %Input[2:2,9] %Input[2:2,10] %Input[2:2,11] %Input[2:2,12] %Input[2:2,13] %Input[2:2,14] %Input[2:2,15] %Input[2:3,0] %Input[2:3,1] %Input[2:3,2] %Input[2:3,3] %Input[2:3,4] %Input[2:3,5] %Input[2:3,6] %Input[2:3,7] %Input[2:3,8] %Input[2:3,9] %Input[2:3,10] %Input[2:3,11] %Input[2:3,12] %Input[2:3,13] %Input[2:3,14] %Input[2:3,15] %Input[2:4,0] %Input[2:4,1] %Input[2:4,2] %Input[2:4,3] %Input[2:4,4] %Input[2:4,5] %Input[2:4,6] %Input[2:4,7] %Input[2:4,8] %Input[2:4,9] %Input[2:4,10] %Input[2:4,11] %Input[2:4,12] %Input[2:4,13] %Input[2:4,14] %Input[2:4,15] %Input[2:5,0] %Input[2:5,1] %Input[2:5,2] %Input[2:5,3] %Input[2:5,4] %Input[2:5,5] %Input[2:5,6] %Input[2:5,7] %Input[2:5,8] %Input[2:5,9] %Input[2:5,10] %Input[2:5,11] %Input[2:5,12] %Input[2:5,13] %Input[2:5,14] %Input[2:5,15] %Input[2:6,0] %Input[2:6,1] %Input[2:6,2] %Input[2:6,3] %Input[2:6,4] %Input[2:6,5] %Input[2:6,6] %Input[2:6,7] %Input[2:6,8] %Input[2:6,9] %Input[2:6,10] %Input[2:6,11] %Input[2:6,12] %Input[2:6,13] %Input[2:6,14] %Input[2:6,15] %Input[2:7,0] %Input[2:7,1] %Input[2:7,2] %Input[2:7,3] %Input[2:7,4] %Input[2:7,5] %Input[2:7,6] %Input[2:7,7] 
%Input[2:7,8] %Input[2:7,9] %Input[2:7,10] %Input[2:7,11] %Input[2:7,12] %Input[2:7,13] %Input[2:7,14] %Input[2:7,15] %Input[2:8,0] %Input[2:8,1] %Input[2:8,2] %Input[2:8,3] %Input[2:8,4] %Input[2:8,5] %Input[2:8,6] %Input[2:8,7] %Input[2:8,8] %Input[2:8,9] %Input[2:8,10] %Input[2:8,11] %Input[2:8,12] %Input[2:8,13] %Input[2:8,14] %Input[2:8,15] %Input[2:9,0] %Input[2:9,1] %Input[2:9,2] %Input[2:9,3] %Input[2:9,4] %Input[2:9,5] %Input[2:9,6] %Input[2:9,7] %Input[2:9,8] %Input[2:9,9] %Input[2:9,10] %Input[2:9,11] %Input[2:9,12] %Input[2:9,13] %Input[2:9,14] %Input[2:9,15] %Input[2:10,0] %Input[2:10,1] %Input[2:10,2] %Input[2:10,3] %Input[2:10,4] %Input[2:10,5] %Input[2:10,6] %Input[2:10,7] %Input[2:10,8] %Input[2:10,9] %Input[2:10,10] %Input[2:10,11] %Input[2:10,12] %Input[2:10,13] %Input[2:10,14] %Input[2:10,15] %Input[2:11,0] %Input[2:11,1] %Input[2:11,2] %Input[2:11,3] %Input[2:11,4] %Input[2:11,5] %Input[2:11,6] %Input[2:11,7] %Input[2:11,8] %Input[2:11,9] %Input[2:11,10] %Input[2:11,11] %Input[2:11,12] %Input[2:11,13] %Input[2:11,14] %Input[2:11,15] %Input[2:12,0] %Input[2:12,1] %Input[2:12,2] %Input[2:12,3] %Input[2:12,4] %Input[2:12,5] %Input[2:12,6] %Input[2:12,7] %Input[2:12,8] %Input[2:12,9] %Input[2:12,10] %Input[2:12,11] %Input[2:12,12] %Input[2:12,13] %Input[2:12,14] %Input[2:12,15] %Input[2:13,0] %Input[2:13,1] %Input[2:13,2] %Input[2:13,3] %Input[2:13,4] %Input[2:13,5] %Input[2:13,6] %Input[2:13,7] %Input[2:13,8] %Input[2:13,9] %Input[2:13,10] %Input[2:13,11] %Input[2:13,12] %Input[2:13,13] %Input[2:13,14] %Input[2:13,15] %Input[2:14,0] %Input[2:14,1] %Input[2:14,2] %Input[2:14,3] %Input[2:14,4] %Input[2:14,5] %Input[2:14,6] %Input[2:14,7] %Input[2:14,8] %Input[2:14,9] %Input[2:14,10] %Input[2:14,11] %Input[2:14,12] %Input[2:14,13] %Input[2:14,14] %Input[2:14,15] %Input[2:15,0] %Input[2:15,1] %Input[2:15,2] %Input[2:15,3] %Input[2:15,4] %Input[2:15,5] %Input[2:15,6] %Input[2:15,7] %Input[2:15,8] %Input[2:15,9] %Input[2:15,10] %Input[2:15,11] 
%Input[2:15,12] %Input[2:15,13] %Input[2:15,14] %Input[2:15,15] %Emotion[2:0,0]<2:1,2> %Emotion[2:0,1] %Gender[2:0,0]<2:1,2> %Gender[2:0,1] %Identity[2:0,0]<2:1,10> %Identity[2:0,1] %Identity[2:0,2] %Identity[2:0,3] %Identity[2:0,4] %Identity[2:0,5] %Identity[2:0,6] %Identity[2:0,7] %Identity[2:0,8] %Identity[2:0,9] +Alberto_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 +Alberto_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 +Betty_happy 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 1 0 1 0 1 0 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 0 +Betty_sad 1 0 1 0 0 
0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 1 1 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 +Lisa_happy 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 +Lisa_sad 1 0 1 0 0 0 1 1 1 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 1 1 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 +Mark_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 
1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 +Mark_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 +Wendy_happy 0 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0 1 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0 0 +Wendy_sad 0 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 0 0 0 0 1 1 1 0 0 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 0 1 0 0 0 1 1 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 +Zane_happy 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 
0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 +Zane_sad 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0 diff --git a/tensor/stats/convolve/table.go b/tensor/stats/convolve/table.go index 1381e21ef7..ebe2ea88a2 100644 --- a/tensor/stats/convolve/table.go +++ b/tensor/stats/convolve/table.go @@ -4,13 +4,7 @@ package convolve -import ( - "reflect" - - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/table" -) - +/* // SmoothTable returns a cloned table with each of the floating-point // columns in the source table smoothed over rows. // khalf is the half-width of the Gaussian smoothing kernel, @@ -34,3 +28,4 @@ func SmoothTable(src *table.Table, khalf int) *table.Table { } return dest } +*/ diff --git a/tensor/stats/glm/README.md b/tensor/stats/glm/README.md index ffe5d5703d..cb6b0c323d 100644 --- a/tensor/stats/glm/README.md +++ b/tensor/stats/glm/README.md @@ -2,7 +2,7 @@ GLM contains results and parameters for running a [general linear model](https://en.wikipedia.org/wiki/General_linear_model), which is a general form of multivariate linear regression, supporting multiple independent and dependent variables. 
-Make a `NewGLM` and then do `Run()` on a tensor [IndexView](../table/IndexView) with the relevant data in columns of the table. +Make a `NewGLM` and then do `Run()` on a [table](../table) with the relevant data in columns of the table. # Fitting Methods diff --git a/tensor/stats/glm/glm.go b/tensor/stats/glm/glm.go index cbac41a520..27a730b01a 100644 --- a/tensor/stats/glm/glm.go +++ b/tensor/stats/glm/glm.go @@ -18,7 +18,7 @@ import ( // linear model, which is a general form of multivariate linear // regression, supporting multiple independent and dependent // variables. Make a NewGLM and then do Run() on a tensor -// table.IndexView with the relevant data in columns of the table. +// table.Table with the relevant data in columns of the table. // Batch-mode gradient descent is used and the relevant parameters // can be altered from defaults before calling Run as needed. type GLM struct { @@ -47,8 +47,7 @@ type GLM struct { // optional names of the dependent variables, for reporting results DepNames []string - /////////////////////////////////////////// - // Parameters for the GLM model fitting: + //////// Parameters for the GLM model fitting: // ZeroOffset restricts the offset of the linear function to 0, // forcing it to pass through the origin. 
Otherwise, a constant offset "b" @@ -81,14 +80,13 @@ type GLM struct { // maximum number of iterations to perform MaxIters int `default:"50"` - /////////////////////////////////////////// - // Cached values from the table + //////// Cached values from the table // Table of data - Table *table.IndexView + Table *table.Table // tensor columns from table with the respective variables - IndepVars, DepVars, PredVars, ErrVars tensor.Tensor + IndepVars, DepVars, PredVars, ErrVars tensor.RowMajor // Number of independent and dependent variables NIndepVars, NDepVars int @@ -110,7 +108,8 @@ func (glm *GLM) Defaults() { func (glm *GLM) init(nIv, nDv int) { glm.NIndepVars = nIv glm.NDepVars = nDv - glm.Coeff.SetShape([]int{nDv, nIv + 1}, "DepVars", "IndepVars") + glm.Coeff.SetShapeSizes(nDv, nIv+1) + // glm.Coeff.SetNames("DepVars", "IndepVars") glm.R2 = make([]float64, nDv) glm.ObsVariance = make([]float64, nDv) glm.ErrVariance = make([]float64, nDv) @@ -122,28 +121,15 @@ func (glm *GLM) init(nIv, nDv int) { // each of the Vars args specifies a column in the table, which can have either a // single scalar value for each row, or a tensor cell with multiple values. // predVars and errVars (predicted values and error values) are optional. 
-func (glm *GLM) SetTable(ix *table.IndexView, indepVars, depVars, predVars, errVars string) error { - dt := ix.Table - iv, err := dt.ColumnByName(indepVars) - if err != nil { - return err - } - dv, err := dt.ColumnByName(depVars) - if err != nil { - return err - } - var pv, ev tensor.Tensor +func (glm *GLM) SetTable(dt *table.Table, indepVars, depVars, predVars, errVars string) error { + iv := dt.Column(indepVars) + dv := dt.Column(depVars) + var pv, ev *tensor.Rows if predVars != "" { - pv, err = dt.ColumnByName(predVars) - if err != nil { - return err - } + pv = dt.Column(predVars) } if errVars != "" { - ev, err = dt.ColumnByName(errVars) - if err != nil { - return err - } + ev = dt.Column(errVars) } if pv != nil && !pv.Shape().IsEqual(dv.Shape()) { return fmt.Errorf("predVars must have same shape as depVars") @@ -151,10 +137,10 @@ func (glm *GLM) SetTable(ix *table.IndexView, indepVars, depVars, predVars, errV if ev != nil && !ev.Shape().IsEqual(dv.Shape()) { return fmt.Errorf("errVars must have same shape as depVars") } - _, nIv := iv.RowCellSize() - _, nDv := dv.RowCellSize() + _, nIv := iv.Shape().RowCellSize() + _, nDv := dv.Shape().RowCellSize() glm.init(nIv, nDv) - glm.Table = ix + glm.Table = dt glm.IndepVars = iv glm.DepVars = dv glm.PredVars = pv @@ -168,17 +154,17 @@ func (glm *GLM) SetTable(ix *table.IndexView, indepVars, depVars, predVars, errV // Initial values of the coefficients, and other parameters for the regression, // should be set prior to running. 
func (glm *GLM) Run() { - ix := glm.Table + dt := glm.Table iv := glm.IndepVars dv := glm.DepVars pv := glm.PredVars ev := glm.ErrVars if pv == nil { - pv = dv.Clone() + pv = tensor.Clone(dv) } if ev == nil { - ev = dv.Clone() + ev = tensor.Clone(dv) } nDv := glm.NDepVars @@ -190,7 +176,7 @@ func (glm *GLM) Run() { lastItr := false sse := 0.0 prevmse := 0.0 - n := ix.Len() + n := dt.NumRows() norm := 1.0 / float64(n) lrate := norm * glm.LRate for itr := 0; itr < glm.MaxIters; itr++ { @@ -202,28 +188,28 @@ func (glm *GLM) Run() { lrate *= 0.5 } for i := 0; i < n; i++ { - row := ix.Indexes[i] + row := dt.RowIndex(i) for di := 0; di < nDv; di++ { pred := 0.0 for ii := 0; ii < nIv; ii++ { - pred += glm.Coeff.Value([]int{di, ii}) * iv.FloatRowCell(row, ii) + pred += glm.Coeff.Float(di, ii) * iv.FloatRow(row, ii) } if !glm.ZeroOffset { - pred += glm.Coeff.Value([]int{di, nIv}) + pred += glm.Coeff.Float(di, nIv) } - targ := dv.FloatRowCell(row, di) + targ := dv.FloatRow(row, di) err := targ - pred sse += err * err for ii := 0; ii < nIv; ii++ { - dc.Values[di*nCi+ii] += err * iv.FloatRowCell(row, ii) + dc.Values[di*nCi+ii] += err * iv.FloatRow(row, ii) } if !glm.ZeroOffset { dc.Values[di*nCi+nIv] += err } if lastItr { - pv.SetFloatRowCell(row, di, pred) + pv.SetFloatRow(pred, row, di) if ev != nil { - ev.SetFloatRowCell(row, di, err) + ev.SetFloatRow(err, row, di) } } } @@ -262,10 +248,13 @@ func (glm *GLM) Run() { obsMeans := make([]float64, nDv) errMeans := make([]float64, nDv) for i := 0; i < n; i++ { - row := ix.Indexes[i] + row := i + if dt.Indexes != nil { + row = dt.Indexes[i] + } for di := 0; di < nDv; di++ { - obsMeans[di] += dv.FloatRowCell(row, di) - errMeans[di] += ev.FloatRowCell(row, di) + obsMeans[di] += dv.FloatRow(row, di) + errMeans[di] += ev.FloatRow(row, di) } } for di := 0; di < nDv; di++ { @@ -275,11 +264,14 @@ func (glm *GLM) Run() { glm.ErrVariance[di] = 0 } for i := 0; i < n; i++ { - row := ix.Indexes[i] + row := i + if dt.Indexes != nil { + row = 
dt.Indexes[i] + } for di := 0; di < nDv; di++ { - o := dv.FloatRowCell(row, di) - obsMeans[di] + o := dv.FloatRow(row, di) - obsMeans[di] glm.ObsVariance[di] += o * o - e := ev.FloatRowCell(row, di) - errMeans[di] + e := ev.FloatRow(row, di) - errMeans[di] glm.ErrVariance[di] += e * e } } @@ -317,7 +309,7 @@ func (glm *GLM) Coeffs() string { } str += " = " for ii := 0; ii <= glm.NIndepVars; ii++ { - str += fmt.Sprintf("\t%8.6g", glm.Coeff.Value([]int{di, ii})) + str += fmt.Sprintf("\t%8.6g", glm.Coeff.Float(di, ii)) if ii < glm.NIndepVars { str += " * " if len(glm.IndepNames) > ii && glm.IndepNames[di] != "" { diff --git a/tensor/stats/histogram/histogram.go b/tensor/stats/histogram/histogram.go index 3cf7e73c3a..7f2d093c57 100644 --- a/tensor/stats/histogram/histogram.go +++ b/tensor/stats/histogram/histogram.go @@ -46,56 +46,11 @@ func F64Table(dt *table.Table, vals []float64, nBins int, min, max float64) { dt.AddFloat64Column("Value") dt.AddFloat64Column("Count") dt.SetNumRows(nBins) - ct := dt.Columns[1].(*tensor.Float64) + ct := dt.Columns.Values[1].(*tensor.Float64) F64(&ct.Values, vals, nBins, min, max) inc := (max - min) / float64(nBins) - vls := dt.Columns[0].(*tensor.Float64).Values + vls := dt.Columns.Values[0].(*tensor.Float64).Values for i := 0; i < nBins; i++ { vls[i] = math32.Truncate64(min+float64(i)*inc, 4) } } - -////////////////////////////////////////////////////// -// float32 - -// F32 generates a histogram of counts of values within given -// number of bins and min / max range. hist vals is sized to nBins. -// if value is < min or > max it is ignored. 
-func F32(hist *[]float32, vals []float32, nBins int, min, max float32) { - *hist = slicesx.SetLength(*hist, nBins) - h := *hist - // 0.1.2.3 = 3-0 = 4 bins - inc := (max - min) / float32(nBins) - for i := 0; i < nBins; i++ { - h[i] = 0 - } - for _, v := range vals { - if v < min || v > max { - continue - } - bin := int((v - min) / inc) - if bin >= nBins { - bin = nBins - 1 - } - h[bin] += 1 - } -} - -// F32Table generates an table with a histogram of counts of values within given -// number of bins and min / max range. The table has columns: Value, Count -// if value is < min or > max it is ignored. -// The Value column represents the min value for each bin, with the max being -// the value of the next bin, or the max if at the end. -func F32Table(dt *table.Table, vals []float32, nBins int, min, max float32) { - dt.DeleteAll() - dt.AddFloat32Column("Value") - dt.AddFloat32Column("Count") - dt.SetNumRows(nBins) - ct := dt.Columns[1].(*tensor.Float32) - F32(&ct.Values, vals, nBins, min, max) - inc := (max - min) / float32(nBins) - vls := dt.Columns[0].(*tensor.Float32).Values - for i := 0; i < nBins; i++ { - vls[i] = math32.Truncate(min+float32(i)*inc, 4) - } -} diff --git a/tensor/stats/histogram/histogram_test.go b/tensor/stats/histogram/histogram_test.go index e2c31cee5e..4806692268 100644 --- a/tensor/stats/histogram/histogram_test.go +++ b/tensor/stats/histogram/histogram_test.go @@ -7,30 +7,9 @@ package histogram import ( "testing" - "cogentcore.org/core/tensor/table" "github.com/stretchr/testify/assert" ) -func TestHistogram32(t *testing.T) { - vals := []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} - ex := []float32{4, 3, 4} - res := []float32{} - - F32(&res, vals, 3, 0, 1) - - assert.Equal(t, ex, res) - - exvals := []float32{0, 0.3333, 0.6667} - dt := table.NewTable() - F32Table(dt, vals, 3, 0, 1) - for ri, v := range ex { - vv := float32(dt.Float("Value", ri)) - cv := float32(dt.Float("Count", ri)) - assert.Equal(t, exvals[ri], vv) - 
assert.Equal(t, v, cv) - } -} - func TestHistogram64(t *testing.T) { vals := []float64{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} ex := []float64{4, 3, 4} @@ -40,13 +19,13 @@ func TestHistogram64(t *testing.T) { assert.Equal(t, ex, res) - exvals := []float64{0, 0.3333, 0.6667} - dt := table.NewTable() - F64Table(dt, vals, 3, 0, 1) - for ri, v := range ex { - vv := dt.Float("Value", ri) - cv := dt.Float("Count", ri) - assert.Equal(t, exvals[ri], vv) - assert.Equal(t, v, cv) - } + // exvals := []float64{0, 0.3333, 0.6667} + // dt := table.New() + // F64Table(dt, vals, 3, 0, 1) + // for ri, v := range ex { + // vv := dt.Float("Value", ri) + // cv := dt.Float("Count", ri) + // assert.Equal(t, exvals[ri], vv) + // assert.Equal(t, v, cv) + // } } diff --git a/tensor/stats/metric/README.md b/tensor/stats/metric/README.md index 3a09f6d033..7d1174182c 100644 --- a/tensor/stats/metric/README.md +++ b/tensor/stats/metric/README.md @@ -1,7 +1,57 @@ # metric -`metric` provides various similarity / distance metrics for comparing floating-point vectors. All functions have 32 and 64 bit variants, and skip NaN's (often used for missing) and will panic if the lengths of the two slices are unequal (no error return). +`metric` provides various similarity / distance metrics for comparing two tensors, operating on the `tensor.Tensor` standard data representation, using this standard function: +```Go +type MetricFunc func(a, b, out tensor.Tensor) error +``` -The signatures of all such metric functions are identical, captured as types: `metric.Func32` and `metric.Func64` so that other functions that use a metric can take a pointer to any such function. +The metric functions always operate on the outermost _row_ dimension, and it is up to the caller to reshape the tensors to accomplish the desired results. The two tensors must have the same shape. + +* To obtain a single summary metric across all values, use `tensor.As1D`. 
+
+* For `RowMajor` data that is naturally organized as a single outer _rows_ dimension with the remaining inner dimensions comprising the _cells_, the results are the metric for each such cell computed across the outer rows dimension. For the `L2Norm` metric for example, each cell has the difference for that cell value across all the rows between the two tensors. See [Matrix functions](#matrix-functions) below for a function that computes the distances _between each cell pattern and all the others_, as a distance or similarity matrix.
+
+* Use `tensor.NewRowCellsView` to reshape any tensor into a 2D rows x cells shape, with the cells starting at a given dimension. Thus, any number of outer dimensions can be collapsed into the outer row dimension, and the remaining dimensions become the cells.
+
+## Metrics
+
+### Value _increases_ with increasing distance (i.e., difference metric)
+
+* `L2Norm`: the square root of the sum of squares differences between tensor values.
+* `SumSquares`: the sum of squares differences between tensor values.
+* `L1Norm`: the sum of the absolute value of differences between tensor values.
+* `Hamming`: the sum of 1s for every element that is different, i.e., "city block" distance.
+* `L2NormBinTol`: the `L2Norm` square root of the sum of squares differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.
+* `SumSquaresBinTol`: the `SumSquares` differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.
+* `InvCosine`: is 1-`Cosine`, which is useful to convert it to an Increasing metric where more different vectors have larger metric values.
+* `InvCorrelation`: is 1-`Correlation`, which is useful to convert it to an Increasing metric where more different vectors have larger metric values. 
+* `CrossEntropy`: is a standard measure of the difference between two probability distributions, reflecting the additional entropy (uncertainty) associated with measuring probabilities under distribution b when in fact they come from distribution a. It is also the entropy of a plus the divergence between a from b, using Kullback-Leibler (KL) divergence. It is computed as: a * log(a/b) + (1-a) * log(1-a/1-b).
+
+### Value _decreases_ with increasing distance (i.e., similarity metric)
+
+* `DotProduct`: the sum of the co-products of the tensor values.
+* `Covariance`: the co-variance between two vectors, i.e., the mean of the co-product of each vector element minus the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
+* `Correlation`: the standardized `Covariance` in the range (-1..1), computed as the mean of the co-product of each vector element minus the mean of that vector, normalized by the product of their standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). Equivalent to the `Cosine` of mean-normalized vectors.
+* `Cosine`: the high-dimensional angle between two vectors, in range (-1..1) as the normalized `DotProduct`: inner product / sqrt(ssA * ssB). See also `Correlation`.
+
+Here is general info about these functions:
+
+The output must be a `tensor.Values` tensor, and it is automatically shaped to hold the stat value(s) for the "cells" in higher-dimensional tensors, and a single scalar value for a 1D input tensor.
+
+All metric functions skip over `NaN`'s, as a missing value.
+
+Metric functions cannot be computed in parallel, e.g., using VectorizeThreaded or GPU, due to shared writing to the same output values. Special implementations are required if that is needed.
+
+# Matrix functions
+
+* `Matrix` computes a distance / similarity matrix using a metric function, operating on the n-dimensional sub-space patterns on a given tensor (i.e., a row-wise list of patterns). 
The result is a square rows x rows matrix where each cell is the metric value for the pattern at the given row. The diagonal contains the self-similarity metric. + +* `CrossMatrix` is like `Matrix` except it compares two separate lists of patterns. + +* `CovarianceMatrix` computes the _covariance matrix_ for row-wise lists of patterns, where the result is a square matrix of cells x cells size ("cells" is number of elements in the patterns per row), and each value represents the extent to which value of a given cell covaries across the rows of the tensor with the value of another cell. For example, if the rows represent time, then the covariance matrix represents the extent to which the patterns tend to move in the same way over time. + + See [matrix](../../matrix) for `EigSym` and `SVD` functions that compute the "principal components" (PCA) of covariance, in terms of the _eigenvectors_ and corresponding _eigenvalues_ of this matrix. The eigenvector (component) with the largest eigenvalue is the "direction" in n-dimensional pattern space along which there is the greatest variance in the patterns across the rows. + + There is also a `matrix.ProjectOnMatrixColumn` convenience function for projecting data along a vector extracted from a matrix, which allows you to project data along an eigenvector from the PCA or SVD functions. By doing this projection along the strongest 2 eigenvectors (those with the largest eigenvalues), you can visualize high-dimensional data in a 2D plot, which typically reveals important aspects of the structure of the underlying high-dimensional data, which is otherwise hard to see given the difficulty in visualizing high-dimensional spaces. diff --git a/tensor/stats/metric/abs.go b/tensor/stats/metric/abs.go deleted file mode 100644 index 7b0dbfe18b..0000000000 --- a/tensor/stats/metric/abs.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package metric - -import ( - "math" - - "cogentcore.org/core/math32" -) - -/////////////////////////////////////////// -// Abs - -// Abs32 computes the sum of absolute value of differences (L1 Norm). -// Skips NaN's and panics if lengths are not equal. -func Abs32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - ss += math32.Abs(av - bv) - } - return ss -} - -// Abs64 computes the sum of absolute value of differences (L1 Norm). -// Skips NaN's and panics if lengths are not equal. -func Abs64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - ss += math.Abs(av - bv) - } - return ss -} - -/////////////////////////////////////////// -// Hamming - -// Hamming32 computes the sum of 1's for every element that is different -// (city block). -// Skips NaN's and panics if lengths are not equal. -func Hamming32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - if av != bv { - ss += 1 - } - } - return ss -} - -// Hamming64 computes the sum of absolute value of differences (L1 Norm). -// Skips NaN's and panics if lengths are not equal. 
-func Hamming64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - if av != bv { - ss += 1 - } - } - return ss -} diff --git a/tensor/stats/metric/doc.go b/tensor/stats/metric/doc.go index 724f01572b..2375767852 100644 --- a/tensor/stats/metric/doc.go +++ b/tensor/stats/metric/doc.go @@ -3,13 +3,6 @@ // license that can be found in the LICENSE file. /* -Package metric provides various similarity / distance metrics for comparing -floating-point vectors. -All functions have 32 and 64 bit variants, and skip NaN's (often used for missing) -and will panic if the lengths of the two slices are unequal (no error return). - -The signatures of all such metric functions are identical, captured as types: -metric.Func32 and metric.Func64 so that other functions that use a metric -can take a pointer to any such function. +Package metric provides various similarity / distance metrics for comparing tensors, operating on the tensor.Tensor standard data representation. */ package metric diff --git a/tensor/stats/metric/enumgen.go b/tensor/stats/metric/enumgen.go index e99276abeb..c9802a2353 100644 --- a/tensor/stats/metric/enumgen.go +++ b/tensor/stats/metric/enumgen.go @@ -6,45 +6,43 @@ import ( "cogentcore.org/core/enums" ) -var _StdMetricsValues = []StdMetrics{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} +var _MetricsValues = []Metrics{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} -// StdMetricsN is the highest valid value for type StdMetrics, plus one. -const StdMetricsN StdMetrics = 13 +// MetricsN is the highest valid value for type Metrics, plus one. 
+const MetricsN Metrics = 13 -var _StdMetricsValueMap = map[string]StdMetrics{`Euclidean`: 0, `SumSquares`: 1, `Abs`: 2, `Hamming`: 3, `EuclideanBinTol`: 4, `SumSquaresBinTol`: 5, `InvCosine`: 6, `InvCorrelation`: 7, `CrossEntropy`: 8, `InnerProduct`: 9, `Covariance`: 10, `Correlation`: 11, `Cosine`: 12} +var _MetricsValueMap = map[string]Metrics{`L2Norm`: 0, `SumSquares`: 1, `L1Norm`: 2, `Hamming`: 3, `L2NormBinTol`: 4, `SumSquaresBinTol`: 5, `InvCosine`: 6, `InvCorrelation`: 7, `CrossEntropy`: 8, `DotProduct`: 9, `Covariance`: 10, `Correlation`: 11, `Cosine`: 12} -var _StdMetricsDescMap = map[StdMetrics]string{0: ``, 1: ``, 2: ``, 3: ``, 4: ``, 5: ``, 6: `InvCosine is 1-Cosine -- useful to convert into an Increasing metric`, 7: `InvCorrelation is 1-Correlation -- useful to convert into an Increasing metric`, 8: ``, 9: `Everything below here is !Increasing -- larger = closer, not farther`, 10: ``, 11: ``, 12: ``} +var _MetricsDescMap = map[Metrics]string{0: `L2Norm is the square root of the sum of squares differences between tensor values, aka the L2 Norm.`, 1: `SumSquares is the sum of squares differences between tensor values.`, 2: `L1Norm is the sum of the absolute value of differences between tensor values, the L1 Norm.`, 3: `Hamming is the sum of 1s for every element that is different, i.e., "city block" distance.`, 4: `L2NormBinTol is the [L2Norm] square root of the sum of squares differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 5: `SumSquaresBinTol is the [SumSquares] differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 6: `InvCosine is 1-[Cosine], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 7: `InvCorrelation is 1-[Correlation], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 8: `CrossEntropy is a standard measure of the 
difference between two probabilty distributions, reflecting the additional entropy (uncertainty) associated with measuring probabilities under distribution b when in fact they come from distribution a. It is also the entropy of a plus the divergence between a from b, using Kullback-Leibler (KL) divergence. It is computed as: a * log(a/b) + (1-a) * log(1-a/1-b).`, 9: `DotProduct is the sum of the co-products of the tensor values.`, 10: `Covariance is co-variance between two vectors, i.e., the mean of the co-product of each vector element minus the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].`, 11: `Correlation is the standardized [Covariance] in the range (-1..1), computed as the mean of the co-product of each vector element minus the mean of that vector, normalized by the product of their standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). Equivalent to the [Cosine] of mean-normalized vectors.`, 12: `Cosine is high-dimensional angle between two vectors, in range (-1..1) as the normalized [DotProduct]: inner product / sqrt(ssA * ssB). See also [Correlation].`} -var _StdMetricsMap = map[StdMetrics]string{0: `Euclidean`, 1: `SumSquares`, 2: `Abs`, 3: `Hamming`, 4: `EuclideanBinTol`, 5: `SumSquaresBinTol`, 6: `InvCosine`, 7: `InvCorrelation`, 8: `CrossEntropy`, 9: `InnerProduct`, 10: `Covariance`, 11: `Correlation`, 12: `Cosine`} +var _MetricsMap = map[Metrics]string{0: `L2Norm`, 1: `SumSquares`, 2: `L1Norm`, 3: `Hamming`, 4: `L2NormBinTol`, 5: `SumSquaresBinTol`, 6: `InvCosine`, 7: `InvCorrelation`, 8: `CrossEntropy`, 9: `DotProduct`, 10: `Covariance`, 11: `Correlation`, 12: `Cosine`} -// String returns the string representation of this StdMetrics value. -func (i StdMetrics) String() string { return enums.String(i, _StdMetricsMap) } +// String returns the string representation of this Metrics value. 
+func (i Metrics) String() string { return enums.String(i, _MetricsMap) } -// SetString sets the StdMetrics value from its string representation, +// SetString sets the Metrics value from its string representation, // and returns an error if the string is invalid. -func (i *StdMetrics) SetString(s string) error { - return enums.SetString(i, s, _StdMetricsValueMap, "StdMetrics") +func (i *Metrics) SetString(s string) error { + return enums.SetString(i, s, _MetricsValueMap, "Metrics") } -// Int64 returns the StdMetrics value as an int64. -func (i StdMetrics) Int64() int64 { return int64(i) } +// Int64 returns the Metrics value as an int64. +func (i Metrics) Int64() int64 { return int64(i) } -// SetInt64 sets the StdMetrics value from an int64. -func (i *StdMetrics) SetInt64(in int64) { *i = StdMetrics(in) } +// SetInt64 sets the Metrics value from an int64. +func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) } -// Desc returns the description of the StdMetrics value. -func (i StdMetrics) Desc() string { return enums.Desc(i, _StdMetricsDescMap) } +// Desc returns the description of the Metrics value. +func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) } -// StdMetricsValues returns all possible values for the type StdMetrics. -func StdMetricsValues() []StdMetrics { return _StdMetricsValues } +// MetricsValues returns all possible values for the type Metrics. +func MetricsValues() []Metrics { return _MetricsValues } -// Values returns all possible values for the type StdMetrics. -func (i StdMetrics) Values() []enums.Enum { return enums.Values(_StdMetricsValues) } +// Values returns all possible values for the type Metrics. +func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) } // MarshalText implements the [encoding.TextMarshaler] interface. 
-func (i StdMetrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil } +func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil } // UnmarshalText implements the [encoding.TextUnmarshaler] interface. -func (i *StdMetrics) UnmarshalText(text []byte) error { - return enums.UnmarshalText(i, text, "StdMetrics") -} +func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") } diff --git a/tensor/stats/metric/funcs.go b/tensor/stats/metric/funcs.go new file mode 100644 index 0000000000..d40601602f --- /dev/null +++ b/tensor/stats/metric/funcs.go @@ -0,0 +1,539 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metric + +import ( + "math" + + "cogentcore.org/core/math32" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/stats/stats" +) + +// MetricFunc is the function signature for a metric function, +// which is computed over the outermost row dimension and the +// output is the shape of the remaining inner cells (a scalar for 1D inputs). +// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D] etc +// to reshape and reslice the data as needed. +// All metric functions skip over NaN's, as a missing value, +// and use the min of the length of the two tensors. +// Metric functions cannot be computed in parallel, +// e.g., using VectorizeThreaded or GPU, due to shared writing +// to the same output values. Special implementations are required +// if that is needed. +type MetricFunc = func(a, b tensor.Tensor) tensor.Values + +// MetricOutFunc is the function signature for a metric function, +// that takes output values as the final argument. See [MetricFunc]. +// This version is for computationally demanding cases and saves +// reallocation of output. 
+type MetricOutFunc = func(a, b tensor.Tensor, out tensor.Values) error + +// SumSquaresScaleOut64 computes the sum of squares differences between tensor values, +// returning scale and ss factors aggregated separately for better numerical stability, per BLAS. +func SumSquaresScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) { + if err = tensor.MustBeSameShape(a, b); err != nil { + return + } + scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) { + if math.IsNaN(a) || math.IsNaN(b) { + return scale, ss + } + d := a - b + if d == 0 { + return scale, ss + } + absxi := math.Abs(d) + if scale < absxi { + ss = 1 + ss*(scale/absxi)*(scale/absxi) + scale = absxi + } else { + ss = ss + (absxi/scale)*(absxi/scale) + } + return scale, ss + }) + return +} + +// SumSquaresOut64 computes the sum of squares differences between tensor values, +// and returns the Float64 output values for use in subsequent computations. +func SumSquaresOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) { + scale64, ss64, err := SumSquaresScaleOut64(a, b) + if err != nil { + return nil, err + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * scale * ss + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return scale64, err +} + +// SumSquaresOut computes the sum of squares differences between tensor values, +// See [MetricOutFunc] for general information. +func SumSquaresOut(a, b tensor.Tensor, out tensor.Values) error { + _, err := SumSquaresOut64(a, b, out) + return err +} + +// SumSquares computes the sum of squares differences between tensor values, +// See [MetricFunc] for general information. 
+func SumSquares(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(SumSquaresOut, a, b) +} + +// L2NormOut computes the L2 Norm: square root of the sum of squares +// differences between tensor values, aka the Euclidean distance. +func L2NormOut(a, b tensor.Tensor, out tensor.Values) error { + scale64, ss64, err := SumSquaresScaleOut64(a, b) + if err != nil { + return err + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * math.Sqrt(ss) + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return nil +} + +// L2Norm computes the L2 Norm: square root of the sum of squares +// differences between tensor values, aka the Euclidean distance. +// See [MetricFunc] for general information. +func L2Norm(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(L2NormOut, a, b) +} + +// L1NormOut computes the sum of the absolute value of differences between the +// tensor values, the L1 Norm. +// See [MetricOutFunc] for general information. +func L1NormOut(a, b tensor.Tensor, out tensor.Values) error { + if err := tensor.MustBeSameShape(a, b); err != nil { + return err + } + VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 { + if math.IsNaN(a) || math.IsNaN(b) { + return agg + } + return agg + math.Abs(a-b) + }) + return nil +} + +// L1Norm computes the sum of the absolute value of differences between the +// tensor values, the L1 Norm. +// See [MetricFunc] for general information. +func L1Norm(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(L1NormOut, a, b) +} + +// HammingOut computes the sum of 1s for every element that is different, +// i.e., "city block" distance. +// See [MetricOutFunc] for general information. 
+func HammingOut(a, b tensor.Tensor, out tensor.Values) error { + if err := tensor.MustBeSameShape(a, b); err != nil { + return err + } + VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 { + if math.IsNaN(a) || math.IsNaN(b) { + return agg + } + if a != b { + agg += 1 + } + return agg + }) + return nil +} + +// Hamming computes the sum of 1s for every element that is different, +// i.e., "city block" distance. +// See [MetricFunc] for general information. +func Hamming(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(HammingOut, a, b) +} + +// SumSquaresBinTolScaleOut64 computes the sum of squares differences between tensor values, +// with binary tolerance: differences < 0.5 are thresholded to 0. +// returning scale and ss factors aggregated separately for better numerical stability, per BLAS. +func SumSquaresBinTolScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) { + if err = tensor.MustBeSameShape(a, b); err != nil { + return + } + scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) { + if math.IsNaN(a) || math.IsNaN(b) { + return scale, ss + } + d := a - b + if math.Abs(d) < 0.5 { + return scale, ss + } + absxi := math.Abs(d) + if scale < absxi { + ss = 1 + ss*(scale/absxi)*(scale/absxi) + scale = absxi + } else { + ss = ss + (absxi/scale)*(absxi/scale) + } + return scale, ss + }) + return +} + +// L2NormBinTolOut computes the L2 Norm square root of the sum of squares +// differences between tensor values (aka Euclidean distance), with binary tolerance: +// differences < 0.5 are thresholded to 0. +func L2NormBinTolOut(a, b tensor.Tensor, out tensor.Values) error { + scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b) + if err != nil { + return err + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * math.Sqrt(ss) + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return nil +} + +// L2NormBinTol computes the L2 Norm square root of the sum of squares +// differences between tensor values (aka Euclidean distance), with binary tolerance: +// differences < 0.5 are thresholded to 0. +// See [MetricFunc] for general information. +func L2NormBinTol(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(L2NormBinTolOut, a, b) +} + +// SumSquaresBinTolOut computes the sum of squares differences between tensor values, +// with binary tolerance: differences < 0.5 are thresholded to 0. +func SumSquaresBinTolOut(a, b tensor.Tensor, out tensor.Values) error { + scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b) + if err != nil { + return err + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * scale * ss + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return nil +} + +// SumSquaresBinTol computes the sum of squares differences between tensor values, +// with binary tolerance: differences < 0.5 are thresholded to 0. +// See [MetricFunc] for general information. +func SumSquaresBinTol(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(SumSquaresBinTolOut, a, b) +} + +// CrossEntropyOut is a standard measure of the difference between two +// probabilty distributions, reflecting the additional entropy (uncertainty) associated +// with measuring probabilities under distribution b when in fact they come from +// distribution a. It is also the entropy of a plus the divergence between a from b, +// using Kullback-Leibler (KL) divergence. 
 It is computed as:
+// a * log(a/b) + (1-a) * log(1-a/1-b).
+// See [MetricOutFunc] for general information.
+func CrossEntropyOut(a, b tensor.Tensor, out tensor.Values) error {
+	if err := tensor.MustBeSameShape(a, b); err != nil {
+		return err
+	}
+	VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
+		if math.IsNaN(a) || math.IsNaN(b) {
+			return agg
+		}
+		b = math32.Clamp(b, 0.000001, 0.999999)
+		if a >= 1.0 {
+			agg += -math.Log(b)
+		} else if a <= 0.0 {
+			agg += -math.Log(1.0 - b)
+		} else {
+			agg += a*math.Log(a/b) + (1-a)*math.Log((1-a)/(1-b))
+		}
+		return agg
+	})
+	return nil
+}
+
+// CrossEntropy is a standard measure of the difference between two
+// probability distributions, reflecting the additional entropy (uncertainty) associated
+// with measuring probabilities under distribution b when in fact they come from
+// distribution a. It is also the entropy of a plus the divergence between a from b,
+// using Kullback-Leibler (KL) divergence. It is computed as:
+// a * log(a/b) + (1-a) * log(1-a/1-b).
+// See [MetricFunc] for general information.
+func CrossEntropy(a, b tensor.Tensor) tensor.Values {
+	return tensor.CallOut2(CrossEntropyOut, a, b)
+}
+
+// DotProductOut computes the sum of the element-wise products of the
+// two tensors (aka the inner product).
+// See [MetricOutFunc] for general information.
+func DotProductOut(a, b tensor.Tensor, out tensor.Values) error {
+	if err := tensor.MustBeSameShape(a, b); err != nil {
+		return err
+	}
+	VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
+		if math.IsNaN(a) || math.IsNaN(b) {
+			return agg
+		}
+		return agg + a*b
+	})
+	return nil
+}
+
+// DotProduct computes the sum of the element-wise products of the
+// two tensors (aka the inner product).
+// See [MetricFunc] for general information. 
+func DotProduct(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(DotProductOut, a, b) +} + +// CovarianceOut computes the co-variance between two vectors, +// i.e., the mean of the co-product of each vector element minus +// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))]. +func CovarianceOut(a, b tensor.Tensor, out tensor.Values) error { + if err := tensor.MustBeSameShape(a, b); err != nil { + return err + } + amean, acount := stats.MeanOut64(a, out) + bmean, _ := stats.MeanOut64(b, out) + cov64 := VectorizePreOut64(a, b, out, 0, amean, bmean, func(a, b, am, bm, agg float64) float64 { + if math.IsNaN(a) || math.IsNaN(b) { + return agg + } + return agg + (a-am)*(b-bm) + }) + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + c := acount.Float1D(i) + if c == 0 { + continue + } + cov := cov64.Float1D(i) / c + out.SetFloat1D(cov, i) + } + return nil +} + +// Covariance computes the co-variance between two vectors, +// i.e., the mean of the co-product of each vector element minus +// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))]. +// See [MetricFunc] for general information. +func Covariance(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(CovarianceOut, a, b) +} + +// CorrelationOut64 computes the correlation between two vectors, +// in range (-1..1) as the mean of the co-product of each vector +// element minus the mean of that vector, normalized by the product of their +// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). +// (i.e., the standardized covariance). +// Equivalent to the cosine of mean-normalized vectors. +// Returns the Float64 output values for subsequent use. 
+func CorrelationOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) { + if err := tensor.MustBeSameShape(a, b); err != nil { + return nil, err + } + amean, _ := stats.MeanOut64(a, out) + bmean, _ := stats.MeanOut64(b, out) + ss64, avar64, bvar64 := VectorizePre3Out64(a, b, 0, 0, 0, amean, bmean, func(a, b, am, bm, ss, avar, bvar float64) (float64, float64, float64) { + if math.IsNaN(a) || math.IsNaN(b) { + return ss, avar, bvar + } + ad := a - am + bd := b - bm + ss += ad * bd // between + avar += ad * ad // within + bvar += bd * bd + return ss, avar, bvar + }) + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + ss := ss64.Float1D(i) + vp := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i)) + if vp > 0 { + ss /= vp + } + ss64.SetFloat1D(ss, i) + out.SetFloat1D(ss, i) + } + return ss64, nil +} + +// CorrelationOut computes the correlation between two vectors, +// in range (-1..1) as the mean of the co-product of each vector +// element minus the mean of that vector, normalized by the product of their +// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). +// (i.e., the standardized [Covariance]). +// Equivalent to the [Cosine] of mean-normalized vectors. +func CorrelationOut(a, b tensor.Tensor, out tensor.Values) error { + _, err := CorrelationOut64(a, b, out) + return err +} + +// Correlation computes the correlation between two vectors, +// in range (-1..1) as the mean of the co-product of each vector +// element minus the mean of that vector, normalized by the product of their +// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). +// (i.e., the standardized [Covariance]). +// Equivalent to the [Cosine] of mean-normalized vectors. +// See [MetricFunc] for general information. 
+func Correlation(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(CorrelationOut, a, b) +} + +// InvCorrelationOut computes 1 minus the correlation between two vectors, +// in range (-1..1) as the mean of the co-product of each vector +// element minus the mean of that vector, normalized by the product of their +// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). +// (i.e., the standardized covariance). +// Equivalent to the [Cosine] of mean-normalized vectors. +// This is useful for a difference measure instead of similarity, +// where more different vectors have larger metric values. +func InvCorrelationOut(a, b tensor.Tensor, out tensor.Values) error { + cor64, err := CorrelationOut64(a, b, out) + if err != nil { + return err + } + nsub := out.Len() + for i := range nsub { + cor := cor64.Float1D(i) + out.SetFloat1D(1-cor, i) + } + return nil +} + +// InvCorrelation computes 1 minus the correlation between two vectors, +// in range (-1..1) as the mean of the co-product of each vector +// element minus the mean of that vector, normalized by the product of their +// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). +// (i.e., the standardized covariance). +// Equivalent to the [Cosine] of mean-normalized vectors. +// This is useful for a difference measure instead of similarity, +// where more different vectors have larger metric values. +// See [MetricFunc] for general information. +func InvCorrelation(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(InvCorrelationOut, a, b) +} + +// CosineOut64 computes the high-dimensional angle between two vectors, +// in range (-1..1) as the normalized [Dot]: +// dot product / sqrt(ssA * ssB). See also [Correlation]. 
+func CosineOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) { + if err := tensor.MustBeSameShape(a, b); err != nil { + return nil, err + } + ss64, avar64, bvar64 := Vectorize3Out64(a, b, 0, 0, 0, func(a, b, ss, avar, bvar float64) (float64, float64, float64) { + if math.IsNaN(a) || math.IsNaN(b) { + return ss, avar, bvar + } + ss += a * b + avar += a * a + bvar += b * b + return ss, avar, bvar + }) + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + ss := ss64.Float1D(i) + vp := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i)) + if vp > 0 { + ss /= vp + } + ss64.SetFloat1D(ss, i) + out.SetFloat1D(ss, i) + } + return ss64, nil +} + +// CosineOut computes the high-dimensional angle between two vectors, +// in range (-1..1) as the normalized dot product: +// dot product / sqrt(ssA * ssB). See also [Correlation] +func CosineOut(a, b tensor.Tensor, out tensor.Values) error { + _, err := CosineOut64(a, b, out) + return err +} + +// Cosine computes the high-dimensional angle between two vectors, +// in range (-1..1) as the normalized dot product: +// dot product / sqrt(ssA * ssB). See also [Correlation] +// See [MetricFunc] for general information. +func Cosine(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(CosineOut, a, b) +} + +// InvCosineOut computes 1 minus the cosine between two vectors, +// in range (-1..1) as the normalized dot product: +// dot product / sqrt(ssA * ssB). +// This is useful for a difference measure instead of similarity, +// where more different vectors have larger metric values. 
+func InvCosineOut(a, b tensor.Tensor, out tensor.Values) error { + cos64, err := CosineOut64(a, b, out) + if err != nil { + return err + } + nsub := out.Len() + for i := range nsub { + cos := cos64.Float1D(i) + out.SetFloat1D(1-cos, i) + } + return nil +} + +// InvCosine computes 1 minus the cosine between two vectors, +// in range (-1..1) as the normalized dot product: +// dot product / sqrt(ssA * ssB). +// This is useful for a difference measure instead of similarity, +// where more different vectors have larger metric values. +// See [MetricFunc] for general information. +func InvCosine(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(InvCosineOut, a, b) +} diff --git a/tensor/stats/metric/matrix.go b/tensor/stats/metric/matrix.go new file mode 100644 index 0000000000..47d0c4afea --- /dev/null +++ b/tensor/stats/metric/matrix.go @@ -0,0 +1,183 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metric + +import ( + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/matrix" +) + +// MatrixOut computes the rows x rows square distance / similarity matrix +// between the patterns for each row of the given higher dimensional input tensor, +// which must have at least 2 dimensions: the outermost rows, +// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView] +// to organize data into the desired split between a 1D outermost Row dimension +// and the remaining Cells dimension. +// The metric function must have the [MetricFunc] signature. +// The results fill in the elements of the output matrix, which is symmetric, +// and only the lower triangular part is computed, with results copied +// to the upper triangular region, for maximum efficiency. 
+func MatrixOut(fun any, in tensor.Tensor, out tensor.Values) error { + mfun, err := AsMetricFunc(fun) + if err != nil { + return err + } + rows, cells := in.Shape().RowCellSize() + if rows == 0 || cells == 0 { + return nil + } + out.SetShapeSizes(rows, rows) + coords := matrix.TriLIndicies(rows) + nc := coords.DimSize(0) + // note: flops estimating 3 per item on average -- different for different metrics. + tensor.VectorizeThreaded(cells*3, func(tsr ...tensor.Tensor) int { return nc }, + func(idx int, tsr ...tensor.Tensor) { + cx := coords.Int(idx, 0) + cy := coords.Int(idx, 1) + sa := tensor.Cells1D(tsr[0], cx) + sb := tensor.Cells1D(tsr[0], cy) + mout := mfun(sa, sb) + tsr[1].SetFloat(mout.Float1D(0), cx, cy) + }, in, out) + for idx := range nc { // copy to upper + cx := coords.Int(idx, 0) + cy := coords.Int(idx, 1) + if cx == cy { // exclude diag + continue + } + out.SetFloat(out.Float(cx, cy), cy, cx) + } + return nil +} + +// Matrix computes the rows x rows square distance / similarity matrix +// between the patterns for each row of the given higher dimensional input tensor, +// which must have at least 2 dimensions: the outermost rows, +// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView] +// to organize data into the desired split between a 1D outermost Row dimension +// and the remaining Cells dimension. +// The metric function must have the [MetricFunc] signature. +// The results fill in the elements of the output matrix, which is symmetric, +// and only the lower triangular part is computed, with results copied +// to the upper triangular region, for maximum efficiency. 
+func Matrix(fun any, in tensor.Tensor) tensor.Values { + return tensor.CallOut1Gen1(MatrixOut, fun, in) +} + +// CrossMatrixOut computes the distance / similarity matrix between +// two different sets of patterns in the two input tensors, where +// the patterns are in the sub-space cells of the tensors, +// which must have at least 2 dimensions: the outermost rows, +// and within that, 1+dimensional patterns that the given distance metric +// function is applied to, with the results filling in the cells of the output matrix. +// The metric function must have the [MetricFunc] signature. +// The rows of the output matrix are the rows of the first input tensor, +// and the columns of the output are the rows of the second input tensor. +func CrossMatrixOut(fun any, a, b tensor.Tensor, out tensor.Values) error { + mfun, err := AsMetricFunc(fun) + if err != nil { + return err + } + arows, acells := a.Shape().RowCellSize() + if arows == 0 || acells == 0 { + return nil + } + brows, bcells := b.Shape().RowCellSize() + if brows == 0 || bcells == 0 { + return nil + } + out.SetShapeSizes(arows, brows) + // note: flops estimating 3 per item on average -- different for different metrics. + flops := min(acells, bcells) * 3 + nc := arows * brows + tensor.VectorizeThreaded(flops, func(tsr ...tensor.Tensor) int { return nc }, + func(idx int, tsr ...tensor.Tensor) { + ar := idx / brows + br := idx % brows + sa := tensor.Cells1D(tsr[0], ar) + sb := tensor.Cells1D(tsr[1], br) + mout := mfun(sa, sb) + tsr[2].SetFloat(mout.Float1D(0), ar, br) + }, a, b, out) + return nil +} + +// CrossMatrix computes the distance / similarity matrix between +// two different sets of patterns in the two input tensors, where +// the patterns are in the sub-space cells of the tensors, +// which must have at least 2 dimensions: the outermost rows, +// and within that, 1+dimensional patterns that the given distance metric +// function is applied to, with the results filling in the cells of the output matrix. 
+// The metric function must have the [MetricFunc] signature.
+// The rows of the output matrix are the rows of the first input tensor,
+// and the columns of the output are the rows of the second input tensor.
+func CrossMatrix(fun any, a, b tensor.Tensor) tensor.Values {
+	return tensor.CallOut2Gen1(CrossMatrixOut, fun, a, b)
+}
+
+// CovarianceMatrixOut generates the cells x cells square covariance matrix
+// for all per-row cells of the given higher dimensional input tensor,
+// which must have at least 2 dimensions: the outermost rows,
+// and within that, 1+dimensional patterns (cells).
+// Each value in the resulting matrix represents the extent to which the
+// value of a given cell covaries across the rows of the tensor with the
+// value of another cell.
+// Uses the given metric function, typically [Covariance] or [Correlation].
+// The metric function must have the [MetricFunc] signature.
+// Use Covariance if vars have similar overall scaling,
+// which is typical in neural network models, and use
+// Correlation if they are on very different scales, because it effectively rescales them.
+// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
+func CovarianceMatrixOut(fun any, in tensor.Tensor, out tensor.Values) error {
+	mfun, err := AsMetricFunc(fun)
+	if err != nil {
+		return err
+	}
+	rows, cells := in.Shape().RowCellSize()
+	if rows == 0 || cells == 0 {
+		return nil
+	}
+	out.SetShapeSizes(cells, cells)
+	flatvw := tensor.NewReshaped(in, rows, cells)
+
+	coords := matrix.TriLIndicies(cells)
+	nc := coords.DimSize(0)
+	// note: flops estimating 3 per item on average -- different for different metrics.
+	tensor.VectorizeThreaded(rows*3, func(tsr ...tensor.Tensor) int { return nc },
+		func(idx int, tsr ...tensor.Tensor) {
+			cx := coords.Int(idx, 0)
+			cy := coords.Int(idx, 1)
+			av := tensor.Reslice(tsr[0], tensor.FullAxis, cx)
+			bv := tensor.Reslice(tsr[0], tensor.FullAxis, cy)
+			mout := mfun(av, bv)
+			tsr[1].SetFloat(mout.Float1D(0), cx, cy)
+		}, flatvw, out)
+	for idx := range nc { // copy to upper
+		cx := coords.Int(idx, 0)
+		cy := coords.Int(idx, 1)
+		if cx == cy { // exclude diag
+			continue
+		}
+		out.SetFloat(out.Float(cx, cy), cy, cx)
+	}
+	return nil
+}
+
+// CovarianceMatrix generates the cells x cells square covariance matrix
+// for all per-row cells of the given higher dimensional input tensor,
+// which must have at least 2 dimensions: the outermost rows,
+// and within that, 1+dimensional patterns (cells).
+// Each value in the resulting matrix represents the extent to which the
+// value of a given cell covaries across the rows of the tensor with the
+// value of another cell.
+// Uses the given metric function, typically [Covariance] or [Correlation].
+// The metric function must have the [MetricFunc] signature.
+// Use Covariance if vars have similar overall scaling,
+// which is typical in neural network models, and use
+// Correlation if they are on very different scales, because it effectively rescales them.
+// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
+func CovarianceMatrix(fun any, in tensor.Tensor) tensor.Values {
+	return tensor.CallOut1Gen1(CovarianceMatrixOut, fun, in)
+}
diff --git a/tensor/stats/metric/matrix_test.go b/tensor/stats/metric/matrix_test.go
new file mode 100644
index 0000000000..a64d49bbcd
--- /dev/null
+++ b/tensor/stats/metric/matrix_test.go
@@ -0,0 +1,79 @@
+// Copyright (c) 2024, Cogent Core. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file. 
+ +package metric + +import ( + "fmt" + "testing" + + "cogentcore.org/core/base/tolassert" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "github.com/stretchr/testify/assert" +) + +func TestCovarIris(t *testing.T) { + // note: these results are verified against this example: + // https://plot.ly/ipython-notebooks/principal-component-analysis/ + + dt := table.New() + dt.AddFloat64Column("data", 4) + dt.AddStringColumn("class") + err := dt.OpenCSV("testdata/iris.data", tensor.Comma) + if err != nil { + t.Error(err) + } + data := dt.Column("data") + covar := tensor.NewFloat64() + err = CovarianceMatrixOut(Correlation, data, covar) + assert.NoError(t, err) + // fmt.Printf("covar: %s\n", covar.String()) + // tensor.SaveCSV(covar, "testdata/iris-covar.tsv", tensor.Tab) + + cv := []float64{1, -0.10936924995064935, 0.8717541573048719, 0.8179536333691635, + -0.10936924995064935, 1, -0.4205160964011548, -0.3565440896138057, + 0.8717541573048719, -0.4205160964011548, 1, 0.9627570970509667, + 0.8179536333691635, -0.3565440896138057, 0.9627570970509667, 1} + + tolassert.EqualTolSlice(t, cv, covar.Values, 1.0e-8) +} + +func runBenchCovar(b *testing.B, n int, thread bool) { + if thread { + tensor.ThreadingThreshold = 1 + } else { + tensor.ThreadingThreshold = 100_000_000 + } + nrows := 10 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, nrows*n*n+1), nrows, n, n)) + ov := tensor.NewFloat64(nrows, n, n) + b.ResetTimer() + for range b.N { + CovarianceMatrixOut(Correlation, av, ov) + } +} + +// to run this benchmark, do: +// go test -bench BenchmarkCovar -count 10 >bench.txt +// go install golang.org/x/perf/cmd/benchstat@latest +// benchstat -row /n -col .name bench.txt + +var ns = []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40} + +func BenchmarkCovarThreaded(b *testing.B) { + for _, n := range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchCovar(b, n, true) + }) + } +} + +func BenchmarkCovarSingle(b *testing.B) { + for _, n := 
range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchCovar(b, n, false) + }) + } +} diff --git a/tensor/stats/metric/metric_test.go b/tensor/stats/metric/metric_test.go index 6cd11ac98d..e6c479d412 100644 --- a/tensor/stats/metric/metric_test.go +++ b/tensor/stats/metric/metric_test.go @@ -5,79 +5,248 @@ package metric import ( + "fmt" "math" "testing" - "cogentcore.org/core/math32" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "github.com/stretchr/testify/assert" ) -func TestAll(t *testing.T) { +func TestFuncs(t *testing.T) { a64 := []float64{.5, .2, .1, .7, math.NaN(), .5} - b64 := []float64{.2, .5, .1, .7, 0, .2} + b64 := []float64{.2, .9, .1, .7, 0, .2} - a32 := []float32{.5, .2, .1, .7, math32.NaN(), .5} - b32 := []float32{.2, .5, .1, .7, 0, .2} + results := []float64{math.Sqrt(0.67), 0.67, 1.3, 3, 0.7, 0.49, 1 - 0.7319115529256469, 1 - 0.11189084777289171, 1.8090248566170337, 0.88, 0.008, 0.11189084777289171, 0.7319115529256469} - ss := SumSquares64(a64, b64) - if ss != 0.27 { - t.Errorf("SumSquares64: %g\n", ss) - } - ss32 := SumSquares32(a32, b32) - if ss32 != float32(ss) { - t.Errorf("SumSquares32: %g\n", ss32) - } + tol := 1.0e-8 - ec := Euclidean64(a64, b64) - if math.Abs(ec-math.Sqrt(0.27)) > 1.0e-10 { - t.Errorf("Euclidean64: %g vs. %g\n", ec, math.Sqrt(0.27)) - } - ec32 := Euclidean32(a32, b32) - if ec32 != float32(ec) { - t.Errorf("Euclidean32: %g\n", ec32) + atsr := tensor.NewNumberFromValues(a64...) + btsr := tensor.NewNumberFromValues(b64...) 
+ out := tensor.NewFloat64(1) + + L2NormOut(atsr, btsr, out) + assert.InDelta(t, results[MetricL2Norm], out.Values[0], tol) + + SumSquaresOut(atsr, btsr, out) + assert.InDelta(t, results[MetricSumSquares], out.Values[0], tol) + + L2NormBinTolOut(atsr, btsr, out) + assert.InDelta(t, results[MetricL2NormBinTol], out.Values[0], tol) + + L1NormOut(atsr, btsr, out) + assert.InDelta(t, results[MetricL1Norm], out.Values[0], tol) + + HammingOut(atsr, btsr, out) + assert.Equal(t, results[MetricHamming], out.Values[0]) + + SumSquaresBinTolOut(atsr, btsr, out) + assert.InDelta(t, results[MetricSumSquaresBinTol], out.Values[0], tol) + + CovarianceOut(atsr, btsr, out) + assert.InDelta(t, results[MetricCovariance], out.Values[0], tol) + + CorrelationOut(atsr, btsr, out) + assert.InDelta(t, results[MetricCorrelation], out.Values[0], tol) + + InvCorrelationOut(atsr, btsr, out) + assert.InDelta(t, results[MetricInvCorrelation], out.Values[0], tol) + + CrossEntropyOut(atsr, btsr, out) + assert.InDelta(t, results[MetricCrossEntropy], out.Values[0], tol) + + DotProductOut(atsr, btsr, out) + assert.InDelta(t, results[MetricDotProduct], out.Values[0], tol) + + CosineOut(atsr, btsr, out) + assert.InDelta(t, results[MetricCosine], out.Values[0], tol) + + InvCosineOut(atsr, btsr, out) + assert.InDelta(t, results[MetricInvCosine], out.Values[0], tol) + + for met := MetricL2Norm; met < MetricsN; met++ { + out := met.Call(atsr, btsr) + assert.InDelta(t, results[met], out.Float1D(0), tol) } +} - cv := Covariance64(a64, b64) - if cv != 0.023999999999999994 { - t.Errorf("Covariance64: %g\n", cv) +func TestMatrix(t *testing.T) { + + simres := []float64{0, 3.464101552963257, 8.83176040649414, 9.273618698120117, 8.717798233032227, 9.380831718444824, 4.690415859222412, 5.830951690673828, 8.124038696289062, 8.5440034866333, 5.291502475738525, 6.324555397033691} + + dt := table.New() + err := dt.OpenCSV("../cluster/testdata/faces.dat", tensor.Tab) + assert.NoError(t, err) + in := dt.Column("Input") + 
out := tensor.NewFloat64() + err = MatrixOut(L2Norm, in, out) + assert.NoError(t, err) + // fmt.Println(out.Tensor) + for i, v := range simres { + assert.InDelta(t, v, out.Float1D(i), 1.0e-8) } - cv32 := Covariance32(a32, b32) - if cv32 != float32(cv) { - t.Errorf("Covariance32: %g\n", cv32) +} + +func runBenchFuncs(b *testing.B, n int, fun Metrics) { + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + fun.Call(av, bv) } +} - cr := Correlation64(a64, b64) - if cr != 0.47311118871909136 { - t.Errorf("Correlation64: %g\n", cr) +// 375 ns/op = fastest that DotProduct could be. +func BenchmarkFuncMulBaseline(b *testing.B) { + n := 1000 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + s := float64(0) + for range b.N { + for i := range n { + s += av.Values[i] * bv.Values[i] + } } - cr32 := Correlation32(a32, b32) - if cr32 != 0.47311115 { - t.Errorf("Correlation32: %g\n", cr32) +} + +func runClosure(av, bv *tensor.Float64, fun func(a, b, agg float64) float64) float64 { + // fun := func(a, b, agg float64) float64 { // note: it can inline closure if in same fun + // return agg + a*b + // } + n := 1000 + s := float64(0) + for i := range n { + s = fun(av.Values[i], bv.Values[i], s) // note: Float1D here no extra cost } + return s +} - cs := Cosine64(a64, b64) - if cs != 0.861061697819235 { - t.Errorf("Cosine64: %g\n", cs) +// 1465 ns/op = ~4x penalty for cosure +func BenchmarkFuncMulClosure(b *testing.B) { + n := 1000 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + runClosure(av, bv, func(a, b, agg float64) float64 { + return agg + a*b + }) } - cs32 := Cosine32(a32, b32) - if cs32 != 0.86106175 { - 
t.Errorf("Cosine32: %g\n", cs32) +} + +func runClosureInterface(av, bv tensor.Tensor, fun func(a, b, agg float64) float64) float64 { + n := 1000 + s := float64(0) + for i := range n { + s = fun(av.Float1D(i), bv.Float1D(i), s) } + return s +} - ab := Abs64(a64, b64) - if ab != 0.8999999999999999 { - t.Errorf("Abs64: %g\n", ab) +// 3665 ns/op = going through the Tensor interface = another ~2x +func BenchmarkFuncMulClosureInterface(b *testing.B) { + n := 1000 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + runClosureInterface(av, bv, func(a, b, agg float64) float64 { + return agg + a*b + }) } - ab32 := Abs32(a32, b32) - if ab32 != 0.90000004 { - t.Errorf("Abs32: %g\n", ab32) +} + +// original pre-optimization was: 8027 ns/op = 21x slower than the MulBaseline! +func BenchmarkDotProductOut(b *testing.B) { + n := 1000 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + bv := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + ov := tensor.NewFloat64(1) + b.ResetTimer() + for range b.N { + DotProductOut(av, bv, ov) } +} + +// to run this benchmark, do: +// go test -bench BenchmarkFuncs -count 10 >bench.txt +// go install golang.org/x/perf/cmd/benchstat@latest +// benchstat -row /met -col .name bench.txt + +var fns = []int{10, 20, 50, 100, 500, 1000, 10000} + +// after 12/2/2024 optimizations +// goos: darwin +// goarch: arm64 +// pkg: cogentcore.org/core/tensor/stats/metric +// │ Funcs │ +// │ sec/op │ +// L2Norm 1.853µ ± 1% +// SumSquares 1.878µ ± 1% +// L1Norm 1.686µ ± 1% +// Hamming 1.798µ ± 1% +// L2NormBinTol 1.906µ ± 0% +// SumSquaresBinTol 1.912µ ± 0% +// InvCosine 2.421µ ± 0% +// InvCorrelation 6.379µ ± 1% +// CrossEntropy 5.876µ ± 0% +// DotProduct 1.792µ ± 0% +// Covariance 5.914µ ± 0% +// Correlation 6.437µ ± 0% +// Cosine 2.451µ ± 0% +// geomean 2.777µ + +// prior to optimization: +// │ Funcs 
│ +// │ sec/op │ +// L1Norm 8.283µ ± 0% +// DotProduct 8.299µ ± 0% +// L2Norm 8.457µ ± 1% +// SumSquares 8.483µ ± 1% +// L2NormBinTol 8.466µ ± 0% +// SumSquaresBinTol 8.470µ ± 0% +// Hamming 8.556µ ± 0% +// CrossEntropy 12.84µ ± 0% +// Cosine 13.91µ ± 0% +// InvCosine 14.43µ ± 0% +// Covariance 39.47µ ± 0% +// Correlation 47.15µ ± 0% +// InvCorrelation 45.48µ ± 0% +// geomean 13.80µ - hm := Hamming64(a64, b64) - if hm != 3 { - t.Errorf("Hamming64: %g\n", hm) +// BenchmarkFuncMulBaseline: 376.7 ns/op +// BenchmarkFuncMulBaselineClosureArg: 1464 ns/op + +func BenchmarkFuncs(b *testing.B) { + for met := MetricL2Norm; met < MetricsN; met++ { + b.Run(fmt.Sprintf("met=%s", met.String()), func(b *testing.B) { + runBenchFuncs(b, 1000, met) + }) } - hm32 := Hamming32(a32, b32) - if hm32 != 3 { - t.Errorf("Hamming32: %g\n", hm32) +} + +func runBenchNs(b *testing.B, fun Metrics) { + for _, n := range fns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchFuncs(b, n, fun) + }) } } + +func BenchmarkNsL1Norm(b *testing.B) { + runBenchNs(b, MetricL1Norm) +} + +func BenchmarkNsCosine(b *testing.B) { + runBenchNs(b, MetricCosine) +} + +func BenchmarkNsCovariance(b *testing.B) { + runBenchNs(b, MetricCovariance) +} + +func BenchmarkNsCorrelation(b *testing.B) { + runBenchNs(b, MetricCorrelation) +} diff --git a/tensor/stats/metric/metrics.go b/tensor/stats/metric/metrics.go index 876be155c8..b71d1c19de 100644 --- a/tensor/stats/metric/metrics.go +++ b/tensor/stats/metric/metrics.go @@ -6,112 +6,139 @@ package metric -// Func32 is a distance / similarity metric operating on slices of float32 numbers -type Func32 func(a, b []float32) float32 +import ( + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" +) -// Func64 is a distance / similarity metric operating on slices of float64 numbers -type Func64 func(a, b []float64) float64 +func init() { + tensor.AddFunc(MetricL2Norm.FuncName(), L2Norm) + tensor.AddFunc(MetricSumSquares.FuncName(), SumSquares) + 
tensor.AddFunc(MetricL1Norm.FuncName(), L1Norm) + tensor.AddFunc(MetricHamming.FuncName(), Hamming) + tensor.AddFunc(MetricL2NormBinTol.FuncName(), L2NormBinTol) + tensor.AddFunc(MetricSumSquaresBinTol.FuncName(), SumSquaresBinTol) + tensor.AddFunc(MetricInvCosine.FuncName(), InvCosine) + tensor.AddFunc(MetricInvCorrelation.FuncName(), InvCorrelation) + tensor.AddFunc(MetricDotProduct.FuncName(), DotProduct) + tensor.AddFunc(MetricCrossEntropy.FuncName(), CrossEntropy) + tensor.AddFunc(MetricCovariance.FuncName(), Covariance) + tensor.AddFunc(MetricCorrelation.FuncName(), Correlation) + tensor.AddFunc(MetricCosine.FuncName(), Cosine) +} -// StdMetrics are standard metric functions -type StdMetrics int32 //enums:enum +// Metrics are standard metric functions +type Metrics int32 //enums:enum -trim-prefix Metric const ( - Euclidean StdMetrics = iota - SumSquares - Abs - Hamming + // L2Norm is the square root of the sum of squares differences + // between tensor values, aka the L2 Norm. + MetricL2Norm Metrics = iota + + // SumSquares is the sum of squares differences between tensor values. + MetricSumSquares + + // L1Norm is the sum of the absolute value of differences + // between tensor values, the L1 Norm. + MetricL1Norm + + // Hamming is the sum of 1s for every element that is different, + // i.e., "city block" distance. + MetricHamming + + // L2NormBinTol is the [L2Norm] square root of the sum of squares + // differences between tensor values, with binary tolerance: + // differences < 0.5 are thresholded to 0. + MetricL2NormBinTol + + // SumSquaresBinTol is the [SumSquares] differences between tensor values, + // with binary tolerance: differences < 0.5 are thresholded to 0. + MetricSumSquaresBinTol + + // InvCosine is 1-[Cosine], which is useful to convert it + // to an Increasing metric where more different vectors have larger metric values. 
+	MetricInvCosine
 
-	EuclideanBinTol
-	SumSquaresBinTol
+	// InvCorrelation is 1-[Correlation], which is useful to convert it
+	// to an Increasing metric where more different vectors have larger metric values.
+	MetricInvCorrelation
 
-	// InvCosine is 1-Cosine -- useful to convert into an Increasing metric
-	InvCosine
+	// CrossEntropy is a standard measure of the difference between two
+	// probability distributions, reflecting the additional entropy (uncertainty) associated
+	// with measuring probabilities under distribution b when in fact they come from
+	// distribution a. It is also the entropy of a plus the divergence between a from b,
+	// using Kullback-Leibler (KL) divergence. It is computed as:
+	// a * log(a/b) + (1-a) * log((1-a)/(1-b)).
+	MetricCrossEntropy
 
-	// InvCorrelation is 1-Correlation -- useful to convert into an Increasing metric
-	InvCorrelation
+	//////// Everything below here is !Increasing -- larger = closer, not farther
 
-	CrossEntropy
+	// DotProduct is the sum of the co-products of the tensor values.
+	MetricDotProduct
 
-	// Everything below here is !Increasing -- larger = closer, not farther
-	InnerProduct
-	Covariance
-	Correlation
-	Cosine
+	// Covariance is co-variance between two vectors,
+	// i.e., the mean of the co-product of each vector element minus
+	// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
+	MetricCovariance
+
+	// Correlation is the standardized [Covariance] in the range (-1..1),
+	// computed as the mean of the co-product of each vector
+	// element minus the mean of that vector, normalized by the product of their
+	// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
+	// Equivalent to the [Cosine] of mean-normalized vectors.
+	MetricCorrelation
+
+	// Cosine is the high-dimensional angle between two vectors,
+	// in range (-1..1) as the normalized [DotProduct]:
+	// inner product / sqrt(ssA * ssB). See also [Correlation].
+ MetricCosine ) +// FuncName returns the package-qualified function name to use +// in tensor.Call to call this function. +func (m Metrics) FuncName() string { + return "metric." + m.String() +} + +// Func returns function for given metric. +func (m Metrics) Func() MetricFunc { + fn := errors.Log1(tensor.FuncByName(m.FuncName())) + return fn.Fun.(MetricFunc) +} + +// Call calls a standard Metrics enum function on given tensors. +// Output results are in the out tensor. +func (m Metrics) Call(a, b tensor.Tensor) tensor.Values { + return m.Func()(a, b) +} + // Increasing returns true if the distance metric is such that metric -// values increase as a function of distance (e.g., Euclidean) +// values increase as a function of distance (e.g., L2Norm) // and false if metric values decrease as a function of distance // (e.g., Cosine, Correlation) -func Increasing(std StdMetrics) bool { - if std >= InnerProduct { +func (m Metrics) Increasing() bool { + if m >= MetricDotProduct { return false } return true } -// StdFunc32 returns a standard metric function as specified -func StdFunc32(std StdMetrics) Func32 { - switch std { - case Euclidean: - return Euclidean32 - case SumSquares: - return SumSquares32 - case Abs: - return Abs32 - case Hamming: - return Hamming32 - case EuclideanBinTol: - return EuclideanBinTol32 - case SumSquaresBinTol: - return SumSquaresBinTol32 - case InvCorrelation: - return InvCorrelation32 - case InvCosine: - return InvCosine32 - case CrossEntropy: - return CrossEntropy32 - case InnerProduct: - return InnerProduct32 - case Covariance: - return Covariance32 - case Correlation: - return Correlation32 - case Cosine: - return Cosine32 +// AsMetricFunc returns given function as a [MetricFunc] function, +// or an error if it does not fit that signature. 
+func AsMetricFunc(fun any) (MetricFunc, error) { + mfun, ok := fun.(MetricFunc) + if !ok { + return nil, errors.New("metric.AsMetricFunc: function does not fit the MetricFunc signature") } - return nil + return mfun, nil } -// StdFunc64 returns a standard metric function as specified -func StdFunc64(std StdMetrics) Func64 { - switch std { - case Euclidean: - return Euclidean64 - case SumSquares: - return SumSquares64 - case Abs: - return Abs64 - case Hamming: - return Hamming64 - case EuclideanBinTol: - return EuclideanBinTol64 - case SumSquaresBinTol: - return SumSquaresBinTol64 - case InvCorrelation: - return InvCorrelation64 - case InvCosine: - return InvCosine64 - case CrossEntropy: - return CrossEntropy64 - case InnerProduct: - return InnerProduct64 - case Covariance: - return Covariance64 - case Correlation: - return Correlation64 - case Cosine: - return Cosine64 +// AsMetricOutFunc returns given function as a [MetricFunc] function, +// or an error if it does not fit that signature. +func AsMetricOutFunc(fun any) (MetricOutFunc, error) { + mfun, ok := fun.(MetricOutFunc) + if !ok { + return nil, errors.New("metric.AsMetricOutFunc: function does not fit the MetricOutFunc signature") } - return nil + return mfun, nil } diff --git a/tensor/stats/metric/misc.go b/tensor/stats/metric/misc.go new file mode 100644 index 0000000000..cbe831104d --- /dev/null +++ b/tensor/stats/metric/misc.go @@ -0,0 +1,55 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metric + +import ( + "math" + + "cogentcore.org/core/tensor" +) + +// ClosestRow returns the closest fit between probe pattern and patterns in +// a "vocabulary" tensor with outermost row dimension, using given metric +// function, which must fit the MetricFunc signature. +// The metric *must have the Increasing property*, i.e., larger = further. 
+// Output is a 1D tensor with 2 elements: the row index and metric value for that row. +// Note: this does _not_ use any existing Indexes for the probe, +// but does for the vocab, and the returned index is the logical index +// into any existing Indexes. +func ClosestRow(fun any, probe, vocab tensor.Tensor) tensor.Values { + return tensor.CallOut2Gen1(ClosestRowOut, fun, probe, vocab) +} + +// ClosestRowOut returns the closest fit between probe pattern and patterns in +// a "vocabulary" tensor with outermost row dimension, using given metric +// function, which must fit the MetricFunc signature. +// The metric *must have the Increasing property*, i.e., larger = further. +// Output is a 1D tensor with 2 elements: the row index and metric value for that row. +// Note: this does _not_ use any existing Indexes for the probe, +// but does for the vocab, and the returned index is the logical index +// into any existing Indexes. +func ClosestRowOut(fun any, probe, vocab tensor.Tensor, out tensor.Values) error { + out.SetShapeSizes(2) + mfun, err := AsMetricFunc(fun) + if err != nil { + return err + } + rows, _ := vocab.Shape().RowCellSize() + mi := -1 + mind := math.MaxFloat64 + pr1d := tensor.As1D(probe) + for ri := range rows { + sub := tensor.Cells1D(vocab, ri) + mout := mfun(pr1d, sub) + d := mout.Float1D(0) + if d < mind { + mi = ri + mind = d + } + } + out.SetFloat1D(float64(mi), 0) + out.SetFloat1D(mind, 1) + return nil +} diff --git a/tensor/stats/metric/prob.go b/tensor/stats/metric/prob.go deleted file mode 100644 index 7749a1b1fd..0000000000 --- a/tensor/stats/metric/prob.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package metric - -import ( - "math" - - "cogentcore.org/core/math32" -) - -/////////////////////////////////////////// -// CrossEntropy - -// CrossEntropy32 computes cross-entropy between the two vectors. -// Skips NaN's and panics if lengths are not equal. -func CrossEntropy32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - bv = math32.Max(bv, 0.000001) - bv = math32.Min(bv, 0.999999) - if av >= 1.0 { - ss += -math32.Log(bv) - } else if av <= 0.0 { - ss += -math32.Log(1.0 - bv) - } else { - ss += av*math32.Log(av/bv) + (1-av)*math32.Log((1-av)/(1-bv)) - } - } - return ss -} - -// CrossEntropy64 computes the cross-entropy between the two vectors. -// Skips NaN's and panics if lengths are not equal. -func CrossEntropy64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - bv = math.Max(bv, 0.000001) - bv = math.Min(bv, 0.999999) - if av >= 1.0 { - ss += -math.Log(bv) - } else if av <= 0.0 { - ss += -math.Log(1.0 - bv) - } else { - ss += av*math.Log(av/bv) + (1-av)*math.Log((1-av)/(1-bv)) - } - } - return ss -} diff --git a/tensor/stats/metric/squares.go b/tensor/stats/metric/squares.go deleted file mode 100644 index 61576ba56a..0000000000 --- a/tensor/stats/metric/squares.go +++ /dev/null @@ -1,606 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package metric - -import ( - "math" - - "cogentcore.org/core/math32" - "cogentcore.org/core/tensor/stats/stats" -) - -/////////////////////////////////////////// -// SumSquares - -// SumSquares32 computes the sum-of-squares distance between two vectors. 
-// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. -func SumSquares32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float32 = 0 - sumSquares float32 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - absxi := math32.Abs(av - bv) - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * scale * sumSquares -} - -// SumSquares64 computes the sum-of-squares distance between two vectors. -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. -func SumSquares64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float64 = 0 - sumSquares float64 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math.IsNaN(av) || math.IsNaN(bv) { - continue - } - absxi := math.Abs(av - bv) - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * scale * sumSquares -} - -/////////////////////////////////////////// -// SumSquaresBinTol - -// SumSquaresBinTol32 computes the sum-of-squares distance between two vectors. -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. 
-// BinTol version uses binary tolerance for 0-1 valued-vectors where -// abs diff < .5 counts as 0 error (i.e., closer than not). -func SumSquaresBinTol32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float32 = 0 - sumSquares float32 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - absxi := math32.Abs(av - bv) - if absxi < 0.5 { - continue - } - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * scale * sumSquares -} - -// SumSquaresBinTol64 computes the sum-of-squares distance between two vectors. -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. -// BinTol version uses binary tolerance for 0-1 valued-vectors where -// abs diff < .5 counts as 0 error (i.e., closer than not). 
-func SumSquaresBinTol64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float64 = 0 - sumSquares float64 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math.IsNaN(av) || math.IsNaN(bv) { - continue - } - absxi := math.Abs(av - bv) - if absxi < 0.5 { - continue - } - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * scale * sumSquares -} - -/////////////////////////////////////////// -// Euclidean - -// Euclidean32 computes the square-root of sum-of-squares distance -// between two vectors (aka the L2 norm). -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. -func Euclidean32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float32 = 0 - sumSquares float32 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - absxi := math32.Abs(av - bv) - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * math32.Sqrt(sumSquares) -} - -// Euclidean64 computes the square-root of sum-of-squares distance -// between two vectors (aka the L2 norm). -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. 
-func Euclidean64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float64 = 0 - sumSquares float64 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math.IsNaN(av) || math.IsNaN(bv) { - continue - } - absxi := math.Abs(av - bv) - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * math.Sqrt(sumSquares) -} - -/////////////////////////////////////////// -// EuclideanBinTol - -// EuclideanBinTol32 computes the square-root of sum-of-squares distance -// between two vectors (aka the L2 norm). -// Skips NaN's and panics if lengths are not equal. -// Uses optimized algorithm from BLAS that avoids numerical overflow. -// BinTol version uses binary tolerance for 0-1 valued-vectors where -// abs diff < .5 counts as 0 error (i.e., closer than not). -func EuclideanBinTol32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float32 = 0 - sumSquares float32 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - absxi := math32.Abs(av - bv) - if absxi < 0.5 { - continue - } - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * math32.Sqrt(sumSquares) -} - -// EuclideanBinTol64 computes the square-root of sum-of-squares distance -// between two vectors (aka the L2 norm). -// Skips NaN's and panics if lengths are not equal. 
-// Uses optimized algorithm from BLAS that avoids numerical overflow. -// BinTol version uses binary tolerance for 0-1 valued-vectors where -// abs diff < .5 counts as 0 error (i.e., closer than not). -func EuclideanBinTol64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0] - b[0]) - } - return 0 - } - var ( - scale float64 = 0 - sumSquares float64 = 1 - ) - for i, av := range a { - bv := b[i] - if av == bv || math.IsNaN(av) || math.IsNaN(bv) { - continue - } - absxi := math.Abs(av - bv) - if absxi < 0.5 { - continue - } - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * math.Sqrt(sumSquares) -} - -/////////////////////////////////////////// -// Covariance - -// Covariance32 computes the mean of the co-product of each vector element minus -// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))] -// Skips NaN's and panics if lengths are not equal. -func Covariance32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - am := stats.Mean32(a) - bm := stats.Mean32(b) - n := 0 - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - ss += (av - am) * (bv - bm) - n++ - } - if n > 0 { - ss /= float32(n) - } - return ss -} - -// Covariance64 computes the mean of the co-product of each vector element minus -// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))] -// Skips NaN's and panics if lengths are not equal. 
-func Covariance64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - am := stats.Mean64(a) - bm := stats.Mean64(b) - n := 0 - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - ss += (av - am) * (bv - bm) - n++ - } - if n > 0 { - ss /= float64(n) - } - return ss -} - -/////////////////////////////////////////// -// Correlation - -// Correlation32 computes the vector similarity in range (-1..1) as the -// mean of the co-product of each vector element minus the mean of that vector, -// normalized by the product of their standard deviations: -// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). -// (i.e., the standardized covariance) -- equivalent to the cosine of mean-normalized -// vectors. -// Skips NaN's and panics if lengths are not equal. -func Correlation32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - am := stats.Mean32(a) - bm := stats.Mean32(b) - var avar, bvar float32 - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - ad := av - am - bd := bv - bm - ss += ad * bd // between - avar += ad * ad // within - bvar += bd * bd - } - vp := math32.Sqrt(avar * bvar) - if vp > 0 { - ss /= vp - } - return ss -} - -// Correlation64 computes the vector similarity in range (-1..1) as the -// mean of the co-product of each vector element minus the mean of that vector, -// normalized by the product of their standard deviations: -// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). -// (i.e., the standardized covariance) -- equivalent to the cosine of mean-normalized -// vectors. -// Skips NaN's and panics if lengths are not equal. 
-func Correlation64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - am := stats.Mean64(a) - bm := stats.Mean64(b) - var avar, bvar float64 - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - ad := av - am - bd := bv - bm - ss += ad * bd // between - avar += ad * ad // within - bvar += bd * bd - } - vp := math.Sqrt(avar * bvar) - if vp > 0 { - ss /= vp - } - return ss -} - -/////////////////////////////////////////// -// InnerProduct - -// InnerProduct32 computes the sum of the element-wise product of the two vectors. -// Skips NaN's and panics if lengths are not equal. -func InnerProduct32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - ss += av * bv - } - return ss -} - -// InnerProduct64 computes the mean of the co-product of each vector element minus -// the mean of that vector, normalized by the product of their standard deviations: -// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). -// (i.e., the standardized covariance) -- equivalent to the cosine of mean-normalized -// vectors. -// Skips NaN's and panics if lengths are not equal. -func InnerProduct64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - ss += av * bv - } - return ss -} - -/////////////////////////////////////////// -// Cosine - -// Cosine32 computes the cosine of the angle between two vectors (-1..1), -// as the normalized inner product: inner product / sqrt(ssA * ssB). -// If vectors are mean-normalized = Correlation. -// Skips NaN's and panics if lengths are not equal. 
-func Cosine32(a, b []float32) float32 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float32(0) - var ass, bss float32 - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - ss += av * bv // between - ass += av * av // within - bss += bv * bv - } - vp := math32.Sqrt(ass * bss) - if vp > 0 { - ss /= vp - } - return ss -} - -// Cosine32 computes the cosine of the angle between two vectors (-1..1), -// as the normalized inner product: inner product / sqrt(ssA * ssB). -// If vectors are mean-normalized = Correlation. -// Skips NaN's and panics if lengths are not equal. -func Cosine64(a, b []float64) float64 { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - ss := float64(0) - var ass, bss float64 - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - ss += av * bv // between - ass += av * av // within - bss += bv * bv - } - vp := math.Sqrt(ass * bss) - if vp > 0 { - ss /= vp - } - return ss -} - -/////////////////////////////////////////// -// InvCosine - -// InvCosine32 computes 1 - cosine of the angle between two vectors (-1..1), -// as the normalized inner product: inner product / sqrt(ssA * ssB). -// If vectors are mean-normalized = Correlation. -// Skips NaN's and panics if lengths are not equal. -func InvCosine32(a, b []float32) float32 { - return 1 - Cosine32(a, b) -} - -// InvCosine32 computes 1 - cosine of the angle between two vectors (-1..1), -// as the normalized inner product: inner product / sqrt(ssA * ssB). -// If vectors are mean-normalized = Correlation. -// Skips NaN's and panics if lengths are not equal. 
-func InvCosine64(a, b []float64) float64 { - return 1 - Cosine64(a, b) -} - -/////////////////////////////////////////// -// InvCorrelation - -// InvCorrelation32 computes 1 - the vector similarity in range (-1..1) as the -// mean of the co-product of each vector element minus the mean of that vector, -// normalized by the product of their standard deviations: -// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). -// (i.e., the standardized covariance) -- equivalent to the cosine of mean-normalized -// vectors. -// Skips NaN's and panics if lengths are not equal. -func InvCorrelation32(a, b []float32) float32 { - return 1 - Correlation32(a, b) -} - -// InvCorrelation64 computes 1 - the vector similarity in range (-1..1) as the -// mean of the co-product of each vector element minus the mean of that vector, -// normalized by the product of their standard deviations: -// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). -// (i.e., the standardized covariance) -- equivalent to the cosine of mean-normalized -// vectors. -// Skips NaN's and panics if lengths are not equal. -func InvCorrelation64(a, b []float64) float64 { - return 1 - Correlation64(a, b) -} diff --git a/tensor/stats/metric/tensor.go b/tensor/stats/metric/tensor.go deleted file mode 100644 index 6d9ae20455..0000000000 --- a/tensor/stats/metric/tensor.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package metric - -import ( - "math" - - "cogentcore.org/core/tensor" -) - -// ClosestRow32 returns the closest fit between probe pattern and patterns in -// an tensor with float32 data where the outer-most dimension is assumed to be a row -// (e.g., as a column in an table), using the given metric function, -// *which must have the Increasing property* -- i.e., larger = further. -// returns the row and metric value for that row. 
-// Col cell sizes must match size of probe (panics if not). -func ClosestRow32(probe tensor.Tensor, col tensor.Tensor, mfun Func32) (int, float32) { - pr := probe.(*tensor.Float32) - cl := col.(*tensor.Float32) - rows := col.Shape().DimSize(0) - csz := col.Len() / rows - if csz != probe.Len() { - panic("metric.ClosestRow32: probe size != cell size of tensor column!\n") - } - ci := -1 - minv := float32(math.MaxFloat32) - for ri := 0; ri < rows; ri++ { - st := ri * csz - rvals := cl.Values[st : st+csz] - v := mfun(pr.Values, rvals) - if v < minv { - ci = ri - minv = v - } - } - return ci, minv -} - -// ClosestRow64 returns the closest fit between probe pattern and patterns in -// a tensor with float64 data where the outer-most dimension is assumed to be a row -// (e.g., as a column in an table), using the given metric function, -// *which must have the Increasing property* -- i.e., larger = further. -// returns the row and metric value for that row. -// Col cell sizes must match size of probe (panics if not). -func ClosestRow64(probe tensor.Tensor, col tensor.Tensor, mfun Func64) (int, float64) { - pr := probe.(*tensor.Float64) - cl := col.(*tensor.Float64) - rows := col.DimSize(0) - csz := col.Len() / rows - if csz != probe.Len() { - panic("metric.ClosestRow64: probe size != cell size of tensor column!\n") - } - ci := -1 - minv := math.MaxFloat64 - for ri := 0; ri < rows; ri++ { - st := ri * csz - rvals := cl.Values[st : st+csz] - v := mfun(pr.Values, rvals) - if v < minv { - ci = ri - minv = v - } - } - return ci, minv -} - -// ClosestRow32Py returns the closest fit between probe pattern and patterns in -// an tensor.Float32 where the outer-most dimension is assumed to be a row -// (e.g., as a column in an table), using the given metric function, -// *which must have the Increasing property* -- i.e., larger = further. -// returns the row and metric value for that row. -// Col cell sizes must match size of probe (panics if not). 
-// Py version is for Python, returns a slice with row, cor, takes std metric -func ClosestRow32Py(probe tensor.Tensor, col tensor.Tensor, std StdMetrics) []float32 { - row, cor := ClosestRow32(probe, col, StdFunc32(std)) - return []float32{float32(row), cor} -} - -// ClosestRow64Py returns the closest fit between probe pattern and patterns in -// an tensor.Tensor where the outer-most dimension is assumed to be a row -// (e.g., as a column in an table), using the given metric function, -// *which must have the Increasing property* -- i.e., larger = further. -// returns the row and metric value for that row. -// Col cell sizes must match size of probe (panics if not). -// Optimized for tensor.Float64 but works for any tensor. -// Py version is for Python, returns a slice with row, cor, takes std metric -func ClosestRow64Py(probe tensor.Tensor, col tensor.Tensor, std StdMetrics) []float64 { - row, cor := ClosestRow64(probe, col, StdFunc64(std)) - return []float64{float64(row), cor} -} diff --git a/tensor/stats/metric/testdata/iris.data b/tensor/stats/metric/testdata/iris.data new file mode 100644 index 0000000000..a3490e0e07 --- /dev/null +++ b/tensor/stats/metric/testdata/iris.data @@ -0,0 +1,150 @@ +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa 
+5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa +4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa +5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor 
+5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor +5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica +7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica diff --git 
a/tensor/stats/metric/tol.go b/tensor/stats/metric/tol.go deleted file mode 100644 index a99058561f..0000000000 --- a/tensor/stats/metric/tol.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package metric - -import ( - "math" - - "cogentcore.org/core/math32" -) - -/////////////////////////////////////////// -// Tolerance - -// Tolerance32 sets a = b for any element where |a-b| <= tol. -// This can be called prior to any metric function. -func Tolerance32(a, b []float32, tol float32) { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - for i, av := range a { - bv := b[i] - if math32.IsNaN(av) || math32.IsNaN(bv) { - continue - } - if math32.Abs(av-bv) <= tol { - a[i] = bv - } - } -} - -// Tolerance64 sets a = b for any element where |a-b| <= tol. -// This can be called prior to any metric function. -func Tolerance64(a, b []float64, tol float64) { - if len(a) != len(b) { - panic("metric: slice lengths do not match") - } - for i, av := range a { - bv := b[i] - if math.IsNaN(av) || math.IsNaN(bv) { - continue - } - if math.Abs(av-bv) <= tol { - a[i] = bv - } - } -} diff --git a/tensor/stats/metric/vec.go b/tensor/stats/metric/vec.go new file mode 100644 index 0000000000..57a4fb2e51 --- /dev/null +++ b/tensor/stats/metric/vec.go @@ -0,0 +1,824 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package metric + +import ( + "cogentcore.org/core/tensor" +) + +// VectorizeOut64 is the general compute function for metric. +// This version makes a Float64 output tensor for aggregating +// and computing values, and then copies the results back to the +// original output. This allows metric functions to operate directly +// on integer valued inputs and produce sensible results. 
+// It returns the Float64 output tensor for further processing as needed. +// a and b are already enforced to be the same shape. +func VectorizeOut64(a, b tensor.Tensor, out tensor.Values, ini float64, fun func(a, b, agg float64) float64) *tensor.Float64 { + rows, cells := a.Shape().RowCellSize() + o64 := tensor.NewFloat64(cells) + if rows <= 0 { + return o64 + } + if cells == 1 { + out.SetShapeSizes(1) + agg := ini + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), agg) + } + default: + for i := range rows { + agg = fun(x.Float1D(i), b.Float1D(i), agg) + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), agg) + } + default: + for i := range rows { + agg = fun(x.Float1D(i), b.Float1D(i), agg) + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(a.Float1D(i), y.Float1D(i), agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(a.Float1D(i), y.Float1D(i), agg) + } + default: + for i := range rows { + agg = fun(a.Float1D(i), b.Float1D(i), agg) + } + } + } + o64.SetFloat1D(agg, 0) + out.SetFloat1D(agg, 0) + return o64 + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ for i := range cells { + o64.SetFloat1D(ini, i) + } + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j) + } + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j) + } + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j) + } + } + } + } + for j := range cells { + out.SetFloat1D(o64.Float1D(j), j) + } + return o64 +} + +// VectorizePreOut64 is a version of [VectorizeOut64] that takes additional +// tensor.Float64 inputs of pre-computed values, e.g., the means of each output 
cell. +func VectorizePreOut64(a, b tensor.Tensor, out tensor.Values, ini float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, agg float64) float64) *tensor.Float64 { + rows, cells := a.Shape().RowCellSize() + o64 := tensor.NewFloat64(cells) + if rows <= 0 { + return o64 + } + if cells == 1 { + out.SetShapeSizes(1) + agg := ini + prevA := preA.Float1D(0) + prevB := preB.Float1D(0) + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + default: + for i := range rows { + agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg) + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + default: + for i := range rows { + agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg) + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg) + } + default: + for i := range rows { + agg = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, agg) + } + } + } + o64.SetFloat1D(agg, 0) + out.SetFloat1D(agg, 0) + return o64 + } + osz := tensor.CellsSize(a.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ for j := range cells { + o64.SetFloat1D(ini, j) + } + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), 
preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j) + } + } + } + } + for i := range cells { + out.SetFloat1D(o64.Float1D(i), i) + } + return o64 +} + +// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates +// two output values, x and y as tensor.Float64. +func Vectorize2Out64(a, b tensor.Tensor, iniX, iniY float64, fun func(a, b, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) { + rows, cells := a.Shape().RowCellSize() + ox64 = tensor.NewFloat64(cells) + oy64 = tensor.NewFloat64(cells) + if rows <= 0 { + return + } + if cells == 1 { + ox := iniX + oy := iniY + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy) + } + case *tensor.Float32: + for i := range rows { + ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy) + } + default: + for i := range rows { + ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy) + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy) + } + case *tensor.Float32: + for i := range rows { + ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy) + } + default: + for i := range rows { + ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy) + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy) + } + case *tensor.Float32: + for i := range rows { + ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy) + } + default: + for i := range rows { + ox, oy = fun(a.Float1D(i), b.Float1D(i), ox, oy) + } + } + } + ox64.SetFloat1D(ox, 0) + oy64.SetFloat1D(oy, 0) + return + } + for j := range cells { + ox64.SetFloat1D(iniX, j) + oy64.SetFloat1D(iniY, j) + } + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := 
fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j)) + 
ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + } + } + return +} + +// Vectorize3Out64 is a version of [VectorizeOut64] that has 3 outputs instead of 1. +func Vectorize3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, fun func(a, b, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) { + rows, cells := a.Shape().RowCellSize() + ox64 = tensor.NewFloat64(cells) + oy64 = tensor.NewFloat64(cells) + oz64 = tensor.NewFloat64(cells) + if rows <= 0 { + return + } + if cells == 1 { + ox := iniX + oy := iniY + oz := iniZ + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz) + } + default: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz) + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz) + } + default: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz) + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz) + } + default: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), ox, oy, oz) + } + } + } + ox64.SetFloat1D(ox, 0) + oy64.SetFloat1D(oy, 0) + oz64.SetFloat1D(oz, 0) + return + } + for j := range cells { + ox64.SetFloat1D(iniX, j) + oy64.SetFloat1D(iniY, j) + oz64.SetFloat1D(iniZ, j) + } + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for 
i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case 
*tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + } + return +} + +// VectorizePre3Out64 is a version of [VectorizePreOut64] that takes additional +// tensor.Float64 inputs of pre-computed values, e.g., the means of each output cell, +// and has 3 outputs instead of 1. +func VectorizePre3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) { + rows, cells := a.Shape().RowCellSize() + ox64 = tensor.NewFloat64(cells) + oy64 = tensor.NewFloat64(cells) + oz64 = tensor.NewFloat64(cells) + if rows <= 0 { + return + } + if cells == 1 { + ox := iniX + oy := iniY + oz := iniZ + prevA := preA.Float1D(0) + prevB := preB.Float1D(0) + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } + default: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz) + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } 
+ default: + for i := range rows { + ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz) + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } + case *tensor.Float32: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz) + } + default: + for i := range rows { + ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz) + } + } + } + ox64.SetFloat1D(ox, 0) + oy64.SetFloat1D(oy, 0) + oz64.SetFloat1D(oz, 0) + return + } + for j := range cells { + ox64.SetFloat1D(iniX, j) + oy64.SetFloat1D(iniY, j) + oz64.SetFloat1D(iniZ, j) + } + switch x := a.(type) { + case *tensor.Float64: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + case *tensor.Float32: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), 
oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + default: + switch y := b.(type) { + case *tensor.Float64: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + case *tensor.Float32: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + default: + for i := range rows { + si := i * cells + for j := range cells { + ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + oz64.SetFloat1D(oz, j) + } + } + } + } + return +} diff --git a/tensor/stats/norm/README.md b/tensor/stats/norm/README.md deleted file mode 100644 index ca41bdf01f..0000000000 --- a/tensor/stats/norm/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# norm - -`norm` provides normalization of vector and tensor values. 
The basic functions operate on either `[]float32` or `[]float64` data, with Tensor versions using those, only for Float32 and Float64 tensors. - -* DivNorm does divisive normalization of elements -* SubNorm does subtractive normalization of elements -* ZScore subtracts the mean and divides by the standard deviation -* Abs performs absolute-value on all elements (e.g., use prior to [stats](../stats) to produce Mean of Abs vals etc). - - diff --git a/tensor/stats/norm/abs.go b/tensor/stats/norm/abs.go deleted file mode 100644 index 699152382d..0000000000 --- a/tensor/stats/norm/abs.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package norm - -import ( - "fmt" - "log/slog" - "math" - - "cogentcore.org/core/math32" - "cogentcore.org/core/tensor" -) - -// Abs32 applies the Abs function to each element in given slice -func Abs32(a []float32) { - for i, av := range a { - if math32.IsNaN(av) { - continue - } - a[i] = math32.Abs(av) - } -} - -// Abs64 applies the Abs function to each element in given slice -func Abs64(a []float64) { - for i, av := range a { - if math.IsNaN(av) { - continue - } - a[i] = math.Abs(av) - } -} - -func FloatOnlyError() error { - err := fmt.Errorf("Only float32 or float64 data types supported") - slog.Error(err.Error()) - return err -} - -// AbsTensor applies the Abs function to each element in given tensor, -// for float32 and float64 data types. -func AbsTensor(a tensor.Tensor) { - switch tsr := a.(type) { - case *tensor.Float32: - Abs32(tsr.Values) - case *tensor.Float64: - Abs64(tsr.Values) - default: - FloatOnlyError() - } -} diff --git a/tensor/stats/norm/doc.go b/tensor/stats/norm/doc.go deleted file mode 100644 index c3b769433e..0000000000 --- a/tensor/stats/norm/doc.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -Package norm provides normalization and norm metric computations -e.g., L2 = sqrt of sum of squares of a vector. - -DivNorm does divisive normalization of elements -SubNorm does subtractive normalization of elements -ZScore subtracts the mean and divides by the standard deviation -*/ -package norm diff --git a/tensor/stats/norm/norm.go b/tensor/stats/norm/norm.go deleted file mode 100644 index 1090e2fa53..0000000000 --- a/tensor/stats/norm/norm.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package norm - -//go:generate core generate - -import ( - "math" - - "cogentcore.org/core/math32" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/stats" -) - -// FloatFunc applies given functions to float tensor data, which is either Float32 or Float64 -func FloatFunc(tsr tensor.Tensor, nfunc32 Func32, nfunc64 Func64, stIdx, nIdx int, ffunc32 func(a []float32, fun Func32), ffunc64 func(a []float64, fun Func64)) { - switch tt := tsr.(type) { - case *tensor.Float32: - vals := tt.Values - if nIdx > 0 { - vals = vals[stIdx : stIdx+nIdx] - } - ffunc32(vals, nfunc32) - case *tensor.Float64: - vals := tt.Values - if nIdx > 0 { - vals = vals[stIdx : stIdx+nIdx] - } - ffunc64(vals, nfunc64) - default: - FloatOnlyError() - } -} - -/////////////////////////////////////////// -// DivNorm - -// DivNorm32 does divisive normalization by given norm function -// i.e., it divides each element by the norm value computed from nfunc. -func DivNorm32(a []float32, nfunc Func32) { - nv := nfunc(a) - if nv != 0 { - MultVector32(a, 1/nv) - } -} - -// DivNorm64 does divisive normalization by given norm function -// i.e., it divides each element by the norm value computed from nfunc. 
-func DivNorm64(a []float64, nfunc Func64) { - nv := nfunc(a) - if nv != 0 { - MultVec64(a, 1/nv) - } -} - -/////////////////////////////////////////// -// SubNorm - -// SubNorm32 does subtractive normalization by given norm function -// i.e., it subtracts norm computed by given function from each element. -func SubNorm32(a []float32, nfunc Func32) { - nv := nfunc(a) - AddVector32(a, -nv) -} - -// SubNorm64 does subtractive normalization by given norm function -// i.e., it subtracts norm computed by given function from each element. -func SubNorm64(a []float64, nfunc Func64) { - nv := nfunc(a) - AddVec64(a, -nv) -} - -/////////////////////////////////////////// -// ZScore - -// ZScore32 subtracts the mean and divides by the standard deviation -func ZScore32(a []float32) { - SubNorm32(a, stats.Mean32) - DivNorm32(a, stats.Std32) -} - -// ZScore64 subtracts the mean and divides by the standard deviation -func ZScore64(a []float64) { - SubNorm64(a, stats.Mean64) - DivNorm64(a, stats.Std64) -} - -/////////////////////////////////////////// -// Unit - -// Unit32 subtracts the min and divides by the max, so that values are in 0-1 unit range -func Unit32(a []float32) { - SubNorm32(a, stats.Min32) - DivNorm32(a, stats.Max32) -} - -// Unit64 subtracts the min and divides by the max, so that values are in 0-1 unit range -func Unit64(a []float64) { - SubNorm64(a, stats.Min64) - DivNorm64(a, stats.Max64) -} - -/////////////////////////////////////////// -// MultVec - -// MultVector32 multiplies vector elements by scalar -func MultVector32(a []float32, val float32) { - for i, av := range a { - if math32.IsNaN(av) { - continue - } - a[i] *= val - } -} - -// MultVec64 multiplies vector elements by scalar -func MultVec64(a []float64, val float64) { - for i, av := range a { - if math.IsNaN(av) { - continue - } - a[i] *= val - } -} - -/////////////////////////////////////////// -// AddVec - -// AddVector32 adds scalar to vector -func AddVector32(a []float32, val float32) { - for i, 
av := range a { - if math32.IsNaN(av) { - continue - } - a[i] += val - } -} - -// AddVec64 adds scalar to vector -func AddVec64(a []float64, val float64) { - for i, av := range a { - if math.IsNaN(av) { - continue - } - a[i] += val - } -} - -/////////////////////////////////////////// -// Thresh - -// Thresh32 thresholds the values of the vector -- anything above the high threshold is set -// to the high value, and everything below the low threshold is set to the low value. -func Thresh32(a []float32, hi bool, hiThr float32, lo bool, loThr float32) { - for i, av := range a { - if math32.IsNaN(av) { - continue - } - if hi && av > hiThr { - a[i] = hiThr - } - if lo && av < loThr { - a[i] = loThr - } - } -} - -// Thresh64 thresholds the values of the vector -- anything above the high threshold is set -// to the high value, and everything below the low threshold is set to the low value. -func Thresh64(a []float64, hi bool, hiThr float64, lo bool, loThr float64) { - for i, av := range a { - if math.IsNaN(av) { - continue - } - if hi && av > hiThr { - a[i] = hiThr - } - if lo && av < loThr { - a[i] = loThr - } - } -} - -/////////////////////////////////////////// -// Binarize - -// Binarize32 turns vector into binary-valued, by setting anything >= the threshold -// to the high value, and everything below to the low value. -func Binarize32(a []float32, thr, hiVal, loVal float32) { - for i, av := range a { - if math32.IsNaN(av) { - continue - } - if av >= thr { - a[i] = hiVal - } else { - a[i] = loVal - } - } -} - -// Binarize64 turns vector into binary-valued, by setting anything >= the threshold -// to the high value, and everything below to the low value. 
-func Binarize64(a []float64, thr, hiVal, loVal float64) { - for i, av := range a { - if math.IsNaN(av) { - continue - } - if av >= thr { - a[i] = hiVal - } else { - a[i] = loVal - } - } -} - -// Func32 is a norm function operating on slice of float32 numbers -type Func32 func(a []float32) float32 - -// Func64 is a norm function operating on slices of float64 numbers -type Func64 func(a []float64) float64 diff --git a/tensor/stats/norm/norm_test.go b/tensor/stats/norm/norm_test.go deleted file mode 100644 index b9b420e8de..0000000000 --- a/tensor/stats/norm/norm_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package norm - -import ( - "testing" - - "cogentcore.org/core/base/tolassert" - "cogentcore.org/core/tensor" - "github.com/stretchr/testify/assert" -) - -func TestNorm32(t *testing.T) { - vals := []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} - - zn := []float32{-1.5075567, -1.2060454, -0.90453404, -0.60302263, -0.30151132, 0, 0.3015114, 0.60302263, 0.90453404, 1.2060453, 1.5075567} - nvals := make([]float32, len(vals)) - copy(nvals, vals) - ZScore32(nvals) - assert.Equal(t, zn, nvals) - - copy(nvals, vals) - Unit32(nvals) - assert.Equal(t, vals, nvals) - - tn := []float32{0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.8, 0.8} - copy(nvals, vals) - Thresh32(nvals, true, 0.8, true, 0.2) - assert.Equal(t, tn, nvals) - - bn := []float32{0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1} - copy(nvals, vals) - Binarize32(nvals, 0.5, 1.0, 0.0) - assert.Equal(t, bn, nvals) - - tsr := tensor.New[float32]([]int{11}).(*tensor.Float32) - copy(tsr.Values, vals) - TensorZScore(tsr, 0) - tolassert.EqualTolSlice(t, zn, tsr.Values, 1.0e-6) - - copy(tsr.Values, vals) - TensorUnit(tsr, 0) - tolassert.EqualTolSlice(t, vals, tsr.Values, 1.0e-6) - -} - -func TestNorm64(t *testing.T) { - vals := []float64{0, 0.1, 0.2, 0.3, 0.4, 
0.5, 0.6, 0.7, 0.8, 0.9, 1} - - zn := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 1.507556722888818} - nvals := make([]float64, len(vals)) - copy(nvals, vals) - ZScore64(nvals) - assert.Equal(t, zn, nvals) - - copy(nvals, vals) - Unit64(nvals) - assert.Equal(t, vals, nvals) - - tn := []float64{0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.8, 0.8} - copy(nvals, vals) - Thresh64(nvals, true, 0.8, true, 0.2) - assert.Equal(t, tn, nvals) - - bn := []float64{0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1} - copy(nvals, vals) - Binarize64(nvals, 0.5, 1.0, 0.0) - assert.Equal(t, bn, nvals) - - tsr := tensor.New[float64]([]int{11}).(*tensor.Float64) - copy(tsr.Values, vals) - TensorZScore(tsr, 0) - tolassert.EqualTolSlice(t, zn, tsr.Values, 1.0e-6) - - copy(tsr.Values, vals) - TensorUnit(tsr, 0) - tolassert.EqualTolSlice(t, vals, tsr.Values, 1.0e-6) - -} diff --git a/tensor/stats/norm/tensor.go b/tensor/stats/norm/tensor.go deleted file mode 100644 index ffc00db8f7..0000000000 --- a/tensor/stats/norm/tensor.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package norm - -import ( - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/stats" -) - -/////////////////////////////////////////// -// DivNorm - -// TensorDivNorm does divisive normalization by given norm function -// computed on the first ndim dims of the tensor, where 0 = all values, -// 1 = norm each of the sub-dimensions under the first outer-most dimension etc. -// ndim must be < NumDims() if not 0. 
-func TensorDivNorm(tsr tensor.Tensor, ndim int, nfunc32 Func32, nfunc64 Func64) { - if ndim == 0 { - FloatFunc(tsr, nfunc32, nfunc64, 0, 0, DivNorm32, DivNorm64) - } - if ndim >= tsr.NumDims() { - panic("norm.TensorSubNorm32: number of dims must be < NumDims()") - } - sln := 1 - ln := tsr.Len() - for i := 0; i < ndim; i++ { - sln *= tsr.Shape().DimSize(i) - } - dln := ln / sln - for sl := 0; sl < sln; sl++ { - st := sl * dln - FloatFunc(tsr, nfunc32, nfunc64, st, dln, DivNorm32, DivNorm64) - } -} - -/////////////////////////////////////////// -// SubNorm - -// TensorSubNorm does subtractive normalization by given norm function -// computed on the first ndim dims of the tensor, where 0 = all values, -// 1 = norm each of the sub-dimensions under the first outer-most dimension etc. -// ndim must be < NumDims() if not 0 (panics). -func TensorSubNorm(tsr tensor.Tensor, ndim int, nfunc32 Func32, nfunc64 Func64) { - if ndim == 0 { - FloatFunc(tsr, nfunc32, nfunc64, 0, 0, SubNorm32, SubNorm64) - } - if ndim >= tsr.NumDims() { - panic("norm.TensorSubNorm32: number of dims must be < NumDims()") - } - sln := 1 - ln := tsr.Len() - for i := 0; i < ndim; i++ { - sln *= tsr.Shape().DimSize(i) - } - dln := ln / sln - for sl := 0; sl < sln; sl++ { - st := sl * dln - FloatFunc(tsr, nfunc32, nfunc64, st, dln, SubNorm32, SubNorm64) - } -} - -// TensorZScore subtracts the mean and divides by the standard deviation -// computed on the first ndim dims of the tensor, where 0 = all values, -// 1 = norm each of the sub-dimensions under the first outer-most dimension etc. -// ndim must be < NumDims() if not 0 (panics). 
-// must be a float32 or float64 tensor -func TensorZScore(tsr tensor.Tensor, ndim int) { - TensorSubNorm(tsr, ndim, stats.Mean32, stats.Mean64) - TensorDivNorm(tsr, ndim, stats.Std32, stats.Std64) -} - -// TensorUnit subtracts the min and divides by the max, so that values are in 0-1 unit range -// computed on the first ndim dims of the tensor, where 0 = all values, -// 1 = norm each of the sub-dimensions under the first outer-most dimension etc. -// ndim must be < NumDims() if not 0 (panics). -// must be a float32 or float64 tensor -func TensorUnit(tsr tensor.Tensor, ndim int) { - TensorSubNorm(tsr, ndim, stats.Min32, stats.Min64) - TensorDivNorm(tsr, ndim, stats.Max32, stats.Max64) -} diff --git a/tensor/stats/pca/README.md b/tensor/stats/pca/README.md deleted file mode 100644 index 5f3f7b4d63..0000000000 --- a/tensor/stats/pca/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# pca - -This performs principal component's analysis and associated covariance matrix computations, operating on `table.Table` or `tensor.Tensor` data, using the [gonum](https://github.com/gonum/gonum) matrix interface. - -There is support for the SVD version, which is much faster and produces the same results, with options for how much information to compute trading off with compute time. - - diff --git a/tensor/stats/pca/covar.go b/tensor/stats/pca/covar.go deleted file mode 100644 index 5aa1e9f0d0..0000000000 --- a/tensor/stats/pca/covar.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package pca - -import ( - "fmt" - - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" -) - -// CovarTableColumn generates a covariance matrix from given column name -// in given IndexView of an table.Table, and given metric function -// (typically Covariance or Correlation -- use Covar if vars have similar -// overall scaling, which is typical in neural network models, and use -// Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix. -func CovarTableColumn(cmat tensor.Tensor, ix *table.IndexView, column string, mfun metric.Func64) error { - col, err := ix.Table.ColumnByName(column) - if err != nil { - return err - } - rows := ix.Len() - nd := col.NumDims() - if nd < 2 || rows == 0 { - return fmt.Errorf("pca.CovarTableColumn: must have 2 or more dims and rows != 0") - } - ln := col.Len() - sz := ln / col.DimSize(0) // size of cell - - cshp := []int{sz, sz} - cmat.SetShape(cshp) - - av := make([]float64, rows) - bv := make([]float64, rows) - sdim := []int{0, 0} - for ai := 0; ai < sz; ai++ { - sdim[0] = ai - TableColumnRowsVec(av, ix, col, ai) - for bi := 0; bi <= ai; bi++ { // lower diag - sdim[1] = bi - TableColumnRowsVec(bv, ix, col, bi) - cv := mfun(av, bv) - cmat.SetFloat(sdim, cv) - } - } - // now fill in upper diagonal with values from lower diagonal - // note: assumes symmetric distance function - fdim := []int{0, 0} - for ai := 0; ai < sz; ai++ { - sdim[0] = ai - fdim[1] = ai - for bi := ai + 1; bi < sz; bi++ { // upper diag - fdim[0] = bi - sdim[1] = bi - cv := cmat.Float(fdim) - cmat.SetFloat(sdim, cv) - } - } - - if nm, has := 
ix.Table.MetaData["name"]; has { - cmat.SetMetaData("name", nm+"_"+column) - } else { - cmat.SetMetaData("name", column) - } - if ds, has := ix.Table.MetaData["desc"]; has { - cmat.SetMetaData("desc", ds) - } - return nil -} - -// CovarTensor generates a covariance matrix from given tensor.Tensor, -// where the outer-most dimension is rows, and all other dimensions within that -// are covaried against each other, using given metric function -// (typically Covariance or Correlation -- use Covar if vars have similar -// overall scaling, which is typical in neural network models, and use -// Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix. 
-func CovarTensor(cmat tensor.Tensor, tsr tensor.Tensor, mfun metric.Func64) error { - rows := tsr.DimSize(0) - nd := tsr.NumDims() - if nd < 2 || rows == 0 { - return fmt.Errorf("pca.CovarTensor: must have 2 or more dims and rows != 0") - } - ln := tsr.Len() - sz := ln / rows - - cshp := []int{sz, sz} - cmat.SetShape(cshp) - - av := make([]float64, rows) - bv := make([]float64, rows) - sdim := []int{0, 0} - for ai := 0; ai < sz; ai++ { - sdim[0] = ai - TensorRowsVec(av, tsr, ai) - for bi := 0; bi <= ai; bi++ { // lower diag - sdim[1] = bi - TensorRowsVec(bv, tsr, bi) - cv := mfun(av, bv) - cmat.SetFloat(sdim, cv) - } - } - // now fill in upper diagonal with values from lower diagonal - // note: assumes symmetric distance function - fdim := []int{0, 0} - for ai := 0; ai < sz; ai++ { - sdim[0] = ai - fdim[1] = ai - for bi := ai + 1; bi < sz; bi++ { // upper diag - fdim[0] = bi - sdim[1] = bi - cv := cmat.Float(fdim) - cmat.SetFloat(sdim, cv) - } - } - - if nm, has := tsr.MetaData("name"); has { - cmat.SetMetaData("name", nm+"Covar") - } else { - cmat.SetMetaData("name", "Covar") - } - if ds, has := tsr.MetaData("desc"); has { - cmat.SetMetaData("desc", ds) - } - return nil -} - -// TableColumnRowsVec extracts row-wise vector from given cell index into vec. -// vec must be of size ix.Len() -- number of rows -func TableColumnRowsVec(vec []float64, ix *table.IndexView, col tensor.Tensor, cidx int) { - rows := ix.Len() - ln := col.Len() - sz := ln / col.DimSize(0) // size of cell - for ri := 0; ri < rows; ri++ { - coff := ix.Indexes[ri]*sz + cidx - vec[ri] = col.Float1D(coff) - } -} - -// TensorRowsVec extracts row-wise vector from given cell index into vec. 
-// vec must be of size tsr.DimSize(0) -- number of rows -func TensorRowsVec(vec []float64, tsr tensor.Tensor, cidx int) { - rows := tsr.DimSize(0) - ln := tsr.Len() - sz := ln / rows - for ri := 0; ri < rows; ri++ { - coff := ri*sz + cidx - vec[ri] = tsr.Float1D(coff) - } -} - -// CovarTableColumnStd generates a covariance matrix from given column name -// in given IndexView of an table.Table, and given metric function -// (typically Covariance or Correlation -- use Covar if vars have similar -// overall scaling, which is typical in neural network models, and use -// Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix. -// This Std version is usable e.g., in Python where the func cannot be passed. -func CovarTableColumnStd(cmat tensor.Tensor, ix *table.IndexView, column string, met metric.StdMetrics) error { - return CovarTableColumn(cmat, ix, column, metric.StdFunc64(met)) -} - -// CovarTensorStd generates a covariance matrix from given tensor.Tensor, -// where the outer-most dimension is rows, and all other dimensions within that -// are covaried against each other, using given metric function -// (typically Covariance or Correlation -- use Covar if vars have similar -// overall scaling, which is typical in neural network models, and use -// Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. 
-// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix. -// This Std version is usable e.g., in Python where the func cannot be passed. -func CovarTensorStd(cmat tensor.Tensor, tsr tensor.Tensor, met metric.StdMetrics) error { - return CovarTensor(cmat, tsr, metric.StdFunc64(met)) -} diff --git a/tensor/stats/pca/pca.go b/tensor/stats/pca/pca.go deleted file mode 100644 index b69b57e1a3..0000000000 --- a/tensor/stats/pca/pca.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package pca - -//go:generate core generate - -import ( - "fmt" - - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" - "gonum.org/v1/gonum/mat" -) - -// PCA computes the eigenvalue decomposition of a square similarity matrix, -// typically generated using the correlation metric. -type PCA struct { - - // the covariance matrix computed on original data, which is then eigen-factored - Covar tensor.Tensor `display:"no-inline"` - - // the eigenvectors, in same size as Covar - each eigenvector is a column in this 2D square matrix, ordered *lowest* to *highest* across the columns -- i.e., maximum eigenvector is the last column - Vectors tensor.Tensor `display:"no-inline"` - - // the eigenvalues, ordered *lowest* to *highest* - Values []float64 `display:"no-inline"` -} - -func (pa *PCA) Init() { - pa.Covar = &tensor.Float64{} - pa.Vectors = &tensor.Float64{} - pa.Values = nil -} - -// TableColumn is a convenience method that computes a covariance matrix -// on given column of table and then performs the PCA on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. 
-// mfun is metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -func (pa *PCA) TableColumn(ix *table.IndexView, column string, mfun metric.Func64) error { - if pa.Covar == nil { - pa.Init() - } - err := CovarTableColumn(pa.Covar, ix, column, mfun) - if err != nil { - return err - } - return pa.PCA() -} - -// Tensor is a convenience method that computes a covariance matrix -// on given tensor and then performs the PCA on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. 
-func (pa *PCA) Tensor(tsr tensor.Tensor, mfun metric.Func64) error { - if pa.Covar == nil { - pa.Init() - } - err := CovarTensor(pa.Covar, tsr, mfun) - if err != nil { - return err - } - return pa.PCA() -} - -// TableColumnStd is a convenience method that computes a covariance matrix -// on given column of table and then performs the PCA on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is a Std metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -// This Std version is usable e.g., in Python where the func cannot be passed. -func (pa *PCA) TableColumnStd(ix *table.IndexView, column string, met metric.StdMetrics) error { - return pa.TableColumn(ix, column, metric.StdFunc64(met)) -} - -// TensorStd is a convenience method that computes a covariance matrix -// on given tensor and then performs the PCA on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is Std metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). 
-// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the PCA eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -// This Std version is usable e.g., in Python where the func cannot be passed. -func (pa *PCA) TensorStd(tsr tensor.Tensor, met metric.StdMetrics) error { - return pa.Tensor(tsr, metric.StdFunc64(met)) -} - -// PCA performs the eigen decomposition of the existing Covar matrix. -// Vectors and Values fields contain the results. -func (pa *PCA) PCA() error { - if pa.Covar == nil || pa.Covar.NumDims() != 2 { - return fmt.Errorf("pca.PCA: Covar matrix is nil or not 2D") - } - var eig mat.EigenSym - // note: MUST be a Float64 otherwise doesn't have Symmetric function - ok := eig.Factorize(pa.Covar.(*tensor.Float64), true) - if !ok { - return fmt.Errorf("gonum EigenSym Factorize failed") - } - if pa.Vectors == nil { - pa.Vectors = &tensor.Float64{} - } - var ev mat.Dense - eig.VectorsTo(&ev) - tensor.CopyDense(pa.Vectors, &ev) - nr := pa.Vectors.DimSize(0) - if len(pa.Values) != nr { - pa.Values = make([]float64, nr) - } - eig.Values(pa.Values) - return nil -} - -// ProjectColumn projects values from the given column of given table (via IndexView) -// onto the idx'th eigenvector (0 = largest eigenvalue, 1 = next, etc). -// Must have already called PCA() method. 
-func (pa *PCA) ProjectColumn(vals *[]float64, ix *table.IndexView, column string, idx int) error { - col, err := ix.Table.ColumnByName(column) - if err != nil { - return err - } - if pa.Vectors == nil { - return fmt.Errorf("PCA.ProjectColumn Vectors are nil -- must call PCA first") - } - nr := pa.Vectors.DimSize(0) - if idx >= nr { - return fmt.Errorf("PCA.ProjectColumn eigenvector index > rank of matrix") - } - cvec := make([]float64, nr) - eidx := nr - 1 - idx // eigens in reverse order - vec := pa.Vectors.(*tensor.Float64) - for ri := 0; ri < nr; ri++ { - cvec[ri] = vec.Value([]int{ri, eidx}) // vecs are in columns, reverse magnitude order - } - rows := ix.Len() - if len(*vals) != rows { - *vals = make([]float64, rows) - } - ln := col.Len() - sz := ln / col.DimSize(0) // size of cell - if sz != nr { - return fmt.Errorf("PCA.ProjectColumn column cell size != pca eigenvectors") - } - rdim := []int{0} - for row := 0; row < rows; row++ { - sum := 0.0 - rdim[0] = ix.Indexes[row] - rt := col.SubSpace(rdim) - for ci := 0; ci < sz; ci++ { - sum += cvec[ci] * rt.Float1D(ci) - } - (*vals)[row] = sum - } - return nil -} - -// ProjectColumnToTable projects values from the given column of given table (via IndexView) -// onto the given set of eigenvectors (idxs, 0 = largest eigenvalue, 1 = next, etc), -// and stores results along with labels from column labNm into results table. -// Must have already called PCA() method. 
-func (pa *PCA) ProjectColumnToTable(projections *table.Table, ix *table.IndexView, column, labNm string, idxs []int) error { - _, err := ix.Table.ColumnByName(column) - if err != nil { - return err - } - if pa.Vectors == nil { - return fmt.Errorf("PCA.ProjectColumn Vectors are nil -- must call PCA first") - } - rows := ix.Len() - projections.DeleteAll() - pcolSt := 0 - if labNm != "" { - projections.AddStringColumn(labNm) - pcolSt = 1 - } - for _, idx := range idxs { - projections.AddFloat64Column(fmt.Sprintf("Projection%v", idx)) - } - projections.SetNumRows(rows) - - for ii, idx := range idxs { - pcol := projections.Columns[pcolSt+ii].(*tensor.Float64) - pa.ProjectColumn(&pcol.Values, ix, column, idx) - } - - if labNm != "" { - lcol, err := ix.Table.ColumnByName(labNm) - if err == nil { - plcol := projections.Columns[0] - for row := 0; row < rows; row++ { - plcol.SetString1D(row, lcol.String1D(row)) - } - } - } - return nil -} diff --git a/tensor/stats/pca/pca_test.go b/tensor/stats/pca/pca_test.go deleted file mode 100644 index fb62388c8b..0000000000 --- a/tensor/stats/pca/pca_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package pca - -import ( - "fmt" - "math" - "testing" - - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" -) - -func TestPCAIris(t *testing.T) { - // note: these results are verified against this example: - // https://plot.ly/ipython-notebooks/principal-component-analysis/ - - dt := table.NewTable() - dt.AddFloat64TensorColumn("data", []int{4}) - dt.AddStringColumn("class") - err := dt.OpenCSV("testdata/iris.data", table.Comma) - if err != nil { - t.Error(err) - } - ix := table.NewIndexView(dt) - pc := &PCA{} - // pc.TableColumn(ix, "data", metric.Covariance64) - // fmt.Printf("covar: %v\n", pc.Covar) - err = pc.TableColumn(ix, "data", metric.Correlation64) - if err != nil { - t.Error(err) - } - // fmt.Printf("correl: %v\n", pc.Covar) - // fmt.Printf("correl vec: %v\n", pc.Vectors) - // fmt.Printf("correl val: %v\n", pc.Values) - - errtol := 1.0e-9 - corvals := []float64{0.020607707235624825, 0.14735327830509573, 0.9212209307072254, 2.910818083752054} - for i, v := range pc.Values { - dif := math.Abs(corvals[i] - v) - if dif > errtol { - err = fmt.Errorf("eigenvalue: %v differs from correct: %v was: %v", i, corvals[i], v) - t.Error(err) - } - } - - prjt := &table.Table{} - err = pc.ProjectColumnToTable(prjt, ix, "data", "class", []int{0, 1}) - if err != nil { - t.Error(err) - } - // prjt.SaveCSV("test_data/projection01.csv", table.Comma, true) -} diff --git a/tensor/stats/pca/svd.go b/tensor/stats/pca/svd.go deleted file mode 100644 index d94ab60041..0000000000 --- a/tensor/stats/pca/svd.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package pca - -import ( - "fmt" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" - "gonum.org/v1/gonum/mat" -) - -// SVD computes the eigenvalue decomposition of a square similarity matrix, -// typically generated using the correlation metric. -type SVD struct { - - // type of SVD to run: SVDNone is the most efficient if you only need the values which are always computed. Otherwise, SVDThin is the next most efficient for getting approximate vectors - Kind mat.SVDKind - - // condition value -- minimum normalized eigenvalue to return in values - Cond float64 `default:"0.01"` - - // the rank (count) of singular values greater than Cond - Rank int - - // the covariance matrix computed on original data, which is then eigen-factored - Covar tensor.Tensor `display:"no-inline"` - - // the eigenvectors, in same size as Covar - each eigenvector is a column in this 2D square matrix, ordered *lowest* to *highest* across the columns -- i.e., maximum eigenvector is the last column - Vectors tensor.Tensor `display:"no-inline"` - - // the eigenvalues, ordered *lowest* to *highest* - Values []float64 `display:"no-inline"` -} - -func (svd *SVD) Init() { - svd.Kind = mat.SVDNone - svd.Cond = 0.01 - svd.Covar = &tensor.Float64{} - svd.Vectors = &tensor.Float64{} - svd.Values = nil -} - -// TableColumn is a convenience method that computes a covariance matrix -// on given column of table and then performs the SVD on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). 
-// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the SVD eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -func (svd *SVD) TableColumn(ix *table.IndexView, column string, mfun metric.Func64) error { - if svd.Covar == nil { - svd.Init() - } - err := CovarTableColumn(svd.Covar, ix, column, mfun) - if err != nil { - return err - } - return svd.SVD() -} - -// Tensor is a convenience method that computes a covariance matrix -// on given tensor and then performs the SVD on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the SVD eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -func (svd *SVD) Tensor(tsr tensor.Tensor, mfun metric.Func64) error { - if svd.Covar == nil { - svd.Init() - } - err := CovarTensor(svd.Covar, tsr, mfun) - if err != nil { - return err - } - return svd.SVD() -} - -// TableColumnStd is a convenience method that computes a covariance matrix -// on given column of table and then performs the SVD on the resulting matrix. 
-// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is a Std metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the SVD eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. -// This Std version is usable e.g., in Python where the func cannot be passed. -func (svd *SVD) TableColumnStd(ix *table.IndexView, column string, met metric.StdMetrics) error { - return svd.TableColumn(ix, column, metric.StdFunc64(met)) -} - -// TensorStd is a convenience method that computes a covariance matrix -// on given tensor and then performs the SVD on the resulting matrix. -// If no error occurs, the results can be read out from Vectors and Values -// or used in Projection methods. -// mfun is Std metric function, typically Covariance or Correlation -- use Covar -// if vars have similar overall scaling, which is typical in neural network models, -// and use Correl if they are on very different scales -- Correl effectively rescales). -// A Covariance matrix computes the *row-wise* vector similarities for each -// pairwise combination of column cells -- i.e., the extent to which each -// cell co-varies in its value with each other cell across the rows of the table. -// This is the input to the SVD eigenvalue decomposition of the resulting -// covariance matrix, which extracts the eigenvectors as directions with maximal -// variance in this matrix. 
-// This Std version is usable e.g., in Python where the func cannot be passed. -func (svd *SVD) TensorStd(tsr tensor.Tensor, met metric.StdMetrics) error { - return svd.Tensor(tsr, metric.StdFunc64(met)) -} - -// SVD performs the eigen decomposition of the existing Covar matrix. -// Vectors and Values fields contain the results. -func (svd *SVD) SVD() error { - if svd.Covar == nil || svd.Covar.NumDims() != 2 { - return fmt.Errorf("svd.SVD: Covar matrix is nil or not 2D") - } - var eig mat.SVD - // note: MUST be a Float64 otherwise doesn't have Symmetric function - ok := eig.Factorize(svd.Covar, svd.Kind) - if !ok { - return fmt.Errorf("gonum SVD Factorize failed") - } - if svd.Kind > mat.SVDNone { - if svd.Vectors == nil { - svd.Vectors = &tensor.Float64{} - } - var ev mat.Dense - eig.UTo(&ev) - tensor.CopyDense(svd.Vectors, &ev) - } - nr := svd.Covar.DimSize(0) - if len(svd.Values) != nr { - svd.Values = make([]float64, nr) - } - eig.Values(svd.Values) - svd.Rank = eig.Rank(svd.Cond) - return nil -} - -// ProjectColumn projects values from the given column of given table (via IndexView) -// onto the idx'th eigenvector (0 = largest eigenvalue, 1 = next, etc). -// Must have already called SVD() method. 
-func (svd *SVD) ProjectColumn(vals *[]float64, ix *table.IndexView, column string, idx int) error { - col, err := ix.Table.ColumnByName(column) - if err != nil { - return err - } - if svd.Vectors == nil || svd.Vectors.Len() == 0 { - return fmt.Errorf("SVD.ProjectColumn Vectors are nil: must call SVD first, with Kind = mat.SVDFull so that the vectors are returned") - } - nr := svd.Vectors.DimSize(0) - if idx >= nr { - return fmt.Errorf("SVD.ProjectColumn eigenvector index > rank of matrix") - } - cvec := make([]float64, nr) - // eidx := nr - 1 - idx // eigens in reverse order - vec := svd.Vectors.(*tensor.Float64) - for ri := 0; ri < nr; ri++ { - cvec[ri] = vec.Value([]int{ri, idx}) // vecs are in columns, reverse magnitude order - } - rows := ix.Len() - if len(*vals) != rows { - *vals = make([]float64, rows) - } - ln := col.Len() - sz := ln / col.DimSize(0) // size of cell - if sz != nr { - return fmt.Errorf("SVD.ProjectColumn column cell size != svd eigenvectors") - } - rdim := []int{0} - for row := 0; row < rows; row++ { - sum := 0.0 - rdim[0] = ix.Indexes[row] - rt := col.SubSpace(rdim) - for ci := 0; ci < sz; ci++ { - sum += cvec[ci] * rt.Float1D(ci) - } - (*vals)[row] = sum - } - return nil -} - -// ProjectColumnToTable projects values from the given column of given table (via IndexView) -// onto the given set of eigenvectors (idxs, 0 = largest eigenvalue, 1 = next, etc), -// and stores results along with labels from column labNm into results table. -// Must have already called SVD() method. 
-func (svd *SVD) ProjectColumnToTable(projections *table.Table, ix *table.IndexView, column, labNm string, idxs []int) error { - _, err := ix.Table.ColumnByName(column) - if errors.Log(err) != nil { - return err - } - if svd.Vectors == nil { - return fmt.Errorf("SVD.ProjectColumn Vectors are nil -- must call SVD first") - } - rows := ix.Len() - projections.DeleteAll() - pcolSt := 0 - if labNm != "" { - projections.AddStringColumn(labNm) - pcolSt = 1 - } - for _, idx := range idxs { - projections.AddFloat64Column(fmt.Sprintf("Projection%v", idx)) - } - projections.SetNumRows(rows) - - for ii, idx := range idxs { - pcol := projections.Columns[pcolSt+ii].(*tensor.Float64) - svd.ProjectColumn(&pcol.Values, ix, column, idx) - } - - if labNm != "" { - lcol, err := ix.Table.ColumnByName(labNm) - if errors.Log(err) == nil { - plcol := projections.Columns[0] - for row := 0; row < rows; row++ { - plcol.SetString1D(row, lcol.String1D(row)) - } - } - } - return nil -} diff --git a/tensor/stats/pca/svd_test.go b/tensor/stats/pca/svd_test.go deleted file mode 100644 index f1caad8b7c..0000000000 --- a/tensor/stats/pca/svd_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package pca - -import ( - "fmt" - "math" - "testing" - - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" - "gonum.org/v1/gonum/mat" -) - -func TestSVDIris(t *testing.T) { - // note: these results are verified against this example: - // https://plot.ly/ipython-notebooks/principal-component-analysis/ - - dt := table.NewTable() - dt.AddFloat64TensorColumn("data", []int{4}) - dt.AddStringColumn("class") - err := dt.OpenCSV("testdata/iris.data", table.Comma) - if err != nil { - t.Error(err) - } - ix := table.NewIndexView(dt) - pc := &SVD{} - pc.Init() - pc.Kind = mat.SVDFull - // pc.TableColumn(ix, "data", metric.Covariance64) - // fmt.Printf("covar: %v\n", pc.Covar) - err = pc.TableColumn(ix, "data", metric.Correlation64) - if err != nil { - t.Error(err) - } - // fmt.Printf("correl: %v\n", pc.Covar) - // fmt.Printf("correl vec: %v\n", pc.Vectors) - // fmt.Printf("correl val: %v\n", pc.Values) - - errtol := 1.0e-9 - corvals := []float64{2.910818083752054, 0.9212209307072254, 0.14735327830509573, 0.020607707235624825} - for i, v := range pc.Values { - dif := math.Abs(corvals[i] - v) - if dif > errtol { - err = fmt.Errorf("eigenvalue: %v differs from correct: %v was: %v", i, corvals[i], v) - t.Error(err) - } - } - - prjt := &table.Table{} - err = pc.ProjectColumnToTable(prjt, ix, "data", "class", []int{0, 1}) - if err != nil { - t.Error(err) - } - // prjt.SaveCSV("test_data/svd_projection01.csv", table.Comma, true) -} diff --git a/tensor/stats/pca/testdata/projection01.csv b/tensor/stats/pca/testdata/projection01.csv deleted file mode 100644 index 49ef2cc2de..0000000000 --- a/tensor/stats/pca/testdata/projection01.csv +++ /dev/null @@ -1,151 +0,0 @@ -_H:,$class,#Prjn0,#Prjn1 -_D:,Iris-setosa,2.669230878293515,5.180887223993903 -_D:,Iris-setosa,2.6964340118689534,4.643645304250262 -_D:,Iris-setosa,2.4811633041648684,4.752183452725602 -_D:,Iris-setosa,2.5715124347750256,4.6266149223441255 
-_D:,Iris-setosa,2.5906582247213548,5.236211037073636 -_D:,Iris-setosa,3.0080988099460617,5.682216917525971 -_D:,Iris-setosa,2.490941664609344,4.90871396981208 -_D:,Iris-setosa,2.7014546083439073,5.053209215928301 -_D:,Iris-setosa,2.461583693196517,4.364930473160547 -_D:,Iris-setosa,2.6716628159090594,4.731768854441222 -_D:,Iris-setosa,2.83139678191279,5.479803509512478 -_D:,Iris-setosa,2.6551056848221406,4.980855020942431 -_D:,Iris-setosa,2.5876357448399223,4.599871891007371 -_D:,Iris-setosa,2.152073732956798,4.4073842762800135 -_D:,Iris-setosa,2.786962753802378,5.900069370044279 -_D:,Iris-setosa,2.91688203729186,6.252471718236359 -_D:,Iris-setosa,2.7755972077070026,5.673779006789473 -_D:,Iris-setosa,2.72579198328178,5.187428800901795 -_D:,Iris-setosa,3.134584682611489,5.694815200208339 -_D:,Iris-setosa,2.7049109092473627,5.4672052268301075 -_D:,Iris-setosa,3.0266540576265015,5.206355516636538 -_D:,Iris-setosa,2.787807505767021,5.381191154323272 -_D:,Iris-setosa,2.1492079743192316,5.078845780997149 -_D:,Iris-setosa,3.065961378000392,5.0217290889404955 -_D:,Iris-setosa,2.829481886501435,4.987183453994805 -_D:,Iris-setosa,2.8649219750292487,4.685096095953507 -_D:,Iris-setosa,2.8727022188802023,5.068401847428211 -_D:,Iris-setosa,2.7795934408940464,5.220228538013025 -_D:,Iris-setosa,2.7478035318656753,5.12556341091417 -_D:,Iris-setosa,2.6555395058441627,4.758511885777976 -_D:,Iris-setosa,2.734112159416322,4.703188072698243 -_D:,Iris-setosa,3.023525466483502,5.215219715084075 -_D:,Iris-setosa,2.565019386717418,5.769020857593508 -_D:,Iris-setosa,2.6938310857368215,5.977704115236996 -_D:,Iris-setosa,2.6716628159090594,4.731768854441222 -_D:,Iris-setosa,2.579749389727401,4.8617694840464685 -_D:,Iris-setosa,2.8200541258968146,5.327705091649766 -_D:,Iris-setosa,2.6716628159090594,4.731768854441222 -_D:,Iris-setosa,2.3771228011053585,4.455376644891153 -_D:,Iris-setosa,2.7536917703846737,5.090441052263297 -_D:,Iris-setosa,2.615429420681249,5.148087486882674 
-_D:,Iris-setosa,2.670269508854147,3.8512605122309354 -_D:,Iris-setosa,2.3244518180425704,4.640487943720612 -_D:,Iris-setosa,2.959488937325338,5.174040650658726 -_D:,Iris-setosa,2.993973616474687,5.482184714474499 -_D:,Iris-setosa,2.700757954816452,4.612955044823157 -_D:,Iris-setosa,2.706475204818863,5.462773127606339 -_D:,Iris-setosa,2.487051542683867,4.7170610940747295 -_D:,Iris-setosa,2.7791596198720234,5.44257167317748 -_D:,Iris-setosa,2.669664699315537,4.958544088829447 -_D:,Iris-versicolor,6.337614909993668,5.758736852585482 -_D:,Iris-versicolor,5.964502241617807,5.537668456115144 -_D:,Iris-versicolor,6.48452514559209,5.639709899111897 -_D:,Iris-versicolor,5.327637994258105,4.345950542131197 -_D:,Iris-versicolor,6.180206770343914,5.206787172475347 -_D:,Iris-versicolor,5.5910618634814915,4.893739850295462 -_D:,Iris-versicolor,6.058741494153441,5.603752801471018 -_D:,Iris-versicolor,4.411318411598967,4.180724099023395 -_D:,Iris-versicolor,6.092986230876757,5.323491504409287 -_D:,Iris-versicolor,5.064020246438732,4.608909730008894 -_D:,Iris-versicolor,4.685148340884838,3.8519522930677232 -_D:,Iris-versicolor,5.581611212797471,5.160069542558326 -_D:,Iris-versicolor,5.445475981028535,4.419929343667774 -_D:,Iris-versicolor,5.946486926220956,5.1459833773263215 -_D:,Iris-versicolor,4.989360604871448,4.930078364218072 -_D:,Iris-versicolor,6.03286271372347,5.548157261113388 -_D:,Iris-versicolor,5.599275928354467,5.054702466605709 -_D:,Iris-versicolor,5.267449599849797,4.810353395755553 -_D:,Iris-versicolor,6.123382832850215,4.537648289297856 -_D:,Iris-versicolor,5.155956562699788,4.553101045795742 -_D:,Iris-versicolor,6.0473759480580656,5.377462438216211 -_D:,Iris-versicolor,5.509383508845732,5.032119807214825 -_D:,Iris-versicolor,6.329115122535858,4.8609849846135385 -_D:,Iris-versicolor,5.859700207775821,5.040344574095806 -_D:,Iris-versicolor,5.81413570511593,5.242699398686921 -_D:,Iris-versicolor,6.006961043214098,5.4183697753636615 
-_D:,Iris-versicolor,6.396607952597476,5.316160059940694 -_D:,Iris-versicolor,6.577633923578247,5.487883208527084 -_D:,Iris-versicolor,5.834560068048925,5.111074162530968 -_D:,Iris-versicolor,4.892795525981836,4.667909043901078 -_D:,Iris-versicolor,5.071929491630652,4.421204082361892 -_D:,Iris-versicolor,4.957242986082623,4.412553027769875 -_D:,Iris-versicolor,5.264321008706797,4.8192175942030895 -_D:,Iris-versicolor,6.292544559458566,4.94516130671415 -_D:,Iris-versicolor,5.4948016042729355,4.980238793935716 -_D:,Iris-versicolor,5.75944371538022,5.580393986512508 -_D:,Iris-versicolor,6.263800020391029,5.561027271073654 -_D:,Iris-versicolor,5.978036892823293,4.652243143547671 -_D:,Iris-versicolor,5.253652116138878,5.033181402053425 -_D:,Iris-versicolor,5.274967011195318,4.531061840960656 -_D:,Iris-versicolor,5.424572016914717,4.625513824203992 -_D:,Iris-versicolor,5.862026034129797,5.236429549056926 -_D:,Iris-versicolor,5.348781900797956,4.728771422472485 -_D:,Iris-versicolor,4.489891065171127,4.1254002859436625 -_D:,Iris-versicolor,5.3907839912928255,4.757623931493361 -_D:,Iris-versicolor,5.307453573751144,5.065981139164654 -_D:,Iris-versicolor,5.390350170270803,4.979967066657817 -_D:,Iris-versicolor,5.709661381034398,5.168235726016927 -_D:,Iris-versicolor,4.3716421474580756,4.347956564963636 -_D:,Iris-versicolor,5.358560261242432,4.885301939558963 -_D:,Iris-virginica,7.323421646324768,5.690050203535673 -_D:,Iris-virginica,6.357753550341829,4.890322364767835 -_D:,Iris-virginica,7.535955596732253,5.6819621606557655 -_D:,Iris-virginica,6.80033427529343,5.265598656785007 -_D:,Iris-virginica,7.2209683289161575,5.463003241869551 -_D:,Iris-virginica,8.204019210854437,5.882887686119622 -_D:,Iris-virginica,5.478415461702605,4.34438451900287 -_D:,Iris-virginica,7.729583699619444,5.652683363923848 -_D:,Iris-virginica,7.230875690701601,5.048522359834326 -_D:,Iris-virginica,7.772675030657245,6.30491315647896 -_D:,Iris-virginica,6.648297331958487,5.620265043094353 
-_D:,Iris-virginica,6.7874273237059555,5.1179323381460655 -_D:,Iris-virginica,7.146742508370896,5.561828740914276 -_D:,Iris-virginica,6.356623075792351,4.672411328827146 -_D:,Iris-virginica,6.614223583751759,5.015585898722028 -_D:,Iris-virginica,6.881994286002045,5.606876892851283 -_D:,Iris-virginica,6.820347707283804,5.430508501185606 -_D:,Iris-virginica,8.160258946192082,6.669215772364471 -_D:,Iris-virginica,8.649096750676604,5.56930851166386 -_D:,Iris-virginica,6.309535511567507,4.473732005048484 -_D:,Iris-virginica,7.375681698444933,5.801473985262766 -_D:,Iris-virginica,6.167254038597639,4.910736963052213 -_D:,Iris-virginica,8.310491651529492,5.7305761244013915 -_D:,Iris-virginica,6.446127454437865,5.065721014166677 -_D:,Iris-virginica,7.131749672855478,5.806482808191717 -_D:,Iris-virginica,7.423963861305202,5.8867900427806665 -_D:,Iris-virginica,6.30942940030594,5.1189353495622845 -_D:,Iris-virginica,6.262646655762151,5.2689242897408715 -_D:,Iris-virginica,7.048590243830385,5.229899574428954 -_D:,Iris-virginica,7.247261833271931,5.6843766347671725 -_D:,Iris-virginica,7.748466657060339,5.59968217238376 -_D:,Iris-virginica,7.9772348586177895,6.724267858166305 -_D:,Iris-virginica,7.10515134881865,5.236441151336846 -_D:,Iris-virginica,6.366359449061206,5.142870888225977 -_D:,Iris-virginica,6.548622005853021,4.887301728239255 -_D:,Iris-virginica,8.07875158007291,5.92265528784978 -_D:,Iris-virginica,7.00802344756605,5.767626365306011 -_D:,Iris-virginica,6.7417750537116445,5.4858323142653385 -_D:,Iris-virginica,6.15228409316162,5.2295829757217485 -_D:,Iris-virginica,7.114518778320504,5.689506748979877 -_D:,Iris-virginica,7.295978570323296,5.638886762401811 -_D:,Iris-virginica,7.0532647866177385,5.6962614697432885 -_D:,Iris-virginica,6.357753550341829,4.890322364767835 -_D:,Iris-virginica,7.439695337523697,5.768461104296018 -_D:,Iris-virginica,7.3579940928085374,5.832649115823288 -_D:,Iris-virginica,7.0332513546273665,5.5313516253426895 
-_D:,Iris-virginica,6.613484943048682,4.889260769929234 -_D:,Iris-virginica,6.75909371558104,5.437263221949017 -_D:,Iris-virginica,6.782974379417489,5.719633996694872 -_D:,Iris-virginica,6.274423132800148,5.198679572439127 diff --git a/tensor/stats/pca/testdata/svd_projection01.csv b/tensor/stats/pca/testdata/svd_projection01.csv deleted file mode 100644 index bef23d7481..0000000000 --- a/tensor/stats/pca/testdata/svd_projection01.csv +++ /dev/null @@ -1,151 +0,0 @@ -$class,#Prjn0,#Prjn1 -Iris-setosa,-2.6692308782935164,-5.180887223993902 -Iris-setosa,-2.6964340118689543,-4.643645304250262 -Iris-setosa,-2.4811633041648697,-4.752183452725602 -Iris-setosa,-2.5715124347750264,-4.626614922344125 -Iris-setosa,-2.5906582247213565,-5.236211037073635 -Iris-setosa,-3.0080988099460635,-5.682216917525971 -Iris-setosa,-2.4909416646093447,-4.908713969812081 -Iris-setosa,-2.701454608343909,-5.053209215928301 -Iris-setosa,-2.461583693196518,-4.364930473160547 -Iris-setosa,-2.6716628159090603,-4.731768854441222 -Iris-setosa,-2.8313967819127916,-5.479803509512477 -Iris-setosa,-2.655105684822142,-4.980855020942431 -Iris-setosa,-2.5876357448399236,-4.599871891007371 -Iris-setosa,-2.152073732956799,-4.4073842762800135 -Iris-setosa,-2.7869627538023796,-5.900069370044279 -Iris-setosa,-2.9168820372918622,-6.25247171823636 -Iris-setosa,-2.7755972077070044,-5.673779006789473 -Iris-setosa,-2.7257919832817814,-5.187428800901794 -Iris-setosa,-3.1345846826114907,-5.694815200208339 -Iris-setosa,-2.7049109092473644,-5.467205226830108 -Iris-setosa,-3.0266540576265033,-5.206355516636537 -Iris-setosa,-2.7878075057670233,-5.381191154323272 -Iris-setosa,-2.149207974319233,-5.078845780997149 -Iris-setosa,-3.0659613780003934,-5.0217290889404955 -Iris-setosa,-2.829481886501436,-4.987183453994805 -Iris-setosa,-2.86492197502925,-4.685096095953507 -Iris-setosa,-2.872702218880204,-5.068401847428211 -Iris-setosa,-2.7795934408940473,-5.220228538013024 -Iris-setosa,-2.747803531865676,-5.12556341091417 
-Iris-setosa,-2.655539505844164,-4.758511885777976 -Iris-setosa,-2.734112159416324,-4.703188072698243 -Iris-setosa,-3.0235254664835036,-5.215219715084074 -Iris-setosa,-2.565019386717419,-5.769020857593508 -Iris-setosa,-2.6938310857368224,-5.977704115236996 -Iris-setosa,-2.6716628159090603,-4.731768854441222 -Iris-setosa,-2.579749389727403,-4.8617694840464685 -Iris-setosa,-2.820054125896816,-5.327705091649766 -Iris-setosa,-2.6716628159090603,-4.731768854441222 -Iris-setosa,-2.3771228011053593,-4.455376644891153 -Iris-setosa,-2.753691770384675,-5.090441052263298 -Iris-setosa,-2.615429420681251,-5.148087486882674 -Iris-setosa,-2.6702695088541484,-3.851260512230936 -Iris-setosa,-2.3244518180425717,-4.640487943720612 -Iris-setosa,-2.95948893732534,-5.174040650658726 -Iris-setosa,-2.9939736164746886,-5.4821847144745 -Iris-setosa,-2.7007579548164533,-4.612955044823157 -Iris-setosa,-2.7064752048188643,-5.46277312760634 -Iris-setosa,-2.4870515426838677,-4.7170610940747295 -Iris-setosa,-2.779159619872025,-5.44257167317748 -Iris-setosa,-2.6696646993155384,-4.958544088829447 -Iris-versicolor,-6.33761490999367,-5.758736852585481 -Iris-versicolor,-5.96450224161781,-5.537668456115145 -Iris-versicolor,-6.484525145592094,-5.639709899111898 -Iris-versicolor,-5.327637994258107,-4.345950542131198 -Iris-versicolor,-6.180206770343917,-5.206787172475347 -Iris-versicolor,-5.591061863481493,-4.8937398502954625 -Iris-versicolor,-6.058741494153444,-5.6037528014710185 -Iris-versicolor,-4.411318411598969,-4.180724099023395 -Iris-versicolor,-6.092986230876758,-5.323491504409287 -Iris-versicolor,-5.064020246438733,-4.608909730008894 -Iris-versicolor,-4.68514834088484,-3.8519522930677237 -Iris-versicolor,-5.581611212797473,-5.160069542558327 -Iris-versicolor,-5.445475981028536,-4.419929343667774 -Iris-versicolor,-5.946486926220958,-5.1459833773263215 -Iris-versicolor,-4.98936060487145,-4.930078364218073 -Iris-versicolor,-6.032862713723472,-5.548157261113388 
-Iris-versicolor,-5.599275928354468,-5.05470246660571 -Iris-versicolor,-5.267449599849799,-4.810353395755553 -Iris-versicolor,-6.123382832850217,-4.537648289297856 -Iris-versicolor,-5.15595656269979,-4.553101045795744 -Iris-versicolor,-6.047375948058068,-5.377462438216212 -Iris-versicolor,-5.509383508845733,-5.032119807214826 -Iris-versicolor,-6.32911512253586,-4.860984984613539 -Iris-versicolor,-5.8597002077758225,-5.040344574095807 -Iris-versicolor,-5.814135705115932,-5.242699398686921 -Iris-versicolor,-6.006961043214099,-5.418369775363661 -Iris-versicolor,-6.396607952597479,-5.316160059940693 -Iris-versicolor,-6.577633923578249,-5.4878832085270846 -Iris-versicolor,-5.834560068048927,-5.111074162530969 -Iris-versicolor,-4.892795525981839,-4.667909043901078 -Iris-versicolor,-5.0719294916306525,-4.421204082361892 -Iris-versicolor,-4.9572429860826235,-4.412553027769875 -Iris-versicolor,-5.2643210087067995,-4.8192175942030895 -Iris-versicolor,-6.292544559458569,-4.94516130671415 -Iris-versicolor,-5.494801604272937,-4.980238793935717 -Iris-versicolor,-5.759443715380222,-5.580393986512508 -Iris-versicolor,-6.263800020391032,-5.561027271073655 -Iris-versicolor,-5.978036892823295,-4.652243143547671 -Iris-versicolor,-5.25365211613888,-5.033181402053426 -Iris-versicolor,-5.274967011195319,-4.531061840960657 -Iris-versicolor,-5.42457201691472,-4.625513824203992 -Iris-versicolor,-5.862026034129799,-5.2364295490569255 -Iris-versicolor,-5.348781900797959,-4.728771422472485 -Iris-versicolor,-4.489891065171128,-4.1254002859436625 -Iris-versicolor,-5.390783991292826,-4.757623931493362 -Iris-versicolor,-5.307453573751146,-5.065981139164655 -Iris-versicolor,-5.390350170270805,-4.979967066657818 -Iris-versicolor,-5.7096613810344,-5.168235726016928 -Iris-versicolor,-4.371642147458076,-4.347956564963637 -Iris-versicolor,-5.358560261242434,-4.885301939558964 -Iris-virginica,-7.32342164632477,-5.690050203535675 -Iris-virginica,-6.357753550341831,-4.890322364767834 
-Iris-virginica,-7.535955596732257,-5.681962160655765 -Iris-virginica,-6.800334275293433,-5.265598656785008 -Iris-virginica,-7.22096832891616,-5.463003241869552 -Iris-virginica,-8.204019210854439,-5.882887686119622 -Iris-virginica,-5.478415461702606,-4.344384519002871 -Iris-virginica,-7.729583699619447,-5.652683363923849 -Iris-virginica,-7.230875690701603,-5.048522359834328 -Iris-virginica,-7.772675030657247,-6.30491315647896 -Iris-virginica,-6.648297331958489,-5.620265043094353 -Iris-virginica,-6.787427323705957,-5.117932338146065 -Iris-virginica,-7.146742508370899,-5.561828740914276 -Iris-virginica,-6.356623075792353,-4.672411328827147 -Iris-virginica,-6.614223583751762,-5.015585898722027 -Iris-virginica,-6.881994286002047,-5.606876892851284 -Iris-virginica,-6.820347707283807,-5.4305085011856065 -Iris-virginica,-8.160258946192084,-6.669215772364471 -Iris-virginica,-8.649096750676605,-5.56930851166386 -Iris-virginica,-6.309535511567509,-4.473732005048484 -Iris-virginica,-7.375681698444936,-5.801473985262767 -Iris-virginica,-6.16725403859764,-4.910736963052214 -Iris-virginica,-8.310491651529494,-5.730576124401391 -Iris-virginica,-6.4461274544378675,-5.065721014166677 -Iris-virginica,-7.131749672855481,-5.806482808191716 -Iris-virginica,-7.423963861305205,-5.886790042780667 -Iris-virginica,-6.309429400305942,-5.1189353495622845 -Iris-virginica,-6.262646655762153,-5.2689242897408715 -Iris-virginica,-7.048590243830388,-5.229899574428954 -Iris-virginica,-7.247261833271932,-5.684376634767173 -Iris-virginica,-7.748466657060342,-5.599682172383759 -Iris-virginica,-7.977234858617792,-6.724267858166306 -Iris-virginica,-7.105151348818652,-5.236441151336847 -Iris-virginica,-6.366359449061208,-5.142870888225977 -Iris-virginica,-6.548622005853022,-4.887301728239255 -Iris-virginica,-8.078751580072911,-5.922655287849781 -Iris-virginica,-7.0080234475660514,-5.767626365306012 -Iris-virginica,-6.741775053711646,-5.485832314265339 -Iris-virginica,-6.152284093161622,-5.229582975721749 
-Iris-virginica,-7.114518778320507,-5.689506748979878 -Iris-virginica,-7.295978570323298,-5.638886762401812 -Iris-virginica,-7.053264786617741,-5.696261469743289 -Iris-virginica,-6.357753550341831,-4.890322364767834 -Iris-virginica,-7.4396953375237,-5.768461104296019 -Iris-virginica,-7.35799409280854,-5.832649115823288 -Iris-virginica,-7.033251354627368,-5.53135162534269 -Iris-virginica,-6.613484943048684,-4.889260769929234 -Iris-virginica,-6.7590937155810416,-5.437263221949018 -Iris-virginica,-6.7829743794174915,-5.719633996694874 -Iris-virginica,-6.274423132800151,-5.1986795724391275 diff --git a/tensor/stats/simat/README.md b/tensor/stats/simat/README.md deleted file mode 100644 index 6d2f2ee7f7..0000000000 --- a/tensor/stats/simat/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# simat - -`simat` provides similarity / distance matrix functions that create a `SimMat` matrix from Tensor or Table data. Any metric function defined in metric package (or user-created) can be used. - -The SimMat contains the Tensor of the similarity matrix values, and labels for the Rows and Columns. - -The `etview` package provides a `SimMatGrid` widget that displays the SimMat with the labels. - diff --git a/tensor/stats/simat/doc.go b/tensor/stats/simat/doc.go deleted file mode 100644 index fcc4953bac..0000000000 --- a/tensor/stats/simat/doc.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -Package simat provides similarity / distance matrix functions that create -a SimMat matrix from Tensor or Table data. Any metric function defined -in metric package (or user-created) can be used. - -The SimMat contains the Tensor of the similarity matrix values, and -labels for the Rows and Columns. - -The etview package provides a SimMatGrid widget that displays the SimMat -with the labels. 
-*/ -package simat diff --git a/tensor/stats/simat/simat.go b/tensor/stats/simat/simat.go deleted file mode 100644 index 49f4f06b0d..0000000000 --- a/tensor/stats/simat/simat.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package simat - -import ( - "fmt" - - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" -) - -// SimMat is a similarity / distance matrix with additional row and column -// labels for display purposes. -type SimMat struct { - - // the similarity / distance matrix (typically an tensor.Float64) - Mat tensor.Tensor - - // labels for the rows -- blank rows trigger generation of grouping lines - Rows []string - - // labels for the columns -- blank columns trigger generation of grouping lines - Columns []string -} - -// NewSimMat returns a new SimMat similarity matrix -func NewSimMat() *SimMat { - return &SimMat{} -} - -// Init initializes SimMat with default Matrix and nil rows, cols -func (smat *SimMat) Init() { - smat.Mat = &tensor.Float64{} - smat.Mat.SetMetaData("grid-fill", "1") // best for sim mats -- can override later if need to - smat.Rows = nil - smat.Columns = nil -} - -// TableColumnStd generates a similarity / distance matrix from given column name -// in given IndexView of an table.Table, and given standard metric function. -// if labNm is not empty, uses given column name for labels, which if blankRepeat -// is true are filtered so that any sequentially repeated labels are blank. -// This Std version is usable e.g., in Python where the func cannot be passed. 
-func (smat *SimMat) TableColumnStd(ix *table.IndexView, column, labNm string, blankRepeat bool, met metric.StdMetrics) error { - return smat.TableColumn(ix, column, labNm, blankRepeat, metric.StdFunc64(met)) -} - -// TableColumn generates a similarity / distance matrix from given column name -// in given IndexView of an table.Table, and given metric function. -// if labNm is not empty, uses given column name for labels, which if blankRepeat -// is true are filtered so that any sequentially repeated labels are blank. -func (smat *SimMat) TableColumn(ix *table.IndexView, column, labNm string, blankRepeat bool, mfun metric.Func64) error { - col, err := ix.Table.ColumnByName(column) - if err != nil { - return err - } - smat.Init() - sm := smat.Mat - - rows := ix.Len() - nd := col.NumDims() - if nd < 2 || rows == 0 { - return fmt.Errorf("simat.Tensor: must have 2 or more dims and rows != 0") - } - ln := col.Len() - sz := ln / col.DimSize(0) // size of cell - - sshp := []int{rows, rows} - sm.SetShape(sshp) - - av := make([]float64, sz) - bv := make([]float64, sz) - ardim := []int{0} - brdim := []int{0} - sdim := []int{0, 0} - for ai := 0; ai < rows; ai++ { - ardim[0] = ix.Indexes[ai] - sdim[0] = ai - ar := col.SubSpace(ardim) - ar.Floats(&av) - for bi := 0; bi <= ai; bi++ { // lower diag - brdim[0] = ix.Indexes[bi] - sdim[1] = bi - br := col.SubSpace(brdim) - br.Floats(&bv) - sv := mfun(av, bv) - sm.SetFloat(sdim, sv) - } - } - // now fill in upper diagonal with values from lower diagonal - // note: assumes symmetric distance function - fdim := []int{0, 0} - for ai := 0; ai < rows; ai++ { - sdim[0] = ai - fdim[1] = ai - for bi := ai + 1; bi < rows; bi++ { // upper diag - fdim[0] = bi - sdim[1] = bi - sv := sm.Float(fdim) - sm.SetFloat(sdim, sv) - } - } - - if nm, has := ix.Table.MetaData["name"]; has { - sm.SetMetaData("name", nm+"_"+column) - } else { - sm.SetMetaData("name", column) - } - if ds, has := ix.Table.MetaData["desc"]; has { - sm.SetMetaData("desc", ds) - } 
- - if labNm == "" { - return nil - } - lc, err := ix.Table.ColumnByName(labNm) - if err != nil { - return err - } - smat.Rows = make([]string, rows) - last := "" - for r := 0; r < rows; r++ { - lbl := lc.String1D(ix.Indexes[r]) - if blankRepeat && lbl == last { - continue - } - smat.Rows[r] = lbl - last = lbl - } - smat.Columns = smat.Rows // identical - return nil -} - -// BlankRepeat returns string slice with any sequentially repeated strings blanked out -func BlankRepeat(str []string) []string { - sz := len(str) - br := make([]string, sz) - last := "" - for r, s := range str { - if s == last { - continue - } - br[r] = s - last = s - } - return br -} diff --git a/tensor/stats/simat/simat_test.go b/tensor/stats/simat/simat_test.go deleted file mode 100644 index 103f7ae9a6..0000000000 --- a/tensor/stats/simat/simat_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package simat - -import ( - "testing" - - "cogentcore.org/core/tensor/stats/metric" - "cogentcore.org/core/tensor/table" - - "github.com/stretchr/testify/assert" -) - -var simres = `Tensor: [12, 12] -[0]: 0 3.4641016151377544 8.831760866327848 9.273618495495704 8.717797887081348 9.38083151964686 4.69041575982343 5.830951894845301 8.12403840463596 8.54400374531753 5.291502622129181 6.324555320336759 -[1]: 3.4641016151377544 0 9.38083151964686 8.717797887081348 9.273618495495704 8.831760866327848 5.830951894845301 4.69041575982343 8.717797887081348 7.937253933193772 6.324555320336759 5.291502622129181 -[2]: 8.831760866327848 9.38083151964686 0 3.4641016151377544 4.242640687119285 5.0990195135927845 9.38083151964686 9.899494936611665 4.47213595499958 5.744562646538029 9.38083151964686 9.899494936611665 -[3]: 9.273618495495704 8.717797887081348 3.4641016151377544 0 5.477225575051661 3.7416573867739413 9.797958971132712 9.273618495495704 5.656854249492381 4.58257569495584 9.797958971132712 9.273618495495704 -[4]: 8.717797887081348 9.273618495495704 4.242640687119285 5.477225575051661 0 4 8.831760866327848 9.38083151964686 4.242640687119285 5.5677643628300215 8.831760866327848 9.38083151964686 -[5]: 9.38083151964686 8.831760866327848 5.0990195135927845 3.7416573867739413 4 0 9.486832980505138 8.94427190999916 5.830951894845301 4.795831523312719 9.486832980505138 8.94427190999916 -[6]: 4.69041575982343 5.830951894845301 9.38083151964686 9.797958971132712 8.831760866327848 9.486832980505138 0 3.4641016151377544 9.16515138991168 9.539392014169456 4.242640687119285 5.477225575051661 -[7]: 5.830951894845301 4.69041575982343 9.899494936611665 9.273618495495704 9.38083151964686 8.94427190999916 3.4641016151377544 0 9.695359714832659 9 5.477225575051661 4.242640687119285 -[8]: 8.12403840463596 8.717797887081348 4.47213595499958 5.656854249492381 4.242640687119285 5.830951894845301 9.16515138991168 9.695359714832659 0 3.605551275463989 9.16515138991168 9.695359714832659 -[9]: 
8.54400374531753 7.937253933193772 5.744562646538029 4.58257569495584 5.5677643628300215 4.795831523312719 9.539392014169456 9 3.605551275463989 0 9.539392014169456 9 -[10]: 5.291502622129181 6.324555320336759 9.38083151964686 9.797958971132712 8.831760866327848 9.486832980505138 4.242640687119285 5.477225575051661 9.16515138991168 9.539392014169456 0 3.4641016151377544 -[11]: 6.324555320336759 5.291502622129181 9.899494936611665 9.273618495495704 9.38083151964686 8.94427190999916 5.477225575051661 4.242640687119285 9.695359714832659 9 3.4641016151377544 0 -` - -func TestSimMat(t *testing.T) { - dt := &table.Table{} - err := dt.OpenCSV("../clust/testdata/faces.dat", table.Tab) - if err != nil { - t.Error(err) - } - ix := table.NewIndexView(dt) - smat := &SimMat{} - smat.TableColumn(ix, "Input", "Name", false, metric.Euclidean64) - - // fmt.Println(smat.Mat) - assert.Equal(t, simres, smat.Mat.String()) -} diff --git a/tensor/stats/simat/tensor.go b/tensor/stats/simat/tensor.go deleted file mode 100644 index fc6d2c0e8a..0000000000 --- a/tensor/stats/simat/tensor.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package simat - -import ( - "fmt" - - "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/metric" -) - -// Tensor computes a similarity / distance matrix on tensor -// using given metric function. Outer-most dimension ("rows") is -// used as "indexical" dimension and all other dimensions within that -// are compared. -// Results go in smat which is ensured to have proper square 2D shape -// (rows * rows). 
-func Tensor(smat tensor.Tensor, a tensor.Tensor, mfun metric.Func64) error { - rows := a.DimSize(0) - nd := a.NumDims() - if nd < 2 || rows == 0 { - return fmt.Errorf("simat.Tensor: must have 2 or more dims and rows != 0") - } - ln := a.Len() - sz := ln / rows - - sshp := []int{rows, rows} - smat.SetShape(sshp) - - av := make([]float64, sz) - bv := make([]float64, sz) - ardim := []int{0} - brdim := []int{0} - sdim := []int{0, 0} - for ai := 0; ai < rows; ai++ { - ardim[0] = ai - sdim[0] = ai - ar := a.SubSpace(ardim) - ar.Floats(&av) - for bi := 0; bi <= ai; bi++ { // lower diag - brdim[0] = bi - sdim[1] = bi - br := a.SubSpace(brdim) - br.Floats(&bv) - sv := mfun(av, bv) - smat.SetFloat(sdim, sv) - } - } - // now fill in upper diagonal with values from lower diagonal - // note: assumes symmetric distance function - fdim := []int{0, 0} - for ai := 0; ai < rows; ai++ { - sdim[0] = ai - fdim[1] = ai - for bi := ai + 1; bi < rows; bi++ { // upper diag - fdim[0] = bi - sdim[1] = bi - sv := smat.Float(fdim) - smat.SetFloat(sdim, sv) - } - } - return nil -} - -// Tensors computes a similarity / distance matrix on two tensors -// using given metric function. Outer-most dimension ("rows") is -// used as "indexical" dimension and all other dimensions within that -// are compared. Resulting reduced 2D shape of two tensors must be -// the same (returns error if not). 
-// Rows of smat = a, cols = b -func Tensors(smat tensor.Tensor, a, b tensor.Tensor, mfun metric.Func64) error { - arows := a.DimSize(0) - and := a.NumDims() - brows := b.DimSize(0) - bnd := b.NumDims() - if and < 2 || bnd < 2 || arows == 0 || brows == 0 { - return fmt.Errorf("simat.Tensors: must have 2 or more dims and rows != 0") - } - alen := a.Len() - asz := alen / arows - blen := b.Len() - bsz := blen / brows - if asz != bsz { - return fmt.Errorf("simat.Tensors: size of inner dimensions must be same") - } - - sshp := []int{arows, brows} - smat.SetShape(sshp, "a", "b") - - av := make([]float64, asz) - bv := make([]float64, bsz) - ardim := []int{0} - brdim := []int{0} - sdim := []int{0, 0} - for ai := 0; ai < arows; ai++ { - ardim[0] = ai - sdim[0] = ai - ar := a.SubSpace(ardim) - ar.Floats(&av) - for bi := 0; bi < brows; bi++ { - brdim[0] = bi - sdim[1] = bi - br := b.SubSpace(brdim) - br.Floats(&bv) - sv := mfun(av, bv) - smat.SetFloat(sdim, sv) - } - } - return nil -} - -// TensorStd computes a similarity / distance matrix on tensor -// using given Std metric function. Outer-most dimension ("rows") is -// used as "indexical" dimension and all other dimensions within that -// are compared. -// Results go in smat which is ensured to have proper square 2D shape -// (rows * rows). -// This Std version is usable e.g., in Python where the func cannot be passed. -func TensorStd(smat tensor.Tensor, a tensor.Tensor, met metric.StdMetrics) error { - return Tensor(smat, a, metric.StdFunc64(met)) -} - -// TensorsStd computes a similarity / distance matrix on two tensors -// using given Std metric function. Outer-most dimension ("rows") is -// used as "indexical" dimension and all other dimensions within that -// are compared. Resulting reduced 2D shape of two tensors must be -// the same (returns error if not). -// Rows of smat = a, cols = b -// This Std version is usable e.g., in Python where the func cannot be passed. 
-func TensorsStd(smat tensor.Tensor, a, b tensor.Tensor, met metric.StdMetrics) error { - return Tensors(smat, a, b, metric.StdFunc64(met)) -} diff --git a/tensor/stats/split/README.md b/tensor/stats/split/README.md deleted file mode 100644 index 6d6cdfde63..0000000000 --- a/tensor/stats/split/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# split - -`split` provides `GroupBy`, `Agg`, `Permute` and other functions that create and populate Splits of `table.Table` data. These are powerful tools for quickly summarizing and analyzing data. - - diff --git a/tensor/stats/split/agg.go b/tensor/stats/split/agg.go deleted file mode 100644 index cea87595be..0000000000 --- a/tensor/stats/split/agg.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package split - -import ( - "fmt" - - "cogentcore.org/core/tensor/stats/stats" - "cogentcore.org/core/tensor/table" -) - -// AggIndex performs aggregation using given standard statistic (e.g., Mean) across -// all splits, and returns the SplitAgg container of the results, which are also -// stored in the Splits. Column is specified by index. -func AggIndex(spl *table.Splits, colIndex int, stat stats.Stats) *table.SplitAgg { - ag := spl.AddAgg(stat.String(), colIndex) - for _, sp := range spl.Splits { - agv := stats.StatIndex(sp, colIndex, stat) - ag.Aggs = append(ag.Aggs, agv) - } - return ag -} - -// AggColumn performs aggregation using given standard statistic (e.g., Mean) across -// all splits, and returns the SplitAgg container of the results, which are also -// stored in the Splits. Column is specified by name; returns error for bad column name. 
-func AggColumn(spl *table.Splits, column string, stat stats.Stats) (*table.SplitAgg, error) { - dt := spl.Table() - if dt == nil { - return nil, fmt.Errorf("split.AggTry: No splits to aggregate over") - } - colIndex, err := dt.ColumnIndex(column) - if err != nil { - return nil, err - } - return AggIndex(spl, colIndex, stat), nil -} - -// AggAllNumericColumns performs aggregation using given standard aggregation function across -// all splits, for all number-valued columns in the table. -func AggAllNumericColumns(spl *table.Splits, stat stats.Stats) { - dt := spl.Table() - for ci, cl := range dt.Columns { - if cl.IsString() { - continue - } - AggIndex(spl, ci, stat) - } -} - -/////////////////////////////////////////////////// -// Desc - -// DescIndex performs aggregation using standard statistics across -// all splits, and stores results in the Splits. Column is specified by index. -func DescIndex(spl *table.Splits, colIndex int) { - dt := spl.Table() - if dt == nil { - return - } - col := dt.Columns[colIndex] - sts := stats.DescStats - if col.NumDims() > 1 { // nd cannot do qiles - sts = stats.DescStatsND - } - for _, st := range sts { - AggIndex(spl, colIndex, st) - } -} - -// DescColumn performs aggregation using standard statistics across -// all splits, and stores results in the Splits. -// Column is specified by name; returns error for bad column name. -func DescColumn(spl *table.Splits, column string) error { - dt := spl.Table() - if dt == nil { - return fmt.Errorf("split.DescTry: No splits to aggregate over") - } - colIndex, err := dt.ColumnIndex(column) - if err != nil { - return err - } - DescIndex(spl, colIndex) - return nil -} diff --git a/tensor/stats/split/agg_test.go b/tensor/stats/split/agg_test.go deleted file mode 100644 index e9fae284c7..0000000000 --- a/tensor/stats/split/agg_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package split - -import ( - "testing" - - "cogentcore.org/core/tensor/stats/stats" - "cogentcore.org/core/tensor/table" - - "github.com/stretchr/testify/assert" -) - -func TestAgg(t *testing.T) { - dt := table.NewTable().SetNumRows(4) - dt.AddStringColumn("Group") - dt.AddFloat32Column("Value") - for i := 0; i < dt.Rows; i++ { - gp := "A" - if i >= 2 { - gp = "B" - } - dt.SetString("Group", i, gp) - dt.SetFloat("Value", i, float64(i)) - } - ix := table.NewIndexView(dt) - spl := GroupBy(ix, "Group") - assert.Equal(t, 2, len(spl.Splits)) - - AggColumn(spl, "Value", stats.Mean) - - st := spl.AggsToTable(table.ColumnNameOnly) - assert.Equal(t, 0.5, st.Float("Value", 0)) - assert.Equal(t, 2.5, st.Float("Value", 1)) - assert.Equal(t, "A", st.StringValue("Group", 0)) - assert.Equal(t, "B", st.StringValue("Group", 1)) -} - -func TestAggEmpty(t *testing.T) { - dt := table.NewTable().SetNumRows(4) - dt.AddStringColumn("Group") - dt.AddFloat32Column("Value") - for i := 0; i < dt.Rows; i++ { - gp := "A" - if i >= 2 { - gp = "B" - } - dt.SetString("Group", i, gp) - dt.SetFloat("Value", i, float64(i)) - } - ix := table.NewIndexView(dt) - ix.Filter(func(et *table.Table, row int) bool { - return false // exclude all - }) - spl := GroupBy(ix, "Group") - assert.Equal(t, 1, len(spl.Splits)) - - AggColumn(spl, "Value", stats.Mean) - - st := spl.AggsToTable(table.ColumnNameOnly) - if st == nil { - t.Error("AggsToTable should not be nil!") - } -} diff --git a/tensor/stats/split/doc.go b/tensor/stats/split/doc.go deleted file mode 100644 index 5af6a8cb87..0000000000 --- a/tensor/stats/split/doc.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -/* -Package split provides GroupBy, Agg, Permute and other functions that -create and populate Splits of table.Table data. These are powerful -tools for quickly summarizing and analyzing data. -*/ -package split diff --git a/tensor/stats/split/group.go b/tensor/stats/split/group.go deleted file mode 100644 index 3a8259fdef..0000000000 --- a/tensor/stats/split/group.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package split - -//go:generate core generate - -import ( - "log" - "slices" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/tensor/table" -) - -// All returns a single "split" with all of the rows in given view -// useful for leveraging the aggregation management functions in splits -func All(ix *table.IndexView) *table.Splits { - spl := &table.Splits{} - spl.Levels = []string{"All"} - spl.New(ix.Table, []string{"All"}, ix.Indexes...) - return spl -} - -// GroupByIndex returns a new Splits set based on the groups of values -// across the given set of column indexes. -// Uses a stable sort on columns, so ordering of other dimensions is preserved. 
-func GroupByIndex(ix *table.IndexView, colIndexes []int) *table.Splits { - nc := len(colIndexes) - if nc == 0 || ix.Table == nil { - return nil - } - if ix.Table.ColumnNames == nil { - log.Println("split.GroupBy: Table does not have any column names -- will not work") - return nil - } - spl := &table.Splits{} - spl.Levels = make([]string, nc) - for i, ci := range colIndexes { - spl.Levels[i] = ix.Table.ColumnNames[ci] - } - srt := ix.Clone() - srt.SortStableColumns(colIndexes, true) // important for consistency - lstValues := make([]string, nc) - curValues := make([]string, nc) - var curIx *table.IndexView - for _, rw := range srt.Indexes { - diff := false - for i, ci := range colIndexes { - cl := ix.Table.Columns[ci] - cv := cl.String1D(rw) - curValues[i] = cv - if cv != lstValues[i] { - diff = true - } - } - if diff || curIx == nil { - curIx = spl.New(ix.Table, curValues, rw) - copy(lstValues, curValues) - } else { - curIx.AddIndex(rw) - } - } - if spl.Len() == 0 { // prevent crashing from subsequent ops: add an empty split - spl.New(ix.Table, curValues) // no rows added here - } - return spl -} - -// GroupBy returns a new Splits set based on the groups of values -// across the given set of column names. -// Uses a stable sort on columns, so ordering of other dimensions is preserved. -func GroupBy(ix *table.IndexView, columns ...string) *table.Splits { - return GroupByIndex(ix, errors.Log1(ix.Table.ColumnIndexesByNames(columns...))) -} - -// GroupByFunc returns a new Splits set based on the given function -// which returns value(s) to group on for each row of the table. -// The function should always return the same number of values -- if -// it doesn't behavior is undefined. -// Uses a stable sort on columns, so ordering of other dimensions is preserved. 
-func GroupByFunc(ix *table.IndexView, fun func(row int) []string) *table.Splits { - if ix.Table == nil { - return nil - } - - // save function values - funvals := make(map[int][]string, ix.Len()) - nv := 0 // number of valeus - for _, rw := range ix.Indexes { - sv := fun(rw) - if nv == 0 { - nv = len(sv) - } - funvals[rw] = slices.Clone(sv) - } - - srt := ix.Clone() - srt.SortStable(func(et *table.Table, i, j int) bool { // sort based on given function values - fvi := funvals[i] - fvj := funvals[j] - for fi := 0; fi < nv; fi++ { - if fvi[fi] < fvj[fi] { - return true - } else if fvi[fi] > fvj[fi] { - return false - } - } - return false - }) - - // now do our usual grouping operation - spl := &table.Splits{} - lstValues := make([]string, nv) - var curIx *table.IndexView - for _, rw := range srt.Indexes { - curValues := funvals[rw] - diff := (curIx == nil) - if !diff { - for fi := 0; fi < nv; fi++ { - if lstValues[fi] != curValues[fi] { - diff = true - break - } - } - } - if diff { - curIx = spl.New(ix.Table, curValues, rw) - copy(lstValues, curValues) - } else { - curIx.AddIndex(rw) - } - } - return spl -} diff --git a/tensor/stats/split/random.go b/tensor/stats/split/random.go deleted file mode 100644 index 4099d22e18..0000000000 --- a/tensor/stats/split/random.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package split - -import ( - "fmt" - "math" - - "cogentcore.org/core/tensor/table" - "gonum.org/v1/gonum/floats" -) - -// Permuted generates permuted random splits of table rows, using given list of probabilities, -// which will be normalized to sum to 1 (error returned if sum = 0) -// names are optional names for each split (e.g., Train, Test) which will be -// used to label the Values of the resulting Splits. 
-func Permuted(ix *table.IndexView, probs []float64, names []string) (*table.Splits, error) { - if ix == nil || ix.Len() == 0 { - return nil, fmt.Errorf("split.Random table is nil / empty") - } - np := len(probs) - if len(names) > 0 && len(names) != np { - return nil, fmt.Errorf("split.Random names not same len as probs") - } - sum := floats.Sum(probs) - if sum == 0 { - return nil, fmt.Errorf("split.Random probs sum to 0") - } - nr := ix.Len() - ns := make([]int, np) - cum := 0 - fnr := float64(nr) - for i, p := range probs { - p /= sum - per := int(math.Round(p * fnr)) - if cum+per > nr { - per = nr - cum - if per <= 0 { - break - } - } - ns[i] = per - cum += per - } - spl := &table.Splits{} - perm := ix.Clone() - perm.Permuted() - cum = 0 - spl.SetLevels("permuted") - for i, n := range ns { - nm := "" - if names != nil { - nm = names[i] - } else { - nm = fmt.Sprintf("p=%v", probs[i]/sum) - } - spl.New(ix.Table, []string{nm}, perm.Indexes[cum:cum+n]...) - cum += n - } - return spl, nil -} diff --git a/tensor/stats/split/random_test.go b/tensor/stats/split/random_test.go deleted file mode 100644 index 507ebd77d5..0000000000 --- a/tensor/stats/split/random_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package split - -import ( - "testing" - - "cogentcore.org/core/tensor/table" - - "github.com/stretchr/testify/assert" -) - -func TestPermuted(t *testing.T) { - dt := table.NewTable().SetNumRows(25) - dt.AddStringColumn("Name") - dt.AddFloat32TensorColumn("Input", []int{5, 5}, "Y", "X") - dt.AddFloat32TensorColumn("Output", []int{5, 5}, "Y", "X") - ix := table.NewIndexView(dt) - spl, err := Permuted(ix, []float64{.5, .5}, nil) - if err != nil { - t.Error(err) - } - // for i, sp := range spl.Splits { - // fmt.Printf("split: %v name: %v len: %v idxs: %v\n", i, spl.Values[i], len(sp.Indexes), sp.Indexes) - // } - assert.Equal(t, 2, len(spl.Splits)) - assert.Contains(t, []int{12, 13}, len(spl.Splits[0].Indexes)) - assert.Contains(t, []int{12, 13}, len(spl.Splits[1].Indexes)) - - spl, err = Permuted(ix, []float64{.25, .5, .25}, []string{"test", "train", "validate"}) - if err != nil { - t.Error(err) - } - // for i, sp := range spl.Splits { - // fmt.Printf("split: %v name: %v len: %v idxs: %v\n", i, spl.Values[i], len(sp.Indexes), sp.Indexes) - // } - assert.Equal(t, 3, len(spl.Splits)) - assert.Equal(t, 6, len(spl.Splits[0].Indexes)) - assert.Equal(t, 13, len(spl.Splits[1].Indexes)) - assert.Equal(t, 6, len(spl.Splits[2].Indexes)) -} diff --git a/tensor/stats/stats/README.md b/tensor/stats/stats/README.md index 5225e42383..0e76fc6428 100644 --- a/tensor/stats/stats/README.md +++ b/tensor/stats/stats/README.md @@ -1,12 +1,26 @@ # stats -The `stats` package provides standard statistic computations operating over floating-point data (both 32 and 64 bit) in the following formats. Each statistic returns a single scalar value summarizing the data in a different way. Some formats also support multi-dimensional tensor data, returning a summary stat for each tensor value, using the outer-most ("row-wise") dimension to summarize over. 
+The `stats` package provides standard statistic computations operating on the `tensor.Tensor` standard data representation, using this standard function: +```Go +type StatsFunc = func(in tensor.Tensor) tensor.Values +``` + +The stats functions always operate on the outermost _row_ dimension, and it is up to the caller to reshape the tensor to accomplish the desired results. -* `[]float32` and `[]float64` slices, as e.g., `Mean32` and `Mean64`, skipping any `NaN` values as missing data. +* To obtain a single summary statistic across all values, use `tensor.As1D`. -* `tensor.Float32`, `tensor.Float64` using the underlying `Values` slice, and other generic `Tensor` using the `Floats` interface (less efficient). +* For `RowMajor` data that is naturally organized as a single outer _rows_ dimension with the remaining inner dimensions comprising the _cells_, the results are the statistic for each such cell computed across the outer rows dimension. For the `Mean` statistic for example, each cell contains the average of that cell across all the rows. -* `table.IndexView` indexed views of `table.Table` data, with `*Column` functions (e.g., `MeanColumn`) using names to specify columns, and `*Index` versions operating on column indexes. Also available for this type are `CountIf*`, `PctIf*`, `PropIf*` functions that return count, percentage, or propoprtion of values according to given function. +* Use `tensor.NewRowCellsView` to reshape any tensor into a 2D rows x cells shape, with the cells starting at a given dimension. Thus, any number of outer dimensions can be collapsed into the outer row dimension, and the remaining dimensions become the cells. 
+ +By contrast, the [NumPy Statistics](https://numpy.org/doc/stable/reference/generated/numpy.mean.html#numpy.mean) functions take an `axis` dimension to compute over, but passing such arguments via the universal function calling api for tensors introduces complications, so it is simpler to just have a single designated behavior and reshape the data to achieve the desired results. + +All stats are registered in the `tensor.Funcs` global list (for use in Goal), and can be called through the `Stats` enum e.g.: +```Go +stats.Mean.Call(in, out) +``` + +All stats functions skip over `NaN`s as a missing value, so they are equivalent to the `nanmean` etc versions in NumPy. ## Stats @@ -14,23 +28,79 @@ The following statistics are supported (per the `Stats` enum in `stats.go`): * `Count`: count of number of elements * `Sum`: sum of elements +* `L1Norm`: L1 Norm: sum of absolute values * `Prod`: product of elements * `Min`: minimum value -* `Max`: max maximum value -* `MinAbs`: minimum absolute value -* `MaxAbs`: maximum absolute value -* `Mean`: mean mean value +* `Max`: maximum value +* `MinAbs`: minimum of absolute values +* `MaxAbs`: maximum of absolute values +* `Mean`: mean value * `Var`: sample variance (squared diffs from mean, divided by n-1) * `Std`: sample standard deviation (sqrt of Var) * `Sem`: sample standard error of the mean (Std divided by sqrt(n)) -* `L1Norm`: L1 Norm: sum of absolute values * `SumSq`: sum of squared element values * `L2Norm`: L2 Norm: square-root of sum-of-squares * `VarPop`: population variance (squared diffs from mean, divided by n) * `StdPop`: population standard deviation (sqrt of VarPop) * `SemPop`: population standard error of the mean (StdPop divided by sqrt(n)) -* `Median`: middle value in sorted ordering (only for IndexView) -* `Q1`: Q1 first quartile = 25%ile value = .25 quantile value (only for IndexView) -* `Q3`: Q3 third quartile = 75%ile value = .75 quantile value (only for IndexView) - +* `Median`: middle value in sorted 
ordering (uses a `Rows` view) +* `Q1`: Q1 first quartile = 25%ile value = .25 quantile value (uses `Rows`) +* `Q3`: Q3 third quartile = 75%ile value = .75 quantile value (uses `Rows`) + +Here is the general info associated with these function calls: + +The output must be a `tensor.Values` tensor, and it is automatically shaped to hold the stat value(s) for the "cells" in higher-dimensional tensors, and a single scalar value for a 1D input tensor. + +Stats functions cannot be computed in parallel, e.g., using VectorizeThreaded or GPU, due to shared writing to the same output values. Special implementations are required if that is needed. + +## Normalization functions + +The stats package also has the following standard normalization functions for transforming data into standard ranges in various ways: + +* `UnitNorm` subtracts `min` and divides by resulting `max` to normalize to 0..1 unit range. +* `ZScore` subtracts the mean and divides by the standard deviation. +* `Clamp` enforces min, max range, clamping values to those bounds if they exceed them. +* `Binarize` sets all values below a given threshold to 0, and those above to 1. + +## Groups + +The `Groups` function (and `TableGroups` convenience function for `table.Table` columns) creates lists of indexes for each unique value in a 1D tensor, and `GroupStats` calls a stats function on those groups, thereby creating a "pivot table" that summarizes data in terms of the groups present within it. The data is stored in a [tensorfs](../tensorfs) data filesystem, which can be visualized and further manipulated. 
+ +For example, with this data: +``` +Person Score Time +Alia 40 8 +Alia 30 12 +Ben 20 10 +Ben 10 12 +``` +The `Groups` function called on the `Person` column would create the following `tensorfs` structure: +``` +Groups + Person + Alia: [0,1] // int tensor + Ben: [2,3] + // other groups here if passed +``` +Then the `GroupStats` function operating on this `tensorfs` directory, using the `Score` and `Time` data and the `Mean` stat, followed by a second call with the `Sem` stat, would produce: +``` +Stats + Person + Person: [Alia,Ben] // string tensor of group values of Person + Score + Mean: [35, 15] // float64 tensor of means + Sem: [5, 5] + Time + Mean: [10, 11] + Sem: [1, 0.5] + // other groups here.. +``` + +The `Person` directory can be turned directly into a `table.Table` and plotted or otherwise used, in the `tensorfs` system and associated `databrowser`. + +See the [examples/planets](../examples/planets) example for an interactive exploration of data on exoplanets using the `Groups` functions. + +## Vectorize functions + +See [vec.go](vec.go) for corresponding `tensor.Vectorize` functions that are used in performing the computations. These cannot be parallelized directly due to shared writing to output accumulators, and other ordering constraints. If needed, special atomic-locking or other such techniques would be required. diff --git a/tensor/stats/stats/desc.go b/tensor/stats/stats/desc.go deleted file mode 100644 index e6795cef8c..0000000000 --- a/tensor/stats/stats/desc.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package stats - -import ( - "cogentcore.org/core/tensor/table" -) - -// DescStats are all the standard stats -var DescStats = []Stats{Count, Mean, Std, Sem, Min, Max, Q1, Median, Q3} - -// DescStatsND are all the standard stats for n-dimensional (n > 1) data -- cannot do quantiles -var DescStatsND = []Stats{Count, Mean, Std, Sem, Min, Max} - -// DescAll returns a table of standard descriptive stats for -// all numeric columns in given table, operating over all non-Null, non-NaN elements -// in each column. -func DescAll(ix *table.IndexView) *table.Table { - st := ix.Table - nAgg := len(DescStats) - dt := table.NewTable().SetNumRows(nAgg) - dt.AddStringColumn("Stat") - for ci := range st.Columns { - col := st.Columns[ci] - if col.IsString() { - continue - } - dt.AddFloat64TensorColumn(st.ColumnNames[ci], col.Shape().Sizes[1:], col.Shape().Names[1:]...) - } - dtnm := dt.Columns[0] - dtci := 1 - qs := []float64{.25, .5, .75} - sq := len(DescStatsND) - for ci := range st.Columns { - col := st.Columns[ci] - if col.IsString() { - continue - } - _, csz := col.RowCellSize() - dtst := dt.Columns[dtci] - for i, styp := range DescStatsND { - ag := StatIndex(ix, ci, styp) - si := i * csz - for j := 0; j < csz; j++ { - dtst.SetFloat1D(si+j, ag[j]) - } - if dtci == 1 { - dtnm.SetString1D(i, styp.String()) - } - } - if col.NumDims() == 1 { - qvs := QuantilesIndex(ix, ci, qs) - for i, qv := range qvs { - dtst.SetFloat1D(sq+i, qv) - dtnm.SetString1D(sq+i, DescStats[sq+i].String()) - } - } - dtci++ - } - return dt -} - -// DescIndex returns a table of standard descriptive aggregates -// of non-Null, non-NaN elements in given IndexView indexed view of an -// table.Table, for given column index. 
-func DescIndex(ix *table.IndexView, colIndex int) *table.Table { - st := ix.Table - col := st.Columns[colIndex] - stats := DescStats - if col.NumDims() > 1 { // nd cannot do qiles - stats = DescStatsND - } - nAgg := len(stats) - dt := table.NewTable().SetNumRows(nAgg) - dt.AddStringColumn("Stat") - dt.AddFloat64TensorColumn(st.ColumnNames[colIndex], col.Shape().Sizes[1:], col.Shape().Names[1:]...) - dtnm := dt.Columns[0] - dtst := dt.Columns[1] - _, csz := col.RowCellSize() - for i, styp := range DescStatsND { - ag := StatIndex(ix, colIndex, styp) - si := i * csz - for j := 0; j < csz; j++ { - dtst.SetFloat1D(si+j, ag[j]) - } - dtnm.SetString1D(i, styp.String()) - } - if col.NumDims() == 1 { - sq := len(DescStatsND) - qs := []float64{.25, .5, .75} - qvs := QuantilesIndex(ix, colIndex, qs) - for i, qv := range qvs { - dtst.SetFloat1D(sq+i, qv) - dtnm.SetString1D(sq+i, DescStats[sq+i].String()) - } - } - return dt -} - -// DescColumn returns a table of standard descriptive stats -// of non-NaN elements in given IndexView indexed view of an -// table.Table, for given column name. -// If name not found, returns error message. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func DescColumn(ix *table.IndexView, column string) (*table.Table, error) { - colIndex, err := ix.Table.ColumnIndex(column) - if err != nil { - return nil, err - } - return DescIndex(ix, colIndex), nil -} diff --git a/tensor/stats/stats/describe.go b/tensor/stats/stats/describe.go new file mode 100644 index 0000000000..e8f047d744 --- /dev/null +++ b/tensor/stats/stats/describe.go @@ -0,0 +1,61 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package stats + +import ( + "strconv" + + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" +) + +// DescriptiveStats are the standard descriptive stats used in Describe function. +// Cannot apply the final 3 sort-based stats to higher-dimensional data. +var DescriptiveStats = []Stats{StatCount, StatMean, StatStd, StatSem, StatMin, StatMax, StatQ1, StatMedian, StatQ3} + +// Describe adds standard descriptive statistics for given tensor +// to the given [tensorfs] directory, adding a directory for each tensor +// and result tensor stats for each result. +// This is an easy way to provide a comprehensive description of data. +// The [DescriptiveStats] list is: [Count], [Mean], [Std], [Sem], +// [Min], [Max], [Q1], [Median], [Q3] +func Describe(dir *tensorfs.Node, tsrs ...tensor.Tensor) { + dd := dir.Dir("Describe") + for i, tsr := range tsrs { + nr := tsr.DimSize(0) + if nr == 0 { + continue + } + nm := metadata.Name(tsr) + if nm == "" { + nm = strconv.Itoa(i) + } + td := dd.Dir(nm) + for _, st := range DescriptiveStats { + stnm := st.String() + sv := tensorfs.Scalar[float64](td, stnm) + stout := st.Call(tsr) + sv.CopyFrom(stout) + } + } +} + +// DescribeTable runs [Describe] on given columns in table. +func DescribeTable(dir *tensorfs.Node, dt *table.Table, columns ...string) { + Describe(dir, dt.ColumnList(columns...)...) +} + +// DescribeTableAll runs [Describe] on all numeric columns in given table. +func DescribeTableAll(dir *tensorfs.Node, dt *table.Table) { + var cols []string + for i, cl := range dt.Columns.Values { + if !cl.IsString() { + cols = append(cols, dt.ColumnName(i)) + } + } + Describe(dir, dt.ColumnList(cols...)...) +} diff --git a/tensor/stats/stats/doc.go b/tensor/stats/stats/doc.go index e65f1ce9da..bb7b3ed36a 100644 --- a/tensor/stats/stats/doc.go +++ b/tensor/stats/stats/doc.go @@ -3,14 +3,6 @@ // license that can be found in the LICENSE file. 
/* -Package agg provides aggregation functions operating on IndexView indexed views -of table.Table data, along with standard AggFunc functions that can be used -at any level of aggregation from tensor on up. - -The main functions use names to specify columns, and *Index and *Try versions -are available that operate on column indexes and return errors, respectively. - -See tsragg package for functions that operate directly on a tensor.Tensor -without the indexview indirection. +Package stats provides standard statistic computations operating on the `tensor.Tensor` standard data representation. */ package stats diff --git a/tensor/stats/stats/enumgen.go b/tensor/stats/stats/enumgen.go index 7a760c668e..1c5691e379 100644 --- a/tensor/stats/stats/enumgen.go +++ b/tensor/stats/stats/enumgen.go @@ -11,11 +11,11 @@ var _StatsValues = []Stats{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // StatsN is the highest valid value for type Stats, plus one. const StatsN Stats = 20 -var _StatsValueMap = map[string]Stats{`Count`: 0, `Sum`: 1, `Prod`: 2, `Min`: 3, `Max`: 4, `MinAbs`: 5, `MaxAbs`: 6, `Mean`: 7, `Var`: 8, `Std`: 9, `Sem`: 10, `L1Norm`: 11, `SumSq`: 12, `L2Norm`: 13, `VarPop`: 14, `StdPop`: 15, `SemPop`: 16, `Median`: 17, `Q1`: 18, `Q3`: 19} +var _StatsValueMap = map[string]Stats{`Count`: 0, `Sum`: 1, `L1Norm`: 2, `Prod`: 3, `Min`: 4, `Max`: 5, `MinAbs`: 6, `MaxAbs`: 7, `Mean`: 8, `Var`: 9, `Std`: 10, `Sem`: 11, `SumSq`: 12, `L2Norm`: 13, `VarPop`: 14, `StdPop`: 15, `SemPop`: 16, `Median`: 17, `Q1`: 18, `Q3`: 19} -var _StatsDescMap = map[Stats]string{0: `count of number of elements`, 1: `sum of elements`, 2: `product of elements`, 3: `minimum value`, 4: `max maximum value`, 5: `minimum absolute value`, 6: `maximum absolute value`, 7: `mean mean value`, 8: `sample variance (squared diffs from mean, divided by n-1)`, 9: `sample standard deviation (sqrt of Var)`, 10: `sample standard error of the mean (Std divided by sqrt(n))`, 11: `L1 Norm: sum of absolute 
values`, 12: `sum of squared values`, 13: `L2 Norm: square-root of sum-of-squares`, 14: `population variance (squared diffs from mean, divided by n)`, 15: `population standard deviation (sqrt of VarPop)`, 16: `population standard error of the mean (StdPop divided by sqrt(n))`, 17: `middle value in sorted ordering`, 18: `Q1 first quartile = 25%ile value = .25 quantile value`, 19: `Q3 third quartile = 75%ile value = .75 quantile value`} +var _StatsDescMap = map[Stats]string{0: `count of number of elements.`, 1: `sum of elements.`, 2: `L1 Norm: sum of absolute values of elements.`, 3: `product of elements.`, 4: `minimum value.`, 5: `maximum value.`, 6: `minimum of absolute values.`, 7: `maximum of absolute values.`, 8: `mean value = sum / count.`, 9: `sample variance (squared deviations from mean, divided by n-1).`, 10: `sample standard deviation (sqrt of Var).`, 11: `sample standard error of the mean (Std divided by sqrt(n)).`, 12: `sum of squared values.`, 13: `L2 Norm: square-root of sum-of-squares.`, 14: `population variance (squared diffs from mean, divided by n).`, 15: `population standard deviation (sqrt of VarPop).`, 16: `population standard error of the mean (StdPop divided by sqrt(n)).`, 17: `middle value in sorted ordering.`, 18: `Q1 first quartile = 25%ile value = .25 quantile value.`, 19: `Q3 third quartile = 75%ile value = .75 quantile value.`} -var _StatsMap = map[Stats]string{0: `Count`, 1: `Sum`, 2: `Prod`, 3: `Min`, 4: `Max`, 5: `MinAbs`, 6: `MaxAbs`, 7: `Mean`, 8: `Var`, 9: `Std`, 10: `Sem`, 11: `L1Norm`, 12: `SumSq`, 13: `L2Norm`, 14: `VarPop`, 15: `StdPop`, 16: `SemPop`, 17: `Median`, 18: `Q1`, 19: `Q3`} +var _StatsMap = map[Stats]string{0: `Count`, 1: `Sum`, 2: `L1Norm`, 3: `Prod`, 4: `Min`, 5: `Max`, 6: `MinAbs`, 7: `MaxAbs`, 8: `Mean`, 9: `Var`, 10: `Std`, 11: `Sem`, 12: `SumSq`, 13: `L2Norm`, 14: `VarPop`, 15: `StdPop`, 16: `SemPop`, 17: `Median`, 18: `Q1`, 19: `Q3`} // String returns the string representation of this Stats value. 
func (i Stats) String() string { return enums.String(i, _StatsMap) } diff --git a/tensor/stats/stats/floats.go b/tensor/stats/stats/floats.go deleted file mode 100644 index dff87d1152..0000000000 --- a/tensor/stats/stats/floats.go +++ /dev/null @@ -1,730 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - - "cogentcore.org/core/math32" -) - -// Stat32 returns statistic according to given Stats type applied -// to all non-NaN elements in given slice of float32 -func Stat32(a []float32, stat Stats) float32 { - switch stat { - case Count: - return Count32(a) - case Sum: - return Sum32(a) - case Prod: - return Prod32(a) - case Min: - return Min32(a) - case Max: - return Max32(a) - case MinAbs: - return MinAbs32(a) - case MaxAbs: - return MaxAbs32(a) - case Mean: - return Mean32(a) - case Var: - return Var32(a) - case Std: - return Std32(a) - case Sem: - return Sem32(a) - case L1Norm: - return L1Norm32(a) - case SumSq: - return SumSq32(a) - case L2Norm: - return L2Norm32(a) - case VarPop: - return VarPop32(a) - case StdPop: - return StdPop32(a) - case SemPop: - return SemPop32(a) - // case Median: - // return Median32(a) - // case Q1: - // return Q132(a) - // case Q3: - // return Q332(a) - } - return 0 -} - -// Stat64 returns statistic according to given Stats type applied -// to all non-NaN elements in given slice of float64 -func Stat64(a []float64, stat Stats) float64 { - switch stat { - case Count: - return Count64(a) - case Sum: - return Sum64(a) - case Prod: - return Prod64(a) - case Min: - return Min64(a) - case Max: - return Max64(a) - case MinAbs: - return MinAbs64(a) - case MaxAbs: - return MaxAbs64(a) - case Mean: - return Mean64(a) - case Var: - return Var64(a) - case Std: - return Std64(a) - case Sem: - return Sem64(a) - case L1Norm: - return L1Norm64(a) - case SumSq: - return SumSq64(a) - case 
L2Norm: - return L2Norm64(a) - case VarPop: - return VarPop64(a) - case StdPop: - return StdPop64(a) - case SemPop: - return SemPop64(a) - // case Median: - // return Median64(a) - // case Q1: - // return Q164(a) - // case Q3: - // return Q364(a) - } - return 0 -} - -/////////////////////////////////////////// -// Count - -// Count32 computes the number of non-NaN vector values. -// Skips NaN's -func Count32(a []float32) float32 { - n := 0 - for _, av := range a { - if math32.IsNaN(av) { - continue - } - n++ - } - return float32(n) -} - -// Count64 computes the number of non-NaN vector values. -// Skips NaN's -func Count64(a []float64) float64 { - n := 0 - for _, av := range a { - if math.IsNaN(av) { - continue - } - n++ - } - return float64(n) -} - -/////////////////////////////////////////// -// Sum - -// Sum32 computes the sum of vector values. -// Skips NaN's -func Sum32(a []float32) float32 { - s := float32(0) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - s += av - } - return s -} - -// Sum64 computes the sum of vector values. -// Skips NaN's -func Sum64(a []float64) float64 { - s := float64(0) - for _, av := range a { - if math.IsNaN(av) { - continue - } - s += av - } - return s -} - -/////////////////////////////////////////// -// Prod - -// Prod32 computes the product of vector values. -// Skips NaN's -func Prod32(a []float32) float32 { - s := float32(1) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - s *= av - } - return s -} - -// Prod64 computes the product of vector values. -// Skips NaN's -func Prod64(a []float64) float64 { - s := float64(1) - for _, av := range a { - if math.IsNaN(av) { - continue - } - s *= av - } - return s -} - -/////////////////////////////////////////// -// Min - -// Min32 computes the max over vector values. 
-// Skips NaN's -func Min32(a []float32) float32 { - m := float32(math.MaxFloat32) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - m = math32.Min(m, av) - } - return m -} - -// MinIndex32 computes the min over vector values, and returns index of min as well -// Skips NaN's -func MinIndex32(a []float32) (float32, int) { - m := float32(math.MaxFloat32) - mi := -1 - for i, av := range a { - if math32.IsNaN(av) { - continue - } - if av < m { - m = av - mi = i - } - } - return m, mi -} - -// Min64 computes the max over vector values. -// Skips NaN's -func Min64(a []float64) float64 { - m := float64(math.MaxFloat64) - for _, av := range a { - if math.IsNaN(av) { - continue - } - m = math.Min(m, av) - } - return m -} - -// MinIndex64 computes the min over vector values, and returns index of min as well -// Skips NaN's -func MinIndex64(a []float64) (float64, int) { - m := float64(math.MaxFloat64) - mi := -1 - for i, av := range a { - if math.IsNaN(av) { - continue - } - if av < m { - m = av - mi = i - } - } - return m, mi -} - -/////////////////////////////////////////// -// Max - -// Max32 computes the max over vector values. -// Skips NaN's -func Max32(a []float32) float32 { - m := float32(-math.MaxFloat32) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - m = math32.Max(m, av) - } - return m -} - -// MaxIndex32 computes the max over vector values, and returns index of max as well -// Skips NaN's -func MaxIndex32(a []float32) (float32, int) { - m := float32(-math.MaxFloat32) - mi := -1 - for i, av := range a { - if math32.IsNaN(av) { - continue - } - if av > m { - m = av - mi = i - } - } - return m, mi -} - -// Max64 computes the max over vector values. 
-// Skips NaN's -func Max64(a []float64) float64 { - m := float64(-math.MaxFloat64) - for _, av := range a { - if math.IsNaN(av) { - continue - } - m = math.Max(m, av) - } - return m -} - -// MaxIndex64 computes the max over vector values, and returns index of max as well -// Skips NaN's -func MaxIndex64(a []float64) (float64, int) { - m := float64(-math.MaxFloat64) - mi := -1 - for i, av := range a { - if math.IsNaN(av) { - continue - } - if av > m { - m = av - mi = i - } - } - return m, mi -} - -/////////////////////////////////////////// -// MinAbs - -// MinAbs32 computes the max of absolute value over vector values. -// Skips NaN's -func MinAbs32(a []float32) float32 { - m := float32(math.MaxFloat32) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - m = math32.Min(m, math32.Abs(av)) - } - return m -} - -// MinAbs64 computes the max over vector values. -// Skips NaN's -func MinAbs64(a []float64) float64 { - m := float64(math.MaxFloat64) - for _, av := range a { - if math.IsNaN(av) { - continue - } - m = math.Min(m, math.Abs(av)) - } - return m -} - -/////////////////////////////////////////// -// MaxAbs - -// MaxAbs32 computes the max of absolute value over vector values. -// Skips NaN's -func MaxAbs32(a []float32) float32 { - m := float32(0) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - m = math32.Max(m, math32.Abs(av)) - } - return m -} - -// MaxAbs64 computes the max over vector values. -// Skips NaN's -func MaxAbs64(a []float64) float64 { - m := float64(0) - for _, av := range a { - if math.IsNaN(av) { - continue - } - m = math.Max(m, math.Abs(av)) - } - return m -} - -/////////////////////////////////////////// -// Mean - -// Mean32 computes the mean of the vector (sum / N). 
-// Skips NaN's -func Mean32(a []float32) float32 { - s := float32(0) - n := 0 - for _, av := range a { - if math32.IsNaN(av) { - continue - } - s += av - n++ - } - if n > 0 { - s /= float32(n) - } - return s -} - -// Mean64 computes the mean of the vector (sum / N). -// Skips NaN's -func Mean64(a []float64) float64 { - s := float64(0) - n := 0 - for _, av := range a { - if math.IsNaN(av) { - continue - } - s += av - n++ - } - if n > 0 { - s /= float64(n) - } - return s -} - -/////////////////////////////////////////// -// Var - -// Var32 returns the sample variance of non-NaN elements. -func Var32(a []float32) float32 { - mean := Mean32(a) - n := 0 - s := float32(0) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - dv := av - mean - s += dv * dv - n++ - } - if n > 1 { - s /= float32(n - 1) - } - return s -} - -// Var64 returns the sample variance of non-NaN elements. -func Var64(a []float64) float64 { - mean := Mean64(a) - n := 0 - s := float64(0) - for _, av := range a { - if math.IsNaN(av) { - continue - } - dv := av - mean - s += dv * dv - n++ - } - if n > 1 { - s /= float64(n - 1) - } - return s -} - -/////////////////////////////////////////// -// Std - -// Std32 returns the sample standard deviation of non-NaN elements in vector. -func Std32(a []float32) float32 { - return math32.Sqrt(Var32(a)) -} - -// Std64 returns the sample standard deviation of non-NaN elements in vector. -func Std64(a []float64) float64 { - return math.Sqrt(Var64(a)) -} - -/////////////////////////////////////////// -// Sem - -// Sem32 returns the sample standard error of the mean of non-NaN elements in vector. -func Sem32(a []float32) float32 { - cnt := Count32(a) - if cnt < 2 { - return 0 - } - return Std32(a) / math32.Sqrt(cnt) -} - -// Sem64 returns the sample standard error of the mean of non-NaN elements in vector. 
-func Sem64(a []float64) float64 { - cnt := Count64(a) - if cnt < 2 { - return 0 - } - return Std64(a) / math.Sqrt(cnt) -} - -/////////////////////////////////////////// -// L1Norm - -// L1Norm32 computes the sum of absolute values (L1 Norm). -// Skips NaN's -func L1Norm32(a []float32) float32 { - ss := float32(0) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - ss += math32.Abs(av) - } - return ss -} - -// L1Norm64 computes the sum of absolute values (L1 Norm). -// Skips NaN's -func L1Norm64(a []float64) float64 { - ss := float64(0) - for _, av := range a { - if math.IsNaN(av) { - continue - } - ss += math.Abs(av) - } - return ss -} - -/////////////////////////////////////////// -// SumSquares - -// SumSq32 computes the sum-of-squares of vector. -// Skips NaN's. Uses optimized algorithm from BLAS that avoids numerical overflow. -func SumSq32(a []float32) float32 { - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0]) - } - return 0 - } - var ( - scale float32 = 0 - sumSquares float32 = 1 - ) - for _, v := range a { - if v == 0 || math32.IsNaN(v) { - continue - } - absxi := math32.Abs(v) - if scale < absxi { - sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - sumSquares = sumSquares + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * scale * sumSquares -} - -// SumSq64 computes the sum-of-squares of vector. -// Skips NaN's. Uses optimized algorithm from BLAS that avoids numerical overflow. 
-func SumSq64(a []float64) float64 { - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0]) - } - return 0 - } - var ( - scale float64 = 0 - ss float64 = 1 - ) - for _, v := range a { - if v == 0 || math.IsNaN(v) { - continue - } - absxi := math.Abs(v) - if scale < absxi { - ss = 1 + ss*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - ss = ss + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * scale * ss -} - -/////////////////////////////////////////// -// L2Norm - -// L2Norm32 computes the square-root of sum-of-squares of vector, i.e., the L2 norm. -// Skips NaN's. Uses optimized algorithm from BLAS that avoids numerical overflow. -func L2Norm32(a []float32) float32 { - n := len(a) - if n < 2 { - if n == 1 { - return math32.Abs(a[0]) - } - return 0 - } - var ( - scale float32 = 0 - ss float32 = 1 - ) - for _, v := range a { - if v == 0 || math32.IsNaN(v) { - continue - } - absxi := math32.Abs(v) - if scale < absxi { - ss = 1 + ss*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - ss = ss + (absxi/scale)*(absxi/scale) - } - } - if math32.IsInf(scale, 1) { - return math32.Inf(1) - } - return scale * math32.Sqrt(ss) -} - -// L2Norm64 computes the square-root of sum-of-squares of vector, i.e., the L2 norm. -// Skips NaN's. Uses optimized algorithm from BLAS that avoids numerical overflow. 
-func L2Norm64(a []float64) float64 { - n := len(a) - if n < 2 { - if n == 1 { - return math.Abs(a[0]) - } - return 0 - } - var ( - scale float64 = 0 - ss float64 = 1 - ) - for _, v := range a { - if v == 0 || math.IsNaN(v) { - continue - } - absxi := math.Abs(v) - if scale < absxi { - ss = 1 + ss*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - ss = ss + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * math.Sqrt(ss) -} - -/////////////////////////////////////////// -// VarPop - -// VarPop32 returns the population variance of non-NaN elements. -func VarPop32(a []float32) float32 { - mean := Mean32(a) - n := 0 - s := float32(0) - for _, av := range a { - if math32.IsNaN(av) { - continue - } - dv := av - mean - s += dv * dv - n++ - } - if n > 0 { - s /= float32(n) - } - return s -} - -// VarPop64 returns the population variance of non-NaN elements. -func VarPop64(a []float64) float64 { - mean := Mean64(a) - n := 0 - s := float64(0) - for _, av := range a { - if math.IsNaN(av) { - continue - } - dv := av - mean - s += dv * dv - n++ - } - if n > 0 { - s /= float64(n) - } - return s -} - -/////////////////////////////////////////// -// StdPop - -// StdPop32 returns the population standard deviation of non-NaN elements in vector. -func StdPop32(a []float32) float32 { - return math32.Sqrt(VarPop32(a)) -} - -// StdPop64 returns the population standard deviation of non-NaN elements in vector. -func StdPop64(a []float64) float64 { - return math.Sqrt(VarPop64(a)) -} - -/////////////////////////////////////////// -// SemPop - -// SemPop32 returns the population standard error of the mean of non-NaN elements in vector. -func SemPop32(a []float32) float32 { - cnt := Count32(a) - if cnt < 2 { - return 0 - } - return StdPop32(a) / math32.Sqrt(cnt) -} - -// SemPop64 returns the population standard error of the mean of non-NaN elements in vector. 
-func SemPop64(a []float64) float64 { - cnt := Count64(a) - if cnt < 2 { - return 0 - } - return StdPop64(a) / math.Sqrt(cnt) -} diff --git a/tensor/stats/stats/floats_test.go b/tensor/stats/stats/floats_test.go deleted file mode 100644 index 2f8dc16d47..0000000000 --- a/tensor/stats/stats/floats_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - "testing" - - "cogentcore.org/core/base/tolassert" - "cogentcore.org/core/math32" - "github.com/stretchr/testify/assert" -) - -func TestStats32(t *testing.T) { - vals := []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} - - results := []float32{11, 5.5, 0, 0, 1, 0, 1, 0.5, 0.11, math32.Sqrt(0.11), math32.Sqrt(0.11) / math32.Sqrt(11), 5.5, 3.85, math32.Sqrt(3.85), 0.1, math32.Sqrt(0.1), math32.Sqrt(0.1) / math32.Sqrt(11)} - - assert.Equal(t, results[Count], Count32(vals)) - assert.Equal(t, results[Sum], Sum32(vals)) - assert.Equal(t, results[Prod], Prod32(vals)) - assert.Equal(t, results[Min], Min32(vals)) - assert.Equal(t, results[Max], Max32(vals)) - assert.Equal(t, results[MinAbs], MinAbs32(vals)) - assert.Equal(t, results[MaxAbs], MaxAbs32(vals)) - assert.Equal(t, results[Mean], Mean32(vals)) - assert.Equal(t, results[Var], Var32(vals)) - assert.Equal(t, results[Std], Std32(vals)) - assert.Equal(t, results[Sem], Sem32(vals)) - assert.Equal(t, results[L1Norm], L1Norm32(vals)) - tolassert.EqualTol(t, results[SumSq], SumSq32(vals), 1.0e-6) - tolassert.EqualTol(t, results[L2Norm], L2Norm32(vals), 1.0e-6) - assert.Equal(t, results[VarPop], VarPop32(vals)) - assert.Equal(t, results[StdPop], StdPop32(vals)) - assert.Equal(t, results[SemPop], SemPop32(vals)) - - for stat := Count; stat <= SemPop; stat++ { - tolassert.EqualTol(t, results[stat], Stat32(vals, stat), 1.0e-6) - } -} - -func TestStats64(t *testing.T) { - vals := 
[]float64{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} - - results := []float64{11, 5.5, 0, 0, 1, 0, 1, 0.5, 0.11, math.Sqrt(0.11), math.Sqrt(0.11) / math.Sqrt(11), 5.5, 3.85, math.Sqrt(3.85), 0.1, math.Sqrt(0.1), math.Sqrt(0.1) / math.Sqrt(11)} - - assert.Equal(t, results[Count], Count64(vals)) - assert.Equal(t, results[Sum], Sum64(vals)) - assert.Equal(t, results[Prod], Prod64(vals)) - assert.Equal(t, results[Min], Min64(vals)) - assert.Equal(t, results[Max], Max64(vals)) - assert.Equal(t, results[MinAbs], MinAbs64(vals)) - assert.Equal(t, results[MaxAbs], MaxAbs64(vals)) - assert.Equal(t, results[Mean], Mean64(vals)) - tolassert.EqualTol(t, results[Var], Var64(vals), 1.0e-8) - tolassert.EqualTol(t, results[Std], Std64(vals), 1.0e-8) - tolassert.EqualTol(t, results[Sem], Sem64(vals), 1.0e-8) - assert.Equal(t, results[L1Norm], L1Norm64(vals)) - tolassert.EqualTol(t, results[SumSq], SumSq64(vals), 1.0e-8) - tolassert.EqualTol(t, results[L2Norm], L2Norm64(vals), 1.0e-8) - assert.Equal(t, results[VarPop], VarPop64(vals)) - assert.Equal(t, results[StdPop], StdPop64(vals)) - assert.Equal(t, results[SemPop], SemPop64(vals)) - - for stat := Count; stat <= SemPop; stat++ { - tolassert.EqualTol(t, results[stat], Stat64(vals, stat), 1.0e-8) - } -} diff --git a/tensor/stats/stats/funcs.go b/tensor/stats/stats/funcs.go index 76f23dacef..b0d11fe785 100644 --- a/tensor/stats/stats/funcs.go +++ b/tensor/stats/stats/funcs.go @@ -4,62 +4,553 @@ package stats -import "math" +import ( + "math" -// These are standard StatFunc functions that can operate on tensor.Tensor -// or table.Table, using float64 values + "cogentcore.org/core/tensor" +) -// StatFunc is an statistic function that incrementally updates agg -// aggregation value from each element in the tensor in turn. -// Returns new agg value that will be passed into next item as agg. 
-type StatFunc func(idx int, val float64, agg float64) float64 +// StatsFunc is the function signature for a stats function that +// returns a new output vector. This can be less efficient for repeated +// computations where the output can be re-used: see [StatsOutFunc]. +// But this version can be directly chained with other function calls. +// Function is computed over the outermost row dimension and the +// output is the shape of the remaining inner cells (a scalar for 1D inputs). +// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D] etc +// to reshape and reslice the data as needed. +// All stats functions skip over NaN's, as a missing value. +// Stats functions cannot be computed in parallel, +// e.g., using VectorizeThreaded or GPU, due to shared writing +// to the same output values. Special implementations are required +// if that is needed. +type StatsFunc = func(in tensor.Tensor) tensor.Values -// CountFunc is an StatFunc that computes number of elements (non-Null, non-NaN) -// Use 0 as initial value. -func CountFunc(idx int, val float64, agg float64) float64 { - return agg + 1 +// StatsOutFunc is the function signature for a stats function, +// that takes output values as final argument. See [StatsFunc] +// This version is for computationally demanding cases and saves +// reallocation of output. +type StatsOutFunc = func(in tensor.Tensor, out tensor.Values) error + +// CountOut64 computes the count of non-NaN tensor values, +// and returns the Float64 output values for subsequent use. +func CountOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 { + return VectorizeOut64(in, out, 0, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return agg + 1 + }) +} + +// Count computes the count of non-NaN tensor values. +// See [StatsFunc] for general information. +func Count(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(CountOut, in) +} + +// CountOut computes the count of non-NaN tensor values. 
+// See [StatsOutFunc] for general information. +func CountOut(in tensor.Tensor, out tensor.Values) error { + CountOut64(in, out) + return nil +} + +// SumOut64 computes the sum of tensor values, +// and returns the Float64 output values for subsequent use. +func SumOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 { + return VectorizeOut64(in, out, 0, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return agg + val + }) +} + +// SumOut computes the sum of tensor values. +// See [StatsOutFunc] for general information. +func SumOut(in tensor.Tensor, out tensor.Values) error { + SumOut64(in, out) + return nil +} + +// Sum computes the sum of tensor values. +// See [StatsFunc] for general information. +func Sum(in tensor.Tensor) tensor.Values { + out := tensor.NewOfType(in.DataType()) + SumOut64(in, out) + return out +} + +// L1Norm computes the sum of absolute-value-of tensor values. +// See [StatsFunc] for general information. +func L1Norm(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(L1NormOut, in) +} + +// L1NormOut computes the sum of absolute-value-of tensor values. +// See [StatsFunc] for general information. +func L1NormOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, 0, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return agg + math.Abs(val) + }) + return nil +} + +// Prod computes the product of tensor values. +// See [StatsFunc] for general information. +func Prod(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(ProdOut, in) +} + +// ProdOut computes the product of tensor values. +// See [StatsOutFunc] for general information. +func ProdOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, 1, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return agg * val + }) + return nil +} + +// Min computes the min of tensor values. +// See [StatsFunc] for general information. 
+func Min(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(MinOut, in) +} + +// MinOut computes the min of tensor values. +// See [StatsOutFunc] for general information. +func MinOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return math.Min(agg, val) + }) + return nil +} + +// Max computes the max of tensor values. +// See [StatsFunc] for general information. +func Max(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(MaxOut, in) +} + +// MaxOut computes the max of tensor values. +// See [StatsOutFunc] for general information. +func MaxOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return math.Max(agg, val) + }) + return nil +} + +// MinAbs computes the min of absolute-value-of tensor values. +// See [StatsFunc] for general information. +func MinAbs(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(MinAbsOut, in) +} + +// MinAbsOut computes the min of absolute-value-of tensor values. +// See [StatsOutFunc] for general information. +func MinAbsOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return math.Min(agg, math.Abs(val)) + }) + return nil +} + +// MaxAbs computes the max of absolute-value-of tensor values. +// See [StatsFunc] for general information. +func MaxAbs(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(MaxAbsOut, in) +} + +// MaxAbsOut computes the max of absolute-value-of tensor values. +// See [StatsOutFunc] for general information. 
+func MaxAbsOut(in tensor.Tensor, out tensor.Values) error { + VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return math.Max(agg, math.Abs(val)) + }) + return nil +} + +// MeanOut64 computes the mean of tensor values, +// and returns the Float64 output values for subsequent use. +func MeanOut64(in tensor.Tensor, out tensor.Values) (mean64, count64 *tensor.Float64) { + var sum64 *tensor.Float64 + sum64, count64 = Vectorize2Out64(in, 0, 0, func(val, sum, count float64) (float64, float64) { + if math.IsNaN(val) { + return sum, count + } + count += 1 + sum += val + return sum, count + }) + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + c := count64.Float1D(i) + if c == 0 { + continue + } + mean := sum64.Float1D(i) / c + sum64.SetFloat1D(mean, i) + out.SetFloat1D(mean, i) + } + return sum64, count64 +} + +// Mean computes the mean of tensor values. +// See [StatsFunc] for general information. +func Mean(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(MeanOut, in) +} + +// MeanOut computes the mean of tensor values. +// See [StatsOutFunc] for general information. +func MeanOut(in tensor.Tensor, out tensor.Values) error { + MeanOut64(in, out) + return nil +} + +// SumSqDevOut64 computes the sum of squared mean deviates of tensor values, +// and returns the Float64 output values for subsequent use. +func SumSqDevOut64(in tensor.Tensor, out tensor.Values) (ssd64, mean64, count64 *tensor.Float64) { + mean64, count64 = MeanOut64(in, out) + ssd64 = VectorizePreOut64(in, out, 0, mean64, func(val, mean, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + dv := val - mean + return agg + dv*dv + }) + return +} + +// VarOut64 computes the sample variance of tensor values, +// and returns the Float64 output values for subsequent use. 
+func VarOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) { + var64, mean64, count64 = SumSqDevOut64(in, out) + nsub := out.Len() + for i := range nsub { + c := count64.Float1D(i) + if c < 2 { + continue + } + vr := var64.Float1D(i) / (c - 1) + var64.SetFloat1D(vr, i) + out.SetFloat1D(vr, i) + } + return +} + +// Var computes the sample variance of tensor values. +// Squared deviations from mean, divided by n-1. See also [VarPopFunc]. +// See [StatsFunc] for general information. +func Var(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(VarOut, in) +} + +// VarOut computes the sample variance of tensor values. +// Squared deviations from mean, divided by n-1. See also [VarPopFunc]. +// See [StatsOutFunc] for general information. +func VarOut(in tensor.Tensor, out tensor.Values) error { + VarOut64(in, out) + return nil +} + +// StdOut64 computes the sample standard deviation of tensor values. +// and returns the Float64 output values for subsequent use. +func StdOut64(in tensor.Tensor, out tensor.Values) (std64, mean64, count64 *tensor.Float64) { + std64, mean64, count64 = VarOut64(in, out) + nsub := out.Len() + for i := range nsub { + std := math.Sqrt(std64.Float1D(i)) + std64.SetFloat1D(std, i) + out.SetFloat1D(std, i) + } + return +} + +// Std computes the sample standard deviation of tensor values. +// Sqrt of variance from [VarFunc]. See also [StdPopFunc]. +// See [StatsFunc] for general information. +func Std(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(StdOut, in) +} + +// StdOut computes the sample standard deviation of tensor values. +// Sqrt of variance from [VarFunc]. See also [StdPopFunc]. +// See [StatsOutFunc] for general information. +func StdOut(in tensor.Tensor, out tensor.Values) error { + StdOut64(in, out) + return nil +} + +// Sem computes the sample standard error of the mean of tensor values. +// Standard deviation [StdFunc] / sqrt(n). See also [SemPopFunc]. 
+// See [StatsFunc] for general information. +func Sem(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(SemOut, in) +} + +// SemOut computes the sample standard error of the mean of tensor values. +// Standard deviation [StdFunc] / sqrt(n). See also [SemPopFunc]. +// See [StatsOutFunc] for general information. +func SemOut(in tensor.Tensor, out tensor.Values) error { + var64, _, count64 := VarOut64(in, out) + nsub := out.Len() + for i := range nsub { + c := count64.Float1D(i) + if c < 2 { + out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i) + } else { + out.SetFloat1D(math.Sqrt(var64.Float1D(i))/math.Sqrt(c), i) + } + } + return nil +} + +// VarPopOut64 computes the population variance of tensor values. +// and returns the Float64 output values for subsequent use. +func VarPopOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) { + var64, mean64, count64 = SumSqDevOut64(in, out) + nsub := out.Len() + for i := range nsub { + c := count64.Float1D(i) + if c == 0 { + continue + } + var64.SetFloat1D(var64.Float1D(i)/c, i) + out.SetFloat1D(var64.Float1D(i), i) + } + return +} + +// VarPop computes the population variance of tensor values. +// Squared deviations from mean, divided by n. See also [VarFunc]. +// See [StatsFunc] for general information. +func VarPop(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(VarPopOut, in) +} + +// VarPopOut computes the population variance of tensor values. +// Squared deviations from mean, divided by n. See also [VarFunc]. +// See [StatsOutFunc] for general information. +func VarPopOut(in tensor.Tensor, out tensor.Values) error { + VarPopOut64(in, out) + return nil +} + +// StdPop computes the population standard deviation of tensor values. +// Sqrt of variance from [VarPopFunc]. See also [StdFunc]. +// See [StatsFunc] for general information. 
+func StdPop(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(StdPopOut, in) +} + +// StdPopOut computes the population standard deviation of tensor values. +// Sqrt of variance from [VarPopFunc]. See also [StdFunc]. +// See [StatsOutFunc] for general information. +func StdPopOut(in tensor.Tensor, out tensor.Values) error { + var64, _, _ := VarPopOut64(in, out) + nsub := out.Len() + for i := range nsub { + out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i) + } + return nil +} + +// SemPop computes the population standard error of the mean of tensor values. +// Standard deviation [StdPopFunc] / sqrt(n). See also [SemFunc]. +// See [StatsFunc] for general information. +func SemPop(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(SemPopOut, in) +} + +// SemPopOut computes the population standard error of the mean of tensor values. +// Standard deviation [StdPopFunc] / sqrt(n). See also [SemFunc]. +// See [StatsOutFunc] for general information. +func SemPopOut(in tensor.Tensor, out tensor.Values) error { + var64, _, count64 := VarPopOut64(in, out) + nsub := out.Len() + for i := range nsub { + c := count64.Float1D(i) + if c < 2 { + out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i) + } else { + out.SetFloat1D(math.Sqrt(var64.Float1D(i))/math.Sqrt(c), i) + } + } + return nil +} + +// SumSqScaleOut64 is a helper for sum-of-squares, returning scale and ss +// factors aggregated separately for better numerical stability, per BLAS. +// Returns the Float64 output values for subsequent use. +func SumSqScaleOut64(in tensor.Tensor) (scale64, ss64 *tensor.Float64) { + scale64, ss64 = Vectorize2Out64(in, 0, 1, func(val, scale, ss float64) (float64, float64) { + if math.IsNaN(val) || val == 0 { + return scale, ss + } + absxi := math.Abs(val) + if scale < absxi { + ss = 1 + ss*(scale/absxi)*(scale/absxi) + scale = absxi + } else { + ss = ss + (absxi/scale)*(absxi/scale) + } + return scale, ss + }) + return } -// SumFunc is an StatFunc that computes a sum aggregate. 
-// use 0 as initial value. -func SumFunc(idx int, val float64, agg float64) float64 { - return agg + val +// SumSqOut64 computes the sum of squares of tensor values, +// and returns the Float64 output values for subsequent use. +func SumSqOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 { + scale64, ss64 := SumSqScaleOut64(in) + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) + nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * scale * ss + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return scale64 } -// Prodfunc is an StatFunc that computes a product aggregate. -// use 1 as initial value. -func ProdFunc(idx int, val float64, agg float64) float64 { - return agg * val +// SumSq computes the sum of squares of tensor values, +// See [StatsFunc] for general information. +func SumSq(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(SumSqOut, in) } -// MinFunc is an StatFunc that computes a min aggregate. -// use math.MaxFloat64 for initial agg value. -func MinFunc(idx int, val float64, agg float64) float64 { - return math.Min(agg, val) +// SumSqOut computes the sum of squares of tensor values, +// See [StatsOutFunc] for general information. +func SumSqOut(in tensor.Tensor, out tensor.Values) error { + SumSqOut64(in, out) + return nil } -// MaxFunc is an StatFunc that computes a max aggregate. -// use -math.MaxFloat64 for initial agg value. -func MaxFunc(idx int, val float64, agg float64) float64 { - return math.Max(agg, val) +// L2NormOut64 computes the square root of the sum of squares of tensor values, +// known as the L2 norm, and returns the Float64 output values for +// use in subsequent computations. +func L2NormOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 { + scale64, ss64 := SumSqScaleOut64(in) + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ nsub := out.Len() + for i := range nsub { + scale := scale64.Float1D(i) + ss := ss64.Float1D(i) + v := 0.0 + if math.IsInf(scale, 1) { + v = math.Inf(1) + } else { + v = scale * math.Sqrt(ss) + } + scale64.SetFloat1D(v, i) + out.SetFloat1D(v, i) + } + return scale64 } -// MinAbsFunc is an StatFunc that computes a min aggregate. -// use math.MaxFloat64 for initial agg value. -func MinAbsFunc(idx int, val float64, agg float64) float64 { - return math.Min(agg, math.Abs(val)) +// L2Norm computes the square root of the sum of squares of tensor values, +// known as the L2 norm. +// See [StatsFunc] for general information. +func L2Norm(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(L2NormOut, in) } -// MaxAbsFunc is an StatFunc that computes a max aggregate. -// use -math.MaxFloat64 for initial agg value. -func MaxAbsFunc(idx int, val float64, agg float64) float64 { - return math.Max(agg, math.Abs(val)) +// L2NormOut computes the square root of the sum of squares of tensor values, +// known as the L2 norm. +// See [StatsOutFunc] for general information. +func L2NormOut(in tensor.Tensor, out tensor.Values) error { + L2NormOut64(in, out) + return nil } -// L1NormFunc is an StatFunc that computes the L1 norm: sum of absolute values -// use 0 as initial value. -func L1NormFunc(idx int, val float64, agg float64) float64 { - return agg + math.Abs(val) +// First returns the first tensor value(s), as a stats function, +// for the starting point in a naturally-ordered set of data. +// See [StatsFunc] for general information. +func First(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(FirstOut, in) } -// Note: SumSq is not numerically stable for large N in simple func form. +// FirstOut returns the first tensor value(s), as a stats function, +// for the starting point in a naturally-ordered set of data. +// See [StatsOutFunc] for general information. 
+func FirstOut(in tensor.Tensor, out tensor.Values) error { + rows, cells := in.Shape().RowCellSize() + if cells == 1 { + out.SetShapeSizes(1) + if rows > 0 { + out.SetFloat1D(in.Float1D(0), 0) + } + return nil + } + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) + if rows == 0 { + return nil + } + for i := range cells { + out.SetFloat1D(in.Float1D(i), i) + } + return nil +} + +// Final returns the final tensor value(s), as a stats function, +// for the ending point in a naturally-ordered set of data. +// See [StatsFunc] for general information. +func Final(in tensor.Tensor) tensor.Values { + return tensor.CallOut1(FinalOut, in) +} + +// FinalOut returns the first tensor value(s), as a stats function, +// for the ending point in a naturally-ordered set of data. +// See [StatsOutFunc] for general information. +func FinalOut(in tensor.Tensor, out tensor.Values) error { + rows, cells := in.Shape().RowCellSize() + if cells == 1 { + out.SetShapeSizes(1) + if rows > 0 { + out.SetFloat1D(in.Float1D(rows-1), 0) + } + return nil + } + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) + if rows == 0 { + return nil + } + st := (rows - 1) * cells + for i := range cells { + out.SetFloat1D(in.Float1D(st+i), i) + } + return nil +} diff --git a/tensor/stats/stats/group.go b/tensor/stats/stats/group.go new file mode 100644 index 0000000000..bde11f40cf --- /dev/null +++ b/tensor/stats/stats/group.go @@ -0,0 +1,215 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stats + +import ( + "strconv" + "strings" + + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" +) + +// Groups generates indexes for each unique value in each of the given tensors. 
+// One can then use the resulting indexes for the [tensor.Rows] indexes to +// perform computations restricted to grouped subsets of data, as in the +// [GroupStats] function. See [GroupCombined] for function that makes a +// "Combined" Group that has a unique group for each _combination_ of +// the separate, independent groups created by this function. +// It creates subdirectories in a "Groups" directory within given [tensorfs], +// for each tensor passed in here, using the metadata Name property for +// names (index if empty). +// Within each subdirectory there are int tensors for each unique 1D +// row-wise value of elements in the input tensor, named as the string +// representation of the value, where the int tensor contains a list of +// row-wise indexes corresponding to the source rows having that value. +// Note that these indexes are directly in terms of the underlying [Tensor] data +// rows, indirected through any existing indexes on the inputs, so that +// the results can be used directly as Indexes into the corresponding tensor data. +// Uses a stable sort on columns, so ordering of other dimensions is preserved. 
+func Groups(dir *tensorfs.Node, tsrs ...tensor.Tensor) error { + gd := dir.Dir("Groups") + makeIdxs := func(dir *tensorfs.Node, srt *tensor.Rows, val string, start, r int) { + n := r - start + it := tensorfs.Value[int](dir, val, n) + for j := range n { + it.SetIntRow(srt.Indexes[start+j], j, 0) // key to indirect through sort indexes + } + } + + for i, tsr := range tsrs { + nr := tsr.DimSize(0) + if nr == 0 { + continue + } + nm := metadata.Name(tsr) + if nm == "" { + nm = strconv.Itoa(i) + } + td := gd.Dir(nm) + srt := tensor.AsRows(tsr).CloneIndexes() + srt.SortStable(tensor.Ascending) + start := 0 + if tsr.IsString() { + lastVal := srt.StringRow(0, 0) + for r := range nr { + v := srt.StringRow(r, 0) + if v != lastVal { + makeIdxs(td, srt, lastVal, start, r) + start = r + lastVal = v + } + } + if start != nr-1 { + makeIdxs(td, srt, lastVal, start, nr) + } + } else { + lastVal := srt.FloatRow(0, 0) + for r := range nr { + v := srt.FloatRow(r, 0) + if v != lastVal { + makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, r) + start = r + lastVal = v + } + } + if start != nr-1 { + makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, nr) + } + } + } + return nil +} + +// TableGroups runs [Groups] on the given columns from given [table.Table]. +func TableGroups(dir *tensorfs.Node, dt *table.Table, columns ...string) error { + dv := table.NewView(dt) + // important for consistency across columns, to do full outer product sort first. + dv.SortColumns(tensor.Ascending, tensor.StableSort, columns...) + return Groups(dir, dv.ColumnList(columns...)...) +} + +// GroupAll copies all indexes from the first given tensor, +// into an "All/All" tensor in the given [tensorfs], which can then +// be used with [GroupStats] to generate summary statistics across +// all the data. See [Groups] for more general documentation. 
+func GroupAll(dir *tensorfs.Node, tsrs ...tensor.Tensor) error { + gd := dir.Dir("Groups") + tsr := tensor.AsRows(tsrs[0]) + nr := tsr.NumRows() + if nr == 0 { + return nil + } + td := gd.Dir("All") + it := tensorfs.Value[int](td, "All", nr) + for j := range nr { + it.SetIntRow(tsr.RowIndex(j), j, 0) // key to indirect through any existing indexes + } + return nil +} + +// todo: GroupCombined + +// GroupStats computes the given stats function on the unique grouped indexes +// produced by the [Groups] function, in the given [tensorfs] directory, +// applied to each of the tensors passed here. +// It creates a "Stats" subdirectory in given directory, with +// subdirectories with the name of each value tensor (if it does not +// yet exist), and then creates a subdirectory within that +// for the statistic name. Within that statistic directory, it creates +// a String tensor with the unique values of each source [Groups] tensor, +// and a aligned Float64 tensor with the statistics results for each such +// unique group value. See the README.md file for a diagram of the results. +func GroupStats(dir *tensorfs.Node, stat Stats, tsrs ...tensor.Tensor) error { + gd := dir.Dir("Groups") + sd := dir.Dir("Stats") + stnm := StripPackage(stat.String()) + groups, _ := gd.Nodes() + for _, gp := range groups { + gpnm := gp.Name() + ggd := gd.Dir(gpnm) + vals := ggd.ValuesFunc(nil) + nv := len(vals) + if nv == 0 { + continue + } + sgd := sd.Dir(gpnm) + gv := sgd.Node(gpnm) + if gv == nil { + gtsr := tensorfs.Value[string](sgd, gpnm, nv) + for i, v := range vals { + gtsr.SetStringRow(metadata.Name(v), i, 0) + } + } + for _, tsr := range tsrs { + vd := sgd.Dir(metadata.Name(tsr)) + sv := tensorfs.Value[float64](vd, stnm, nv) + for i, v := range vals { + idx := tensor.AsIntSlice(v) + sg := tensor.NewRows(tsr.AsValues(), idx...) 
+ stout := stat.Call(sg) + sv.SetFloatRow(stout.Float1D(0), i, 0) + } + } + } + return nil +} + +// TableGroupStats runs [GroupStats] using standard [Stats] +// on the given columns from given [table.Table]. +func TableGroupStats(dir *tensorfs.Node, stat Stats, dt *table.Table, columns ...string) error { + return GroupStats(dir, stat, dt.ColumnList(columns...)...) +} + +// GroupDescribe runs standard descriptive statistics on given tensor data +// using [GroupStats] function, with [DescriptiveStats] list of stats. +func GroupDescribe(dir *tensorfs.Node, tsrs ...tensor.Tensor) error { + for _, st := range DescriptiveStats { + err := GroupStats(dir, st, tsrs...) + if err != nil { + return err + } + } + return nil +} + +// TableGroupDescribe runs [GroupDescribe] on the given columns from given [table.Table]. +func TableGroupDescribe(dir *tensorfs.Node, dt *table.Table, columns ...string) error { + return GroupDescribe(dir, dt.ColumnList(columns...)...) +} + +// GroupStatsAsTable returns the results from [GroupStats] in given directory +// as a [table.Table], using [tensorfs.DirTable] function. +func GroupStatsAsTable(dir *tensorfs.Node) *table.Table { + return tensorfs.DirTable(dir.Node("Stats"), nil) +} + +// GroupStatsAsTableNoStatName returns the results from [GroupStats] +// in given directory as a [table.Table], using [tensorfs.DirTable] function. +// Column names are updated to not include the stat name, if there is only +// one statistic such that the resulting name will still be unique. +// Otherwise, column names are Value/Stat. 
+func GroupStatsAsTableNoStatName(dir *tensorfs.Node) *table.Table { + dt := tensorfs.DirTable(dir.Node("Stats"), nil) + cols := make(map[string]string) + for _, nm := range dt.Columns.Keys { + vn := nm + si := strings.Index(nm, "/") + if si > 0 { + vn = nm[:si] + } + if _, exists := cols[vn]; exists { + continue + } + cols[vn] = nm + } + for k, v := range cols { + ci := dt.Columns.IndexByKey(v) + dt.Columns.RenameIndex(ci, k) + } + return dt +} diff --git a/tensor/stats/stats/group_test.go b/tensor/stats/stats/group_test.go new file mode 100644 index 0000000000..f0ddc2d3d4 --- /dev/null +++ b/tensor/stats/stats/group_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stats + +import ( + "testing" + + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorfs" + "github.com/stretchr/testify/assert" +) + +func TestGroup(t *testing.T) { + dt := table.New().SetNumRows(4) + dt.AddStringColumn("Name") + dt.AddFloat32Column("Value") + for i := range dt.NumRows() { + gp := "A" + if i >= 2 { + gp = "B" + } + dt.Column("Name").SetStringRow(gp, i, 0) + dt.Column("Value").SetFloatRow(float64(i), i, 0) + } + dir, _ := tensorfs.NewDir("Group") + err := TableGroups(dir, dt, "Name") + assert.NoError(t, err) + + ixs := dir.ValuesFunc(nil) + assert.Equal(t, []int{0, 1}, tensor.AsInt(ixs[0]).Values) + assert.Equal(t, []int{2, 3}, tensor.AsInt(ixs[1]).Values) + + err = TableGroupStats(dir, StatMean, dt, "Value") + assert.NoError(t, err) + + gdt := GroupStatsAsTableNoStatName(dir) + assert.Equal(t, 0.5, gdt.Column("Value").Float1D(0)) + assert.Equal(t, 2.5, gdt.Column("Value").Float1D(1)) + assert.Equal(t, "A", gdt.Column("Name").String1D(0)) + assert.Equal(t, "B", gdt.Column("Name").String1D(1)) +} + +/* +func TestAggEmpty(t *testing.T) { + dt := table.New().SetNumRows(4) + 
dt.AddStringColumn("Group") + dt.AddFloat32Column("Value") + for i := 0; i < dt.Rows; i++ { + gp := "A" + if i >= 2 { + gp = "B" + } + dt.SetString("Group", i, gp) + dt.SetFloat("Value", i, float64(i)) + } + ix := table.NewIndexed(dt) + ix.Filter(func(et *table.Table, row int) bool { + return false // exclude all + }) + spl := GroupBy(ix, "Group") + assert.Equal(t, 1, len(spl.Splits)) + + AggColumn(spl, "Value", stats.Mean) + + st := spl.AggsToTable(table.ColumnNameOnly) + if st == nil { + t.Error("AggsToTable should not be nil!") + } +} +*/ diff --git a/tensor/stats/stats/if.go b/tensor/stats/stats/if.go deleted file mode 100644 index fc29404cf8..0000000000 --- a/tensor/stats/stats/if.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "cogentcore.org/core/base/errors" - "cogentcore.org/core/tensor/table" -) - -// IfFunc is used for the *If aggregators -- counted if it returns true -type IfFunc func(idx int, val float64) bool - -/////////////////////////////////////////////////// -// CountIf - -// CountIfIndex returns the count of true return values for given IfFunc on -// non-NaN elements in given IndexView indexed view of an -// table.Table, for given column index. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func CountIfIndex(ix *table.IndexView, colIndex int, iffun IfFunc) []float64 { - return StatIndexFunc(ix, colIndex, 0, func(idx int, val float64, agg float64) float64 { - if iffun(idx, val) { - return agg + 1 - } - return agg - }) -} - -// CountIfColumn returns the count of true return values for given IfFunc on -// non-NaN elements in given IndexView indexed view of an -// table.Table, for given column name. -// If name not found, nil is returned. 
-// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func CountIfColumn(ix *table.IndexView, column string, iffun IfFunc) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return CountIfIndex(ix, colIndex, iffun) -} - -/////////////////////////////////////////////////// -// PropIf - -// PropIfIndex returns the proportion (0-1) of true return values for given IfFunc on -// non-Null, non-NaN elements in given IndexView indexed view of an -// table.Table, for given column index. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func PropIfIndex(ix *table.IndexView, colIndex int, iffun IfFunc) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - pif := CountIfIndex(ix, colIndex, iffun) - for i := range pif { - if cnt[i] > 0 { - pif[i] /= cnt[i] - } - } - return pif -} - -// PropIfColumn returns the proportion (0-1) of true return values for given IfFunc on -// non-NaN elements in given IndexView indexed view of an -// table.Table, for given column name. -// If name not found, nil is returned -- use Try version for error message. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func PropIfColumn(ix *table.IndexView, column string, iffun IfFunc) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return PropIfIndex(ix, colIndex, iffun) -} - -/////////////////////////////////////////////////// -// PctIf - -// PctIfIndex returns the percentage (0-100) of true return values for given IfFunc on -// non-Null, non-NaN elements in given IndexView indexed view of an -// table.Table, for given column index. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func PctIfIndex(ix *table.IndexView, colIndex int, iffun IfFunc) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - pif := CountIfIndex(ix, colIndex, iffun) - for i := range pif { - if cnt[i] > 0 { - pif[i] = 100.0 * (pif[i] / cnt[i]) - } - } - return pif -} - -// PctIfColumn returns the percentage (0-100) of true return values for given IfFunc on -// non-Null, non-NaN elements in given IndexView indexed view of an -// table.Table, for given column name. -// If name not found, nil is returned -- use Try version for error message. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func PctIfColumn(ix *table.IndexView, column string, iffun IfFunc) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return PctIfIndex(ix, colIndex, iffun) -} diff --git a/tensor/stats/stats/indexview.go b/tensor/stats/stats/indexview.go deleted file mode 100644 index e849f4cfba..0000000000 --- a/tensor/stats/stats/indexview.go +++ /dev/null @@ -1,761 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/tensor/table" -) - -// Every IndexView Stats method in this file follows one of these signatures: - -// IndexViewFuncIndex is a stats function operating on IndexView, taking a column index arg -type IndexViewFuncIndex func(ix *table.IndexView, colIndex int) []float64 - -// IndexViewFuncColumn is a stats function operating on IndexView, taking a column name arg -type IndexViewFuncColumn func(ix *table.IndexView, column string) []float64 - -// StatIndex returns IndexView statistic according to given Stats type applied -// to all non-NaN elements in given IndexView indexed view of -// an table.Table, for given column index. 
-// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StatIndex(ix *table.IndexView, colIndex int, stat Stats) []float64 { - switch stat { - case Count: - return CountIndex(ix, colIndex) - case Sum: - return SumIndex(ix, colIndex) - case Prod: - return ProdIndex(ix, colIndex) - case Min: - return MinIndex(ix, colIndex) - case Max: - return MaxIndex(ix, colIndex) - case MinAbs: - return MinAbsIndex(ix, colIndex) - case MaxAbs: - return MaxAbsIndex(ix, colIndex) - case Mean: - return MeanIndex(ix, colIndex) - case Var: - return VarIndex(ix, colIndex) - case Std: - return StdIndex(ix, colIndex) - case Sem: - return SemIndex(ix, colIndex) - case L1Norm: - return L1NormIndex(ix, colIndex) - case SumSq: - return SumSqIndex(ix, colIndex) - case L2Norm: - return L2NormIndex(ix, colIndex) - case VarPop: - return VarPopIndex(ix, colIndex) - case StdPop: - return StdPopIndex(ix, colIndex) - case SemPop: - return SemPopIndex(ix, colIndex) - case Median: - return MedianIndex(ix, colIndex) - case Q1: - return Q1Index(ix, colIndex) - case Q3: - return Q3Index(ix, colIndex) - } - return nil -} - -// StatColumn returns IndexView statistic according to given Stats type applied -// to all non-NaN elements in given IndexView indexed view of -// an table.Table, for given column name. -// If name not found, returns error message. -// Return value(s) is size of column cell: 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StatColumn(ix *table.IndexView, column string, stat Stats) ([]float64, error) { - colIndex, err := ix.Table.ColumnIndex(column) - if err != nil { - return nil, err - } - rv := StatIndex(ix, colIndex, stat) - return rv, nil -} - -// StatIndexFunc applies given StatFunc function to each element in the given column, -// using float64 conversions of the values. ini is the initial value for the agg variable. 
-// Operates independently over each cell on n-dimensional columns and returns the result as a slice -// of values per cell. -func StatIndexFunc(ix *table.IndexView, colIndex int, ini float64, fun StatFunc) []float64 { - cl := ix.Table.Columns[colIndex] - _, csz := cl.RowCellSize() - - ag := make([]float64, csz) - for i := range ag { - ag[i] = ini - } - if csz == 1 { - for _, srw := range ix.Indexes { - val := cl.Float1D(srw) - if !math.IsNaN(val) { - ag[0] = fun(srw, val, ag[0]) - } - } - } else { - for _, srw := range ix.Indexes { - si := srw * csz - for j := range ag { - val := cl.Float1D(si + j) - if !math.IsNaN(val) { - ag[j] = fun(si+j, val, ag[j]) - } - } - } - } - return ag -} - -/////////////////////////////////////////////////// -// Count - -// CountIndex returns the count of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func CountIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, 0, CountFunc) -} - -// CountColumn returns the count of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func CountColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return CountIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Sum - -// SumIndex returns the sum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func SumIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, 0, SumFunc) -} - -// SumColumn returns the sum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SumColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return SumIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Prod - -// ProdIndex returns the product of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func ProdIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, 1, ProdFunc) -} - -// ProdColumn returns the product of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func ProdColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return ProdIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Min - -// MinIndex returns the minimum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func MinIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, math.MaxFloat64, MinFunc) -} - -// MinColumn returns the minimum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MinColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MinIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Max - -// MaxIndex returns the maximum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MaxIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, -math.MaxFloat64, MaxFunc) -} - -// MaxColumn returns the maximum of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MaxColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MaxIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// MinAbs - -// MinAbsIndex returns the minimum of abs of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func MinAbsIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, math.MaxFloat64, MinAbsFunc) -} - -// MinAbsColumn returns the minimum of abs of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MinAbsColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MinAbsIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// MaxAbs - -// MaxAbsIndex returns the maximum of abs of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MaxAbsIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, -math.MaxFloat64, MaxAbsFunc) -} - -// MaxAbsColumn returns the maximum of abs of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MaxAbsColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MaxAbsIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Mean - -// MeanIndex returns the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func MeanIndex(ix *table.IndexView, colIndex int) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - mean := SumIndex(ix, colIndex) - for i := range mean { - if cnt[i] > 0 { - mean[i] /= cnt[i] - } - } - return mean -} - -// MeanColumn returns the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MeanColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MeanIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Var - -// VarIndex returns the sample variance of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Sample variance is normalized by 1/(n-1) -- see VarPop version for 1/n normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func VarIndex(ix *table.IndexView, colIndex int) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - mean := SumIndex(ix, colIndex) - for i := range mean { - if cnt[i] > 0 { - mean[i] /= cnt[i] - } - } - col := ix.Table.Columns[colIndex] - _, csz := col.RowCellSize() - vr := StatIndexFunc(ix, colIndex, 0, func(idx int, val float64, agg float64) float64 { - cidx := idx % csz - dv := val - mean[cidx] - return agg + dv*dv - }) - for i := range vr { - if cnt[i] > 1 { - vr[i] /= (cnt[i] - 1) - } - } - return vr -} - -// VarColumn returns the sample variance of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// Sample variance is normalized by 1/(n-1) -- see VarPop version for 1/n normalization. -// If name not found, nil is returned. 
-// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func VarColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return VarIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Std - -// StdIndex returns the sample std deviation of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Sample std deviation is normalized by 1/(n-1) -- see StdPop version for 1/n normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StdIndex(ix *table.IndexView, colIndex int) []float64 { - std := VarIndex(ix, colIndex) - for i := range std { - std[i] = math.Sqrt(std[i]) - } - return std -} - -// StdColumn returns the sample std deviation of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// Sample std deviation is normalized by 1/(n-1) -- see StdPop version for 1/n normalization. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StdColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return StdIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Sem - -// SemIndex returns the sample standard error of the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Sample sem is normalized by 1/(n-1) -- see SemPop version for 1/n normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func SemIndex(ix *table.IndexView, colIndex int) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - sem := StdIndex(ix, colIndex) - for i := range sem { - if cnt[i] > 0 { - sem[i] /= math.Sqrt(cnt[i]) - } - } - return sem -} - -// SemColumn returns the sample standard error of the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// Sample sem is normalized by 1/(n-1) -- see SemPop version for 1/n normalization. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SemColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return SemIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// L1Norm - -// L1NormIndex returns the L1 norm (sum abs values) of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func L1NormIndex(ix *table.IndexView, colIndex int) []float64 { - return StatIndexFunc(ix, colIndex, 0, L1NormFunc) -} - -// L1NormColumn returns the L1 norm (sum abs values) of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func L1NormColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return L1NormIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// SumSq - -// SumSqIndex returns the sum-of-squares of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SumSqIndex(ix *table.IndexView, colIndex int) []float64 { - cl := ix.Table.Columns[colIndex] - _, csz := cl.RowCellSize() - - scale := make([]float64, csz) - ss := make([]float64, csz) - for i := range ss { - ss[i] = 1 - } - n := len(ix.Indexes) - if csz == 1 { - if n < 2 { - if n == 1 { - ss[0] = math.Abs(cl.Float1D(ix.Indexes[0])) - return ss - } - return scale // all 0s - } - for _, srw := range ix.Indexes { - v := cl.Float1D(srw) - absxi := math.Abs(v) - if scale[0] < absxi { - ss[0] = 1 + ss[0]*(scale[0]/absxi)*(scale[0]/absxi) - scale[0] = absxi - } else { - ss[0] = ss[0] + (absxi/scale[0])*(absxi/scale[0]) - } - } - if math.IsInf(scale[0], 1) { - ss[0] = math.Inf(1) - } else { - ss[0] = scale[0] * scale[0] * ss[0] - } - } else { - if n < 2 { - if n == 1 { - si := csz * ix.Indexes[0] - for j := range csz { - ss[j] = math.Abs(cl.Float1D(si + j)) - } - return ss - } - return scale // all 0s - } - for _, srw := range ix.Indexes { - si := srw * csz - for j := range ss { - v := cl.Float1D(si + j) - absxi := math.Abs(v) - if scale[j] < absxi { - ss[j] = 1 + ss[j]*(scale[j]/absxi)*(scale[j]/absxi) - scale[j] = absxi - } else { - ss[j] = ss[j] + (absxi/scale[j])*(absxi/scale[j]) - } - } - } - for j := range ss { - if math.IsInf(scale[j], 1) { - ss[j] = math.Inf(1) - } else { - ss[j] = scale[j] * scale[j] * ss[j] - } - } - } - return ss -} - -// SumSqColumn returns the sum-of-squares of non-NaN elements in given -// IndexView indexed view of an 
table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SumSqColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return SumSqIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// L2Norm - -// L2NormIndex returns the L2 norm (square root of sum-of-squares) of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func L2NormIndex(ix *table.IndexView, colIndex int) []float64 { - ss := SumSqIndex(ix, colIndex) - for i := range ss { - ss[i] = math.Sqrt(ss[i]) - } - return ss -} - -// L2NormColumn returns the L2 norm (square root of sum-of-squares) of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func L2NormColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return L2NormIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// VarPop - -// VarPopIndex returns the population variance of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// population variance is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func VarPopIndex(ix *table.IndexView, colIndex int) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - mean := SumIndex(ix, colIndex) - for i := range mean { - if cnt[i] > 0 { - mean[i] /= cnt[i] - } - } - col := ix.Table.Columns[colIndex] - _, csz := col.RowCellSize() - vr := StatIndexFunc(ix, colIndex, 0, func(idx int, val float64, agg float64) float64 { - cidx := idx % csz - dv := val - mean[cidx] - return agg + dv*dv - }) - for i := range vr { - if cnt[i] > 0 { - vr[i] /= cnt[i] - } - } - return vr -} - -// VarPopColumn returns the population variance of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// population variance is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func VarPopColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return VarPopIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// StdPop - -// StdPopIndex returns the population std deviation of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// population std dev is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StdPopIndex(ix *table.IndexView, colIndex int) []float64 { - std := VarPopIndex(ix, colIndex) - for i := range std { - std[i] = math.Sqrt(std[i]) - } - return std -} - -// StdPopColumn returns the population std deviation of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// population std dev is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. 
-// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func StdPopColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return StdPopIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// SemPop - -// SemPopIndex returns the population standard error of the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// population sem is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SemPopIndex(ix *table.IndexView, colIndex int) []float64 { - cnt := CountIndex(ix, colIndex) - if cnt == nil { - return nil - } - sem := StdPopIndex(ix, colIndex) - for i := range sem { - if cnt[i] > 0 { - sem[i] /= math.Sqrt(cnt[i]) - } - } - return sem -} - -// SemPopColumn returns the standard error of the mean of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// population sem is normalized by 1/n -- see Var version for 1/(n-1) sample normalization. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func SemPopColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return SemPopIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Median - -// MedianIndex returns the median of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func MedianIndex(ix *table.IndexView, colIndex int) []float64 { - return QuantilesIndex(ix, colIndex, []float64{.5}) -} - -// MedianColumn returns the median of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func MedianColumn(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return MedianIndex(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Q1 - -// Q1Index returns the first quartile of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func Q1Index(ix *table.IndexView, colIndex int) []float64 { - return QuantilesIndex(ix, colIndex, []float64{.25}) -} - -// Q1Column returns the first quartile of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func Q1Column(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return Q1Index(ix, colIndex) -} - -/////////////////////////////////////////////////// -// Q3 - -// Q3Index returns the third quartile of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. 
-func Q3Index(ix *table.IndexView, colIndex int) []float64 { - return QuantilesIndex(ix, colIndex, []float64{.75}) -} - -// Q3Column returns the third quartile of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned. -// Return value is size of each column cell -- 1 for scalar 1D columns -// and N for higher-dimensional columns. -func Q3Column(ix *table.IndexView, column string) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return Q3Index(ix, colIndex) -} diff --git a/tensor/stats/stats/indexview_test.go b/tensor/stats/stats/indexview_test.go deleted file mode 100644 index 31c541bdd5..0000000000 --- a/tensor/stats/stats/indexview_test.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - "testing" - - "cogentcore.org/core/base/errors" - "cogentcore.org/core/base/tolassert" - "cogentcore.org/core/tensor/table" - - "github.com/stretchr/testify/assert" -) - -func TestIndexView(t *testing.T) { - dt := table.NewTable().SetNumRows(5) - dt.AddFloat64Column("data") - dt.SetFloat("data", 0, 1) - dt.SetFloat("data", 1, 2) - dt.SetFloat("data", 2, 3) - dt.SetFloat("data", 3, 4) - dt.SetFloat("data", 4, 5) - - ix := table.NewIndexView(dt) - - results := []float64{5, 15, 120, 1, 5, 1, 5, 3, 2.5, math.Sqrt(2.5), math.Sqrt(2.5) / math.Sqrt(5), - 15, 55, math.Sqrt(55), 2, math.Sqrt(2), math.Sqrt(2) / math.Sqrt(5), 3, 2, 4} - - assert.Equal(t, results[Count:Count+1], CountColumn(ix, "data")) - assert.Equal(t, results[Sum:Sum+1], SumColumn(ix, "data")) - assert.Equal(t, results[Prod:Prod+1], ProdColumn(ix, "data")) - assert.Equal(t, results[Min:Min+1], MinColumn(ix, "data")) - assert.Equal(t, results[Max:Max+1], MaxColumn(ix, "data")) - assert.Equal(t, 
results[MinAbs:MinAbs+1], MinAbsColumn(ix, "data")) - assert.Equal(t, results[MaxAbs:MaxAbs+1], MaxAbsColumn(ix, "data")) - assert.Equal(t, results[Mean:Mean+1], MeanColumn(ix, "data")) - assert.Equal(t, results[Var:Var+1], VarColumn(ix, "data")) - assert.Equal(t, results[Std:Std+1], StdColumn(ix, "data")) - assert.Equal(t, results[Sem:Sem+1], SemColumn(ix, "data")) - assert.Equal(t, results[L1Norm:L1Norm+1], L1NormColumn(ix, "data")) - tolassert.EqualTol(t, results[SumSq], SumSqColumn(ix, "data")[0], 1.0e-8) - tolassert.EqualTol(t, results[L2Norm], L2NormColumn(ix, "data")[0], 1.0e-8) - assert.Equal(t, results[VarPop:VarPop+1], VarPopColumn(ix, "data")) - assert.Equal(t, results[StdPop:StdPop+1], StdPopColumn(ix, "data")) - assert.Equal(t, results[SemPop:SemPop+1], SemPopColumn(ix, "data")) - assert.Equal(t, results[Median:Median+1], MedianColumn(ix, "data")) - assert.Equal(t, results[Q1:Q1+1], Q1Column(ix, "data")) - assert.Equal(t, results[Q3:Q3+1], Q3Column(ix, "data")) - - for _, stat := range StatsValues() { - tolassert.EqualTol(t, results[stat], errors.Log1(StatColumn(ix, "data", stat))[0], 1.0e-8) - } - - desc := DescAll(ix) - assert.Equal(t, len(DescStats), desc.Rows) - assert.Equal(t, 2, desc.NumColumns()) - - for ri, stat := range DescStats { - dv := desc.Float("data", ri) - // fmt.Println(ri, ag.String(), dv, results[ag]) - assert.Equal(t, results[stat], dv) - } - - desc, err := DescColumn(ix, "data") - if err != nil { - t.Error(err) - } - assert.Equal(t, len(DescStats), desc.Rows) - assert.Equal(t, 2, desc.NumColumns()) - for ri, stat := range DescStats { - dv := desc.Float("data", ri) - // fmt.Println(ri, ag.String(), dv, results[ag]) - assert.Equal(t, results[stat], dv) - } - - pcts := PctIfColumn(ix, "data", func(idx int, val float64) bool { - return val > 2 - }) - assert.Equal(t, []float64{60}, pcts) - - props := PropIfColumn(ix, "data", func(idx int, val float64) bool { - return val > 2 - }) - assert.Equal(t, []float64{0.6}, props) -} diff --git 
a/tensor/stats/stats/norm.go b/tensor/stats/stats/norm.go new file mode 100644 index 0000000000..0ecd6f61bc --- /dev/null +++ b/tensor/stats/stats/norm.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stats + +import ( + "cogentcore.org/core/math32" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/tmath" +) + +// ZScore computes Z-normalized values into given output tensor, +// subtracting the Mean and dividing by the standard deviation. +func ZScore(a tensor.Tensor) tensor.Values { + return tensor.CallOut1(ZScoreOut, a) +} + +// ZScore computes Z-normalized values into given output tensor, +// subtracting the Mean and dividing by the standard deviation. +func ZScoreOut(a tensor.Tensor, out tensor.Values) error { + mout := tensor.NewFloat64() + std, mean, _ := StdOut64(a, mout) + tmath.SubOut(a, mean, out) + tmath.DivOut(out, std, out) + return nil +} + +// UnitNorm computes unit normalized values into given output tensor, +// subtracting the Min value and dividing by the Max of the remaining numbers. +func UnitNorm(a tensor.Tensor) tensor.Values { + return tensor.CallOut1(UnitNormOut, a) +} + +// UnitNormOut computes unit normalized values into given output tensor, +// subtracting the Min value and dividing by the Max of the remaining numbers. +func UnitNormOut(a tensor.Tensor, out tensor.Values) error { + mout := tensor.NewFloat64() + err := MinOut(a, mout) + if err != nil { + return err + } + tmath.SubOut(a, mout, out) + MaxOut(out, mout) + tmath.DivOut(out, mout, out) + return nil +} + +// Clamp ensures that all values are within min, max limits, clamping +// values to those bounds if they exceed them. min and max args are +// treated as scalars (first value used). 
+func Clamp(in, minv, maxv tensor.Tensor) tensor.Values { + return tensor.CallOut3(ClampOut, in, minv, minv) +} + +// ClampOut ensures that all values are within min, max limits, clamping +// values to those bounds if they exceed them. min and max args are +// treated as scalars (first value used). +func ClampOut(in, minv, maxv tensor.Tensor, out tensor.Values) error { + tensor.SetShapeFrom(out, in) + mn := minv.Float1D(0) + mx := maxv.Float1D(0) + tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) { + tsr[1].SetFloat1D(math32.Clamp(tsr[0].Float1D(idx), mn, mx), idx) + }, in, out) + return nil +} + +// Binarize results in a binary-valued output by setting +// values >= the threshold to 1, else 0. threshold is +// treated as a scalar (first value used). +func Binarize(in, threshold tensor.Tensor) tensor.Values { + return tensor.CallOut2(BinarizeOut, in, threshold) +} + +// BinarizeOut results in a binary-valued output by setting +// values >= the threshold to 1, else 0. threshold is +// treated as a scalar (first value used). +func BinarizeOut(in, threshold tensor.Tensor, out tensor.Values) error { + tensor.SetShapeFrom(out, in) + thr := threshold.Float1D(0) + tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) { + v := tsr[0].Float1D(idx) + if v >= thr { + v = 1 + } else { + v = 0 + } + tsr[1].SetFloat1D(v, idx) + }, in, out) + return nil +} diff --git a/tensor/stats/stats/quantiles.go b/tensor/stats/stats/quantiles.go index eb191691a4..4aeb593f2f 100644 --- a/tensor/stats/stats/quantiles.go +++ b/tensor/stats/stats/quantiles.go @@ -8,66 +8,100 @@ import ( "math" "cogentcore.org/core/base/errors" - "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor" ) -// QuantilesIndex returns the given quantile(s) of non-NaN elements in given -// IndexView indexed view of an table.Table, for given column index. -// Column must be a 1d Column -- returns nil for n-dimensional columns. 
-// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc. Uses linear interpolation. +// Quantiles returns the given quantile(s) of non-NaN elements in given +// 1D tensor. Because sorting uses indexes, this only works for 1D case. +// If needed for a sub-space of values, that can be extracted through slicing +// and then used. Logs an error if not 1D. +// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc. +// Uses linear interpolation. // Because this requires a sort, it is more efficient to get as many quantiles // as needed in one pass. -func QuantilesIndex(ix *table.IndexView, colIndex int, qs []float64) []float64 { - nq := len(qs) - if nq == 0 { - return nil +func Quantiles(in, qs tensor.Tensor) tensor.Values { + return tensor.CallOut2(QuantilesOut, in, qs) +} + +// QuantilesOut returns the given quantile(s) of non-NaN elements in given +// 1D tensor. Because sorting uses indexes, this only works for 1D case. +// If needed for a sub-space of values, that can be extracted through slicing +// and then used. Returns and logs an error if not 1D. +// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc. +// Uses linear interpolation. +// Because this requires a sort, it is more efficient to get as many quantiles +// as needed in one pass. 
+func QuantilesOut(in, qs tensor.Tensor, out tensor.Values) error { + if in.NumDims() != 1 { + return errors.Log(errors.New("stats.QuantilesFunc: only 1D input tensors allowed")) + } + if qs.NumDims() != 1 { + return errors.Log(errors.New("stats.QuantilesFunc: only 1D quantile tensors allowed")) } - col := ix.Table.Columns[colIndex] - if col.NumDims() > 1 { // only valid for 1D + tensor.SetShapeFrom(out, in) + sin := tensor.AsRows(in.AsValues()) + sin.ExcludeMissing() + sin.Sort(tensor.Ascending) + sz := len(sin.Indexes) - 1 // length of our own index list + if sz <= 0 { + out.(tensor.Values).SetZeros() return nil } - rvs := make([]float64, nq) - six := ix.Clone() // leave original indexes intact - six.Filter(func(et *table.Table, row int) bool { // get rid of NaNs in this column - if math.IsNaN(col.Float1D(row)) { - return false - } - return true - }) - six.SortColumn(colIndex, true) - sz := len(six.Indexes) - 1 // length of our own index list fsz := float64(sz) - for i, q := range qs { + nq := qs.Len() + for i := range nq { + q := qs.Float1D(i) val := 0.0 qi := q * fsz lwi := math.Floor(qi) lwii := int(lwi) if lwii >= sz { - val = col.Float1D(six.Indexes[sz]) + val = sin.FloatRow(sz, 0) } else if lwii < 0 { - val = col.Float1D(six.Indexes[0]) + val = sin.FloatRow(0, 0) } else { phi := qi - lwi - lwv := col.Float1D(six.Indexes[lwii]) - hiv := col.Float1D(six.Indexes[lwii+1]) + lwv := sin.FloatRow(lwii, 0) + hiv := sin.FloatRow(lwii+1, 0) val = (1-phi)*lwv + phi*hiv } - rvs[i] = val + out.SetFloat1D(val, i) } - return rvs + return nil } -// Quantiles returns the given quantile(s) of non-Null, non-NaN elements in given -// IndexView indexed view of an table.Table, for given column name. -// If name not found, nil is returned -- use Try version for error message. -// Column must be a 1d Column -- returns nil for n-dimensional columns. -// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc. Uses linear interpolation. 
-// Because this requires a sort, it is more efficient to get as many quantiles -// as needed in one pass. -func Quantiles(ix *table.IndexView, column string, qs []float64) []float64 { - colIndex := errors.Log1(ix.Table.ColumnIndex(column)) - if colIndex == -1 { - return nil - } - return QuantilesIndex(ix, colIndex, qs) +// Median computes the median (50% quantile) of tensor values. +// See [StatsFunc] for general information. +func Median(in tensor.Tensor) tensor.Values { + return Quantiles(in, tensor.NewFloat64Scalar(.5)) +} + +// Q1 computes the first quantile (25%) of tensor values. +// See [StatsFunc] for general information. +func Q1(in tensor.Tensor) tensor.Values { + return Quantiles(in, tensor.NewFloat64Scalar(.25)) +} + +// Q3 computes the third quantile (75%) of tensor values. +// See [StatsFunc] for general information. +func Q3(in tensor.Tensor) tensor.Values { + return Quantiles(in, tensor.NewFloat64Scalar(.75)) +} + +// MedianOut computes the median (50% quantile) of tensor values. +// See [StatsFunc] for general information. +func MedianOut(in tensor.Tensor, out tensor.Values) error { + return QuantilesOut(in, tensor.NewFloat64Scalar(.5), out) +} + +// Q1Out computes the first quantile (25%) of tensor values. +// See [StatsFunc] for general information. +func Q1Out(in tensor.Tensor, out tensor.Values) error { + return QuantilesOut(in, tensor.NewFloat64Scalar(.25), out) +} + +// Q3Out computes the third quantile (75%) of tensor values. +// See [StatsFunc] for general information. 
+func Q3Out(in tensor.Tensor, out tensor.Values) error { + return QuantilesOut(in, tensor.NewFloat64Scalar(.75), out) } diff --git a/tensor/stats/stats/stats.go b/tensor/stats/stats/stats.go index a77a2a15a8..dce603414d 100644 --- a/tensor/stats/stats/stats.go +++ b/tensor/stats/stats/stats.go @@ -4,70 +4,147 @@ package stats +import ( + "strings" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" +) + //go:generate core generate +func init() { + tensor.AddFunc(StatCount.FuncName(), Count) + tensor.AddFunc(StatSum.FuncName(), Sum) + tensor.AddFunc(StatL1Norm.FuncName(), L1Norm) + tensor.AddFunc(StatProd.FuncName(), Prod) + tensor.AddFunc(StatMin.FuncName(), Min) + tensor.AddFunc(StatMax.FuncName(), Max) + tensor.AddFunc(StatMinAbs.FuncName(), MinAbs) + tensor.AddFunc(StatMaxAbs.FuncName(), MaxAbs) + tensor.AddFunc(StatMean.FuncName(), Mean) + tensor.AddFunc(StatVar.FuncName(), Var) + tensor.AddFunc(StatStd.FuncName(), Std) + tensor.AddFunc(StatSem.FuncName(), Sem) + tensor.AddFunc(StatSumSq.FuncName(), SumSq) + tensor.AddFunc(StatL2Norm.FuncName(), L2Norm) + tensor.AddFunc(StatVarPop.FuncName(), VarPop) + tensor.AddFunc(StatStdPop.FuncName(), StdPop) + tensor.AddFunc(StatSemPop.FuncName(), SemPop) + tensor.AddFunc(StatMedian.FuncName(), Median) + tensor.AddFunc(StatQ1.FuncName(), Q1) + tensor.AddFunc(StatQ3.FuncName(), Q3) + tensor.AddFunc(StatFirst.FuncName(), First) + tensor.AddFunc(StatFinal.FuncName(), Final) +} + // Stats is a list of different standard aggregation functions, which can be used // to choose an aggregation function -type Stats int32 //enums:enum +type Stats int32 //enums:enum -trim-prefix Stat const ( - // count of number of elements - Count Stats = iota + // count of number of elements. + StatCount Stats = iota - // sum of elements - Sum + // sum of elements. + StatSum - // product of elements - Prod + // L1 Norm: sum of absolute values of elements. + StatL1Norm - // minimum value - Min + // product of elements. 
+ StatProd - // max maximum value - Max + // minimum value. + StatMin - // minimum absolute value - MinAbs + // maximum value. + StatMax - // maximum absolute value - MaxAbs + // minimum of absolute values. + StatMinAbs - // mean mean value - Mean + // maximum of absolute values. + StatMaxAbs - // sample variance (squared diffs from mean, divided by n-1) - Var + // mean value = sum / count. + StatMean - // sample standard deviation (sqrt of Var) - Std + // sample variance (squared deviations from mean, divided by n-1). + StatVar - // sample standard error of the mean (Std divided by sqrt(n)) - Sem + // sample standard deviation (sqrt of Var). + StatStd - // L1 Norm: sum of absolute values - L1Norm + // sample standard error of the mean (Std divided by sqrt(n)). + StatSem - // sum of squared values - SumSq + // sum of squared values. + StatSumSq - // L2 Norm: square-root of sum-of-squares - L2Norm + // L2 Norm: square-root of sum-of-squares. + StatL2Norm - // population variance (squared diffs from mean, divided by n) - VarPop + // population variance (squared diffs from mean, divided by n). + StatVarPop - // population standard deviation (sqrt of VarPop) - StdPop + // population standard deviation (sqrt of VarPop). + StatStdPop - // population standard error of the mean (StdPop divided by sqrt(n)) - SemPop + // population standard error of the mean (StdPop divided by sqrt(n)). + StatSemPop - // middle value in sorted ordering - Median + // middle value in sorted ordering. + StatMedian - // Q1 first quartile = 25%ile value = .25 quantile value - Q1 + // Q1 first quartile = 25%ile value = .25 quantile value. + StatQ1 - // Q3 third quartile = 75%ile value = .75 quantile value - Q3 + // Q3 third quartile = 75%ile value = .75 quantile value. + StatQ3 + + // first item in the set of data: for data with a natural ordering. + StatFirst + + // final item in the set of data: for data with a natural ordering. 
+ StatFinal ) + +// FuncName returns the package-qualified function name to use +// in tensor.Call to call this function. +func (s Stats) FuncName() string { + return "stats." + s.String() +} + +// Func returns function for given stat. +func (s Stats) Func() StatsFunc { + fn := errors.Log1(tensor.FuncByName(s.FuncName())) + return fn.Fun.(StatsFunc) +} + +// Call calls this statistic function on given tensors. +// returning output as a newly created tensor. +func (s Stats) Call(in tensor.Tensor) tensor.Values { + return s.Func()(in) +} + +// StripPackage removes any package name from given string, +// used for naming based on FuncName() which could be custom +// or have a package prefix. +func StripPackage(name string) string { + spl := strings.Split(name, ".") + if len(spl) > 1 { + return spl[len(spl)-1] + } + return name +} + +// AsStatsFunc returns given function as a [StatsFunc] function, +// or an error if it does not fit that signature. +func AsStatsFunc(fun any) (StatsFunc, error) { + sfun, ok := fun.(StatsFunc) + if !ok { + return nil, errors.New("metric.AsStatsFunc: function does not fit the StatsFunc signature") + } + return sfun, nil +} diff --git a/tensor/stats/stats/stats_test.go b/tensor/stats/stats/stats_test.go new file mode 100644 index 0000000000..d5389c8195 --- /dev/null +++ b/tensor/stats/stats/stats_test.go @@ -0,0 +1,334 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stats + +import ( + "fmt" + "math" + "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +func TestFuncs64(t *testing.T) { + vals := []float64{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1} + ix := tensor.NewNumberFromValues(vals...) 
+ out := tensor.NewFloat64(1) + + results := []float64{11, 5.5, 5.5, 0, 0, 1, 0, 1, 0.5, 0.11, math.Sqrt(0.11), math.Sqrt(0.11) / math.Sqrt(11), 3.85, math.Sqrt(3.85), 0.1, math.Sqrt(0.1), math.Sqrt(0.1) / math.Sqrt(11), 0.5, 0.25, 0.75, 0, 1} + + tol := 1.0e-8 + + CountOut(ix, out) + assert.Equal(t, results[StatCount], out.Values[0]) + + SumOut(ix, out) + assert.Equal(t, results[StatSum], out.Values[0]) + + L1NormOut(ix, out) + assert.Equal(t, results[StatL1Norm], out.Values[0]) + + ProdOut(ix, out) + assert.Equal(t, results[StatProd], out.Values[0]) + + MinOut(ix, out) + assert.Equal(t, results[StatMin], out.Values[0]) + + MaxOut(ix, out) + assert.Equal(t, results[StatMax], out.Values[0]) + + MinAbsOut(ix, out) + assert.Equal(t, results[StatMinAbs], out.Values[0]) + + MaxAbsOut(ix, out) + assert.Equal(t, results[StatMaxAbs], out.Values[0]) + + MeanOut(ix, out) + assert.Equal(t, results[StatMean], out.Values[0]) + + VarOut(ix, out) + assert.InDelta(t, results[StatVar], out.Values[0], tol) + + StdOut(ix, out) + assert.InDelta(t, results[StatStd], out.Values[0], tol) + + SemOut(ix, out) + assert.InDelta(t, results[StatSem], out.Values[0], tol) + + VarPopOut(ix, out) + assert.InDelta(t, results[StatVarPop], out.Values[0], tol) + + StdPopOut(ix, out) + assert.InDelta(t, results[StatStdPop], out.Values[0], tol) + + SemPopOut(ix, out) + assert.InDelta(t, results[StatSemPop], out.Values[0], tol) + + SumSqOut(ix, out) + assert.InDelta(t, results[StatSumSq], out.Values[0], tol) + + L2NormOut(ix, out) + assert.InDelta(t, results[StatL2Norm], out.Values[0], tol) + + MedianOut(ix, out) + assert.InDelta(t, results[StatMedian], out.Values[0], tol) + + Q1Out(ix, out) + assert.InDelta(t, results[StatQ1], out.Values[0], tol) + + Q3Out(ix, out) + assert.InDelta(t, results[StatQ3], out.Values[0], tol) + + FirstOut(ix, out) + assert.InDelta(t, results[StatFirst], out.Values[0], tol) + + FinalOut(ix, out) + assert.InDelta(t, results[StatFinal], out.Values[0], tol) + + for stat := 
StatCount; stat < StatsN; stat++ { + out := stat.Call(ix) + assert.InDelta(t, results[stat], out.Float1D(0), tol) + } +} + +func TestFuncsInt(t *testing.T) { + vals := []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} + tsr := tensor.NewNumberFromValues(vals...) + ix := tensor.NewRows(tsr) + out := tensor.NewInt(1) + + results := []int{11, 550, 550, 0, 0, 100, 0, 100, 50, 1100, int(math.Sqrt(1100)), int(math.Sqrt(1100) / math.Sqrt(11)), 38500, 196, 1000, int(math.Sqrt(1000)), int(math.Sqrt(1000) / math.Sqrt(11))} + + CountOut(ix, out) + assert.Equal(t, results[StatCount], out.Values[0]) + + SumOut(ix, out) + assert.Equal(t, results[StatSum], out.Values[0]) + + L1NormOut(ix, out) + assert.Equal(t, results[StatL1Norm], out.Values[0]) + + ProdOut(ix, out) + assert.Equal(t, results[StatProd], out.Values[0]) + + MinOut(ix, out) + assert.Equal(t, results[StatMin], out.Values[0]) + + MaxOut(ix, out) + assert.Equal(t, results[StatMax], out.Values[0]) + + MinAbsOut(ix, out) + assert.Equal(t, results[StatMinAbs], out.Values[0]) + + MaxAbsOut(ix, out) + assert.Equal(t, results[StatMaxAbs], out.Values[0]) + + MeanOut(ix, out) + assert.Equal(t, results[StatMean], out.Values[0]) + + VarOut(ix, out) + assert.Equal(t, results[StatVar], out.Values[0]) + + StdOut(ix, out) + assert.Equal(t, results[StatStd], out.Values[0]) + + SemOut(ix, out) + assert.Equal(t, results[StatSem], out.Values[0]) + + VarPopOut(ix, out) + assert.Equal(t, results[StatVarPop], out.Values[0]) + + StdPopOut(ix, out) + assert.Equal(t, results[StatStdPop], out.Values[0]) + + SemPopOut(ix, out) + assert.Equal(t, results[StatSemPop], out.Values[0]) + + SumSqOut(ix, out) + assert.Equal(t, results[StatSumSq], out.Values[0]) + + L2NormOut(ix, out) + assert.Equal(t, results[StatL2Norm], out.Values[0]) + + for stat := StatCount; stat <= StatSemPop; stat++ { + out := stat.Call(ix) + assert.Equal(t, results[stat], out.Int1D(0)) + } +} + +func TestFuncsCell(t *testing.T) { + vals := []float64{0, 0.1, 0.2, 0.3, 0.4, 0.5, 
0.6, 0.7, 0.8, 0.9} + tsr := tensor.NewFloat32(20, 10) + + for i := range 20 { + for j := range 10 { + tsr.SetFloatRow(vals[j], i, j) + } + } + + ix := tensor.NewRows(tsr) + out := tensor.NewFloat32(20, 10) + + CountOut(ix, out) + nsub := out.Len() + for i := range nsub { + assert.Equal(t, 20.0, out.FloatRow(0, i)) + } + MeanOut(ix, out) + for i := range nsub { + assert.InDelta(t, vals[i], out.FloatRow(0, i), 1.0e-7) // lower tol, using float32 + } + VarOut(ix, out) + for i := range nsub { + assert.InDelta(t, 0.0, out.FloatRow(0, i), 1.0e-7) + } +} + +func TestNorm(t *testing.T) { + vals := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0.1, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 1.507556722888818, .3} + + oned := tensor.NewNumberFromValues(vals...) + oneout := oned.Clone() + + ZScoreOut(oned, oneout) + mout := tensor.NewFloat64() + std, mean, _ := StdOut64(oneout, mout) + assert.InDelta(t, 1.0, std.Float1D(0), 1.0e-6) + assert.InDelta(t, 0.0, mean.Float1D(0), 1.0e-6) + + UnitNormOut(oned, oneout) + MinOut(oneout, mout) + assert.InDelta(t, 0.0, mout.Float1D(0), 1.0e-6) + MaxOut(oneout, mout) + assert.InDelta(t, 1.0, mout.Float1D(0), 1.0e-6) + // fmt.Println(oneout) + + minv := tensor.NewFloat64Scalar(0) + maxv := tensor.NewFloat64Scalar(1) + ClampOut(oned, minv, maxv, oneout) + MinOut(oneout, mout) + assert.InDelta(t, 0.0, mout.Float1D(0), 1.0e-6) + MaxOut(oneout, mout) + assert.InDelta(t, 1.0, mout.Float1D(0), 1.0e-6) + // fmt.Println(oneout) + + thr := tensor.NewFloat64Scalar(0.5) + BinarizeOut(oned, thr, oneout) + MinOut(oneout, mout) + assert.InDelta(t, 0.0, mout.Float1D(0), 1.0e-6) + MaxOut(oneout, mout) + assert.InDelta(t, 1.0, mout.Float1D(0), 1.0e-6) + // fmt.Println(oneout) +} + +// after optimizing: 12/1/2024: also, GOEXPERIMENT=newinliner didn't make any diff +// go test -bench BenchmarkFuncs -count=1 +// goos: darwin +// goarch: arm64 +// pkg: 
cogentcore.org/core/tensor/stats/stats +// stat=Count-16 677764 1789 ns/op +// stat=Sum-16 668791 1809 ns/op +// stat=L1Norm-16 821071 1484 ns/op +// stat=Prod-16 703598 1706 ns/op +// stat=Min-16 182587 6564 ns/op +// stat=Max-16 181981 6577 ns/op +// stat=MinAbs-16 176342 6787 ns/op +// stat=MaxAbs-16 175491 6784 ns/op +// stat=Mean-16 592713 2014 ns/op +// stat=Var-16 330260 3620 ns/op +// stat=Std-16 329876 3625 ns/op +// stat=Sem-16 330141 3629 ns/op +// stat=SumSq-16 366603 3267 ns/op +// stat=L2Norm-16 366862 3264 ns/op +// stat=VarPop-16 330362 3617 ns/op +// stat=StdPop-16 329172 3626 ns/op +// stat=SemPop-16 331568 3631 ns/op +// stat=Median-16 116071 10316 ns/op +// stat=Q1-16 116175 10334 ns/op +// stat=Q3-16 115149 10331 ns/op + +// old: prior to optimizing 12/1/2024: +// stat=Count-16 166908 7189 ns/op +// stat=Sum-16 166287 7198 ns/op +// stat=L1Norm-16 166587 7195 ns/op +// stat=Prod-16 166029 7185 ns/op +// stat=Min-16 125803 9523 ns/op +// stat=Max-16 125067 9561 ns/op +// stat=MinAbs-16 126109 9524 ns/op +// stat=MaxAbs-16 126346 9500 ns/op +// stat=Mean-16 83302 14365 ns/op +// stat=Var-16 53138 22707 ns/op +// stat=Std-16 53073 22611 ns/op +// stat=Sem-16 52928 22611 ns/op +// stat=SumSq-16 125698 9486 ns/op +// stat=L2Norm-16 126196 9483 ns/op +// stat=VarPop-16 53010 22659 ns/op +// stat=StdPop-16 52994 22573 ns/op +// stat=SemPop-16 52897 22726 ns/op +// stat=Median-16 116223 10334 ns/op +// stat=Q1-16 115728 10431 ns/op +// stat=Q3-16 111325 10307 ns/op + +func runBenchFuncs(b *testing.B, n int, fun Stats) { + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + fun.Call(av) + } +} + +func BenchmarkFuncs(b *testing.B) { + for stf := StatCount; stf < StatsN; stf++ { + b.Run(fmt.Sprintf("stat=%s", stf.String()), func(b *testing.B) { + runBenchFuncs(b, 1000, stf) + }) + } +} + +// 258.6 ns/op, vs 1809 actual +func BenchmarkSumBaseline(b *testing.B) { + n := 1000 + av := 
tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + sum := float64(0) + for i := range n { + val := av.Float1D(i) + if math.IsNaN(val) { + continue + } + sum += val + } + _ = sum + } +} + +func runClosure(av *tensor.Float64, fun func(a, agg float64) float64) float64 { + // fun := func(a, agg float64) float64 { // note: it can inline closure if in same fun + // return agg + a + // } + n := 1000 + s := float64(0) + for i := range n { + s = fun(av.Float1D(i), s) // note: Float1D here no extra cost + } + return s +} + +// 1242 ns/op, vs 1809 actual -- mostly it is the closure +func BenchmarkSumClosure(b *testing.B) { + n := 1000 + av := tensor.AsFloat64(tensor.Reshape(tensor.NewIntRange(1, n+1), n)) + b.ResetTimer() + for range b.N { + runClosure(av, func(val, agg float64) float64 { + if math.IsNaN(val) { + return agg + } + return agg + val + }) + } +} diff --git a/tensor/stats/stats/table.go b/tensor/stats/stats/table.go index 018fb75a9a..694ca6621a 100644 --- a/tensor/stats/stats/table.go +++ b/tensor/stats/stats/table.go @@ -4,12 +4,7 @@ package stats -import ( - "reflect" - - "cogentcore.org/core/tensor/table" -) - +/* // MeanTables returns an table.Table with the mean values across all float // columns of the input tables, which must have the same columns but not // necessarily the same number of rows. @@ -58,7 +53,7 @@ func MeanTables(dts []*table.Table) *table.Table { ci := si + j cv := cl.Float1D(ci) cv += dc.Float1D(ci) - cl.SetFloat1D(ci, cv) + cl.SetFloat1D(cv, ci) } } } @@ -69,10 +64,11 @@ func MeanTables(dts []*table.Table) *table.Table { cv := cl.Float1D(ci) if rns[ri] > 0 { cv /= float64(rns[ri]) - cl.SetFloat1D(ci, cv) + cl.SetFloat1D(cv, ci) } } } } return ot } +*/ diff --git a/tensor/stats/stats/tensor.go b/tensor/stats/stats/tensor.go deleted file mode 100644 index bc33041d8e..0000000000 --- a/tensor/stats/stats/tensor.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright (c) 2024, Cogent Core. 
All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - - "cogentcore.org/core/tensor" -) - -// StatTensor returns Tensor statistic according to given Stats type applied -// to all non-NaN elements in given Tensor -func StatTensor(tsr tensor.Tensor, stat Stats) float64 { - switch stat { - case Count: - return CountTensor(tsr) - case Sum: - return SumTensor(tsr) - case Prod: - return ProdTensor(tsr) - case Min: - return MinTensor(tsr) - case Max: - return MaxTensor(tsr) - case MinAbs: - return MinAbsTensor(tsr) - case MaxAbs: - return MaxAbsTensor(tsr) - case Mean: - return MeanTensor(tsr) - case Var: - return VarTensor(tsr) - case Std: - return StdTensor(tsr) - case Sem: - return SemTensor(tsr) - case L1Norm: - return L1NormTensor(tsr) - case SumSq: - return SumSqTensor(tsr) - case L2Norm: - return L2NormTensor(tsr) - case VarPop: - return VarPopTensor(tsr) - case StdPop: - return StdPopTensor(tsr) - case SemPop: - return SemPopTensor(tsr) - // case Median: - // return MedianTensor(tsr) - // case Q1: - // return Q1Tensor(tsr) - // case Q3: - // return Q3Tensor(tsr) - } - return 0 -} - -// TensorStat applies given StatFunc function to each element in the tensor -// (automatically skips NaN elements), using float64 conversions of the values. -// ini is the initial value for the agg variable. returns final aggregate value -func TensorStat(tsr tensor.Tensor, ini float64, fun StatFunc) float64 { - ln := tsr.Len() - agg := ini - for j := 0; j < ln; j++ { - val := tsr.Float1D(j) - if !math.IsNaN(val) { - agg = fun(j, val, agg) - } - } - return agg -} - -// CountTensor returns the count of non-NaN elements in given Tensor. -func CountTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, 0, CountFunc) -} - -// SumTensor returns the sum of non-NaN elements in given Tensor. 
-func SumTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, 0, SumFunc) -} - -// ProdTensor returns the product of non-NaN elements in given Tensor. -func ProdTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, 1, ProdFunc) -} - -// MinTensor returns the minimum of non-NaN elements in given Tensor. -func MinTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, math.MaxFloat64, MinFunc) -} - -// MaxTensor returns the maximum of non-NaN elements in given Tensor. -func MaxTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, -math.MaxFloat64, MaxFunc) -} - -// MinAbsTensor returns the minimum of non-NaN elements in given Tensor. -func MinAbsTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, math.MaxFloat64, MinAbsFunc) -} - -// MaxAbsTensor returns the maximum of non-NaN elements in given Tensor. -func MaxAbsTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, -math.MaxFloat64, MaxAbsFunc) -} - -// MeanTensor returns the mean of non-NaN elements in given Tensor. -func MeanTensor(tsr tensor.Tensor) float64 { - cnt := CountTensor(tsr) - if cnt == 0 { - return 0 - } - return SumTensor(tsr) / cnt -} - -// VarTensor returns the sample variance of non-NaN elements in given Tensor. -func VarTensor(tsr tensor.Tensor) float64 { - cnt := CountTensor(tsr) - if cnt < 2 { - return 0 - } - mean := SumTensor(tsr) / cnt - vr := TensorStat(tsr, 0, func(idx int, val float64, agg float64) float64 { - dv := val - mean - return agg + dv*dv - }) - return vr / (cnt - 1) -} - -// StdTensor returns the sample standard deviation of non-NaN elements in given Tensor. -func StdTensor(tsr tensor.Tensor) float64 { - return math.Sqrt(VarTensor(tsr)) -} - -// SemTensor returns the sample standard error of the mean of non-NaN elements in given Tensor. 
-func SemTensor(tsr tensor.Tensor) float64 { - cnt := CountTensor(tsr) - if cnt < 2 { - return 0 - } - return StdTensor(tsr) / math.Sqrt(cnt) -} - -// L1NormTensor returns the L1 norm: sum of absolute values of non-NaN elements in given Tensor. -func L1NormTensor(tsr tensor.Tensor) float64 { - return TensorStat(tsr, 0, L1NormFunc) -} - -// SumSqTensor returns the sum-of-squares of non-NaN elements in given Tensor. -func SumSqTensor(tsr tensor.Tensor) float64 { - n := tsr.Len() - if n < 2 { - if n == 1 { - return math.Abs(tsr.Float1D(0)) - } - return 0 - } - var ( - scale float64 = 0 - ss float64 = 1 - ) - for j := 0; j < n; j++ { - v := tsr.Float1D(j) - if v == 0 || math.IsNaN(v) { - continue - } - absxi := math.Abs(v) - if scale < absxi { - ss = 1 + ss*(scale/absxi)*(scale/absxi) - scale = absxi - } else { - ss = ss + (absxi/scale)*(absxi/scale) - } - } - if math.IsInf(scale, 1) { - return math.Inf(1) - } - return scale * scale * ss -} - -// L2NormTensor returns the L2 norm: square root of sum-of-squared values of non-NaN elements in given Tensor. -func L2NormTensor(tsr tensor.Tensor) float64 { - return math.Sqrt(SumSqTensor(tsr)) -} - -// VarPopTensor returns the population variance of non-NaN elements in given Tensor. -func VarPopTensor(tsr tensor.Tensor) float64 { - cnt := CountTensor(tsr) - if cnt < 2 { - return 0 - } - mean := SumTensor(tsr) / cnt - vr := TensorStat(tsr, 0, func(idx int, val float64, agg float64) float64 { - dv := val - mean - return agg + dv*dv - }) - return vr / cnt -} - -// StdPopTensor returns the population standard deviation of non-NaN elements in given Tensor. -func StdPopTensor(tsr tensor.Tensor) float64 { - return math.Sqrt(VarPopTensor(tsr)) -} - -// SemPopTensor returns the population standard error of the mean of non-NaN elements in given Tensor. 
-func SemPopTensor(tsr tensor.Tensor) float64 { - cnt := CountTensor(tsr) - if cnt < 2 { - return 0 - } - return StdPopTensor(tsr) / math.Sqrt(cnt) -} diff --git a/tensor/stats/stats/tensor_test.go b/tensor/stats/stats/tensor_test.go deleted file mode 100644 index 663ff43f50..0000000000 --- a/tensor/stats/stats/tensor_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stats - -import ( - "math" - "testing" - - "cogentcore.org/core/base/tolassert" - "cogentcore.org/core/tensor" - "github.com/stretchr/testify/assert" -) - -func TestTsrAgg(t *testing.T) { - tsr := tensor.New[float64]([]int{5}).(*tensor.Float64) - tsr.Values = []float64{1, 2, 3, 4, 5} - - results := []float64{5, 15, 120, 1, 5, 1, 5, 3, 2.5, math.Sqrt(2.5), math.Sqrt(2.5) / math.Sqrt(5), - 15, 55, math.Sqrt(55), 2, math.Sqrt(2), math.Sqrt(2) / math.Sqrt(5), 3, 2, 4} - - assert.Equal(t, results[Count], CountTensor(tsr)) - assert.Equal(t, results[Sum], SumTensor(tsr)) - assert.Equal(t, results[Prod], ProdTensor(tsr)) - assert.Equal(t, results[Min], MinTensor(tsr)) - assert.Equal(t, results[Max], MaxTensor(tsr)) - assert.Equal(t, results[MinAbs], MinAbsTensor(tsr)) - assert.Equal(t, results[MaxAbs], MaxAbsTensor(tsr)) - assert.Equal(t, results[Mean], MeanTensor(tsr)) - assert.Equal(t, results[Var], VarTensor(tsr)) - assert.Equal(t, results[Std], StdTensor(tsr)) - assert.Equal(t, results[Sem], SemTensor(tsr)) - assert.Equal(t, results[L1Norm], L1NormTensor(tsr)) - tolassert.EqualTol(t, results[SumSq], SumSqTensor(tsr), 1.0e-8) - tolassert.EqualTol(t, results[L2Norm], L2NormTensor(tsr), 1.0e-8) - assert.Equal(t, results[VarPop], VarPopTensor(tsr)) - assert.Equal(t, results[StdPop], StdPopTensor(tsr)) - assert.Equal(t, results[SemPop], SemPopTensor(tsr)) - - for stat := Count; stat <= SemPop; stat++ { - tolassert.EqualTol(t, results[stat], 
StatTensor(tsr, stat), 1.0e-8) - } -} diff --git a/tensor/stats/stats/vec.go b/tensor/stats/stats/vec.go new file mode 100644 index 0000000000..9617d08c53 --- /dev/null +++ b/tensor/stats/stats/vec.go @@ -0,0 +1,197 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stats + +import ( + "cogentcore.org/core/tensor" +) + +// VectorizeOut64 is the general compute function for stats. +// This version makes a Float64 output tensor for aggregating +// and computing values, and then copies the results back to the +// original output. This allows stats functions to operate directly +// on integer valued inputs and produce sensible results. +// It returns the Float64 output tensor for further processing as needed. +func VectorizeOut64(in tensor.Tensor, out tensor.Values, ini float64, fun func(val, agg float64) float64) *tensor.Float64 { + rows, cells := in.Shape().RowCellSize() + o64 := tensor.NewFloat64(cells) + if rows <= 0 { + return o64 + } + if cells == 1 { + out.SetShapeSizes(1) + agg := ini + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), agg) + } + default: + for i := range rows { + agg = fun(in.Float1D(i), agg) + } + } + o64.SetFloat1D(agg, 0) + out.SetFloat1D(agg, 0) + return o64 + } + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ for i := range cells { + o64.SetFloat1D(ini, i) + } + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(in.Float1D(i*cells+j), o64.Float1D(j)), j) + } + } + } + for j := range cells { + out.SetFloat1D(o64.Float1D(j), j) + } + return o64 +} + +// VectorizePreOut64 is a version of [VectorizeOut64] that takes an additional +// tensor.Float64 input of pre-computed values, e.g., the means of each output cell. +func VectorizePreOut64(in tensor.Tensor, out tensor.Values, ini float64, pre *tensor.Float64, fun func(val, pre, agg float64) float64) *tensor.Float64 { + rows, cells := in.Shape().RowCellSize() + o64 := tensor.NewFloat64(cells) + if rows <= 0 { + return o64 + } + if cells == 1 { + out.SetShapeSizes(1) + agg := ini + prev := pre.Float1D(0) + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + agg = fun(x.Float1D(i), prev, agg) + } + case *tensor.Float32: + for i := range rows { + agg = fun(x.Float1D(i), prev, agg) + } + default: + for i := range rows { + agg = fun(in.Float1D(i), prev, agg) + } + } + o64.SetFloat1D(agg, 0) + out.SetFloat1D(agg, 0) + return o64 + } + osz := tensor.CellsSize(in.ShapeSizes()) + out.SetShapeSizes(osz...) 
+ for j := range cells { + o64.SetFloat1D(ini, j) + } + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j) + } + } + case *tensor.Float32: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j) + } + } + default: + for i := range rows { + for j := range cells { + o64.SetFloat1D(fun(in.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j) + } + } + } + for i := range cells { + out.SetFloat1D(o64.Float1D(i), i) + } + return o64 +} + +// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates +// two output values, x and y as tensor.Float64. +func Vectorize2Out64(in tensor.Tensor, iniX, iniY float64, fun func(val, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) { + rows, cells := in.Shape().RowCellSize() + ox64 = tensor.NewFloat64(cells) + oy64 = tensor.NewFloat64(cells) + if rows <= 0 { + return ox64, oy64 + } + if cells == 1 { + ox := iniX + oy := iniY + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + ox, oy = fun(x.Float1D(i), ox, oy) + } + case *tensor.Float32: + for i := range rows { + ox, oy = fun(x.Float1D(i), ox, oy) + } + default: + for i := range rows { + ox, oy = fun(in.Float1D(i), ox, oy) + } + } + ox64.SetFloat1D(ox, 0) + oy64.SetFloat1D(oy, 0) + return + } + for j := range cells { + ox64.SetFloat1D(iniX, j) + oy64.SetFloat1D(iniY, j) + } + switch x := in.(type) { + case *tensor.Float64: + for i := range rows { + for j := range cells { + ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + case *tensor.Float32: + for i := range rows { + for j := range cells { + ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + default: + for i := range rows { 
+ for j := range cells { + ox, oy := fun(in.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j)) + ox64.SetFloat1D(ox, j) + oy64.SetFloat1D(oy, j) + } + } + } + return +} diff --git a/tensor/string.go b/tensor/string.go index 4a04bfbe68..f2cc576ef3 100644 --- a/tensor/string.go +++ b/tensor/string.go @@ -6,13 +6,9 @@ package tensor import ( "fmt" - "log" - "math" "strconv" - "strings" - "cogentcore.org/core/base/slicesx" - "gonum.org/v1/gonum/mat" + "cogentcore.org/core/base/errors" ) // String is a tensor of string values @@ -21,10 +17,10 @@ type String struct { } // NewString returns a new n-dimensional tensor of string values -// with the given sizes per dimension (shape), and optional dimension names. -func NewString(sizes []int, names ...string) *String { +// with the given sizes per dimension (shape). +func NewString(sizes ...int) *String { tsr := &String{} - tsr.SetShape(sizes, names...) + tsr.SetShapeSizes(sizes...) tsr.Values = make([]string, tsr.Len()) return tsr } @@ -33,7 +29,7 @@ func NewString(sizes []int, names ...string) *String { // using given shape. func NewStringShape(shape *Shape) *String { tsr := &String{} - tsr.Shp.CopyShape(shape) + tsr.shape.CopyFrom(shape) tsr.Values = make([]string, tsr.Len()) return tsr } @@ -52,156 +48,138 @@ func Float64ToString(val float64) string { return strconv.FormatFloat(val, 'g', -1, 64) } +// String satisfies the fmt.Stringer interface for string of tensor data. 
+func (tsr *String) String() string { + return Sprintf("", tsr, 0) +} + func (tsr *String) IsString() bool { return true } -func (tsr *String) AddScalar(i []int, val float64) float64 { - j := tsr.Shp.Offset(i) - fv := StringToFloat64(tsr.Values[j]) + val - tsr.Values[j] = Float64ToString(fv) - return fv +func (tsr *String) AsValues() Values { return tsr } + +/////// Strings + +func (tsr *String) SetString(val string, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = val } -func (tsr *String) MulScalar(i []int, val float64) float64 { - j := tsr.Shp.Offset(i) - fv := StringToFloat64(tsr.Values[j]) * val - tsr.Values[j] = Float64ToString(fv) - return fv +func (tsr *String) String1D(i int) string { + return tsr.Values[i] } -func (tsr *String) SetString(i []int, val string) { - j := tsr.Shp.Offset(i) - tsr.Values[j] = val +func (tsr *String) SetString1D(val string, i int) { + tsr.Values[NegIndex(i, len(tsr.Values))] = val } -func (tsr String) SetString1D(off int, val string) { - tsr.Values[off] = val +func (tsr *String) StringRow(row, cell int) string { + _, sz := tsr.shape.RowCellSize() + return tsr.Values[row*sz+cell] } -func (tsr *String) SetStringRowCell(row, cell int, val string) { - _, sz := tsr.Shp.RowCellSize() +func (tsr *String) SetStringRow(val string, row, cell int) { + _, sz := tsr.shape.RowCellSize() tsr.Values[row*sz+cell] = val } -// String satisfies the fmt.Stringer interface for string of tensor data -func (tsr *String) String() string { - str := tsr.Label() - sz := len(tsr.Values) - if sz > 1000 { - return str +// AppendRowString adds a row and sets string value(s), up to number of cells. 
+func (tsr *String) AppendRowString(val ...string) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) } - var b strings.Builder - b.WriteString(str) - b.WriteString("\n") - oddRow := true - rows, cols, _, _ := Projection2DShape(&tsr.Shp, oddRow) - for r := 0; r < rows; r++ { - rc, _ := Projection2DCoords(&tsr.Shp, oddRow, r, 0) - b.WriteString(fmt.Sprintf("%v: ", rc)) - for c := 0; c < cols; c++ { - idx := Projection2DIndex(tsr.Shape(), oddRow, r, c) - vl := tsr.Values[idx] - b.WriteString(vl) - } - b.WriteString("\n") + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetStringRow(val[i], nrow, i) } - return b.String() } -func (tsr *String) Float(i []int) float64 { - j := tsr.Shp.Offset(i) - return StringToFloat64(tsr.Values[j]) +/////// Floats + +func (tsr *String) Float(i ...int) float64 { + return StringToFloat64(tsr.Values[tsr.shape.IndexTo1D(i...)]) } -func (tsr *String) SetFloat(i []int, val float64) { - j := tsr.Shp.Offset(i) - tsr.Values[j] = Float64ToString(val) +func (tsr *String) SetFloat(val float64, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = Float64ToString(val) } -func (tsr *String) Float1D(off int) float64 { - return StringToFloat64(tsr.Values[off]) +func (tsr *String) Float1D(i int) float64 { + return StringToFloat64(tsr.Values[NegIndex(i, len(tsr.Values))]) } -func (tsr *String) SetFloat1D(off int, val float64) { - tsr.Values[off] = Float64ToString(val) +func (tsr *String) SetFloat1D(val float64, i int) { + tsr.Values[NegIndex(i, len(tsr.Values))] = Float64ToString(val) } -func (tsr *String) FloatRowCell(row, cell int) float64 { - _, sz := tsr.Shp.RowCellSize() +func (tsr *String) FloatRow(row, cell int) float64 { + _, sz := tsr.shape.RowCellSize() return StringToFloat64(tsr.Values[row*sz+cell]) } -func (tsr *String) SetFloatRowCell(row, cell int, val float64) { - _, sz := tsr.Shp.RowCellSize() +func (tsr *String) SetFloatRow(val float64, row, cell int) { + _, sz := 
tsr.shape.RowCellSize() tsr.Values[row*sz+cell] = Float64ToString(val) } -// Floats sets []float64 slice of all elements in the tensor -// (length is ensured to be sufficient). -// This can be used for all of the gonum/floats methods -// for basic math, gonum/stats, etc. -func (tsr *String) Floats(flt *[]float64) { - *flt = slicesx.SetLength(*flt, len(tsr.Values)) - for i, v := range tsr.Values { - (*flt)[i] = StringToFloat64(v) +// AppendRowFloat adds a row and sets float value(s), up to number of cells. +func (tsr *String) AppendRowFloat(val ...float64) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetFloatRow(val[i], nrow, i) } } -// SetFloats sets tensor values from a []float64 slice (copies values). -func (tsr *String) SetFloats(flt []float64) { - for i, v := range flt { - tsr.Values[i] = Float64ToString(v) - } +/////// Ints + +func (tsr *String) Int(i ...int) int { + return errors.Ignore1(strconv.Atoi(tsr.Values[tsr.shape.IndexTo1D(i...)])) } -// At is the gonum/mat.Matrix interface method for returning 2D matrix element at given -// row, column index. Assumes Row-major ordering and logs an error if NumDims < 2. -func (tsr *String) At(i, j int) float64 { - nd := tsr.NumDims() - if nd < 2 { - log.Println("tensor Dims gonum Matrix call made on Tensor with dims < 2") - return 0 - } else if nd == 2 { - return tsr.Float([]int{i, j}) - } else { - ix := make([]int, nd) - ix[nd-2] = i - ix[nd-1] = j - return tsr.Float(ix) - } +func (tsr *String) SetInt(val int, i ...int) { + tsr.Values[tsr.shape.IndexTo1D(i...)] = strconv.Itoa(val) } -// T is the gonum/mat.Matrix transpose method. -// It performs an implicit transpose by returning the receiver inside a Transpose. 
-func (tsr *String) T() mat.Matrix { - return mat.Transpose{tsr} +func (tsr *String) Int1D(i int) int { + return errors.Ignore1(strconv.Atoi(tsr.Values[NegIndex(i, len(tsr.Values))])) } -// Range returns the min, max (and associated indexes, -1 = no values) for the tensor. -// This is needed for display and is thus in the core api in optimized form -// Other math operations can be done using gonum/floats package. -func (tsr *String) Range() (min, max float64, minIndex, maxIndex int) { - minIndex = -1 - maxIndex = -1 - for j, vl := range tsr.Values { - fv := StringToFloat64(vl) - if math.IsNaN(fv) { - continue - } - if fv < min || minIndex < 0 { - min = fv - minIndex = j - } - if fv > max || maxIndex < 0 { - max = fv - maxIndex = j - } +func (tsr *String) SetInt1D(val int, i int) { + tsr.Values[NegIndex(i, len(tsr.Values))] = strconv.Itoa(val) +} + +func (tsr *String) IntRow(row, cell int) int { + _, sz := tsr.shape.RowCellSize() + return errors.Ignore1(strconv.Atoi(tsr.Values[row*sz+cell])) +} + +func (tsr *String) SetIntRow(val int, row, cell int) { + _, sz := tsr.shape.RowCellSize() + tsr.Values[row*sz+cell] = strconv.Itoa(val) +} + +// AppendRowInt adds a row and sets int value(s), up to number of cells. +func (tsr *String) AppendRowInt(val ...int) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow, sz := tsr.shape.RowCellSize() + tsr.SetNumRows(nrow + 1) + mx := min(sz, len(val)) + for i := range mx { + tsr.SetIntRow(val[i], nrow, i) } - return } -// SetZeros is simple convenience function initialize all values to 0 +// SetZeros is a simple convenience function initialize all values to the +// zero value of the type (empty strings for string type). 
func (tsr *String) SetZeros() { for j := range tsr.Values { tsr.Values[j] = "" @@ -211,8 +189,8 @@ func (tsr *String) SetZeros() { // Clone clones this tensor, creating a duplicate copy of itself with its // own separate memory representation of all the values, and returns // that as a Tensor (which can be converted into the known type as needed). -func (tsr *String) Clone() Tensor { - csr := NewStringShape(&tsr.Shp) +func (tsr *String) Clone() Values { + csr := NewStringShape(&tsr.shape) copy(csr.Values, tsr.Values) return csr } @@ -220,21 +198,39 @@ func (tsr *String) Clone() Tensor { // CopyFrom copies all avail values from other tensor into this tensor, with an // optimized implementation if the other tensor is of the same type, and // otherwise it goes through appropriate standard type. -func (tsr *String) CopyFrom(frm Tensor) { +func (tsr *String) CopyFrom(frm Values) { if fsm, ok := frm.(*String); ok { copy(tsr.Values, fsm.Values) return } - sz := min(len(tsr.Values), frm.Len()) + sz := min(tsr.Len(), frm.Len()) for i := 0; i < sz; i++ { tsr.Values[i] = Float64ToString(frm.Float1D(i)) } } -// CopyShapeFrom copies just the shape from given source tensor -// calling SetShape with the shape params from source (see for more docs). -func (tsr *String) CopyShapeFrom(frm Tensor) { - tsr.SetShape(frm.Shape().Sizes, frm.Shape().Names...) +// AppendFrom appends values from other tensor into this tensor, +// which must have the same cell size as this tensor. +// It uses and optimized implementation if the other tensor +// is of the same type, and otherwise it goes through +// appropriate standard type. 
+func (tsr *String) AppendFrom(frm Values) error { + rows, cell := tsr.shape.RowCellSize() + frows, fcell := frm.Shape().RowCellSize() + if cell != fcell { + return fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell) + } + tsr.SetNumRows(rows + frows) + st := rows * cell + fsz := frows * fcell + if fsm, ok := frm.(*String); ok { + copy(tsr.Values[st:st+fsz], fsm.Values) + return nil + } + for i := 0; i < fsz; i++ { + tsr.Values[st+i] = Float64ToString(frm.Float1D(i)) + } + return nil } // CopyCellsFrom copies given range of values from other tensor into this tensor, @@ -242,7 +238,7 @@ func (tsr *String) CopyShapeFrom(frm Tensor) { // start = starting index on from Tensor to start copying from, and n = number of // values to copy. Uses an optimized implementation if the other tensor is // of the same type, and otherwise it goes through appropriate standard type. -func (tsr *String) CopyCellsFrom(frm Tensor, to, start, n int) { +func (tsr *String) CopyCellsFrom(frm Values, to, start, n int) { if fsm, ok := frm.(*String); ok { for i := 0; i < n; i++ { tsr.Values[to+i] = fsm.Values[start+i] @@ -260,8 +256,33 @@ func (tsr *String) CopyCellsFrom(frm Tensor, to, start, n int) { // will affect both), as its Values slice is a view onto the original (which // is why only inner-most contiguous supsaces are supported). // Use Clone() method to separate the two. -func (tsr *String) SubSpace(offs []int) Tensor { - b := tsr.subSpaceImpl(offs) +func (tsr *String) SubSpace(offs ...int) Values { + b := tsr.subSpaceImpl(offs...) rt := &String{Base: *b} return rt } + +// RowTensor is a convenience version of [RowMajor.SubSpace] to return the +// SubSpace for the outermost row dimension. [Rows] defines a version +// of this that indirects through the row indexes. +func (tsr *String) RowTensor(row int) Values { + return tsr.SubSpace(row) +} + +// SetRowTensor sets the values of the SubSpace at given row to given values. 
+func (tsr *String) SetRowTensor(val Values, row int) { + _, cells := tsr.shape.RowCellSize() + st := row * cells + mx := min(val.Len(), cells) + tsr.CopyCellsFrom(val, st, 0, mx) +} + +// AppendRow adds a row and sets values to given values. +func (tsr *String) AppendRow(val Values) { + if tsr.NumDims() == 0 { + tsr.SetShapeSizes(0) + } + nrow := tsr.DimSize(0) + tsr.SetNumRows(nrow + 1) + tsr.SetRowTensor(val, nrow) +} diff --git a/tensor/table/README.md b/tensor/table/README.md index e7ed404e77..55e3f1525a 100644 --- a/tensor/table/README.md +++ b/tensor/table/README.md @@ -6,130 +6,59 @@ See [examples/dataproc](examples/dataproc) for a demo of how to use this system for data analysis, paralleling the example in [Python Data Science](https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html) using pandas, to see directly how that translates into this framework. -As a general convention, it is safest, clearest, and quite fast to access columns by name instead of index (there is a map that caches the column indexes), so the base access method names generally take a column name argument, and those that take a column index have an `Index` suffix. In addition, we use the `Try` suffix for versions that return an error message. It is a bit painful for the writer of these methods but very convenient for the users. +Whereas an individual `Tensor` can only hold one data type, the `Table` allows coordinated storage and processing of heterogeneous data types, aligned by the outermost row dimension. The main `tensor` data processing functions are defined on the individual tensors (which are the universal computational element in the `tensor` system), but the coordinated row-wise indexing in the table is important for sorting or filtering a collection of data in the same way, and grouping data by a common set of "splits" for data analysis. Plotting is also driven by the table, with one column providing a shared X axis for the rest of the columns. 
-The following packages are included: +The `Table` mainly provides "infrastructure" methods for adding tensor columns and CSV (comma separated values, and related tab separated values, TSV) file reading and writing. Any function that can be performed on an individual column should be done using the `tensor.Rows` and `Tensor` methods directly. -* [bitslice](bitslice) is a Go slice of bytes `[]byte` that has methods for setting individual bits, as if it was a slice of bools, while being 8x more memory efficient. This is used for encoding null entries in `etensor`, and as a Tensor of bool / bits there as well, and is generally very useful for binary (boolean) data. +As a general convention, it is safest, clearest, and quite fast to access columns by name instead of index (there is a `map` from name to index), so the base access method names generally take a column name argument, and those that take a column index have an `Index` suffix. -* [etensor](etensor) is a Tensor (n-dimensional array) object. `etensor.Tensor` is an interface that applies to many different type-specific instances, such as `etensor.Float32`. A tensor is just a `etensor.Shape` plus a slice holding the specific data type. Our tensor is based directly on the [Apache Arrow](https://github.com/apache/arrow/tree/master/go) project's tensor, and it fully interoperates with it. Arrow tensors are designed to be read-only, and we needed some extra support to make our `etable.Table` work well, so we had to roll our own. Our tensors also interoperate fully with Gonum's 2D-specific Matrix type for the 2D case. +The table itself stores raw data `tensor.Tensor` values, and the `Column` (by name) and `ColumnByIndex` methods return a `tensor.Rows` with the `Indexes` pointing to the shared table-wide `Indexes` (which can be `nil` if standard sequential order is being used). 
-* [etable](etable) has the `etable.Table` DataTable / DataFrame object, which is useful for many different data analysis and database functions, and also for holding patterns to present to a neural network, and logs of output from the models, etc. A `etable.Table` is just a slice of `etensor.Tensor` columns, that are all aligned along the outer-most *row* dimension. Index-based indirection, which is essential for efficient Sort, Filter etc, is provided by the `etable.IndexView` type, which is an indexed view into a Table. All data processing operations are defined on the IndexView. +If you call Sort, Filter or other routines on an individual column tensor, then you can grab the updated indexes via the `IndexesFromTensor` method so that they apply to the entire table. The `SortColumn` and `FilterString` methods do this for you. -* [eplot](eplot) provides an interactive 2D plotting GUI in [GoGi](https://cogentcore.org/core/gi) for Table data, using the [gonum plot](https://github.com/gonum/plot) plotting package. You can select which columns to plot and specify various basic plot parameters. +There are also multi-column `Sort` and `Filter` methods on the Table itself. -* [tensorcore](tensorcore) provides an interactive tabular, spreadsheet-style GUI using [GoGi](https://cogentcore.org/core/gi) for viewing and editing `etable.Table` and `etable.Tensor` objects. The `tensorcore.TensorGrid` also provides a colored grid display higher-dimensional tensor data. - -* [agg](agg) provides standard aggregation functions (`Sum`, `Mean`, `Var`, `Std` etc) operating over `etable.IndexView` views of Table data. It also defines standard `AggFunc` functions such as `SumFunc` which can be used for `Agg` functions on either a Tensor or IndexView. - -* [tsragg](tsragg) provides the same agg functions as in `agg`, but operating on all the values in a given `Tensor`. Because of the indexed, row-based nature of tensors in a Table, these are not the same as the `agg` functions. 
- -* [split](split) supports splitting a Table into any number of indexed sub-views and aggregating over those (i.e., pivot tables), grouping, summarizing data, etc. - -* [metric](metric) provides similarity / distance metrics such as `Euclidean`, `Cosine`, or `Correlation` that operate on slices of `[]float64` or `[]float32`. - -* [simat](simat) provides similarity / distance matrix computation methods operating on `etensor.Tensor` or `etable.Table` data. The `SimMat` type holds the resulting matrix and labels for the rows and columns, which has a special `SimMatGrid` view in `etview` for visualizing labeled similarity matricies. - -* [pca](pca) provides principal-components-analysis (PCA) and covariance matrix computation functions. - -* [clust](clust) provides standard agglomerative hierarchical clustering including ability to plot results in an eplot. - -* [minmax](minmax) is home of basic Min / Max range struct, and `norm` has lots of good functions for computing standard norms and normalizing vectors. - -* [utils](utils) has various table-related utility command-line utility tools, including `etcat` which combines multiple table files into one file, including option for averaging column data. +It is very low-cost to create a new View of an existing Table, via `NewView`, as they can share the underlying `Columns` data. # Cheat Sheet -`et` is the etable pointer variable for examples below: +`dt` is the Table pointer variable for examples below: ## Table Access -Scalar columns: +Column data access: ```Go -val := et.Float("ColName", row) +// FloatRow is a method on the `tensor.Rows` returned from the `Column` method. +// This is the best method to use in general for generic 1D data access, +// as it works on any data from 1D on up (although it only samples the first value +// from higher dimensional data) . 
+val := dt.Column("Values").FloatRow(3) ``` ```Go -str := et.StringValue("ColName", row) +dt.Column("Name").SetStringRow("Alia", 4) ``` -Tensor (higher-dimensional) columns: +To access higher-dimensional "cell" level data using a simple 1D index into the cell patterns: ```Go -tsr := et.Tensor("ColName", row) // entire tensor at cell (a row-level SubSpace of column tensor) +// FloatRow with a second argument takes a row index plus a 1D cell index +// into the row's cell tensor, so you can access values within +// higher-dimensional cells without constructing a full +// multi-dimensional index. +val := dt.Column("Values").FloatRow(3, 2) ``` ```Go -val := et.TensorFloat1D("ColName", row, cellidx) // idx is 1D index into cell tensor +dt.Column("Name").SetStringRow("Alia", 4, 1) ``` -## Set Table Value +todo: more -```Go -et.SetFloat("ColName", row, val) -``` +## Sorting and Filtering -```Go -et.SetString("ColName", row, str) -``` - -Tensor (higher-dimensional) columns: - -```Go -et.SetTensor("ColName", row, tsr) // set entire tensor at cell -``` - -```Go -et.SetTensorFloat1D("ColName", row, cellidx, val) // idx is 1D index into cell tensor -``` - -## Find Value(s) in Column - -Returns all rows where value matches given value, in string form (any number will convert to a string) - -```Go -rows := et.RowsByString("ColName", "value", etable.Contains, etable.IgnoreCase) -``` - -Other options are `etable.Equals` instead of `Contains` to search for an exact full string, and `etable.UseCase` if case should be used instead of ignored. - -## Index Views (Sort, Filter, etc) - -The [IndexView](https://godoc.org/github.com/goki/etable/v2/etable#IndexView) provides a list of row-wise indexes into a table, and Sorting, Filtering and Splitting all operate on this index view without changing the underlying table data, for maximum efficiency and flexibility. 
- -```Go -ix := etable.NewIndexView(et) // new view with all rows -``` - -### Sort - -```Go -ix.SortColumnName("Name", etable.Ascending) // etable.Ascending or etable.Descending -SortedTable := ix.NewTable() // turn an IndexView back into a new Table organized in order of indexes -``` - -or: - -```Go -nmcl := et.ColumnByName("Name") // nmcl is an etensor of the Name column, cached -ix.Sort(func(t *Table, i, j int) bool { - return nmcl.StringValue1D(i) < nmcl.StringValue1D(j) -}) -``` - -### Filter - -```Go -nmcl := et.ColumnByName("Name") // column we're filtering on -ix.Filter(func(t *Table, row int) bool { - // filter return value is for what to *keep* (=true), not exclude - // here we keep any row with a name that contains the string "in" - return strings.Contains(nmcl.StringValue1D(row), "in") -}) -``` - -### Splits ("pivot tables" etc), Aggregation +## Splits ("pivot tables" etc), Aggregation Create a table of mean values of "Data" column grouped by unique entries in "Name" column, resulting table will be called "DataMean": @@ -142,7 +71,7 @@ gps := byNm.AggsToTable(etable.AddAggName) // etable.AddAggName or etable.ColNam Describe (basic stats) all columns in a table: ```Go -ix := etable.NewIndexView(et) // new view with all rows +ix := etable.NewRows(et) // new view with all rows desc := agg.DescAll(ix) // summary stats of all columns // get value at given column name (from original table), row "Mean" mean := desc.Float("ColNm", desc.RowsByString("Agg", "Mean", etable.Equals, etable.UseCase)[0]) @@ -163,12 +92,12 @@ Here is the mapping of special header prefix characters to standard types: '#': etensor.FLOAT64, '|': etensor.INT64, '@': etensor.UINT8, -'^': etensor.BOOl, +'^': etensor.BOOL, ``` Columns that have tensor cell shapes (not just scalars) are marked as such with the *first* such column having a `` suffix indicating the shape of the *cells* in this column, e.g., `<2:5,4>` indicates a 2D cell Y=5,X=4. 
Each individual column is then indexed as `[ndims:x,y..]` e.g., the first would be `[2:0,0]`, then `[2:0,1]` etc. -### Example +## Example Here's a TSV file for a scalar String column (`Name`), a 2D 1x4 tensor float32 column (`Input`), and a 2D 1x2 float32 `Output` column. @@ -180,6 +109,8 @@ _D: Event_2 0 0 1 0 0 1 _D: Event_3 0 0 0 1 0 1 ``` +## Logging one row at a time + diff --git a/tensor/table/columns.go b/tensor/table/columns.go new file mode 100644 index 0000000000..811ee1b04e --- /dev/null +++ b/tensor/table/columns.go @@ -0,0 +1,89 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package table + +import ( + "cogentcore.org/core/base/keylist" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" +) + +// Columns is the underlying column list and number of rows for Table. +// Each column is a raw [tensor.Values] tensor, and [Table] +// provides a [tensor.Rows] indexed view onto the Columns. +type Columns struct { + keylist.List[string, tensor.Values] + + // number of rows, which is enforced to be the size of the + // outermost row dimension of the column tensors. + Rows int `edit:"-"` +} + +// NewColumns returns a new Columns. +func NewColumns() *Columns { + return &Columns{} +} + +// SetNumRows sets the number of rows in the table, across all columns. +// It is safe to set this to 0. For incrementally growing tables (e.g., a log) +// it is best to first set the anticipated full size, which allocates the +// full amount of memory, and then set to 0 and grow incrementally. +func (cl *Columns) SetNumRows(rows int) *Columns { //types:add + cl.Rows = rows // can be 0 + for _, tsr := range cl.Values { + tsr.SetNumRows(rows) + } + return cl +} + +// AddColumn adds the given tensor (as a [tensor.Values]) as a column, +// returning an error and not adding if the name is not unique. 
+// Automatically adjusts the shape to fit the current number of rows, +// (setting Rows if this is the first column added) +// and calls the metadata SetName with column name. +func (cl *Columns) AddColumn(name string, tsr tensor.Values) error { + if cl.Len() == 0 { + cl.Rows = tsr.DimSize(0) + } + err := cl.Add(name, tsr) + if err != nil { + return err + } + tsr.SetNumRows(cl.Rows) + metadata.SetName(tsr, name) + return nil +} + +// InsertColumn inserts the given tensor as a column at given index, +// returning an error and not adding if the name is not unique. +// Automatically adjusts the shape to fit the current number of rows. +func (cl *Columns) InsertColumn(idx int, name string, tsr tensor.Values) error { + cl.Insert(idx, name, tsr) + tsr.SetNumRows(cl.Rows) + return nil +} + +// Clone returns a complete copy of this set of columns. +func (cl *Columns) Clone() *Columns { + cp := NewColumns().SetNumRows(cl.Rows) + for i, nm := range cl.Keys { + tsr := cl.Values[i] + cp.AddColumn(nm, tsr.Clone()) + } + return cp +} + +// AppendRows appends shared columns in both tables with input table rows. +func (cl *Columns) AppendRows(cl2 *Columns) { + for i, nm := range cl.Keys { + c2 := cl2.At(nm) + if c2 == nil { + continue + } + c1 := cl.Values[i] + c1.AppendFrom(c2) + } + cl.SetNumRows(cl.Rows + cl2.Rows) +} diff --git a/tensor/table/enumgen.go b/tensor/table/enumgen.go deleted file mode 100644 index c4eda161e5..0000000000 --- a/tensor/table/enumgen.go +++ /dev/null @@ -1,46 +0,0 @@ -// Code generated by "core generate"; DO NOT EDIT. - -package table - -import ( - "cogentcore.org/core/enums" -) - -var _DelimsValues = []Delims{0, 1, 2, 3} - -// DelimsN is the highest valid value for type Delims, plus one. 
-const DelimsN Delims = 4 - -var _DelimsValueMap = map[string]Delims{`Tab`: 0, `Comma`: 1, `Space`: 2, `Detect`: 3} - -var _DelimsDescMap = map[Delims]string{0: `Tab is the tab rune delimiter, for TSV tab separated values`, 1: `Comma is the comma rune delimiter, for CSV comma separated values`, 2: `Space is the space rune delimiter, for SSV space separated value`, 3: `Detect is used during reading a file -- reads the first line and detects tabs or commas`} - -var _DelimsMap = map[Delims]string{0: `Tab`, 1: `Comma`, 2: `Space`, 3: `Detect`} - -// String returns the string representation of this Delims value. -func (i Delims) String() string { return enums.String(i, _DelimsMap) } - -// SetString sets the Delims value from its string representation, -// and returns an error if the string is invalid. -func (i *Delims) SetString(s string) error { return enums.SetString(i, s, _DelimsValueMap, "Delims") } - -// Int64 returns the Delims value as an int64. -func (i Delims) Int64() int64 { return int64(i) } - -// SetInt64 sets the Delims value from an int64. -func (i *Delims) SetInt64(in int64) { *i = Delims(in) } - -// Desc returns the description of the Delims value. -func (i Delims) Desc() string { return enums.Desc(i, _DelimsDescMap) } - -// DelimsValues returns all possible values for the type Delims. -func DelimsValues() []Delims { return _DelimsValues } - -// Values returns all possible values for the type Delims. -func (i Delims) Values() []enums.Enum { return enums.Values(_DelimsValues) } - -// MarshalText implements the [encoding.TextMarshaler] interface. -func (i Delims) MarshalText() ([]byte, error) { return []byte(i.String()), nil } - -// UnmarshalText implements the [encoding.TextUnmarshaler] interface. 
-func (i *Delims) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Delims") } diff --git a/tensor/table/indexes.go b/tensor/table/indexes.go new file mode 100644 index 0000000000..c01554c8be --- /dev/null +++ b/tensor/table/indexes.go @@ -0,0 +1,250 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package table + +import ( + "math/rand" + "slices" + "sort" + + "cogentcore.org/core/tensor" +) + +// RowIndex returns the actual index into underlying tensor row based on given +// index value. If Indexes == nil, index is passed through. +func (dt *Table) RowIndex(idx int) int { + if dt.Indexes == nil { + return idx + } + return dt.Indexes[idx] +} + +// NumRows returns the number of rows, which is the number of Indexes if present, +// else actual number of [Columns.Rows]. +func (dt *Table) NumRows() int { + if dt.Indexes == nil { + return dt.Columns.Rows + } + return len(dt.Indexes) +} + +// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor. +func (dt *Table) Sequential() { //types:add + dt.Indexes = nil +} + +// IndexesNeeded is called prior to an operation that needs actual indexes, +// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise +// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure +// all rows are represented. +func (dt *Table) IndexesNeeded() { + if dt.Indexes != nil { + return + } + dt.Indexes = make([]int, dt.Columns.Rows) + for i := range dt.Indexes { + dt.Indexes[i] = i + } +} + +// IndexesFromTensor copies Indexes from the given [tensor.Rows] tensor, +// including if they are nil. This allows column-specific Sort, Filter and +// other such methods to be applied to the entire table. +func (dt *Table) IndexesFromTensor(ix *tensor.Rows) { + dt.Indexes = ix.Indexes +} + +// ValidIndexes deletes all invalid indexes from the list. 
+// Call this if rows (could) have been deleted from table. +func (dt *Table) ValidIndexes() { + if dt.Columns.Rows <= 0 || dt.Indexes == nil { + dt.Indexes = nil + return + } + ni := dt.NumRows() + for i := ni - 1; i >= 0; i-- { + if dt.Indexes[i] >= dt.Columns.Rows { + dt.Indexes = append(dt.Indexes[:i], dt.Indexes[i+1:]...) + } + } +} + +// Permuted sets indexes to a permuted order -- if indexes already exist +// then existing list of indexes is permuted, otherwise a new set of +// permuted indexes are generated +func (dt *Table) Permuted() { + if dt.Columns.Rows <= 0 { + dt.Indexes = nil + return + } + if dt.Indexes == nil { + dt.Indexes = rand.Perm(dt.Columns.Rows) + } else { + rand.Shuffle(len(dt.Indexes), func(i, j int) { + dt.Indexes[i], dt.Indexes[j] = dt.Indexes[j], dt.Indexes[i] + }) + } +} + +// SortColumn sorts the indexes into our Table according to values in +// given column, using either ascending or descending order, +// (use [tensor.Ascending] or [tensor.Descending] for self-documentation). +// Uses first cell of higher dimensional data. +// Returns error if column name not found. +func (dt *Table) SortColumn(columnName string, ascending bool) error { //types:add + dt.IndexesNeeded() + cl, err := dt.ColumnTry(columnName) + if err != nil { + return err + } + cl.Sort(ascending) + dt.IndexesFromTensor(cl) + return nil +} + +// SortFunc sorts the indexes into our Table using given compare function. +// The compare function operates directly on row numbers into the Table +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +func (dt *Table) SortFunc(cmp func(dt *Table, i, j int) int) { + dt.IndexesNeeded() + slices.SortFunc(dt.Indexes, func(a, b int) int { + return cmp(dt, a, b) // key point: these are already indirected through indexes!! 
+ }) +} + +// SortStableFunc stably sorts the indexes into our Table using given compare function. +// The compare function operates directly on row numbers into the Table +// as these row numbers have already been projected through the indexes. +// cmp(a, b) should return a negative number when a < b, a positive +// number when a > b and zero when a == b. +// It is *essential* that it always returns 0 when the two are equal +// for the stable function to actually work. +func (dt *Table) SortStableFunc(cmp func(dt *Table, i, j int) int) { + dt.IndexesNeeded() + slices.SortStableFunc(dt.Indexes, func(a, b int) int { + return cmp(dt, a, b) // key point: these are already indirected through indexes!! + }) +} + +// SortColumns sorts the indexes into our Table according to values in +// given column names, using either ascending or descending order, +// (use [tensor.Ascending] or [tensor.Descending] for self-documentation, +// and optionally using a stable sort. +// Uses first cell of higher dimensional data. +func (dt *Table) SortColumns(ascending, stable bool, columns ...string) { //types:add + dt.SortColumnIndexes(ascending, stable, dt.ColumnIndexList(columns...)...) +} + +// SortColumnIndexes sorts the indexes into our Table according to values in +// given list of column indexes, using either ascending or descending order for +// all of the columns. Uses first cell of higher dimensional data. 
+func (dt *Table) SortColumnIndexes(ascending, stable bool, colIndexes ...int) { + dt.IndexesNeeded() + sf := dt.SortFunc + if stable { + sf = dt.SortStableFunc + } + sf(func(dt *Table, i, j int) int { + for _, ci := range colIndexes { + cl := dt.ColumnByIndex(ci).Tensor + if cl.IsString() { + v := tensor.CompareAscending(cl.StringRow(i, 0), cl.StringRow(j, 0), ascending) + if v != 0 { + return v + } + } else { + v := tensor.CompareAscending(cl.FloatRow(i, 0), cl.FloatRow(j, 0), ascending) + if v != 0 { + return v + } + } + } + return 0 + }) +} + +// SortIndexes sorts the indexes into our Table directly in +// numerical order, producing the native ordering, while preserving +// any filtering that might have occurred. +func (dt *Table) SortIndexes() { + if dt.Indexes == nil { + return + } + sort.Ints(dt.Indexes) +} + +// FilterFunc is a function used for filtering that returns +// true if Table row should be included in the current filtered +// view of the table, and false if it should be removed. +type FilterFunc func(dt *Table, row int) bool + +// Filter filters the indexes into our Table using given Filter function. +// The Filter function operates directly on row numbers into the Table +// as these row numbers have already been projected through the indexes. +func (dt *Table) Filter(filterer func(dt *Table, row int) bool) { + dt.IndexesNeeded() + sz := len(dt.Indexes) + for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering + if !filterer(dt, dt.Indexes[i]) { // delete + dt.Indexes = append(dt.Indexes[:i], dt.Indexes[i+1:]...) + } + } +} + +// FilterString filters the indexes using string values in column compared to given +// string. Includes rows with matching values unless the Exclude option is set. +// If Contains option is set, it only checks if row contains string; +// if IgnoreCase, ignores case, otherwise filtering is case sensitive. +// Uses first cell from higher dimensions. +// Returns error if column name not found. 
+func (dt *Table) FilterString(columnName string, str string, opts tensor.FilterOptions) error { //types:add + dt.IndexesNeeded() + cl, err := dt.ColumnTry(columnName) + if err != nil { + return err + } + cl.FilterString(str, opts) + dt.IndexesFromTensor(cl) + return nil +} + +// New returns a new table with column data organized according to +// the indexes. If Indexes are nil, a clone of the current tensor is returned +// but this function is only sensible if there is an indexed view in place. +func (dt *Table) New() *Table { + if dt.Indexes == nil { + return dt.Clone() + } + rows := len(dt.Indexes) + nt := dt.Clone() + nt.Indexes = nil + nt.SetNumRows(rows) + if rows == 0 { + return nt + } + for ci, cl := range nt.Columns.Values { + scl := dt.Columns.Values[ci] + _, csz := cl.Shape().RowCellSize() + for i, srw := range dt.Indexes { + cl.CopyCellsFrom(scl, i*csz, srw*csz, csz) + } + } + return nt +} + +// DeleteRows deletes n rows of Indexes starting at given index in the list of indexes. +// This does not affect the underlying tensor data; To create an actual in-memory +// ordering with rows deleted, use [Table.New]. +func (dt *Table) DeleteRows(at, n int) { + dt.IndexesNeeded() + dt.Indexes = append(dt.Indexes[:at], dt.Indexes[at+n:]...) +} + +// Swap switches the indexes for i and j +func (dt *Table) Swap(i, j int) { + dt.Indexes[i], dt.Indexes[j] = dt.Indexes[j], dt.Indexes[i] +} diff --git a/tensor/table/indexview.go b/tensor/table/indexview.go deleted file mode 100644 index 5a8da3e570..0000000000 --- a/tensor/table/indexview.go +++ /dev/null @@ -1,520 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package table - -import ( - "fmt" - "log" - "math/rand" - "slices" - "sort" - "strings" -) - -// LessFunc is a function used for sort comparisons that returns -// true if Table row i is less than Table row j -- these are the -// raw row numbers, which have already been projected through -// indexes when used for sorting via Indexes. -type LessFunc func(et *Table, i, j int) bool - -// Filterer is a function used for filtering that returns -// true if Table row should be included in the current filtered -// view of the table, and false if it should be removed. -type Filterer func(et *Table, row int) bool - -// IndexView is an indexed wrapper around an table.Table that provides a -// specific view onto the Table defined by the set of indexes. -// This provides an efficient way of sorting and filtering a table by only -// updating the indexes while doing nothing to the Table itself. -// To produce a table that has data actually organized according to the -// indexed order, call the NewTable method. -// IndexView views on a table can also be organized together as Splits -// of the table rows, e.g., by grouping values along a given column. -type IndexView struct { //types:add - - // Table that we are an indexed view onto - Table *Table - - // current indexes into Table - Indexes []int - - // current Less function used in sorting - lessFunc LessFunc `copier:"-" display:"-" xml:"-" json:"-"` -} - -// NewIndexView returns a new IndexView based on given table, initialized with sequential idxes -func NewIndexView(et *Table) *IndexView { - ix := &IndexView{} - ix.SetTable(et) - return ix -} - -// SetTable sets as indexes into given table with sequential initial indexes -func (ix *IndexView) SetTable(et *Table) { - ix.Table = et - ix.Sequential() -} - -// DeleteInvalid deletes all invalid indexes from the list. -// Call this if rows (could) have been deleted from table. 
-func (ix *IndexView) DeleteInvalid() { - if ix.Table == nil || ix.Table.Rows <= 0 { - ix.Indexes = nil - return - } - ni := ix.Len() - for i := ni - 1; i >= 0; i-- { - if ix.Indexes[i] >= ix.Table.Rows { - ix.Indexes = append(ix.Indexes[:i], ix.Indexes[i+1:]...) - } - } -} - -// Sequential sets indexes to sequential row-wise indexes into table -func (ix *IndexView) Sequential() { //types:add - if ix.Table == nil || ix.Table.Rows <= 0 { - ix.Indexes = nil - return - } - ix.Indexes = make([]int, ix.Table.Rows) - for i := range ix.Indexes { - ix.Indexes[i] = i - } -} - -// Permuted sets indexes to a permuted order -- if indexes already exist -// then existing list of indexes is permuted, otherwise a new set of -// permuted indexes are generated -func (ix *IndexView) Permuted() { - if ix.Table == nil || ix.Table.Rows <= 0 { - ix.Indexes = nil - return - } - if len(ix.Indexes) == 0 { - ix.Indexes = rand.Perm(ix.Table.Rows) - } else { - rand.Shuffle(len(ix.Indexes), func(i, j int) { - ix.Indexes[i], ix.Indexes[j] = ix.Indexes[j], ix.Indexes[i] - }) - } -} - -// AddIndex adds a new index to the list -func (ix *IndexView) AddIndex(idx int) { - ix.Indexes = append(ix.Indexes, idx) -} - -// Sort sorts the indexes into our Table using given Less function. -// The Less function operates directly on row numbers into the Table -// as these row numbers have already been projected through the indexes. -func (ix *IndexView) Sort(lessFunc func(et *Table, i, j int) bool) { - ix.lessFunc = lessFunc - sort.Sort(ix) -} - -// SortIndexes sorts the indexes into our Table directly in -// numerical order, producing the native ordering, while preserving -// any filtering that might have occurred. 
-func (ix *IndexView) SortIndexes() { - sort.Ints(ix.Indexes) -} - -const ( - // Ascending specifies an ascending sort direction for table Sort routines - Ascending = true - - // Descending specifies a descending sort direction for table Sort routines - Descending = false -) - -// SortColumnName sorts the indexes into our Table according to values in -// given column name, using either ascending or descending order. -// Only valid for 1-dimensional columns. -// Returns error if column name not found. -func (ix *IndexView) SortColumnName(column string, ascending bool) error { //types:add - ci, err := ix.Table.ColumnIndex(column) - if err != nil { - log.Println(err) - return err - } - ix.SortColumn(ci, ascending) - return nil -} - -// SortColumn sorts the indexes into our Table according to values in -// given column index, using either ascending or descending order. -// Only valid for 1-dimensional columns. -func (ix *IndexView) SortColumn(colIndex int, ascending bool) { - cl := ix.Table.Columns[colIndex] - if cl.IsString() { - ix.Sort(func(et *Table, i, j int) bool { - if ascending { - return cl.String1D(i) < cl.String1D(j) - } else { - return cl.String1D(i) > cl.String1D(j) - } - }) - } else { - ix.Sort(func(et *Table, i, j int) bool { - if ascending { - return cl.Float1D(i) < cl.Float1D(j) - } else { - return cl.Float1D(i) > cl.Float1D(j) - } - }) - } -} - -// SortColumnNames sorts the indexes into our Table according to values in -// given column names, using either ascending or descending order. -// Only valid for 1-dimensional columns. -// Returns error if column name not found. 
-func (ix *IndexView) SortColumnNames(columns []string, ascending bool) error { - nc := len(columns) - if nc == 0 { - return fmt.Errorf("table.IndexView.SortColumnNames: no column names provided") - } - cis := make([]int, nc) - for i, cn := range columns { - ci, err := ix.Table.ColumnIndex(cn) - if err != nil { - log.Println(err) - return err - } - cis[i] = ci - } - ix.SortColumns(cis, ascending) - return nil -} - -// SortColumns sorts the indexes into our Table according to values in -// given list of column indexes, using either ascending or descending order for -// all of the columns. Only valid for 1-dimensional columns. -func (ix *IndexView) SortColumns(colIndexes []int, ascending bool) { - ix.Sort(func(et *Table, i, j int) bool { - for _, ci := range colIndexes { - cl := ix.Table.Columns[ci] - if cl.IsString() { - if ascending { - if cl.String1D(i) < cl.String1D(j) { - return true - } else if cl.String1D(i) > cl.String1D(j) { - return false - } // if equal, fallthrough to next col - } else { - if cl.String1D(i) > cl.String1D(j) { - return true - } else if cl.String1D(i) < cl.String1D(j) { - return false - } // if equal, fallthrough to next col - } - } else { - if ascending { - if cl.Float1D(i) < cl.Float1D(j) { - return true - } else if cl.Float1D(i) > cl.Float1D(j) { - return false - } // if equal, fallthrough to next col - } else { - if cl.Float1D(i) > cl.Float1D(j) { - return true - } else if cl.Float1D(i) < cl.Float1D(j) { - return false - } // if equal, fallthrough to next col - } - } - } - return false - }) -} - -///////////////////////////////////////////////////////////////////////// -// Stable sorts -- sometimes essential.. - -// SortStable stably sorts the indexes into our Table using given Less function. -// The Less function operates directly on row numbers into the Table -// as these row numbers have already been projected through the indexes. 
-// It is *essential* that it always returns false when the two are equal -// for the stable function to actually work. -func (ix *IndexView) SortStable(lessFunc func(et *Table, i, j int) bool) { - ix.lessFunc = lessFunc - sort.Stable(ix) -} - -// SortStableColumnName sorts the indexes into our Table according to values in -// given column name, using either ascending or descending order. -// Only valid for 1-dimensional columns. -// Returns error if column name not found. -func (ix *IndexView) SortStableColumnName(column string, ascending bool) error { - ci, err := ix.Table.ColumnIndex(column) - if err != nil { - log.Println(err) - return err - } - ix.SortStableColumn(ci, ascending) - return nil -} - -// SortStableColumn sorts the indexes into our Table according to values in -// given column index, using either ascending or descending order. -// Only valid for 1-dimensional columns. -func (ix *IndexView) SortStableColumn(colIndex int, ascending bool) { - cl := ix.Table.Columns[colIndex] - if cl.IsString() { - ix.SortStable(func(et *Table, i, j int) bool { - if ascending { - return cl.String1D(i) < cl.String1D(j) - } else { - return cl.String1D(i) > cl.String1D(j) - } - }) - } else { - ix.SortStable(func(et *Table, i, j int) bool { - if ascending { - return cl.Float1D(i) < cl.Float1D(j) - } else { - return cl.Float1D(i) > cl.Float1D(j) - } - }) - } -} - -// SortStableColumnNames sorts the indexes into our Table according to values in -// given column names, using either ascending or descending order. -// Only valid for 1-dimensional columns. -// Returns error if column name not found. 
-func (ix *IndexView) SortStableColumnNames(columns []string, ascending bool) error { - nc := len(columns) - if nc == 0 { - return fmt.Errorf("table.IndexView.SortStableColumnNames: no column names provided") - } - cis := make([]int, nc) - for i, cn := range columns { - ci, err := ix.Table.ColumnIndex(cn) - if err != nil { - log.Println(err) - return err - } - cis[i] = ci - } - ix.SortStableColumns(cis, ascending) - return nil -} - -// SortStableColumns sorts the indexes into our Table according to values in -// given list of column indexes, using either ascending or descending order for -// all of the columns. Only valid for 1-dimensional columns. -func (ix *IndexView) SortStableColumns(colIndexes []int, ascending bool) { - ix.SortStable(func(et *Table, i, j int) bool { - for _, ci := range colIndexes { - cl := ix.Table.Columns[ci] - if cl.IsString() { - if ascending { - if cl.String1D(i) < cl.String1D(j) { - return true - } else if cl.String1D(i) > cl.String1D(j) { - return false - } // if equal, fallthrough to next col - } else { - if cl.String1D(i) > cl.String1D(j) { - return true - } else if cl.String1D(i) < cl.String1D(j) { - return false - } // if equal, fallthrough to next col - } - } else { - if ascending { - if cl.Float1D(i) < cl.Float1D(j) { - return true - } else if cl.Float1D(i) > cl.Float1D(j) { - return false - } // if equal, fallthrough to next col - } else { - if cl.Float1D(i) > cl.Float1D(j) { - return true - } else if cl.Float1D(i) < cl.Float1D(j) { - return false - } // if equal, fallthrough to next col - } - } - } - return false - }) -} - -// Filter filters the indexes into our Table using given Filter function. -// The Filter function operates directly on row numbers into the Table -// as these row numbers have already been projected through the indexes. 
-func (ix *IndexView) Filter(filterer func(et *Table, row int) bool) { - sz := len(ix.Indexes) - for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering - if !filterer(ix.Table, ix.Indexes[i]) { // delete - ix.Indexes = append(ix.Indexes[:i], ix.Indexes[i+1:]...) - } - } -} - -// FilterColumnName filters the indexes into our Table according to values in -// given column name, using string representation of column values. -// Includes rows with matching values unless exclude is set. -// If contains, only checks if row contains string; if ignoreCase, ignores case. -// Use named args for greater clarity. -// Only valid for 1-dimensional columns. -// Returns error if column name not found. -func (ix *IndexView) FilterColumnName(column string, str string, exclude, contains, ignoreCase bool) error { //types:add - ci, err := ix.Table.ColumnIndex(column) - if err != nil { - log.Println(err) - return err - } - ix.FilterColumn(ci, str, exclude, contains, ignoreCase) - return nil -} - -// FilterColumn sorts the indexes into our Table according to values in -// given column index, using string representation of column values. -// Includes rows with matching values unless exclude is set. -// If contains, only checks if row contains string; if ignoreCase, ignores case. -// Use named args for greater clarity. -// Only valid for 1-dimensional columns. 
-func (ix *IndexView) FilterColumn(colIndex int, str string, exclude, contains, ignoreCase bool) { - col := ix.Table.Columns[colIndex] - lowstr := strings.ToLower(str) - ix.Filter(func(et *Table, row int) bool { - val := col.String1D(row) - has := false - switch { - case contains && ignoreCase: - has = strings.Contains(strings.ToLower(val), lowstr) - case contains: - has = strings.Contains(val, str) - case ignoreCase: - has = strings.EqualFold(val, str) - default: - has = (val == str) - } - if exclude { - return !has - } - return has - }) -} - -// NewTable returns a new table with column data organized according to -// the indexes -func (ix *IndexView) NewTable() *Table { - rows := len(ix.Indexes) - nt := ix.Table.Clone() - nt.SetNumRows(rows) - if rows == 0 { - return nt - } - for ci := range nt.Columns { - scl := ix.Table.Columns[ci] - tcl := nt.Columns[ci] - _, csz := tcl.RowCellSize() - for i, srw := range ix.Indexes { - tcl.CopyCellsFrom(scl, i*csz, srw*csz, csz) - } - } - return nt -} - -// Clone returns a copy of the current index view with its own index memory -func (ix *IndexView) Clone() *IndexView { - nix := &IndexView{} - nix.CopyFrom(ix) - return nix -} - -// CopyFrom copies from given other IndexView (we have our own unique copy of indexes) -func (ix *IndexView) CopyFrom(oix *IndexView) { - ix.Table = oix.Table - ix.Indexes = slices.Clone(oix.Indexes) -} - -// AddRows adds n rows to end of underlying Table, and to the indexes in this view -func (ix *IndexView) AddRows(n int) { //types:add - stidx := ix.Table.Rows - ix.Table.SetNumRows(stidx + n) - for i := stidx; i < stidx+n; i++ { - ix.Indexes = append(ix.Indexes, i) - } -} - -// InsertRows adds n rows to end of underlying Table, and to the indexes starting at -// given index in this view -func (ix *IndexView) InsertRows(at, n int) { - stidx := ix.Table.Rows - ix.Table.SetNumRows(stidx + n) - nw := make([]int, n, n+len(ix.Indexes)-at) - for i := 0; i < n; i++ { - nw[i] = stidx + i - } - ix.Indexes = 
append(ix.Indexes[:at], append(nw, ix.Indexes[at:]...)...) -} - -// DeleteRows deletes n rows of indexes starting at given index in the list of indexes -func (ix *IndexView) DeleteRows(at, n int) { - ix.Indexes = append(ix.Indexes[:at], ix.Indexes[at+n:]...) -} - -// RowsByStringIndex returns the list of *our indexes* whose row in the table has -// given string value in given column index (de-reference our indexes to get actual row). -// if contains, only checks if row contains string; if ignoreCase, ignores case. -// Use named args for greater clarity. -func (ix *IndexView) RowsByStringIndex(colIndex int, str string, contains, ignoreCase bool) []int { - dt := ix.Table - col := dt.Columns[colIndex] - lowstr := strings.ToLower(str) - var idxs []int - for idx, srw := range ix.Indexes { - val := col.String1D(srw) - has := false - switch { - case contains && ignoreCase: - has = strings.Contains(strings.ToLower(val), lowstr) - case contains: - has = strings.Contains(val, str) - case ignoreCase: - has = strings.EqualFold(val, str) - default: - has = (val == str) - } - if has { - idxs = append(idxs, idx) - } - } - return idxs -} - -// RowsByString returns the list of *our indexes* whose row in the table has -// given string value in given column name (de-reference our indexes to get actual row). -// if contains, only checks if row contains string; if ignoreCase, ignores case. -// returns error message for invalid column name. -// Use named args for greater clarity. 
-func (ix *IndexView) RowsByString(column string, str string, contains, ignoreCase bool) ([]int, error) { - dt := ix.Table - ci, err := dt.ColumnIndex(column) - if err != nil { - return nil, err - } - return ix.RowsByStringIndex(ci, str, contains, ignoreCase), nil -} - -// Len returns the length of the index list -func (ix *IndexView) Len() int { - return len(ix.Indexes) -} - -// Less calls the LessFunc for sorting -func (ix *IndexView) Less(i, j int) bool { - return ix.lessFunc(ix.Table, ix.Indexes[i], ix.Indexes[j]) -} - -// Swap switches the indexes for i and j -func (ix *IndexView) Swap(i, j int) { - ix.Indexes[i], ix.Indexes[j] = ix.Indexes[j], ix.Indexes[i] -} diff --git a/tensor/table/io.go b/tensor/table/io.go index c5e0395cfd..e40b39e33d 100644 --- a/tensor/table/io.go +++ b/tensor/table/io.go @@ -11,6 +11,7 @@ import ( "io" "io/fs" "log" + "log/slog" "math" "os" "reflect" @@ -18,39 +19,10 @@ import ( "strings" "cogentcore.org/core/base/errors" - "cogentcore.org/core/core" + "cogentcore.org/core/base/fsx" "cogentcore.org/core/tensor" ) -// Delim are standard CSV delimiter options (Tab, Comma, Space) -type Delims int32 //enums:enum - -const ( - // Tab is the tab rune delimiter, for TSV tab separated values - Tab Delims = iota - - // Comma is the comma rune delimiter, for CSV comma separated values - Comma - - // Space is the space rune delimiter, for SSV space separated value - Space - - // Detect is used during reading a file -- reads the first line and detects tabs or commas - Detect -) - -func (dl Delims) Rune() rune { - switch dl { - case Tab: - return '\t' - case Comma: - return ',' - case Space: - return ' ' - } - return '\t' -} - const ( // Headers is passed to CSV methods for the headers arg, to use headers // that capture full type and tensor shape information. @@ -66,7 +38,7 @@ const ( // and tensor cell geometry of the columns, enabling full reloading // of exactly the same table format and data (recommended). 
// Otherwise, only the data is written. -func (dt *Table) SaveCSV(filename core.Filename, delim Delims, headers bool) error { //types:add +func (dt *Table) SaveCSV(filename fsx.Filename, delim tensor.Delims, headers bool) error { //types:add fp, err := os.Create(string(filename)) defer fp.Close() if err != nil { @@ -79,25 +51,6 @@ func (dt *Table) SaveCSV(filename core.Filename, delim Delims, headers bool) err return err } -// SaveCSV writes a table index view to a comma-separated-values (CSV) file -// (where comma = any delimiter, specified in the delim arg). -// If headers = true then generate column headers that capture the type -// and tensor cell geometry of the columns, enabling full reloading -// of exactly the same table format and data (recommended). -// Otherwise, only the data is written. -func (ix *IndexView) SaveCSV(filename core.Filename, delim Delims, headers bool) error { //types:add - fp, err := os.Create(string(filename)) - defer fp.Close() - if err != nil { - log.Println(err) - return err - } - bw := bufio.NewWriter(fp) - err = ix.WriteCSV(bw, delim, headers) - bw.Flush() - return err -} - // OpenCSV reads a table from a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg), // using the Go standard encoding/csv reader conforming to the official CSV standard. @@ -107,7 +60,7 @@ func (ix *IndexView) SaveCSV(filename core.Filename, delim Delims, headers bool) // information for tensor type and dimensionality. // If the table DOES have existing columns, then those are used robustly // for whatever information fits from each row of the file. 
-func (dt *Table) OpenCSV(filename core.Filename, delim Delims) error { //types:add +func (dt *Table) OpenCSV(filename fsx.Filename, delim tensor.Delims) error { //types:add fp, err := os.Open(string(filename)) if err != nil { return errors.Log(err) @@ -117,7 +70,7 @@ func (dt *Table) OpenCSV(filename core.Filename, delim Delims) error { //types:a } // OpenFS is the version of [Table.OpenCSV] that uses an [fs.FS] filesystem. -func (dt *Table) OpenFS(fsys fs.FS, filename string, delim Delims) error { +func (dt *Table) OpenFS(fsys fs.FS, filename string, delim tensor.Delims) error { fp, err := fsys.Open(filename) if err != nil { return errors.Log(err) @@ -126,28 +79,6 @@ func (dt *Table) OpenFS(fsys fs.FS, filename string, delim Delims) error { return dt.ReadCSV(bufio.NewReader(fp), delim) } -// OpenCSV reads a table idx view from a comma-separated-values (CSV) file -// (where comma = any delimiter, specified in the delim arg), -// using the Go standard encoding/csv reader conforming to the official CSV standard. -// If the table does not currently have any columns, the first row of the file -// is assumed to be headers, and columns are constructed therefrom. -// If the file was saved from table with headers, then these have full configuration -// information for tensor type and dimensionality. -// If the table DOES have existing columns, then those are used robustly -// for whatever information fits from each row of the file. -func (ix *IndexView) OpenCSV(filename core.Filename, delim Delims) error { //types:add - err := ix.Table.OpenCSV(filename, delim) - ix.Sequential() - return err -} - -// OpenFS is the version of [IndexView.OpenCSV] that uses an [fs.FS] filesystem. 
-func (ix *IndexView) OpenFS(fsys fs.FS, filename string, delim Delims) error { - err := ix.Table.OpenFS(fsys, filename, delim) - ix.Sequential() - return err -} - // ReadCSV reads a table from a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg), // using the Go standard encoding/csv reader conforming to the official CSV standard. @@ -157,7 +88,8 @@ func (ix *IndexView) OpenFS(fsys fs.FS, filename string, delim Delims) error { // information for tensor type and dimensionality. // If the table DOES have existing columns, then those are used robustly // for whatever information fits from each row of the file. -func (dt *Table) ReadCSV(r io.Reader, delim Delims) error { +func (dt *Table) ReadCSV(r io.Reader, delim tensor.Delims) error { + dt.Sequential() cr := csv.NewReader(r) cr.Comma = delim.Rune() rec, err := cr.ReadAll() // todo: lazy, avoid resizing @@ -165,7 +97,6 @@ func (dt *Table) ReadCSV(r io.Reader, delim Delims) error { return err } rows := len(rec) - // cols := len(rec[0]) strow := 0 if dt.NumColumns() == 0 || DetectTableHeaders(rec[0]) { dt.DeleteAll() @@ -186,26 +117,24 @@ func (dt *Table) ReadCSV(r io.Reader, delim Delims) error { // ReadCSVRow reads a record of CSV data into given row in table func (dt *Table) ReadCSVRow(rec []string, row int) { - tc := dt.NumColumns() ci := 0 if rec[0] == "_D:" { // data row ci++ } nan := math.NaN() - for j := 0; j < tc; j++ { - tsr := dt.Columns[j] - _, csz := tsr.RowCellSize() + for _, tsr := range dt.Columns.Values { + _, csz := tsr.Shape().RowCellSize() stoff := row * csz for cc := 0; cc < csz; cc++ { str := rec[ci] if !tsr.IsString() { if str == "" || str == "NaN" || str == "-NaN" || str == "Inf" || str == "-Inf" { - tsr.SetFloat1D(stoff+cc, nan) + tsr.SetFloat1D(nan, stoff+cc) } else { - tsr.SetString1D(stoff+cc, strings.TrimSpace(str)) + tsr.SetString1D(strings.TrimSpace(str), stoff+cc) } } else { - tsr.SetString1D(stoff+cc, strings.TrimSpace(str)) + 
tsr.SetString1D(strings.TrimSpace(str), stoff+cc) } ci++ if ci >= len(rec) { @@ -256,14 +185,14 @@ func ConfigFromTableHeaders(dt *Table, hdrs []string) error { hd = hd[:lbst] csh := ShapeFromString(dims) // new tensor starting - dt.AddTensorColumnOfType(typ, hd, csh, "Row") + dt.AddColumnOfType(hd, typ, csh...) continue } dimst = strings.Index(hd, "[") if dimst > 0 { continue } - dt.AddColumnOfType(typ, hd) + dt.AddColumnOfType(hd, typ) } return nil } @@ -358,7 +287,7 @@ func ConfigFromDataValues(dt *Table, hdrs []string, rec [][]string) error { typ = ctyp } } - dt.AddColumnOfType(typ, hd) + dt.AddColumnOfType(hd, typ) } return nil } @@ -384,37 +313,7 @@ func InferDataType(str string) reflect.Kind { return reflect.String } -////////////////////////////////////////////////////////////////////////// -// WriteCSV - -// WriteCSV writes a table to a comma-separated-values (CSV) file -// (where comma = any delimiter, specified in the delim arg). -// If headers = true then generate column headers that capture the type -// and tensor cell geometry of the columns, enabling full reloading -// of exactly the same table format and data (recommended). -// Otherwise, only the data is written. -func (dt *Table) WriteCSV(w io.Writer, delim Delims, headers bool) error { - ncol := 0 - var err error - if headers { - ncol, err = dt.WriteCSVHeaders(w, delim) - if err != nil { - log.Println(err) - return err - } - } - cw := csv.NewWriter(w) - cw.Comma = delim.Rune() - for ri := 0; ri < dt.Rows; ri++ { - err = dt.WriteCSVRowWriter(cw, ri, ncol) - if err != nil { - log.Println(err) - return err - } - } - cw.Flush() - return nil -} +//////// WriteCSV // WriteCSV writes only rows in table idx view to a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg). 
@@ -422,11 +321,11 @@ func (dt *Table) WriteCSV(w io.Writer, delim Delims, headers bool) error { // and tensor cell geometry of the columns, enabling full reloading // of exactly the same table format and data (recommended). // Otherwise, only the data is written. -func (ix *IndexView) WriteCSV(w io.Writer, delim Delims, headers bool) error { +func (dt *Table) WriteCSV(w io.Writer, delim tensor.Delims, headers bool) error { ncol := 0 var err error if headers { - ncol, err = ix.Table.WriteCSVHeaders(w, delim) + ncol, err = dt.WriteCSVHeaders(w, delim) if err != nil { log.Println(err) return err @@ -434,9 +333,10 @@ func (ix *IndexView) WriteCSV(w io.Writer, delim Delims, headers bool) error { } cw := csv.NewWriter(w) cw.Comma = delim.Rune() - nrow := ix.Len() - for ri := 0; ri < nrow; ri++ { - err = ix.Table.WriteCSVRowWriter(cw, ix.Indexes[ri], ncol) + nrow := dt.NumRows() + for ri := range nrow { + ix := dt.RowIndex(ri) + err = dt.WriteCSVRowWriter(cw, ix, ncol) if err != nil { log.Println(err) return err @@ -449,7 +349,7 @@ func (ix *IndexView) WriteCSV(w io.Writer, delim Delims, headers bool) error { // WriteCSVHeaders writes headers to a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg). 
// Returns number of columns in header -func (dt *Table) WriteCSVHeaders(w io.Writer, delim Delims) (int, error) { +func (dt *Table) WriteCSVHeaders(w io.Writer, delim tensor.Delims) (int, error) { cw := csv.NewWriter(w) cw.Comma = delim.Rune() hdrs := dt.TableHeaders() @@ -464,7 +364,7 @@ func (dt *Table) WriteCSVHeaders(w io.Writer, delim Delims) (int, error) { // WriteCSVRow writes given row to a comma-separated-values (CSV) file // (where comma = any delimiter, specified in the delim arg) -func (dt *Table) WriteCSVRow(w io.Writer, row int, delim Delims) error { +func (dt *Table) WriteCSVRow(w io.Writer, row int, delim tensor.Delims) error { cw := csv.NewWriter(w) cw.Comma = delim.Rune() err := dt.WriteCSVRowWriter(cw, row, 0) @@ -475,8 +375,8 @@ func (dt *Table) WriteCSVRow(w io.Writer, row int, delim Delims) error { // WriteCSVRowWriter uses csv.Writer to write one row func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error { prec := -1 - if ps, ok := dt.MetaData["precision"]; ok { - prec, _ = strconv.Atoi(ps) + if ps, err := tensor.Precision(dt); err == nil { + prec = ps } var rec []string if ncol > 0 { @@ -485,8 +385,7 @@ func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error { rec = make([]string, 0) } rc := 0 - for i := range dt.Columns { - tsr := dt.Columns[i] + for _, tsr := range dt.Columns.Values { nd := tsr.NumDims() if nd == 1 { vl := "" @@ -502,7 +401,7 @@ func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error { } rc++ } else { - csh := tensor.NewShape(tsr.Shape().Sizes[1:]) // cell shape + csh := tensor.NewShape(tsr.ShapeSizes()[1:]...) // cell shape tc := csh.Len() for ti := 0; ti < tc; ti++ { vl := "" @@ -528,14 +427,13 @@ func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error { // with full information about type and tensor cell dimensionality. 
func (dt *Table) TableHeaders() []string { hdrs := []string{} - for i := range dt.Columns { - tsr := dt.Columns[i] - nm := dt.ColumnNames[i] + for i, nm := range dt.Columns.Keys { + tsr := dt.Columns.Values[i] nm = string([]byte{TableHeaderChar(tsr.DataType())}) + nm if tsr.NumDims() == 1 { hdrs = append(hdrs, nm) } else { - csh := tensor.NewShape(tsr.Shape().Sizes[1:]) // cell shape + csh := tensor.NewShape(tsr.ShapeSizes()[1:]...) // cell shape tc := csh.Len() nd := csh.NumDims() fnm := nm + fmt.Sprintf("[%v:", nd) @@ -552,7 +450,7 @@ func (dt *Table) TableHeaders() []string { ffnm += "]" + dn + ">" hdrs = append(hdrs, ffnm) for ti := 1; ti < tc; ti++ { - idx := csh.Index(ti) + idx := csh.IndexFrom1D(ti) ffnm := fnm for di := 0; di < nd; di++ { ffnm += fmt.Sprintf("%v", idx[di]) @@ -567,3 +465,44 @@ func (dt *Table) TableHeaders() []string { } return hdrs } + +// CleanCatTSV cleans a TSV file formed by concatenating multiple files together. +// Removes redundant headers and then sorts by given set of columns. +func CleanCatTSV(filename string, sorts ...string) error { + str, err := os.ReadFile(filename) + if err != nil { + slog.Error(err.Error()) + return err + } + lns := strings.Split(string(str), "\n") + if len(lns) == 0 { + return nil + } + hdr := lns[0] + f, err := os.Create(filename) + if err != nil { + slog.Error(err.Error()) + return err + } + for i, ln := range lns { + if i > 0 && ln == hdr { + continue + } + io.WriteString(f, ln) + io.WriteString(f, "\n") + } + f.Close() + dt := New() + err = dt.OpenCSV(fsx.Filename(filename), tensor.Detect) + if err != nil { + slog.Error(err.Error()) + return err + } + dt.SortColumns(tensor.Ascending, tensor.StableSort, sorts...) 
+ st := dt.New() + err = st.SaveCSV(fsx.Filename(filename), tensor.Tab, true) + if err != nil { + slog.Error(err.Error()) + } + return err +} diff --git a/tensor/table/io_test.go b/tensor/table/io_test.go index 925fc9a2e7..38d80d528b 100644 --- a/tensor/table/io_test.go +++ b/tensor/table/io_test.go @@ -9,13 +9,16 @@ import ( "reflect" "strings" "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" ) func TestTableHeaders(t *testing.T) { hdrstr := `$Name %Input[2:0,0]<2:5,5> %Input[2:1,0] %Input[2:2,0] %Input[2:3,0] %Input[2:4,0] %Input[2:0,1] %Input[2:1,1] %Input[2:2,1] %Input[2:3,1] %Input[2:4,1] %Input[2:0,2] %Input[2:1,2] %Input[2:2,2] %Input[2:3,2] %Input[2:4,2] %Input[2:0,3] %Input[2:1,3] %Input[2:2,3] %Input[2:3,3] %Input[2:4,3] %Input[2:0,4] %Input[2:1,4] %Input[2:2,4] %Input[2:3,4] %Input[2:4,4] %Output[2:0,0]<2:5,5> %Output[2:1,0] %Output[2:2,0] %Output[2:3,0] %Output[2:4,0] %Output[2:0,1] %Output[2:1,1] %Output[2:2,1] %Output[2:3,1] %Output[2:4,1] %Output[2:0,2] %Output[2:1,2] %Output[2:2,2] %Output[2:3,2] %Output[2:4,2] %Output[2:0,3] %Output[2:1,3] %Output[2:2,3] %Output[2:3,3] %Output[2:4,3] %Output[2:0,4] %Output[2:1,4] %Output[2:2,4] %Output[2:3,4] %Output[2:4,4] ` hdrs := strings.Split(hdrstr, "\t") - dt := NewTable() + dt := New() err := ConfigFromHeaders(dt, hdrs, nil) if err != nil { t.Error(err) @@ -24,26 +27,29 @@ func TestTableHeaders(t *testing.T) { if dt.NumColumns() != 3 { t.Errorf("TableHeaders: len != 3\n") } - if dt.Columns[0].DataType() != reflect.String { - t.Errorf("TableHeaders: dt.Columns[0] != STRING\n") + cols := dt.Columns.Values + if cols[0].DataType() != reflect.String { + t.Errorf("TableHeaders: cols[0] != STRING\n") } - if dt.Columns[1].DataType() != reflect.Float32 { - t.Errorf("TableHeaders: dt.Columns[1] != FLOAT32\n") + if cols[1].DataType() != reflect.Float32 { + t.Errorf("TableHeaders: cols[1] != FLOAT32\n") } - if dt.Columns[2].DataType() != reflect.Float32 { - t.Errorf("TableHeaders: 
dt.Columns[2] != FLOAT32\n") + if cols[2].DataType() != reflect.Float32 { + t.Errorf("TableHeaders: cols[2] != FLOAT32\n") } - if dt.Columns[1].Shape().Sizes[1] != 5 { - t.Errorf("TableHeaders: dt.Columns[1].Shape().Sizes[1] != 5\n") + shsz := cols[1].ShapeSizes() + if shsz[1] != 5 { + t.Errorf("TableHeaders: cols[1].ShapeSizes[1] != 5\n") } - if dt.Columns[1].Shape().Sizes[2] != 5 { - t.Errorf("TableHeaders: dt.Columns[1].Shape().Sizes[2] != 5\n") + if shsz[2] != 5 { + t.Errorf("TableHeaders: cols[1].ShapeSizes[2] != 5\n") } - if dt.Columns[2].Shape().Sizes[1] != 5 { - t.Errorf("TableHeaders: dt.Columns[2].Shape().Sizes[1] != 5\n") + shsz = cols[2].ShapeSizes() + if shsz[1] != 5 { + t.Errorf("TableHeaders: cols[2].ShapeSizes[1] != 5\n") } - if dt.Columns[2].Shape().Sizes[2] != 5 { - t.Errorf("TableHeaders: dt.Columns[2].Shape().Sizes[2] != 5\n") + if shsz[2] != 5 { + t.Errorf("TableHeaders: cols[2].ShapeSizes[2] != 5\n") } outh := dt.TableHeaders() // fmt.Printf("headers out:\n%v\n", outh) @@ -66,12 +72,12 @@ func TestReadTableDat(t *testing.T) { if err != nil { t.Error(err) } - dt := &Table{} + dt := New() err = dt.ReadCSV(fp, '\t') // tsv if err != nil { t.Error(err) } - sc := dt.Columns + sc := dt.Columns.Values if len(sc) != 3 { t.Errorf("TableHeaders: len != 3\n") } @@ -84,16 +90,16 @@ func TestReadTableDat(t *testing.T) { if sc[2].DataType() != reflect.Float32 { t.Errorf("TableHeaders: sc[2] != FLOAT32\n") } - if sc[1].Shape().DimSize(0) != 6 { - t.Errorf("TableHeaders: sc[1].Dim[0] != 6 = %v\n", sc[1].Shape().DimSize(0)) + if sc[1].DimSize(0) != 6 { + t.Errorf("TableHeaders: sc[1].Dim[0] != 6 = %v\n", sc[1].DimSize(0)) } - if sc[1].Shape().DimSize(1) != 5 { + if sc[1].DimSize(1) != 5 { t.Errorf("TableHeaders: sc[1].Dim[1] != 5\n") } - if sc[2].Shape().DimSize(0) != 6 { - t.Errorf("TableHeaders: sc[2].Dim[0] != 6 = %v\n", sc[2].Shape().DimSize(0)) + if sc[2].DimSize(0) != 6 { + t.Errorf("TableHeaders: sc[2].Dim[0] != 6 = %v\n", sc[2].DimSize(0)) } - if 
sc[2].Shape().DimSize(1) != 5 { + if sc[2].DimSize(1) != 5 { t.Errorf("TableHeaders: sc[2].Dim[1] != 5\n") } fo, err := os.Create("testdata/emer_simple_lines_5x5_rec.dat") @@ -104,3 +110,41 @@ func TestReadTableDat(t *testing.T) { dt.WriteCSV(fo, '\t', Headers) } } + +func TestLog(t *testing.T) { + dt := New() + dt.AddStringColumn("Name") + dt.AddFloat64Column("Value") + err := dt.OpenLog("testdata/log.tsv", tensor.Tab) + assert.NoError(t, err) + err = dt.WriteToLog() + assert.Equal(t, ErrLogNoNewRows, err) + + dt.SetNumRows(1) + dt.Column("Name").SetString1D("test1", 0) + dt.Column("Value").SetFloat1D(42, 0) + err = dt.WriteToLog() + assert.NoError(t, err) + + dt.AddRows(2) + dt.Column("Name").SetString1D("test2", 1) + dt.Column("Value").SetFloat1D(44, 1) + dt.Column("Name").SetString1D("test3", 2) + dt.Column("Value").SetFloat1D(46, 2) + err = dt.WriteToLog() + assert.NoError(t, err) + + err = dt.WriteToLog() + assert.Equal(t, ErrLogNoNewRows, err) + + dt.SetNumRows(0) + err = dt.WriteToLog() + assert.Equal(t, ErrLogNoNewRows, err) + + dt.AddRows(1) + dt.Column("Name").SetString1D("test4", 0) + dt.Column("Value").SetFloat1D(50, 0) + err = dt.WriteToLog() + assert.NoError(t, err) + dt.CloseLog() +} diff --git a/tensor/table/log.go b/tensor/table/log.go new file mode 100644 index 0000000000..e2be87566f --- /dev/null +++ b/tensor/table/log.go @@ -0,0 +1,87 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package table + +import ( + "os" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" +) + +func setLogRow(dt *Table, row int) { + metadata.SetTo(dt, "LogRow", row) +} + +func logRow(dt *Table) int { + return errors.Ignore1(metadata.GetFrom[int](dt, "LogRow")) +} + +func setLogDelim(dt *Table, delim tensor.Delims) { + metadata.SetTo(dt, "LogDelim", delim) +} + +func logDelim(dt *Table) tensor.Delims { + return errors.Ignore1(metadata.GetFrom[tensor.Delims](dt, "LogDelim")) +} + +// OpenLog opens a log file for this table, which supports incremental +// output of table data as it is generated, using the standard [Table.SaveCSV] +// output formatting, using given delimiter between values on a line. +// Call [Table.WriteToLog] to write any new data rows to +// the open log file, and [Table.CloseLog] to close the file. +func (dt *Table) OpenLog(filename string, delim tensor.Delims) error { + f, err := os.Create(filename) + if err != nil { + return err + } + metadata.SetFile(dt, f) + setLogDelim(dt, delim) + setLogRow(dt, 0) + return nil +} + +var ( + ErrLogNoNewRows = errors.New("no new rows to write") +) + +// WriteToLog writes any accumulated rows in the table to the file +// opened by [Table.OpenLog]. A Header row is written for the first output. +// If the current number of rows is less than the last number of rows, +// all of those rows are written under the assumption that the rows +// were reset via [Table.SetNumRows]. +// Returns error for any failure, including [ErrLogNoNewRows] if +// no new rows are available to write. 
+func (dt *Table) WriteToLog() error { + f := metadata.File(dt) + if f == nil { + return errors.New("tensor.Table.WriteToLog: log file was not opened") + } + delim := logDelim(dt) + lrow := logRow(dt) + nr := dt.NumRows() + if nr == 0 || lrow == nr { + return ErrLogNoNewRows + } + if lrow == 0 { + dt.WriteCSVHeaders(f, delim) + } + sr := lrow + if nr < lrow { + sr = 0 + } + for r := sr; r < nr; r++ { + dt.WriteCSVRow(f, r, delim) + } + setLogRow(dt, nr) + return nil +} + +// CloseLog closes the log file opened by [Table.OpenLog]. +func (dt *Table) CloseLog() { + f := metadata.File(dt) + f.Close() +} diff --git a/tensor/table/slicetable.go b/tensor/table/slicetable.go index 7587ab6d23..61432a39cf 100644 --- a/tensor/table/slicetable.go +++ b/tensor/table/slicetable.go @@ -22,35 +22,17 @@ func NewSliceTable(st any) (*Table, error) { if eltyp.Kind() != reflect.Struct { return nil, fmt.Errorf("NewSliceTable: element type is not a struct") } - dt := NewTable() + dt := New() for i := 0; i < eltyp.NumField(); i++ { f := eltyp.Field(i) - switch f.Type.Kind() { - case reflect.Float32: - dt.AddFloat32Column(f.Name) - case reflect.Float64: - dt.AddFloat64Column(f.Name) - case reflect.String: - dt.AddStringColumn(f.Name) - } - } - - nr := npv.Len() - dt.SetNumRows(nr) - for ri := 0; ri < nr; ri++ { - for i := 0; i < eltyp.NumField(); i++ { - f := eltyp.Field(i) - switch f.Type.Kind() { - case reflect.Float32: - dt.SetFloat(f.Name, ri, float64(npv.Index(ri).Field(i).Interface().(float32))) - case reflect.Float64: - dt.SetFloat(f.Name, ri, float64(npv.Index(ri).Field(i).Interface().(float64))) - case reflect.String: - dt.SetString(f.Name, ri, npv.Index(ri).Field(i).Interface().(string)) - } + kind := f.Type.Kind() + if !reflectx.KindIsBasic(kind) { + continue } + dt.AddColumnOfType(f.Name, kind) } + UpdateSliceTable(st, dt) return dt, nil } @@ -65,13 +47,17 @@ func UpdateSliceTable(st any, dt *Table) { for ri := 0; ri < nr; ri++ { for i := 0; i < eltyp.NumField(); i++ { f := 
eltyp.Field(i) - switch f.Type.Kind() { - case reflect.Float32: - dt.SetFloat(f.Name, ri, float64(npv.Index(ri).Field(i).Interface().(float32))) - case reflect.Float64: - dt.SetFloat(f.Name, ri, float64(npv.Index(ri).Field(i).Interface().(float64))) - case reflect.String: - dt.SetString(f.Name, ri, npv.Index(ri).Field(i).Interface().(string)) + kind := f.Type.Kind() + if !reflectx.KindIsBasic(kind) { + continue + } + val := npv.Index(ri).Field(i).Interface() + cl := dt.Column(f.Name) + if kind == reflect.String { + cl.SetStringRow(val.(string), ri, 0) + } else { + fv, _ := reflectx.ToFloat(val) + cl.SetFloatRow(fv, ri, 0) } } } diff --git a/tensor/table/slicetable_test.go b/tensor/table/slicetable_test.go index d74948342f..c280ba68fb 100644 --- a/tensor/table/slicetable_test.go +++ b/tensor/table/slicetable_test.go @@ -26,5 +26,5 @@ func TestSliceTable(t *testing.T) { if err != nil { t.Error(err.Error()) } - assert.Equal(t, 2, dt.Rows) + assert.Equal(t, 2, dt.NumRows()) } diff --git a/tensor/table/splits.go b/tensor/table/splits.go deleted file mode 100644 index faa8ec71c3..0000000000 --- a/tensor/table/splits.go +++ /dev/null @@ -1,505 +0,0 @@ -// Copyright (c) 2024, Cogent Core. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package table - -import ( - "fmt" - "slices" - "sort" - "strings" - - "cogentcore.org/core/base/errors" -) - -// SplitAgg contains aggregation results for splits -type SplitAgg struct { - - // the name of the aggregation operation performed, e.g., Sum, Mean, etc - Name string - - // column index on which the aggregation was performed -- results will have same shape as cells in this column - ColumnIndex int - - // aggregation results -- outer index is length of splits, inner is the length of the cell shape for the column - Aggs [][]float64 -} - -// Splits is a list of indexed views into a given Table, that represent a particular -// way of splitting up the data, e.g., whenever a given column value changes. -// -// It is functionally equivalent to the MultiIndex in python's pandas: it has multiple -// levels of indexes as listed in the Levels field, which then have corresponding -// Values for each split. These index levels can be re-ordered, and new Splits or -// IndexViews's can be created from subsets of the existing levels. The Values are -// stored simply as string values, as this is the most general type and often -// index values are labels etc. -// -// For Splits created by the splits.GroupBy function for example, each index Level is -// the column name that the data was grouped by, and the Values for each split are then -// the values of those columns. However, any arbitrary set of levels and values can -// be used, e.g., as in the splits.GroupByFunc function. -// -// Conceptually, a given Split always contains the full "outer product" of all the -// index levels -- there is one split for each unique combination of values along each -// index level. Thus, removing one level collapses across those values and moves the -// corresponding indexes into the remaining split indexes. -// -// You can Sort and Filter based on the index values directly, to reorganize the splits -// and drop particular index values, etc. 
-// -// Splits also maintains Aggs aggregate values for each split, which can be computed using -// standard aggregation methods over data columns, using the split.Agg* functions. -// -// The table code contains the structural methods for managing the Splits data. -// See split package for end-user methods to generate different kinds of splits, -// and perform aggregations, etc. -type Splits struct { - - // the list of index views for each split - Splits []*IndexView - - // levels of indexes used to organize the splits -- each split contains the full outer product across these index levels. for example, if the split was generated by grouping over column values, then these are the column names in order of grouping. the splits are not automatically sorted hierarchically by these levels but e.g., the GroupBy method produces that result -- use the Sort methods to explicitly sort. - Levels []string - - // the values of the index levels associated with each split. The outer dimension is the same length as Splits, and the inner dimension is the levels. - Values [][]string - - // aggregate results, one for each aggregation operation performed -- split-level data is contained within each SplitAgg struct -- deleting a split removes these aggs but adding new splits just invalidates all existing aggs (they are automatically deleted). 
- Aggs []*SplitAgg - - // current Less function used in sorting - lessFunc SplitsLessFunc `copier:"-" display:"-" xml:"-" json:"-"` -} - -// SplitsLessFunc is a function used for sort comparisons that returns -// true if split i is less than split j -type SplitsLessFunc func(spl *Splits, i, j int) bool - -// Len returns number of splits -func (spl *Splits) Len() int { - return len(spl.Splits) -} - -// Table returns the table from the first split (should be same for all) -// returns nil if no splits yet -func (spl *Splits) Table() *Table { - if len(spl.Splits) == 0 { - return nil - } - return spl.Splits[0].Table -} - -// New adds a new split to the list for given table, and with associated -// values, which are copied before saving into Values list, and any number of rows -// from the table associated with this split (also copied). -// Any existing Aggs are deleted by this. -func (spl *Splits) New(dt *Table, values []string, rows ...int) *IndexView { - spl.Aggs = nil - ix := &IndexView{Table: dt} - spl.Splits = append(spl.Splits, ix) - if len(rows) > 0 { - ix.Indexes = append(ix.Indexes, slices.Clone(rows)...) - } - if len(values) > 0 { - spl.Values = append(spl.Values, slices.Clone(values)) - } else { - spl.Values = append(spl.Values, nil) - } - return ix -} - -// ByValue finds split indexes by matching to split values, returns nil if not found. -// values are used in order as far as they go and any remaining values are assumed -// to match, and any empty values will match anything. Can use this to access different -// subgroups within overall set of splits. 
-func (spl *Splits) ByValue(values []string) []int { - var matches []int - for si, sn := range spl.Values { - sz := min(len(sn), len(values)) - match := true - for j := 0; j < sz; j++ { - if values[j] == "" { - continue - } - if values[j] != sn[j] { - match = false - break - } - } - if match { - matches = append(matches, si) - } - } - return matches -} - -// Delete deletes split at given index -- use this to coordinate deletion -// of Splits, Values, and Aggs values for given split -func (spl *Splits) Delete(idx int) { - spl.Splits = append(spl.Splits[:idx], spl.Splits[idx+1:]...) - spl.Values = append(spl.Values[:idx], spl.Values[idx+1:]...) - for _, ag := range spl.Aggs { - ag.Aggs = append(ag.Aggs[:idx], ag.Aggs[idx+1:]...) - } -} - -// Filter removes any split for which given function returns false -func (spl *Splits) Filter(fun func(idx int) bool) { - sz := len(spl.Splits) - for si := sz - 1; si >= 0; si-- { - if !fun(si) { - spl.Delete(si) - } - } -} - -// Sort sorts the splits according to the given Less function. 
-func (spl *Splits) Sort(lessFunc func(spl *Splits, i, j int) bool) { - spl.lessFunc = lessFunc - sort.Sort(spl) -} - -// SortLevels sorts the splits according to the current index level ordering of values -// i.e., first index level is outer sort dimension, then within that is the next, etc -func (spl *Splits) SortLevels() { - spl.Sort(func(sl *Splits, i, j int) bool { - vli := sl.Values[i] - vlj := sl.Values[j] - for k := range vli { - if vli[k] < vlj[k] { - return true - } else if vli[k] > vlj[k] { - return false - } // fallthrough - } - return false - }) -} - -// SortOrder sorts the splits according to the given ordering of index levels -// which can be a subset as well -func (spl *Splits) SortOrder(order []int) error { - if len(order) == 0 || len(order) > len(spl.Levels) { - return fmt.Errorf("table.Splits SortOrder: order length == 0 or > Levels") - } - spl.Sort(func(sl *Splits, i, j int) bool { - vli := sl.Values[i] - vlj := sl.Values[j] - for k := range order { - if vli[order[k]] < vlj[order[k]] { - return true - } else if vli[order[k]] > vlj[order[k]] { - return false - } // fallthrough - } - return false - }) - return nil -} - -// ReorderLevels re-orders the index levels according to the given new ordering indexes -// e.g., []int{1,0} will move the current level 0 to level 1, and 1 to level 0 -// no checking is done to ensure these are sensible beyond basic length test -- -// behavior undefined if so. Typically you want to call SortLevels after this. 
-func (spl *Splits) ReorderLevels(order []int) error { - nlev := len(spl.Levels) - if len(order) != nlev { - return fmt.Errorf("table.Splits ReorderLevels: order length != Levels") - } - old := make([]string, nlev) - copy(old, spl.Levels) - for i := range order { - spl.Levels[order[i]] = old[i] - } - for si := range spl.Values { - copy(old, spl.Values[si]) - for i := range order { - spl.Values[si][order[i]] = old[i] - } - } - return nil -} - -// ExtractLevels returns a new Splits that only has the given levels of indexes, -// in their given order, with the other levels removed and their corresponding indexes -// merged into the appropriate remaining levels. -// Any existing aggregation data is not retained in the new splits. -func (spl *Splits) ExtractLevels(levels []int) (*Splits, error) { - nlv := len(levels) - if nlv == 0 || nlv >= len(spl.Levels) { - return nil, fmt.Errorf("table.Splits ExtractLevels: levels length == 0 or >= Levels") - } - aggs := spl.Aggs - spl.Aggs = nil - ss := spl.Clone() - spl.Aggs = aggs - ss.SortOrder(levels) - // now just do the grouping by levels values - lstValues := make([]string, nlv) - curValues := make([]string, nlv) - var curIx *IndexView - nsp := len(ss.Splits) - for si := nsp - 1; si >= 0; si-- { - diff := false - for li := range levels { - vl := ss.Values[si][levels[li]] - curValues[li] = vl - if vl != lstValues[li] { - diff = true - } - } - if diff || curIx == nil { - curIx = ss.Splits[si] - copy(lstValues, curValues) - ss.Values[si] = slices.Clone(curValues) - } else { - curIx.Indexes = append(curIx.Indexes, ss.Splits[si].Indexes...) 
// absorb - ss.Delete(si) - } - } - ss.Levels = make([]string, nlv) - for li := range levels { - ss.Levels[li] = spl.Levels[levels[li]] - } - return ss, nil -} - -// Clone returns a cloned copy of our SplitAgg -func (sa *SplitAgg) Clone() *SplitAgg { - nsa := &SplitAgg{} - nsa.CopyFrom(sa) - return nsa -} - -// CopyFrom copies from other SplitAgg -- we get our own unique copy of everything -func (sa *SplitAgg) CopyFrom(osa *SplitAgg) { - sa.Name = osa.Name - sa.ColumnIndex = osa.ColumnIndex - nags := len(osa.Aggs) - if nags > 0 { - sa.Aggs = make([][]float64, nags) - for si := range osa.Aggs { - sa.Aggs[si] = slices.Clone(osa.Aggs[si]) - } - } -} - -// Clone returns a cloned copy of our splits -func (spl *Splits) Clone() *Splits { - nsp := &Splits{} - nsp.CopyFrom(spl) - return nsp -} - -// CopyFrom copies from other Splits -- we get our own unique copy of everything -func (spl *Splits) CopyFrom(osp *Splits) { - spl.Splits = make([]*IndexView, len(osp.Splits)) - spl.Values = make([][]string, len(osp.Values)) - for si := range osp.Splits { - spl.Splits[si] = osp.Splits[si].Clone() - spl.Values[si] = slices.Clone(osp.Values[si]) - } - spl.Levels = slices.Clone(osp.Levels) - - nag := len(osp.Aggs) - if nag > 0 { - spl.Aggs = make([]*SplitAgg, nag) - for ai := range osp.Aggs { - spl.Aggs[ai] = osp.Aggs[ai].Clone() - } - } -} - -// AddAgg adds a new set of aggregation results for the Splits -func (spl *Splits) AddAgg(name string, colIndex int) *SplitAgg { - ag := &SplitAgg{Name: name, ColumnIndex: colIndex} - spl.Aggs = append(spl.Aggs, ag) - return ag -} - -// DeleteAggs deletes all existing aggregation data -func (spl *Splits) DeleteAggs() { - spl.Aggs = nil -} - -// AggByName returns Agg results for given name, which does NOT include the -// column name, just the name given to the Agg result -// (e.g., Mean for a standard Mean agg). -// Returns error message if not found. 
-func (spl *Splits) AggByName(name string) (*SplitAgg, error) { - for _, ag := range spl.Aggs { - if ag.Name == name { - return ag, nil - } - } - return nil, fmt.Errorf("table.Splits AggByName: agg results named: %v not found", name) -} - -// AggByColumnName returns Agg results for given column name, -// optionally including :Name agg name appended, where Name -// is the name given to the Agg result (e.g., Mean for a standard Mean agg). -// Returns error message if not found. -func (spl *Splits) AggByColumnName(name string) (*SplitAgg, error) { - dt := spl.Table() - if dt == nil { - return nil, fmt.Errorf("table.Splits AggByColumnName: table nil") - } - nmsp := strings.Split(name, ":") - colIndex, err := dt.ColumnIndex(nmsp[0]) - if err != nil { - return nil, err - } - for _, ag := range spl.Aggs { - if ag.ColumnIndex != colIndex { - continue - } - if len(nmsp) == 2 && nmsp[1] != ag.Name { - continue - } - return ag, nil - } - return nil, fmt.Errorf("table.Splits AggByColumnName: agg results named: %v not found", name) -} - -// SetLevels sets the Levels index names -- must match actual index dimensionality -// of the Values. This is automatically done by e.g., GroupBy, but must be done -// manually if creating custom indexes. -func (spl *Splits) SetLevels(levels ...string) { - spl.Levels = levels -} - -// use these for arg to ArgsToTable* -const ( - // ColumnNameOnly means resulting agg table just has the original column name, no aggregation name - ColumnNameOnly bool = true - // AddAggName means resulting agg table columns have aggregation name appended - AddAggName = false -) - -// AggsToTable returns a Table containing this Splits' aggregate data. -// Must have Levels and Aggs all created as in the split.Agg* methods. 
-// if colName == ColumnNameOnly, then the name of the columns for the Table -// is just the corresponding agg column name -- otherwise it also includes -// the name of the aggregation function with a : divider (e.g., Name:Mean) -func (spl *Splits) AggsToTable(colName bool) *Table { - nsp := len(spl.Splits) - if nsp == 0 { - return nil - } - dt := spl.Splits[0].Table - st := NewTable().SetNumRows(nsp) - for _, cn := range spl.Levels { - oc, _ := dt.ColumnByName(cn) - if oc != nil { - st.AddColumnOfType(oc.DataType(), cn) - } else { - st.AddStringColumn(cn) - } - } - for _, ag := range spl.Aggs { - col := dt.Columns[ag.ColumnIndex] - an := dt.ColumnNames[ag.ColumnIndex] - if colName == AddAggName { - an += ":" + ag.Name - } - st.AddFloat64TensorColumn(an, col.Shape().Sizes[1:], col.Shape().Names[1:]...) - } - for si := range spl.Splits { - cidx := 0 - for ci := range spl.Levels { - col := st.Columns[cidx] - col.SetString1D(si, spl.Values[si][ci]) - cidx++ - } - for _, ag := range spl.Aggs { - col := st.Columns[cidx] - _, csz := col.RowCellSize() - sti := si * csz - av := ag.Aggs[si] - for j, a := range av { - col.SetFloat1D(sti+j, a) - } - cidx++ - } - } - return st -} - -// AggsToTableCopy returns a Table containing this Splits' aggregate data -// and a copy of the first row of data for each split for all non-agg cols, -// which is useful for recording other data that goes along with aggregated values. -// Must have Levels and Aggs all created as in the split.Agg* methods. 
-// if colName == ColumnNameOnly, then the name of the columns for the Table -// is just the corresponding agg column name -- otherwise it also includes -// the name of the aggregation function with a : divider (e.g., Name:Mean) -func (spl *Splits) AggsToTableCopy(colName bool) *Table { - nsp := len(spl.Splits) - if nsp == 0 { - return nil - } - dt := spl.Splits[0].Table - st := NewTable().SetNumRows(nsp) - exmap := make(map[string]struct{}) - for _, cn := range spl.Levels { - st.AddStringColumn(cn) - exmap[cn] = struct{}{} - } - for _, ag := range spl.Aggs { - col := dt.Columns[ag.ColumnIndex] - an := dt.ColumnNames[ag.ColumnIndex] - exmap[an] = struct{}{} - if colName == AddAggName { - an += ":" + ag.Name - } - st.AddFloat64TensorColumn(an, col.Shape().Sizes[1:], col.Shape().Names[1:]...) - } - var cpcol []string - for _, cn := range dt.ColumnNames { - if _, ok := exmap[cn]; !ok { - cpcol = append(cpcol, cn) - col := errors.Log1(dt.ColumnByName(cn)) - st.AddColumn(col.Clone(), cn) - } - } - for si, sidx := range spl.Splits { - cidx := 0 - for ci := range spl.Levels { - col := st.Columns[cidx] - col.SetString1D(si, spl.Values[si][ci]) - cidx++ - } - for _, ag := range spl.Aggs { - col := st.Columns[cidx] - _, csz := col.RowCellSize() - sti := si * csz - av := ag.Aggs[si] - for j, a := range av { - col.SetFloat1D(sti+j, a) - } - cidx++ - } - if len(sidx.Indexes) > 0 { - stidx := sidx.Indexes[0] - for _, cn := range cpcol { - st.CopyCell(cn, si, dt, cn, stidx) - } - } - } - return st -} - -// Less calls the LessFunc for sorting -func (spl *Splits) Less(i, j int) bool { - return spl.lessFunc(spl, i, j) -} - -// Swap switches the indexes for i and j -func (spl *Splits) Swap(i, j int) { - spl.Splits[i], spl.Splits[j] = spl.Splits[j], spl.Splits[i] - spl.Values[i], spl.Values[j] = spl.Values[j], spl.Values[i] - for _, ag := range spl.Aggs { - ag.Aggs[i], ag.Aggs[j] = ag.Aggs[j], ag.Aggs[i] - } -} diff --git a/tensor/table/table.go b/tensor/table/table.go index 
42c507b0ae..e589617e5c 100644 --- a/tensor/table/table.go +++ b/tensor/table/table.go @@ -7,703 +7,325 @@ package table //go:generate core generate import ( - "errors" "fmt" - "log/slog" - "math" "reflect" "slices" - "strings" + "cogentcore.org/core/base/metadata" "cogentcore.org/core/tensor" ) -// Table is a table of data, with columns of tensors, -// each with the same number of Rows (outer-most dimension). +// Table is a table of Tensor columns aligned by a common outermost row dimension. +// Use the [Table.Column] (by name) and [Table.ColumnIndex] methods to obtain a +// [tensor.Rows] view of the column, using the shared [Table.Indexes] of the Table. +// Thus, a coordinated sorting and filtered view of the column data is automatically +// available for any of the tensor package functions that use [tensor.Tensor] as the one +// common data representation for all operations. +// Tensor Columns are always raw value types and support SubSpace operations on cells. type Table struct { //types:add - - // columns of data, as tensor.Tensor tensors - Columns []tensor.Tensor `display:"no-inline"` - - // the names of the columns - ColumnNames []string - - // number of rows, which is enforced to be the size of the outer-most dimension of the column tensors - Rows int `edit:"-"` - - // the map of column names to column numbers - ColumnNameMap map[string]int `display:"-"` - - // misc meta data for the table. We use lower-case key names following the struct tag convention: name = name of table; desc = description; read-only = gui is read-only; precision = n for precision to write out floats in csv. For Column-specific data, we look for ColumnName: prefix, specifically ColumnName:desc = description of the column contents, which is shown as tooltip in the tensorcore.Table, and :width for width of a column - MetaData map[string]string -} - -func NewTable(name ...string) *Table { + // Columns has the list of column tensor data for this table. 
+ // Different tables can provide different indexed views onto the same Columns. + Columns *Columns + + // Indexes are the indexes into Tensor rows, with nil = sequential. + // Only set if order is different from default sequential order. + // These indexes are shared into the `tensor.Rows` Column values + // to provide a coordinated indexed view into the underlying data. + Indexes []int + + // Meta is misc metadata for the table. Use lower-case key names + // following the struct tag convention: + // - name string = name of table + // - doc string = documentation, description + // - read-only bool = gui is read-only + // - precision int = n for precision to write out floats in csv. + Meta metadata.Data +} + +// New returns a new Table with its own (empty) set of Columns. +// Can pass an optional name which calls metadata SetName. +func New(name ...string) *Table { dt := &Table{} + dt.Columns = NewColumns() if len(name) > 0 { - dt.SetMetaData("name", name[0]) + metadata.SetName(dt, name[0]) } return dt } -// IsValidRow returns error if the row is invalid -func (dt *Table) IsValidRow(row int) error { - if row < 0 || row >= dt.Rows { - return fmt.Errorf("table.Table IsValidRow: row %d is out of valid range [0..%d]", row, dt.Rows) +// NewView returns a new Table with its own Rows view into the +// same underlying set of Column tensor data as the source table. +// Indexes are copied from the existing table -- use Sequential +// to reset to full sequential view. +func NewView(src *Table) *Table { + dt := &Table{Columns: src.Columns} + if src.Indexes != nil { + dt.Indexes = slices.Clone(src.Indexes) } - return nil + dt.Meta.CopyFrom(src.Meta) + return dt } -// NumRows returns the number of rows -func (dt *Table) NumRows() int { return dt.Rows } - -// NumColumns returns the number of columns -func (dt *Table) NumColumns() int { return len(dt.Columns) } +// Init initializes a new empty table with [NewColumns]. 
+func (dt *Table) Init() { + dt.Columns = NewColumns() +} -// Column returns the tensor at given column index -func (dt *Table) Column(i int) tensor.Tensor { return dt.Columns[i] } +func (dt *Table) Metadata() *metadata.Data { return &dt.Meta } -// ColumnByName returns the tensor at given column name, with error message if not found. -// Returns nil if not found -func (dt *Table) ColumnByName(name string) (tensor.Tensor, error) { - i, ok := dt.ColumnNameMap[name] - if !ok { - return nil, fmt.Errorf("table.Table ColumnByNameTry: column named: %v not found", name) +// IsValidRow returns error if the row is invalid, if error checking is needed. +func (dt *Table) IsValidRow(row int) error { + if row < 0 || row >= dt.NumRows() { + return fmt.Errorf("table.Table IsValidRow: row %d is out of valid range [0..%d]", row, dt.NumRows()) } - return dt.Columns[i], nil + return nil } -// ColumnIndex returns the index of the given column name, -// along with an error if not found. -func (dt *Table) ColumnIndex(name string) (int, error) { - i, ok := dt.ColumnNameMap[name] - if !ok { - return 0, fmt.Errorf("table.Table ColumnIndex: column named: %v not found", name) +// NumColumns returns the number of columns. +func (dt *Table) NumColumns() int { return dt.Columns.Len() } + +// Column returns the tensor with given column name, as a [tensor.Rows] +// with the shared [Table.Indexes] from this table. It is best practice to +// access columns by name, and direct access through [Table.Columns] does not +// provide the shared table-wide Indexes. +// Returns nil if not found. +func (dt *Table) Column(name string) *tensor.Rows { + cl := dt.Columns.At(name) + if cl == nil { + return nil } - return i, nil + return tensor.NewRows(cl, dt.Indexes...) } -// ColumnIndexesByNames returns the indexes of the given column names. -// idxs have -1 if name not found. 
-func (dt *Table) ColumnIndexesByNames(names ...string) ([]int, error) { - nc := len(names) - if nc == 0 { - return nil, nil +// ColumnTry is a version of [Table.Column] that also returns an error +// if the column name is not found, for cases when error is needed. +func (dt *Table) ColumnTry(name string) (*tensor.Rows, error) { + cl := dt.Column(name) + if cl != nil { + return cl, nil } - var errs []error - cidx := make([]int, nc) - for i, cn := range names { - var err error - cidx[i], err = dt.ColumnIndex(cn) - if err != nil { - errs = append(errs, err) + return nil, fmt.Errorf("table.Table: Column named %q not found", name) +} + +// ColumnIndex returns the tensor at the given column index, +// as a [tensor.Rows] with the shared [Table.Indexes] from this table. +// It is best practice to instead access columns by name using [Table.Column]. +// Direct access through [Table.Columns} does not provide the shared table-wide Indexes. +// Will panic if out of range. +func (dt *Table) ColumnByIndex(idx int) *tensor.Rows { + cl := dt.Columns.Values[idx] + return tensor.NewRows(cl, dt.Indexes...) +} + +// ColumnList returns a list of tensors with given column names, +// as [tensor.Rows] with the shared [Table.Indexes] from this table. +func (dt *Table) ColumnList(names ...string) []tensor.Tensor { + list := make([]tensor.Tensor, 0, len(names)) + for _, nm := range names { + cl := dt.Column(nm) + if cl != nil { + list = append(list, cl) } } - return cidx, errors.Join(errs...) + return list } -// ColumnName returns the name of given column +// ColumnName returns the name of given column. func (dt *Table) ColumnName(i int) string { - return dt.ColumnNames[i] + return dt.Columns.Keys[i] +} + +// ColumnIndex returns the index for given column name. +func (dt *Table) ColumnIndex(name string) int { + return dt.Columns.IndexByKey(name) } -// UpdateColumnNameMap updates the column name map, returning an error -// if any of the column names are duplicates. 
-func (dt *Table) UpdateColumnNameMap() error { - nc := dt.NumColumns() - dt.ColumnNameMap = make(map[string]int, nc) - var errs []error - for i, nm := range dt.ColumnNames { - if _, has := dt.ColumnNameMap[nm]; has { - err := fmt.Errorf("table.Table duplicate column name: %s", nm) - slog.Warn(err.Error()) - errs = append(errs, err) - } else { - dt.ColumnNameMap[nm] = i +// ColumnIndexList returns a list of indexes to columns of given names. +func (dt *Table) ColumnIndexList(names ...string) []int { + list := make([]int, 0, len(names)) + for _, nm := range names { + ci := dt.ColumnIndex(nm) + if ci >= 0 { + list = append(list, ci) } } - if len(errs) > 0 { - return errors.Join(errs...) - } - return nil + return list } // AddColumn adds a new column to the table, of given type and column name -// (which must be unique). The cells of this column hold a single scalar value: -// see AddColumnTensor for n-dimensional cells. -func AddColumn[T string | bool | float32 | float64 | int | int32 | byte](dt *Table, name string) tensor.Tensor { - rows := max(1, dt.Rows) - tsr := tensor.New[T]([]int{rows}, "Row") - dt.AddColumn(tsr, name) +// (which must be unique). If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func AddColumn[T tensor.DataTypes](dt *Table, name string, cellSizes ...int) tensor.Tensor { + rows := dt.Columns.Rows + sz := append([]int{rows}, cellSizes...) + tsr := tensor.New[T](sz...) + // tsr.SetNames("Row") + dt.AddColumn(name, tsr) return tsr } // InsertColumn inserts a new column to the table, of given type and column name // (which must be unique), at given index. -// The cells of this column hold a single scalar value. 
-func InsertColumn[T string | bool | float32 | float64 | int | int32 | byte](dt *Table, name string, idx int) tensor.Tensor { - rows := max(1, dt.Rows) - tsr := tensor.New[T]([]int{rows}, "Row") - dt.InsertColumn(tsr, name, idx) - return tsr -} - -// AddTensorColumn adds a new n-dimensional column to the table, of given type, column name -// (which must be unique), and dimensionality of each _cell_. -// An outer-most Row dimension will be added to this dimensionality to create -// the tensor column. -func AddTensorColumn[T string | bool | float32 | float64 | int | int32 | byte](dt *Table, name string, cellSizes []int, dimNames ...string) tensor.Tensor { - rows := max(1, dt.Rows) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func InsertColumn[T tensor.DataTypes](dt *Table, name string, idx int, cellSizes ...int) tensor.Tensor { + rows := dt.Columns.Rows sz := append([]int{rows}, cellSizes...) - nms := append([]string{"Row"}, dimNames...) - tsr := tensor.New[T](sz, nms...) - dt.AddColumn(tsr, name) + tsr := tensor.New[T](sz...) + // tsr.SetNames("Row") + dt.InsertColumn(idx, name, tsr) return tsr } -// AddColumn adds the given tensor as a column to the table, +// AddColumn adds the given [tensor.Values] as a column to the table, // returning an error and not adding if the name is not unique. // Automatically adjusts the shape to fit the current number of rows. 
-func (dt *Table) AddColumn(tsr tensor.Tensor, name string) error { - dt.ColumnNames = append(dt.ColumnNames, name) - err := dt.UpdateColumnNameMap() - if err != nil { - dt.ColumnNames = dt.ColumnNames[:len(dt.ColumnNames)-1] - return err - } - dt.Columns = append(dt.Columns, tsr) - rows := max(1, dt.Rows) - tsr.SetNumRows(rows) - return nil +func (dt *Table) AddColumn(name string, tsr tensor.Values) error { + return dt.Columns.AddColumn(name, tsr) } -// InsertColumn inserts the given tensor as a column to the table at given index, +// InsertColumn inserts the given [tensor.Values] as a column to the table at given index, // returning an error and not adding if the name is not unique. // Automatically adjusts the shape to fit the current number of rows. -func (dt *Table) InsertColumn(tsr tensor.Tensor, name string, idx int) error { - if _, has := dt.ColumnNameMap[name]; has { - err := fmt.Errorf("table.Table duplicate column name: %s", name) - slog.Warn(err.Error()) - return err - } - dt.ColumnNames = slices.Insert(dt.ColumnNames, idx, name) - dt.UpdateColumnNameMap() - dt.Columns = slices.Insert(dt.Columns, idx, tsr) - rows := max(1, dt.Rows) - tsr.SetNumRows(rows) - return nil +func (dt *Table) InsertColumn(idx int, name string, tsr tensor.Values) error { + return dt.Columns.InsertColumn(idx, name, tsr) } // AddColumnOfType adds a new scalar column to the table, of given reflect type, // column name (which must be unique), -// The cells of this column hold a single (scalar) value of given type. -// Supported types are string, bool (for [tensor.Bits]), float32, float64, int, int32, and byte. -func (dt *Table) AddColumnOfType(typ reflect.Kind, name string) tensor.Tensor { - rows := max(1, dt.Rows) - tsr := tensor.NewOfType(typ, []int{rows}, "Row") - dt.AddColumn(tsr, name) - return tsr -} - -// AddTensorColumnOfType adds a new n-dimensional column to the table, of given reflect type, -// column name (which must be unique), and dimensionality of each _cell_. 
-// An outer-most Row dimension will be added to this dimensionality to create -// the tensor column. -// Supported types are string, bool (for [tensor.Bits]), float32, float64, int, int32, and byte. -func (dt *Table) AddTensorColumnOfType(typ reflect.Kind, name string, cellSizes []int, dimNames ...string) tensor.Tensor { - rows := max(1, dt.Rows) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +// Supported types include string, bool (for [tensor.Bool]), float32, float64, int, int32, and byte. +func (dt *Table) AddColumnOfType(name string, typ reflect.Kind, cellSizes ...int) tensor.Tensor { + rows := dt.Columns.Rows sz := append([]int{rows}, cellSizes...) - nms := append([]string{"Row"}, dimNames...) - tsr := tensor.NewOfType(typ, sz, nms...) - dt.AddColumn(tsr, name) + tsr := tensor.NewOfType(typ, sz...) + // tsr.SetNames("Row") + dt.AddColumn(name, tsr) return tsr } // AddStringColumn adds a new String column with given name. -// The cells of this column hold a single string value. -func (dt *Table) AddStringColumn(name string) *tensor.String { - return AddColumn[string](dt, name).(*tensor.String) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func (dt *Table) AddStringColumn(name string, cellSizes ...int) *tensor.String { + return AddColumn[string](dt, name, cellSizes...).(*tensor.String) } // AddFloat64Column adds a new float64 column with given name. -// The cells of this column hold a single scalar value. -func (dt *Table) AddFloat64Column(name string) *tensor.Float64 { - return AddColumn[float64](dt, name).(*tensor.Float64) -} - -// AddFloat64TensorColumn adds a new n-dimensional float64 column with given name -// and dimensionality of each _cell_. -// An outer-most Row dimension will be added to this dimensionality to create -// the tensor column. 
-func (dt *Table) AddFloat64TensorColumn(name string, cellSizes []int, dimNames ...string) *tensor.Float64 { - return AddTensorColumn[float64](dt, name, cellSizes, dimNames...).(*tensor.Float64) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func (dt *Table) AddFloat64Column(name string, cellSizes ...int) *tensor.Float64 { + return AddColumn[float64](dt, name, cellSizes...).(*tensor.Float64) } // AddFloat32Column adds a new float32 column with given name. -// The cells of this column hold a single scalar value. -func (dt *Table) AddFloat32Column(name string) *tensor.Float32 { - return AddColumn[float32](dt, name).(*tensor.Float32) -} - -// AddFloat32TensorColumn adds a new n-dimensional float32 column with given name -// and dimensionality of each _cell_. -// An outer-most Row dimension will be added to this dimensionality to create -// the tensor column. -func (dt *Table) AddFloat32TensorColumn(name string, cellSizes []int, dimNames ...string) *tensor.Float32 { - return AddTensorColumn[float32](dt, name, cellSizes, dimNames...).(*tensor.Float32) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func (dt *Table) AddFloat32Column(name string, cellSizes ...int) *tensor.Float32 { + return AddColumn[float32](dt, name, cellSizes...).(*tensor.Float32) } // AddIntColumn adds a new int column with given name. -// The cells of this column hold a single scalar value. -func (dt *Table) AddIntColumn(name string) *tensor.Int { - return AddColumn[int](dt, name).(*tensor.Int) -} - -// AddIntTensorColumn adds a new n-dimensional int column with given name -// and dimensionality of each _cell_. -// An outer-most Row dimension will be added to this dimensionality to create -// the tensor column. 
-func (dt *Table) AddIntTensorColumn(name string, cellSizes []int, dimNames ...string) *tensor.Int { - return AddTensorColumn[int](dt, name, cellSizes, dimNames...).(*tensor.Int) +// If no cellSizes are specified, it holds scalar values, +// otherwise the cells are n-dimensional tensors of given size. +func (dt *Table) AddIntColumn(name string, cellSizes ...int) *tensor.Int { + return AddColumn[int](dt, name, cellSizes...).(*tensor.Int) } // DeleteColumnName deletes column of given name. -// returns error if not found. -func (dt *Table) DeleteColumnName(name string) error { - ci, err := dt.ColumnIndex(name) - if err != nil { - return err - } - dt.DeleteColumnIndex(ci) - return nil +// returns false if not found. +func (dt *Table) DeleteColumnName(name string) bool { + return dt.Columns.DeleteByKey(name) } -// DeleteColumnIndex deletes column of given index -func (dt *Table) DeleteColumnIndex(idx int) { - dt.Columns = append(dt.Columns[:idx], dt.Columns[idx+1:]...) - dt.ColumnNames = append(dt.ColumnNames[:idx], dt.ColumnNames[idx+1:]...) - dt.UpdateColumnNameMap() +// DeleteColumnIndex deletes column within the index range [i:j]. +func (dt *Table) DeleteColumnByIndex(i, j int) { + dt.Columns.DeleteByIndex(i, j) } -// DeleteAll deletes all columns -- full reset +// DeleteAll deletes all columns, does full reset. func (dt *Table) DeleteAll() { - dt.Columns = nil - dt.ColumnNames = nil - dt.Rows = 0 - dt.ColumnNameMap = nil -} - -// AddRows adds n rows to each of the columns -func (dt *Table) AddRows(n int) { //types:add - dt.SetNumRows(dt.Rows + n) + dt.Indexes = nil + dt.Columns.Reset() +} + +// AddRows adds n rows to end of underlying Table, and to the indexes in this view. 
+func (dt *Table) AddRows(n int) *Table { //types:add + return dt.SetNumRows(dt.Columns.Rows + n) +} + +// InsertRows adds n rows to end of underlying Table, and to the indexes starting at +// given index in this view, providing an efficient insertion operation that only +// exists in the indexed view. To create an in-memory ordering, use [Table.New]. +func (dt *Table) InsertRows(at, n int) *Table { + dt.IndexesNeeded() + strow := dt.Columns.Rows + stidx := len(dt.Indexes) + dt.SetNumRows(strow + n) // adds n indexes to end of list + // move those indexes to at:at+n in index list + dt.Indexes = append(dt.Indexes[:at], append(dt.Indexes[stidx:], dt.Indexes[at:]...)...) + dt.Indexes = dt.Indexes[:strow+n] + return dt } -// SetNumRows sets the number of rows in the table, across all columns -// if rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0 +// SetNumRows sets the number of rows in the table, across all columns. +// If rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0. +// If indexes are in place and rows are added, indexes for the new rows are added. func (dt *Table) SetNumRows(rows int) *Table { //types:add - dt.Rows = rows // can be 0 - rows = max(1, rows) - for _, tsr := range dt.Columns { - tsr.SetNumRows(rows) + strow := dt.Columns.Rows + dt.Columns.SetNumRows(rows) + if dt.Indexes == nil { + return dt + } + if rows > strow { + for i := range rows - strow { + dt.Indexes = append(dt.Indexes, strow+i) + } + } else { + dt.ValidIndexes() } return dt } +// SetNumRowsToMax gets the current max number of rows across all the column tensors, +// and sets the number of rows to that. This will automatically pad shorter columns +// so they all have the same number of rows. If a table has columns that are not fully +// under its own control, they can change size, so this reestablishes +// a common row dimension. 
+func (dt *Table) SetNumRowsToMax() { + var maxRow int + for _, tsr := range dt.Columns.Values { + maxRow = max(maxRow, tsr.DimSize(0)) + } + dt.SetNumRows(maxRow) +} + // note: no really clean definition of CopyFrom -- no point of re-using existing // table -- just clone it. -// Clone returns a complete copy of this table +// Clone returns a complete copy of this table, including cloning +// the underlying Columns tensors, and the current [Table.Indexes]. +// See also [Table.New] to flatten the current indexes. func (dt *Table) Clone() *Table { - cp := NewTable().SetNumRows(dt.Rows) - cp.CopyMetaDataFrom(dt) - for i, cl := range dt.Columns { - cp.AddColumn(cl.Clone(), dt.ColumnNames[i]) + cp := &Table{} + cp.Columns = dt.Columns.Clone() + cp.Meta.CopyFrom(dt.Meta) + if dt.Indexes != nil { + cp.Indexes = slices.Clone(dt.Indexes) } return cp } -// AppendRows appends shared columns in both tables with input table rows +// AppendRows appends shared columns in both tables with input table rows. func (dt *Table) AppendRows(dt2 *Table) { - shared := false - strow := dt.Rows - for iCol := range dt.Columns { - colName := dt.ColumnName(iCol) - _, err := dt2.ColumnIndex(colName) - if err != nil { - continue - } - if !shared { - shared = true - dt.AddRows(dt2.Rows) - } - for iRow := 0; iRow < dt2.Rows; iRow++ { - dt.CopyCell(colName, iRow+strow, dt2, colName, iRow) - } - } -} - -// SetMetaData sets given meta-data key to given value, safely creating the -// map if not yet initialized. 
Standard Keys are: -// * name -- name of table -// * desc -- description of table -// * read-only -- makes gui read-only (inactive edits) for tensorcore.Table -// * ColumnName:* -- prefix for all column-specific meta-data -// - desc -- description of column -func (dt *Table) SetMetaData(key, val string) { - if dt.MetaData == nil { - dt.MetaData = make(map[string]string) - } - dt.MetaData[key] = val -} - -// CopyMetaDataFrom copies meta data from other table -func (dt *Table) CopyMetaDataFrom(cp *Table) { - nm := len(cp.MetaData) - if nm == 0 { + strow := dt.Columns.Rows + n := dt2.Columns.Rows + dt.Columns.AppendRows(dt2.Columns) + if dt.Indexes == nil { return } - if dt.MetaData == nil { - dt.MetaData = make(map[string]string, nm) - } - for k, v := range cp.MetaData { - dt.MetaData[k] = v + for i := range n { + dt.Indexes = append(dt.Indexes, strow+i) } } - -// Named arg values for Contains, IgnoreCase -const ( - // Contains means the string only needs to contain the target string (see Equals) - Contains bool = true - // Equals means the string must equal the target string (see Contains) - Equals = false - // IgnoreCase means that differences in case are ignored in comparing strings - IgnoreCase = true - // UseCase means that case matters when comparing strings - UseCase = false -) - -// RowsByStringIndex returns the list of rows that have given -// string value in given column index. -// if contains, only checks if row contains string; if ignoreCase, ignores case. -// Use named args for greater clarity. 
-func (dt *Table) RowsByStringIndex(column int, str string, contains, ignoreCase bool) []int { - col := dt.Columns[column] - lowstr := strings.ToLower(str) - var idxs []int - for i := 0; i < dt.Rows; i++ { - val := col.String1D(i) - has := false - switch { - case contains && ignoreCase: - has = strings.Contains(strings.ToLower(val), lowstr) - case contains: - has = strings.Contains(val, str) - case ignoreCase: - has = strings.EqualFold(val, str) - default: - has = (val == str) - } - if has { - idxs = append(idxs, i) - } - } - return idxs -} - -// RowsByString returns the list of rows that have given -// string value in given column name. returns nil & error if name invalid. -// if contains, only checks if row contains string; if ignoreCase, ignores case. -// Use named args for greater clarity. -func (dt *Table) RowsByString(column string, str string, contains, ignoreCase bool) ([]int, error) { - ci, err := dt.ColumnIndex(column) - if err != nil { - return nil, err - } - return dt.RowsByStringIndex(ci, str, contains, ignoreCase), nil -} - -////////////////////////////////////////////////////////////////////////////////////// -// Cell convenience access methods - -// FloatIndex returns the float64 value of cell at given column, row index -// for columns that have 1-dimensional tensors. -// Returns NaN if column is not a 1-dimensional tensor or row not valid. -func (dt *Table) FloatIndex(column, row int) float64 { - if dt.IsValidRow(row) != nil { - return math.NaN() - } - ct := dt.Columns[column] - if ct.NumDims() != 1 { - return math.NaN() - } - return ct.Float1D(row) -} - -// Float returns the float64 value of cell at given column (by name), -// row index for columns that have 1-dimensional tensors. -// Returns NaN if column is not a 1-dimensional tensor -// or col name not found, or row not valid. 
-func (dt *Table) Float(column string, row int) float64 { - if dt.IsValidRow(row) != nil { - return math.NaN() - } - ct, err := dt.ColumnByName(column) - if err != nil { - return math.NaN() - } - if ct.NumDims() != 1 { - return math.NaN() - } - return ct.Float1D(row) -} - -// StringIndex returns the string value of cell at given column, row index -// for columns that have 1-dimensional tensors. -// Returns "" if column is not a 1-dimensional tensor or row not valid. -func (dt *Table) StringIndex(column, row int) string { - if dt.IsValidRow(row) != nil { - return "" - } - ct := dt.Columns[column] - if ct.NumDims() != 1 { - return "" - } - return ct.String1D(row) -} - -// NOTE: String conflicts with [fmt.Stringer], so we have to use StringValue - -// StringValue returns the string value of cell at given column (by name), row index -// for columns that have 1-dimensional tensors. -// Returns "" if column is not a 1-dimensional tensor or row not valid. -func (dt *Table) StringValue(column string, row int) string { - if dt.IsValidRow(row) != nil { - return "" - } - ct, err := dt.ColumnByName(column) - if err != nil { - return "" - } - if ct.NumDims() != 1 { - return "" - } - return ct.String1D(row) -} - -// TensorIndex returns the tensor SubSpace for given column, row index -// for columns that have higher-dimensional tensors so each row is -// represented by an n-1 dimensional tensor, with the outer dimension -// being the row number. Returns nil if column is a 1-dimensional -// tensor or there is any error from the tensor.Tensor.SubSpace call. 
-func (dt *Table) TensorIndex(column, row int) tensor.Tensor { - if dt.IsValidRow(row) != nil { - return nil - } - ct := dt.Columns[column] - if ct.NumDims() == 1 { - return nil - } - return ct.SubSpace([]int{row}) -} - -// Tensor returns the tensor SubSpace for given column (by name), row index -// for columns that have higher-dimensional tensors so each row is -// represented by an n-1 dimensional tensor, with the outer dimension -// being the row number. Returns nil on any error. -func (dt *Table) Tensor(column string, row int) tensor.Tensor { - if dt.IsValidRow(row) != nil { - return nil - } - ct, err := dt.ColumnByName(column) - if err != nil { - return nil - } - if ct.NumDims() == 1 { - return nil - } - return ct.SubSpace([]int{row}) -} - -// TensorFloat1D returns the float value of a Tensor cell's cell at given -// 1D offset within cell, for given column (by name), row index -// for columns that have higher-dimensional tensors so each row is -// represented by an n-1 dimensional tensor, with the outer dimension -// being the row number. Returns 0 on any error. -func (dt *Table) TensorFloat1D(column string, row int, idx int) float64 { - if dt.IsValidRow(row) != nil { - return math.NaN() - } - ct, err := dt.ColumnByName(column) - if err != nil { - return math.NaN() - } - if ct.NumDims() == 1 { - return math.NaN() - } - _, sz := ct.RowCellSize() - if idx >= sz || idx < 0 { - return math.NaN() - } - off := row*sz + idx - return ct.Float1D(off) -} - -///////////////////////////////////////////////////////////////////////////////////// -// Set - -// SetFloatIndex sets the float64 value of cell at given column, row index -// for columns that have 1-dimensional tensors. 
-func (dt *Table) SetFloatIndex(column, row int, val float64) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct := dt.Columns[column] - if ct.NumDims() != 1 { - return fmt.Errorf("table.Table SetFloatIndex: Column %d is a tensor, must use SetTensorFloat1D", column) - } - ct.SetFloat1D(row, val) - return nil -} - -// SetFloat sets the float64 value of cell at given column (by name), row index -// for columns that have 1-dimensional tensors. -func (dt *Table) SetFloat(column string, row int, val float64) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct, err := dt.ColumnByName(column) - if err != nil { - return err - } - if ct.NumDims() != 1 { - return fmt.Errorf("table.Table SetFloat: Column %s is a tensor, must use SetTensorFloat1D", column) - } - ct.SetFloat1D(row, val) - return nil -} - -// SetStringIndex sets the string value of cell at given column, row index -// for columns that have 1-dimensional tensors. Returns true if set. -func (dt *Table) SetStringIndex(column, row int, val string) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct := dt.Columns[column] - if ct.NumDims() != 1 { - return fmt.Errorf("table.Table SetStringIndex: Column %d is a tensor, must use SetTensorFloat1D", column) - } - ct.SetString1D(row, val) - return nil -} - -// SetString sets the string value of cell at given column (by name), row index -// for columns that have 1-dimensional tensors. Returns true if set. -func (dt *Table) SetString(column string, row int, val string) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct, err := dt.ColumnByName(column) - if err != nil { - return err - } - if ct.NumDims() != 1 { - return fmt.Errorf("table.Table SetString: Column %s is a tensor, must use SetTensorFloat1D", column) - } - ct.SetString1D(row, val) - return nil -} - -// SetTensorIndex sets the tensor value of cell at given column, row index -// for columns that have n-dimensional tensors. 
Returns true if set. -func (dt *Table) SetTensorIndex(column, row int, val tensor.Tensor) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct := dt.Columns[column] - _, csz := ct.RowCellSize() - st := row * csz - sz := min(csz, val.Len()) - if ct.IsString() { - for j := 0; j < sz; j++ { - ct.SetString1D(st+j, val.String1D(j)) - } - } else { - for j := 0; j < sz; j++ { - ct.SetFloat1D(st+j, val.Float1D(j)) - } - } - return nil -} - -// SetTensor sets the tensor value of cell at given column (by name), row index -// for columns that have n-dimensional tensors. Returns true if set. -func (dt *Table) SetTensor(column string, row int, val tensor.Tensor) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ci, err := dt.ColumnIndex(column) - if err != nil { - return err - } - return dt.SetTensorIndex(ci, row, val) -} - -// SetTensorFloat1D sets the tensor cell's float cell value at given 1D index within cell, -// at given column (by name), row index for columns that have n-dimensional tensors. -// Returns true if set. -func (dt *Table) SetTensorFloat1D(column string, row int, idx int, val float64) error { - if err := dt.IsValidRow(row); err != nil { - return err - } - ct, err := dt.ColumnByName(column) - if err != nil { - return err - } - _, sz := ct.RowCellSize() - if idx >= sz || idx < 0 { - return fmt.Errorf("table.Table IsValidRow: index %d is out of valid range [0..%d]", idx, sz) - } - off := row*sz + idx - ct.SetFloat1D(off, val) - return nil -} - -////////////////////////////////////////////////////////////////////////////////////// -// Copy Cell - -// CopyCell copies into cell at given column, row from cell in other table. -// It is robust to differences in type; uses destination cell type. -// Returns error if column names are invalid. 
-func (dt *Table) CopyCell(column string, row int, cpt *Table, cpColNm string, cpRow int) error { - ct, err := dt.ColumnByName(column) - if err != nil { - return err - } - cpct, err := cpt.ColumnByName(cpColNm) - if err != nil { - return err - } - _, sz := ct.RowCellSize() - if sz == 1 { - if ct.IsString() { - ct.SetString1D(row, cpct.String1D(cpRow)) - return nil - } - ct.SetFloat1D(row, cpct.Float1D(cpRow)) - return nil - } - _, cpsz := cpct.RowCellSize() - st := row * sz - cst := cpRow * cpsz - msz := min(sz, cpsz) - if ct.IsString() { - for j := 0; j < msz; j++ { - ct.SetString1D(st+j, cpct.String1D(cst+j)) - } - } else { - for j := 0; j < msz; j++ { - ct.SetFloat1D(st+j, cpct.Float1D(cst+j)) - } - } - return nil -} diff --git a/tensor/table/table_test.go b/tensor/table/table_test.go index dbf34cdc43..a848e3e31f 100644 --- a/tensor/table/table_test.go +++ b/tensor/table/table_test.go @@ -8,72 +8,128 @@ import ( "strconv" "testing" + "cogentcore.org/core/tensor" "github.com/stretchr/testify/assert" ) func TestAdd3DCol(t *testing.T) { - dt := NewTable() - dt.AddFloat32TensorColumn("Values", []int{11, 1, 16}) + dt := New() + dt.AddFloat32Column("Values", 11, 1, 16) - col, err := dt.ColumnByName("Values") - if err != nil { - t.Error(err) - } + col := dt.Column("Values").Tensor if col.NumDims() != 4 { t.Errorf("Add4DCol: # of dims != 4\n") } - if col.Shape().DimSize(0) != 1 { - t.Errorf("Add4DCol: dim 0 len != 1, was: %v\n", col.Shape().DimSize(0)) + if col.DimSize(0) != 0 { + t.Errorf("Add4DCol: dim 0 len != 0, was: %v\n", col.DimSize(0)) } - if col.Shape().DimSize(1) != 11 { - t.Errorf("Add4DCol: dim 0 len != 11, was: %v\n", col.Shape().DimSize(1)) + if col.DimSize(1) != 11 { + t.Errorf("Add4DCol: dim 0 len != 11, was: %v\n", col.DimSize(1)) } - if col.Shape().DimSize(2) != 1 { - t.Errorf("Add4DCol: dim 0 len != 1, was: %v\n", col.Shape().DimSize(2)) + if col.DimSize(2) != 1 { + t.Errorf("Add4DCol: dim 0 len != 1, was: %v\n", col.DimSize(2)) } - if 
col.Shape().DimSize(3) != 16 { + if col.DimSize(3) != 16 { t.Errorf("Add4DCol: dim 0 len != 16, was: %v\n", col.Shape().DimSize(3)) } } func NewTestTable() *Table { - dt := NewTable() + dt := New() dt.AddStringColumn("Str") dt.AddFloat64Column("Flt64") dt.AddIntColumn("Int") dt.SetNumRows(3) - for i := 0; i < dt.Rows; i++ { - dt.SetString("Str", i, strconv.Itoa(i)) - dt.SetFloat("Flt64", i, float64(i)) - dt.SetFloat("Int", i, float64(i)) + for i := range dt.NumRows() { + dt.Column("Str").SetStringRow(strconv.Itoa(i), i, 0) + dt.Column("Flt64").SetFloatRow(float64(i), i, 0) + dt.Column("Int").SetFloatRow(float64(i), i, 0) } return dt } -func TestAppendRows(t *testing.T) { +func TestAppendRowsEtc(t *testing.T) { st := NewTestTable() dt := NewTestTable() dt.AppendRows(st) dt.AppendRows(st) dt.AppendRows(st) - for j := 0; j < 3; j++ { - for i := 0; i < st.Rows; i++ { + for j := range 3 { + for i := range st.NumRows() { sr := j*3 + i - ss := st.StringValue("Str", i) - ds := dt.StringValue("Str", sr) + ss := st.Column("Str").StringRow(i, 0) + ds := dt.Column("Str").StringRow(sr, 0) assert.Equal(t, ss, ds) - sf := st.Float("Flt64", i) - df := dt.Float("Flt64", sr) + sf := st.Column("Flt64").FloatRow(i, 0) + df := dt.Column("Flt64").FloatRow(sr, 0) assert.Equal(t, sf, df) - sf = st.Float("Int", i) - df = dt.Float("Int", sr) + sf = st.Column("Int").FloatRow(i, 0) + df = dt.Column("Int").FloatRow(sr, 0) assert.Equal(t, sf, df) } } + dt.Sequential() + dt.SortColumn("Int", tensor.Descending) + assert.Equal(t, []int{2, 5, 8, 11, 1, 4, 7, 10, 0, 3, 6, 9}, dt.Indexes) + + dt.Sequential() + dt.SortColumns(tensor.Descending, true, "Int", "Flt64") + assert.Equal(t, []int{2, 5, 8, 11, 1, 4, 7, 10, 0, 3, 6, 9}, dt.Indexes) + + dt.Sequential() + dt.FilterString("Int", "1", tensor.FilterOptions{Contains: true, IgnoreCase: true}) + assert.Equal(t, []int{1, 4, 7, 10}, dt.Indexes) + + dt.Sequential() + dt.Filter(func(dt *Table, row int) bool { + return dt.Column("Flt64").FloatRow(row, 0) > 
1 + }) + assert.Equal(t, []int{2, 5, 8, 11}, dt.Indexes) +} + +func TestSetNumRows(t *testing.T) { + st := NewTestTable() + dt := NewTestTable() + dt.AppendRows(st) + dt.AppendRows(st) + dt.AppendRows(st) + dt.IndexesNeeded() + dt.SetNumRows(3) + assert.Equal(t, []int{0, 1, 2}, dt.Indexes) +} + +func TestInsertDeleteRows(t *testing.T) { + dt := NewTestTable() + dt.IndexesNeeded() + dt.InsertRows(1, 2) + assert.Equal(t, []int{0, 3, 4, 1, 2}, dt.Indexes) + dt.DeleteRows(1, 2) + assert.Equal(t, []int{0, 1, 2}, dt.Indexes) +} + +func TestCells(t *testing.T) { + dt := New() + err := dt.OpenCSV("../stats/cluster/testdata/faces.dat", tensor.Tab) + assert.NoError(t, err) + in := dt.Column("Input") + for i := range 10 { + vals := make([]float32, 16) + for j := range 16 { + vals[j] = float32(in.FloatRow(i, j)) + } + // fmt.Println(s) + ss := in.Tensor.SubSpace(i).(*tensor.Float32) + // fmt.Println(ss.Values[:16]) + + cl := tensor.AsFloat32(tensor.Cells1D(in, i)) + // fmt.Println(cl.Values[:16]) + assert.Equal(t, vals, ss.Values[:16]) + assert.Equal(t, vals, cl.Values[:16]) + } } diff --git a/tensor/table/typegen.go b/tensor/table/typegen.go index 92eec5f9c5..9754001a97 100644 --- a/tensor/table/typegen.go +++ b/tensor/table/typegen.go @@ -6,6 +6,4 @@ import ( "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/table.IndexView", IDName: "index-view", Doc: "IndexView is an indexed wrapper around an table.Table that provides a\nspecific view onto the Table defined by the set of indexes.\nThis provides an efficient way of sorting and filtering a table by only\nupdating the indexes while doing nothing to the Table itself.\nTo produce a table that has data actually organized according to the\nindexed order, call the NewTable method.\nIndexView views on a table can also be organized together as Splits\nof the table rows, e.g., by grouping values along a given column.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, 
Methods: []types.Method{{Name: "Sequential", Doc: "Sequential sets indexes to sequential row-wise indexes into table", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "SortColumnName", Doc: "SortColumnName sorts the indexes into our Table according to values in\ngiven column name, using either ascending or descending order.\nOnly valid for 1-dimensional columns.\nReturns error if column name not found.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"column", "ascending"}, Returns: []string{"error"}}, {Name: "FilterColumnName", Doc: "FilterColumnName filters the indexes into our Table according to values in\ngiven column name, using string representation of column values.\nIncludes rows with matching values unless exclude is set.\nIf contains, only checks if row contains string; if ignoreCase, ignores case.\nUse named args for greater clarity.\nOnly valid for 1-dimensional columns.\nReturns error if column name not found.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"column", "str", "exclude", "contains", "ignoreCase"}, Returns: []string{"error"}}, {Name: "AddRows", Doc: "AddRows adds n rows to end of underlying Table, and to the indexes in this view", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"n"}}, {Name: "SaveCSV", Doc: "SaveCSV writes a table index view to a comma-separated-values (CSV) file\n(where comma = any delimiter, specified in the delim arg).\nIf headers = true then generate column headers that capture the type\nand tensor cell geometry of the columns, enabling full reloading\nof exactly the same table format and data (recommended).\nOtherwise, only the data is written.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim", "headers"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenCSV reads a table idx view from a comma-separated-values (CSV) file\n(where 
comma = any delimiter, specified in the delim arg),\nusing the Go standard encoding/csv reader conforming to the official CSV standard.\nIf the table does not currently have any columns, the first row of the file\nis assumed to be headers, and columns are constructed therefrom.\nIf the file was saved from table with headers, then these have full configuration\ninformation for tensor type and dimensionality.\nIf the table DOES have existing columns, then those are used robustly\nfor whatever information fits from each row of the file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}, Returns: []string{"error"}}}, Fields: []types.Field{{Name: "Table", Doc: "Table that we are an indexed view onto"}, {Name: "Indexes", Doc: "current indexes into Table"}, {Name: "lessFunc", Doc: "current Less function used in sorting"}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/table.Table", IDName: "table", Doc: "Table is a table of data, with columns of tensors,\neach with the same number of Rows (outer-most dimension).", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "SaveCSV", Doc: "SaveCSV writes a table to a comma-separated-values (CSV) file\n(where comma = any delimiter, specified in the delim arg).\nIf headers = true then generate column headers that capture the type\nand tensor cell geometry of the columns, enabling full reloading\nof exactly the same table format and data (recommended).\nOtherwise, only the data is written.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim", "headers"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenCSV reads a table from a comma-separated-values (CSV) file\n(where comma = any delimiter, specified in the delim arg),\nusing the Go standard encoding/csv reader conforming to the official CSV standard.\nIf the table does not currently have any columns, 
the first row of the file\nis assumed to be headers, and columns are constructed therefrom.\nIf the file was saved from table with headers, then these have full configuration\ninformation for tensor type and dimensionality.\nIf the table DOES have existing columns, then those are used robustly\nfor whatever information fits from each row of the file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}, Returns: []string{"error"}}, {Name: "AddRows", Doc: "AddRows adds n rows to each of the columns", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"n"}}, {Name: "SetNumRows", Doc: "SetNumRows sets the number of rows in the table, across all columns\nif rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"rows"}, Returns: []string{"Table"}}}, Fields: []types.Field{{Name: "Columns", Doc: "columns of data, as tensor.Tensor tensors"}, {Name: "ColumnNames", Doc: "the names of the columns"}, {Name: "Rows", Doc: "number of rows, which is enforced to be the size of the outer-most dimension of the column tensors"}, {Name: "ColumnNameMap", Doc: "the map of column names to column numbers"}, {Name: "MetaData", Doc: "misc meta data for the table. We use lower-case key names following the struct tag convention: name = name of table; desc = description; read-only = gui is read-only; precision = n for precision to write out floats in csv. 
For Column-specific data, we look for ColumnName: prefix, specifically ColumnName:desc = description of the column contents, which is shown as tooltip in the tensorcore.Table, and :width for width of a column"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/table.Table", IDName: "table", Doc: "Table is a table of Tensor columns aligned by a common outermost row dimension.\nUse the [Table.Column] (by name) and [Table.ColumnIndex] methods to obtain a\n[tensor.Rows] view of the column, using the shared [Table.Indexes] of the Table.\nThus, a coordinated sorting and filtered view of the column data is automatically\navailable for any of the tensor package functions that use [tensor.Tensor] as the one\ncommon data representation for all operations.\nTensor Columns are always raw value types and support SubSpace operations on cells.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Sequential", Doc: "Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "SortColumn", Doc: "SortColumn sorts the indexes into our Table according to values in\ngiven column, using either ascending or descending order,\n(use [tensor.Ascending] or [tensor.Descending] for self-documentation).\nUses first cell of higher dimensional data.\nReturns error if column name not found.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"columnName", "ascending"}, Returns: []string{"error"}}, {Name: "SortColumns", Doc: "SortColumns sorts the indexes into our Table according to values in\ngiven column names, using either ascending or descending order,\n(use [tensor.Ascending] or [tensor.Descending] for self-documentation,\nand optionally using a stable sort.\nUses first cell of higher dimensional data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"ascending", 
"stable", "columns"}}, {Name: "FilterString", Doc: "FilterString filters the indexes using string values in column compared to given\nstring. Includes rows with matching values unless the Exclude option is set.\nIf Contains option is set, it only checks if row contains string;\nif IgnoreCase, ignores case, otherwise filtering is case sensitive.\nUses first cell from higher dimensions.\nReturns error if column name not found.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"columnName", "str", "opts"}, Returns: []string{"error"}}, {Name: "SaveCSV", Doc: "SaveCSV writes a table to a comma-separated-values (CSV) file\n(where comma = any delimiter, specified in the delim arg).\nIf headers = true then generate column headers that capture the type\nand tensor cell geometry of the columns, enabling full reloading\nof exactly the same table format and data (recommended).\nOtherwise, only the data is written.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim", "headers"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenCSV reads a table from a comma-separated-values (CSV) file\n(where comma = any delimiter, specified in the delim arg),\nusing the Go standard encoding/csv reader conforming to the official CSV standard.\nIf the table does not currently have any columns, the first row of the file\nis assumed to be headers, and columns are constructed therefrom.\nIf the file was saved from table with headers, then these have full configuration\ninformation for tensor type and dimensionality.\nIf the table DOES have existing columns, then those are used robustly\nfor whatever information fits from each row of the file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}, Returns: []string{"error"}}, {Name: "AddRows", Doc: "AddRows adds n rows to end of underlying Table, and to the indexes in this view.", Directives: 
[]types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"n"}, Returns: []string{"Table"}}, {Name: "SetNumRows", Doc: "SetNumRows sets the number of rows in the table, across all columns.\nIf rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0.\nIf indexes are in place and rows are added, indexes for the new rows are added.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"rows"}, Returns: []string{"Table"}}}, Fields: []types.Field{{Name: "Columns", Doc: "Columns has the list of column tensor data for this table.\nDifferent tables can provide different indexed views onto the same Columns."}, {Name: "Indexes", Doc: "Indexes are the indexes into Tensor rows, with nil = sequential.\nOnly set if order is different from default sequential order.\nThese indexes are shared into the `tensor.Rows` Column values\nto provide a coordinated indexed view into the underlying data."}, {Name: "Meta", Doc: "Meta is misc metadata for the table. Use lower-case key names\nfollowing the struct tag convention:\n\t- name string = name of table\n\t- doc string = documentation, description\n\t- read-only bool = gui is read-only\n\t- precision int = n for precision to write out floats in csv."}}}) diff --git a/tensor/table/util.go b/tensor/table/util.go index a631a168cd..97a3632885 100644 --- a/tensor/table/util.go +++ b/tensor/table/util.go @@ -27,8 +27,8 @@ func (dt *Table) InsertKeyColumns(args ...string) *Table { for j := range nc { colNm := args[2*j] val := args[2*j+1] - col := tensor.NewString([]int{c.Rows}) - c.InsertColumn(col, colNm, 0) + col := tensor.NewString(c.Columns.Rows) + c.InsertColumn(0, colNm, col) for i := range col.Values { col.Values[i] = val } @@ -40,10 +40,10 @@ func (dt *Table) InsertKeyColumns(args ...string) *Table { // values in the first two columns of given format table, conventionally named // Name, Type (but names are not used), which must be of the string type. 
func (dt *Table) ConfigFromTable(ft *Table) error { - nmcol := ft.Columns[0] - tycol := ft.Columns[1] + nmcol := ft.ColumnByIndex(0) + tycol := ft.ColumnByIndex(1) var errs []error - for i := range ft.Rows { + for i := range ft.NumRows() { name := nmcol.String1D(i) typ := strings.ToLower(tycol.String1D(i)) kind := reflect.Float64 @@ -66,7 +66,7 @@ func (dt *Table) ConfigFromTable(ft *Table) error { err := fmt.Errorf("ConfigFromTable: type string %q not recognized", typ) errs = append(errs, err) } - dt.AddColumnOfType(kind, name) + dt.AddColumnOfType(name, kind) } return errors.Join(errs...) } diff --git a/tensor/tensor.go b/tensor/tensor.go index d0953711a0..d61e7b8532 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -10,227 +10,148 @@ import ( "fmt" "reflect" - "gonum.org/v1/gonum/mat" + "cogentcore.org/core/base/metadata" ) +// DataTypes are the primary tensor data types with specific support. +// Any numerical type can also be used. bool is represented using an +// efficient bit slice. +type DataTypes interface { + string | bool | float32 | float64 | int | int32 | uint32 | byte +} + +// MaxSprintLength is the default maximum length of a String() representation +// of a tensor, as generated by the Sprint function. Defaults to 1000. +var MaxSprintLength = 1000 + // todo: add a conversion function to copy data from Column-Major to a tensor: // It is also possible to use Column-Major order, which is used in R, Julia, and MATLAB -// where the inner-most index is first and outer-most last. +// where the inner-most index is first and outermost last. -// Tensor is the interface for n-dimensional tensors. +// Tensor is the most general interface for n-dimensional tensors. // Per C / Go / Python conventions, indexes are Row-Major, ordered from // outer to inner left-to-right, so the inner-most is right-most. 
-// It is implemented by the Base and Number generic types specialized -// by different concrete types: float64, float32, int, int32, byte, -// string, bits (bools). -// For float32 and float64 values, use NaN to indicate missing values. -// All of the data analysis and plot packages skip NaNs. +// It is implemented for raw [Values] with direct integer indexing +// by the [Number], [String], and [Bool] types, covering the different +// concrete types specified by [DataTypes] (see [Values] for +// additional interface methods for raw value types). +// For float32 and float64 values, use NaN to indicate missing values, +// as all of the data analysis and plot packages skip NaNs. +// View Tensor types provide different ways of viewing a source tensor, +// including [Sliced] for arbitrary slices of dimension indexes, +// [Masked] for boolean masked access of individual elements, +// and [Indexed] for arbitrary indexes of values, organized into the +// shape of the indexes, not the original source data. +// [Reshaped] provides length preserving reshaping (mostly for computational +// alignment purposes), and [Rows] provides an optimized row-indexed +// view for [table.Table] data. type Tensor interface { fmt.Stringer - mat.Matrix - // Shape returns a pointer to the shape that fully parametrizes the tensor shape + // Label satisfies the core.Labeler interface for a summary + // description of the tensor, including metadata Name if set. + Label() string + + // Metadata returns the metadata for this tensor, which can be used + // to encode name, docs, shape dimension names, plotting options, etc. + Metadata() *metadata.Data + + // Shape() returns a [Shape] representation of the tensor shape + // (dimension sizes). For tensors that present a view onto another + // tensor, this typically must be constructed. + // In general, it is better to use the specific [Tensor.ShapeSizes], + // [Tensor.ShapeSizes], [Tensor.DimSize] etc as neeed. 
Shape() *Shape - // Len returns the number of elements in the tensor (product of shape dimensions). + // ShapeSizes returns the sizes of each dimension as a slice of ints. + // This is the preferred access for Go code. + ShapeSizes() []int + + // Len returns the total number of elements in the tensor, + // i.e., the product of all shape dimensions. + // Len must always be such that the 1D() accessors return + // values using indexes from 0..Len()-1. Len() int // NumDims returns the total number of dimensions. NumDims() int - // DimSize returns size of given dimension + // DimSize returns size of given dimension. DimSize(dim int) int - // RowCellSize returns the size of the outer-most Row shape dimension, - // and the size of all the remaining inner dimensions (the "cell" size). - // Used for Tensors that are columns in a data table. - RowCellSize() (rows, cells int) - // DataType returns the type of the data elements in the tensor. - // Bool is returned for the Bits tensor type. + // Bool is returned for the Bool tensor type. DataType() reflect.Kind - // Sizeof returns the number of bytes contained in the Values of this tensor. - // for String types, this is just the string pointers. - Sizeof() int64 - - // Bytes returns the underlying byte representation of the tensor values. - // This is the actual underlying data, so make a copy if it can be - // unintentionally modified or retained more than for immediate use. - Bytes() []byte - - // returns true if the data type is a String. otherwise is numeric. + // IsString returns true if the data type is a String; otherwise it is numeric. IsString() bool - // Float returns the value of given index as a float64. - Float(i []int) float64 + // AsValues returns this tensor as raw [Values]. If it already is, + // it is returned directly. 
If it is a View tensor, the view is + // "rendered" into a fully contiguous and optimized [Values] representation + // of that view, which will be faster to access for further processing, + // and enables all the additional functionality provided by the [Values] interface. + AsValues() Values - // SetFloat sets the value of given index as a float64 - SetFloat(i []int, val float64) + //////// Floats - // NOTE: String conflicts with [fmt.Stringer], so we have to use StringValue + // Float returns the value of given n-dimensional index (matching Shape) as a float64. + Float(i ...int) float64 - // StringValue returns the value of given index as a string - StringValue(i []int) string + // SetFloat sets the value of given n-dimensional index (matching Shape) as a float64. + SetFloat(val float64, i ...int) - // SetString sets the value of given index as a string - SetString(i []int, val string) - - // Float1D returns the value of given 1-dimensional index (0-Len()-1) as a float64 + // Float1D returns the value of given 1-dimensional index (0-Len()-1) as a float64. + // If index is negative, it indexes from the end of the list (-1 = last). + // This can be somewhat expensive in wrapper views ([Rows], [Sliced]), which + // convert the flat index back into a full n-dimensional index and use that api. + // [Tensor.FloatRow] is preferred. Float1D(i int) float64 - // SetFloat1D sets the value of given 1-dimensional index (0-Len()-1) as a float64 - SetFloat1D(i int, val float64) - - // FloatRowCell returns the value at given row and cell, where row is outer-most dim, - // and cell is 1D index into remaining inner dims. For Table columns. - FloatRowCell(row, cell int) float64 + // SetFloat1D sets the value of given 1-dimensional index (0-Len()-1) as a float64. + // If index is negative, it indexes from the end of the list (-1 = last). + // This can be somewhat expensive in the commonly-used [Rows] view; + // [Tensor.SetFloatRow] is preferred. 
+ SetFloat1D(val float64, i int) - // SetFloatRowCell sets the value at given row and cell, where row is outer-most dim, - // and cell is 1D index into remaining inner dims. For Table columns. - SetFloatRowCell(row, cell int, val float64) + //////// Strings - // Floats sets []float64 slice of all elements in the tensor - // (length is ensured to be sufficient). - // This can be used for all of the gonum/floats methods - // for basic math, gonum/stats, etc. - Floats(flt *[]float64) + // StringValue returns the value of given n-dimensional index (matching Shape) as a string. + // 'String' conflicts with [fmt.Stringer], so we have to use StringValue here. + StringValue(i ...int) string - // SetFloats sets tensor values from a []float64 slice (copies values). - SetFloats(vals []float64) + // SetString sets the value of given n-dimensional index (matching Shape) as a string. + SetString(val string, i ...int) - // String1D returns the value of given 1-dimensional index (0-Len()-1) as a string + // String1D returns the value of given 1-dimensional index (0-Len()-1) as a string. + // If index is negative, it indexes from the end of the list (-1 = last). String1D(i int) string - // SetString1D sets the value of given 1-dimensional index (0-Len()-1) as a string - SetString1D(i int, val string) - - // StringRowCell returns the value at given row and cell, where row is outer-most dim, - // and cell is 1D index into remaining inner dims. For Table columns - StringRowCell(row, cell int) string - - // SetStringRowCell sets the value at given row and cell, where row is outer-most dim, - // and cell is 1D index into remaining inner dims. For Table columns - SetStringRowCell(row, cell int, val string) - - // SubSpace returns a new tensor with innermost subspace at given - // offset(s) in outermost dimension(s) (len(offs) < NumDims). 
- // The new tensor points to the values of the this tensor (i.e., modifications - // will affect both), as its Values slice is a view onto the original (which - // is why only inner-most contiguous supsaces are supported). - // Use Clone() method to separate the two. - SubSpace(offs []int) Tensor - - // Range returns the min, max (and associated indexes, -1 = no values) for the tensor. - // This is needed for display and is thus in the core api in optimized form - // Other math operations can be done using gonum/floats package. - Range() (min, max float64, minIndex, maxIndex int) - - // SetZeros is simple convenience function initialize all values to 0 - SetZeros() - - // Clone clones this tensor, creating a duplicate copy of itself with its - // own separate memory representation of all the values, and returns - // that as a Tensor (which can be converted into the known type as needed). - Clone() Tensor - - // CopyFrom copies all avail values from other tensor into this tensor, with an - // optimized implementation if the other tensor is of the same type, and - // otherwise it goes through appropriate standard type. - CopyFrom(from Tensor) - - // CopyShapeFrom copies just the shape from given source tensor - // calling SetShape with the shape params from source (see for more docs). - CopyShapeFrom(from Tensor) - - // CopyCellsFrom copies given range of values from other tensor into this tensor, - // using flat 1D indexes: to = starting index in this Tensor to start copying into, - // start = starting index on from Tensor to start copying from, and n = number of - // values to copy. Uses an optimized implementation if the other tensor is - // of the same type, and otherwise it goes through appropriate standard type. - CopyCellsFrom(from Tensor, to, start, n int) - - // SetShape sets the sizes parameters of the tensor, and resizes backing storage appropriately. - // existing names will be preserved if not presented. 
- SetShape(sizes []int, names ...string) - - // SetNumRows sets the number of rows (outer-most dimension). - SetNumRows(rows int) - - // SetMetaData sets a key=value meta data (stored as a map[string]string). - // For TensorGrid display: top-zero=+/-, odd-row=+/-, image=+/-, - // min, max set fixed min / max values, background=color - SetMetaData(key, val string) - - // MetaData retrieves value of given key, bool = false if not set - MetaData(key string) (string, bool) - - // MetaDataMap returns the underlying map used for meta data - MetaDataMap() map[string]string - - // CopyMetaData copies meta data from given source tensor - CopyMetaData(from Tensor) -} + // SetString1D sets the value of given 1-dimensional index (0-Len()-1) as a string. + // If index is negative, it indexes from the end of the list (-1 = last). + SetString1D(val string, i int) -// New returns a new n-dimensional tensor of given value type -// with the given sizes per dimension (shape), and optional dimension names. -func New[T string | bool | float32 | float64 | int | int32 | byte](sizes []int, names ...string) Tensor { - var v T - switch any(v).(type) { - case string: - return NewString(sizes, names...) - case bool: - return NewBits(sizes, names...) - case float64: - return NewNumber[float64](sizes, names...) - case float32: - return NewNumber[float32](sizes, names...) - case int: - return NewNumber[int](sizes, names...) - case int32: - return NewNumber[int32](sizes, names...) - case byte: - return NewNumber[byte](sizes, names...) - default: - panic("tensor.New: unexpected error: type not supported") - } -} + //////// Ints -// NewOfType returns a new n-dimensional tensor of given reflect.Kind type -// with the given sizes per dimension (shape), and optional dimension names. -// Supported types are string, bool (for [Bits]), float32, float64, int, int32, and byte. 
-func NewOfType(typ reflect.Kind, sizes []int, names ...string) Tensor { - switch typ { - case reflect.String: - return NewString(sizes, names...) - case reflect.Bool: - return NewBits(sizes, names...) - case reflect.Float64: - return NewNumber[float64](sizes, names...) - case reflect.Float32: - return NewNumber[float32](sizes, names...) - case reflect.Int: - return NewNumber[int](sizes, names...) - case reflect.Int32: - return NewNumber[int32](sizes, names...) - case reflect.Uint8: - return NewNumber[byte](sizes, names...) - default: - panic(fmt.Sprintf("tensor.NewOfType: type not supported: %v", typ)) - } + // Int returns the value of given n-dimensional index (matching Shape) as a int. + Int(i ...int) int + + // SetInt sets the value of given n-dimensional index (matching Shape) as a int. + SetInt(val int, i ...int) + + // Int1D returns the value of given 1-dimensional index (0-Len()-1) as a int. + // If index is negative, it indexes from the end of the list (-1 = last). + Int1D(i int) int + + // SetInt1D sets the value of given 1-dimensional index (0-Len()-1) as a int. + // If index is negative, it indexes from the end of the list (-1 = last). + SetInt1D(val int, i int) } -// CopyDense copies a gonum mat.Dense matrix into given Tensor -// using standard Float64 interface -func CopyDense(to Tensor, dm *mat.Dense) { - nr, nc := dm.Dims() - to.SetShape([]int{nr, nc}) - idx := 0 - for ri := 0; ri < nr; ri++ { - for ci := 0; ci < nc; ci++ { - v := dm.At(ri, ci) - to.SetFloat1D(idx, v) - idx++ - } +// NegIndex handles negative index values as counting backward from n. 
+func NegIndex(i, n int) int { + if i < 0 { + return n + i } + return i } diff --git a/tensor/tensor_test.go b/tensor/tensor_test.go index 2efa6c1ba4..d77ed684d6 100644 --- a/tensor/tensor_test.go +++ b/tensor/tensor_test.go @@ -8,125 +8,595 @@ import ( "reflect" "testing" + "cogentcore.org/core/base/metadata" "github.com/stretchr/testify/assert" ) +func TestProjection2D(t *testing.T) { + shp := NewShape(5) + var nilINts []int + rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, OnedRow) + assert.Equal(t, []int{5}, rowShape.Sizes) + assert.Equal(t, []int{1}, colShape.Sizes) + assert.Equal(t, []int{0}, rowIdxs) + assert.Equal(t, nilINts, colIdxs) + + rowShape, colShape, rowIdxs, colIdxs = Projection2DDimShapes(shp, OnedColumn) + assert.Equal(t, []int{1}, rowShape.Sizes) + assert.Equal(t, []int{5}, colShape.Sizes) + assert.Equal(t, nilINts, rowIdxs) + assert.Equal(t, []int{0}, colIdxs) + + shp = NewShape(3, 4) + rowShape, colShape, rowIdxs, colIdxs = Projection2DDimShapes(shp, OnedRow) + assert.Equal(t, []int{3}, rowShape.Sizes) + assert.Equal(t, []int{4}, colShape.Sizes) + assert.Equal(t, []int{0}, rowIdxs) + assert.Equal(t, []int{1}, colIdxs) + + shp = NewShape(3, 4, 5) + rowShape, colShape, rowIdxs, colIdxs = Projection2DDimShapes(shp, OnedRow) + assert.Equal(t, []int{3, 4}, rowShape.Sizes) + assert.Equal(t, []int{5}, colShape.Sizes) + assert.Equal(t, []int{0, 1}, rowIdxs) + assert.Equal(t, []int{2}, colIdxs) + + shp = NewShape(3, 4, 5, 6) + rowShape, colShape, rowIdxs, colIdxs = Projection2DDimShapes(shp, OnedRow) + assert.Equal(t, []int{3, 5}, rowShape.Sizes) + assert.Equal(t, []int{4, 6}, colShape.Sizes) + assert.Equal(t, []int{0, 2}, rowIdxs) + assert.Equal(t, []int{1, 3}, colIdxs) + + shp = NewShape(3, 4, 5, 6, 7) + rowShape, colShape, rowIdxs, colIdxs = Projection2DDimShapes(shp, OnedRow) + assert.Equal(t, []int{3, 4, 6}, rowShape.Sizes) + assert.Equal(t, []int{5, 7}, colShape.Sizes) + assert.Equal(t, []int{0, 1, 3}, rowIdxs) + 
assert.Equal(t, []int{2, 4}, colIdxs) +} + +func TestPrintf(t *testing.T) { + ft := NewFloat64(4) + for x := range 4 { + ft.SetFloat(float64(x), x) + } + // fmt.Println(ft.String()) + res := `[4] 0 1 2 3 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(40) + for x := range 40 { + ft.SetFloat(float64(x), x) + } + // fmt.Println(ft.String()) + res = `[40] 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 + 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(4, 1) + for x := range 4 { + ft.SetFloat(float64(x), x, 0) + } + // fmt.Println(ft.String()) + res = `[4 1] +[0] 0 +[1] 1 +[2] 2 +[3] 3 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(4, 3) + for y := range 4 { + for x := range 3 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + // fmt.Println(ft.String()) + res = `[4 3] + [0] [1] [2] +[0] 0 1 2 +[1] 10 11 12 +[2] 20 21 22 +[3] 30 31 32 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(4, 3, 2) + for z := range 4 { + for y := range 3 { + for x := range 2 { + v := z*100 + y*10 + x + ft.SetFloat(float64(v), z, y, x) + } + } + } + // fmt.Println(ft.String()) + res = `[4 3 2] +[r r c] [0] [1] +[0 0] 0 1 +[0 1] 10 11 +[0 2] 20 21 +[1 0] 100 101 +[1 1] 110 111 +[1 2] 120 121 +[2 0] 200 201 +[2 1] 210 211 +[2 2] 220 221 +[3 0] 300 301 +[3 1] 310 311 +[3 2] 320 321 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(5, 4, 3, 2) + for z := range 5 { + for y := range 4 { + for x := range 3 { + for w := range 2 { + v := z*1000 + y*100 + x*10 + w + ft.SetFloat(float64(v), z, y, x, w) + } + } + } + } + // fmt.Println(ft.String()) + res = `[5 4 3 2] +[r c r c] [0 0] [0 1] [1 0] [1 1] [2 0] [2 1] [3 0] [3 1] +[0 0] 0 1 100 101 200 201 300 301 +[0 1] 10 11 110 111 210 211 310 311 +[0 2] 20 21 120 121 220 221 320 321 +[1 0] 1000 1001 1100 1101 1200 1201 1300 1301 +[1 1] 1010 1011 1110 1111 1210 1211 1310 1311 +[1 2] 1020 1021 1120 1121 1220 1221 1320 1321 +[2 0] 2000 
2001 2100 2101 2200 2201 2300 2301 +[2 1] 2010 2011 2110 2111 2210 2211 2310 2311 +[2 2] 2020 2021 2120 2121 2220 2221 2320 2321 +[3 0] 3000 3001 3100 3101 3200 3201 3300 3301 +[3 1] 3010 3011 3110 3111 3210 3211 3310 3311 +[3 2] 3020 3021 3120 3121 3220 3221 3320 3321 +[4 0] 4000 4001 4100 4101 4200 4201 4300 4301 +[4 1] 4010 4011 4110 4111 4210 4211 4310 4311 +[4 2] 4020 4021 4120 4121 4220 4221 4320 4321 +` + assert.Equal(t, res, ft.String()) + + ft = NewFloat64(6, 5, 4, 3, 2) + for z := range 6 { + for y := range 5 { + for x := range 4 { + for w := range 3 { + for q := range 2 { + v := z*10000 + y*1000 + x*100 + w*10 + q + ft.SetFloat(float64(v), z, y, x, w, q) + } + } + } + } + } + // fmt.Println(ft.String()) + res = `[6 5 4 3 2] +[r r c r c] [0 0] [0 1] [1 0] [1 1] [2 0] [2 1] [3 0] [3 1] +[0 0 0] 0 1 100 101 200 201 300 301 +[0 0 1] 10 11 110 111 210 211 310 311 +[0 0 2] 20 21 120 121 220 221 320 321 +[0 1 0] 1000 1001 1100 1101 1200 1201 1300 1301 +[0 1 1] 1010 1011 1110 1111 1210 1211 1310 1311 +[0 1 2] 1020 1021 1120 1121 1220 1221 1320 1321 +[0 2 0] 2000 2001 2100 2101 2200 2201 2300 2301 +[0 2 1] 2010 2011 2110 2111 2210 2211 2310 2311 +[0 2 2] 2020 2021 2120 2121 2220 2221 2320 2321 +[0 3 0] 3000 3001 3100 3101 3200 3201 3300 3301 +[0 3 1] 3010 3011 3110 3111 3210 3211 3310 3311 +[0 3 2] 3020 3021 3120 3121 3220 3221 3320 3321 +[0 4 0] 4000 4001 4100 4101 4200 4201 4300 4301 +[0 4 1] 4010 4011 4110 4111 4210 4211 4310 4311 +[0 4 2] 4020 4021 4120 4121 4220 4221 4320 4321 +[1 0 0] 10000 10001 10100 10101 10200 10201 10300 10301 +[1 0 1] 10010 10011 10110 10111 10210 10211 10310 10311 +[1 0 2] 10020 10021 10120 10121 10220 10221 10320 10321 +[1 1 0] 11000 11001 11100 11101 11200 11201 11300 11301 +[1 1 1] 11010 11011 11110 11111 11210 11211 11310 11311 +[1 1 2] 11020 11021 11120 11121 11220 11221 11320 11321 +[1 2 0] 12000 12001 12100 12101 12200 12201 12300 12301 +[1 2 1] 12010 12011 12110 12111 12210 12211 12310 12311 +[1 2 2] 12020 12021 12120 12121 
12220 12221 12320 12321 +[1 3 0] 13000 13001 13100 13101 13200 13201 13300 13301 +[1 3 1] 13010 13011 13110 13111 13210 13211 13310 13311 +[1 3 2] 13020 13021 13120 13121 13220 13221 13320 13321 +[1 4 0] 14000 14001 14100 14101 14200 14201 14300 14301 +[1 4 1] 14010 14011 14110 14111 14210 14211 14310 14311 +[1 4 2] 14020 14021 14120 14121 14220 14221 14320 14321 +[2 0 0] 20000 20001 20100 20101 20200 20201 20300 20301 +[2 0 1] 20010 20011 20110 20111 20210 20211 20310 20311 +[2 0 2] 20020 20021 20120 20121 20220 20221 20320 20321 +[2 1 0] 21000 21001 21100 21101 21200 21201 21300 21301 +[2 1 1] 21010 21011 21110 21111 21210 21211 21310 21311 +[2 1 2] 21020 21021 21120 21121 21220 21221 21320 21321 +[2 2 0] 22000 22001 22100 22101 22200 22201 22300 22301 +[2 2 1] 22010 22011 22110 22111 22210 22211 22310 22311 +[2 2 2] 22020 22021 22120 22121 22220 22221 22320 22321 +[2 3 0] 23000 23001 23100 23101 23200 23201 23300 23301 +[2 3 1] 23010 23011 23110 23111 23210 23211 23310 23311 +[2 3 2] 23020 23021 23120 23121 23220 23221 23320 23321 +[2 4 0] 24000 24001 24100 24101 24200 24201 24300 24301 +[2 4 1] 24010 24011 24110 24111 24210 24211 24310 24311 +[2 4 2] 24020 24021 24120 24121 24220 24221 24320 24321 +[3 0 0] 30000 30001 30100 30101 30200 30201 30300 30301 +[3 0 1] 30010 30011 30110 30111 30210 30211 30310 30311 +[3 0 2] 30020 30021 30120 30121 30220 30221 30320 30321 +[3 1 0] 31000 31001 31100 31101 31200 31201 31300 31301 +[3 1 1] 31010 31011 31110 31111 31210 31211 31310 31311 +[3 1 2] 31020 31021 31120 31121 31220 31221 31320 31321 +[3 2 0] 32000 32001 32100 32101 32200 32201 32300 32301 +[3 2 1] 32010 32011 32110 32111 32210 32211 32310 32311 +[3 2 2] 32020 32021 32120 32121 32220 32221 32320 32321 +[3 3 0] 33000 33001 33100 33101 33200 33201 33300 33301 +[3 3 1] 33010 33011 33110 33111 33210 33211 33310 33311 +[3 3 2] 33020 33021 33120 33121 33220 33221 33320 33321 +[3 4 0] 34000 34001 34100 34101 34200 34201 34300 34301 +[3 4 1] 34010 34011 34110 34111 
34210 34211 34310 34311 +[3 4 2] 34020 34021 34120 34121 34220 34221 34320 34321 +[4 0 0] 40000 40001 40100 40101 40200 40201 40300 40301 +[4 0 1] 40010 40011 40110 40111 40210 40211 40310 40311 +[4 0 2] 40020 40021 40120 40121 40220 40221 40320 40321 +[4 1 0] 41000 41001 41100 41101 41200 41201 41300 41301 +[4 1 1] 41010 41011 41110 41111 41210 41211 41310 41311 +[4 1 2] 41020 41021 41120 41121 41220 41221 41320 41321 +[4 2 0] 42000 42001 42100 42101 42200 42201 42300 42301 +[4 2 1] 42010 42011 42110 42111 42210 42211 42310 42311 +[4 2 2] 42020 42021 42120 42121 42220 42221 42320 42321 +[4 3 0] 43000 43001 43100 43101 43200 43201 43300 43301 +[4 3 1] 43010 43011 43110 43111 43210 43211 43310 43311 +[4 3 2] 43020 43021 43120 43121 43220 43221 43320 43321 +[4 4 0] 44000 44001 44100 44101 44200 44201 44300 44301 +[4 4 1] 44010 44011 44110 44111 44210 44211 44310 44311 +[4 4 2] 44020 44021 44120 44121 44220 44221 44320 44321 +[5 0 0] 50000 50001 50100 50101 50200 50201 50300 50301 +[5 0 1] 50010 50011 50110 50111 50210 50211 50310 50311 +[5 0 2] 50020 50021 50120 50121 50220 50221 50320 50321 +[5 1 0] 51000 51001 51100 51101 51200 51201 51300 51301 +[5 1 1] 51010 51011 51110 51111 51210 51211 51310 51311 +[5 1 2] 51020 51021 51120 51121 51220 51221 51320 51321 +[5 2 0] 52000 52001 52100 52101 52200 52201 52300 52301 +[5 2 1] 52010 52011 52110 52111 52210 52211 52310 52311 +[5 2 2] 52020 52021 52120 52121 52220 52221 52320 52321 +[5 3 0] 53000 53001 53100 53101 53200 53201 53300 53301 +[5 3 1] 53010 53011 53110 53111 53210 53211 53310 53311 +[5 3 2] 53020 53021 53120 53121 53220 53221 53320 53321 +[5 4 0] 54000 54001 54100 54101 54200 54201 54300 54301 +[5 4 1] 54010 54011 54110 54111 54210 54211 54310 54311 +[5 4 2] 54020 54021 54120 54121 54220 54221 54320 54321 +` + assert.Equal(t, res, ft.String()) +} + func TestTensorString(t *testing.T) { - shp := []int{4, 2} - nms := []string{"Row", "Vals"} - tsr := New[string](shp, nms...) 
+ tsr := New[string](4, 2) + // tsr.SetNames("Row", "Vals") + // assert.Equal(t, []string{"Row", "Vals"}, tsr.Shape().Names) assert.Equal(t, 8, tsr.Len()) assert.Equal(t, true, tsr.IsString()) assert.Equal(t, reflect.String, tsr.DataType()) - assert.Equal(t, 2, tsr.SubSpace([]int{0}).Len()) - r, c := tsr.RowCellSize() + assert.Equal(t, 2, tsr.SubSpace(0).Len()) + r, c := tsr.Shape().RowCellSize() assert.Equal(t, 4, r) assert.Equal(t, 2, c) - tsr.SetString([]int{2, 0}, "test") - assert.Equal(t, "test", tsr.StringValue([]int{2, 0})) - tsr.SetString1D(5, "testing") - assert.Equal(t, "testing", tsr.StringValue([]int{2, 1})) + tsr.SetString("test", 2, 0) + assert.Equal(t, "test", tsr.StringValue(2, 0)) + tsr.SetString1D("testing", 5) + assert.Equal(t, "testing", tsr.StringValue(2, 1)) assert.Equal(t, "test", tsr.String1D(4)) - assert.Equal(t, "test", tsr.StringRowCell(2, 0)) - assert.Equal(t, "testing", tsr.StringRowCell(2, 1)) - assert.Equal(t, "", tsr.StringRowCell(3, 0)) + assert.Equal(t, "test", tsr.StringRow(2, 0)) + assert.Equal(t, "testing", tsr.StringRow(2, 1)) + assert.Equal(t, "", tsr.StringRow(3, 0)) cln := tsr.Clone() - assert.Equal(t, "testing", cln.StringValue([]int{2, 1})) + assert.Equal(t, "testing", cln.StringValue(2, 1)) cln.SetZeros() - assert.Equal(t, "", cln.StringValue([]int{2, 1})) - assert.Equal(t, "testing", tsr.StringValue([]int{2, 1})) + assert.Equal(t, "", cln.StringValue(2, 1)) + assert.Equal(t, "testing", tsr.StringValue(2, 1)) - tsr.SetShape([]int{2, 4}, "Vals", "Row") - assert.Equal(t, "test", tsr.StringValue([]int{1, 0})) - assert.Equal(t, "testing", tsr.StringValue([]int{1, 1})) + tsr.SetShapeSizes(2, 4) + // tsr.SetNames("Vals", "Row") + assert.Equal(t, "test", tsr.StringValue(1, 0)) + assert.Equal(t, "testing", tsr.StringValue(1, 1)) - cln.SetString1D(5, "ctesting") - cln.CopyShapeFrom(tsr) - assert.Equal(t, "ctesting", cln.StringValue([]int{1, 1})) + cln.SetString1D("ctesting", 5) + SetShapeFrom(cln, tsr) + assert.Equal(t, 
"ctesting", cln.StringValue(1, 1)) cln.CopyCellsFrom(tsr, 5, 4, 2) - assert.Equal(t, "test", cln.StringValue([]int{1, 1})) - assert.Equal(t, "testing", cln.StringValue([]int{1, 2})) + assert.Equal(t, "test", cln.StringValue(1, 1)) + assert.Equal(t, "testing", cln.StringValue(1, 2)) tsr.SetNumRows(5) assert.Equal(t, 20, tsr.Len()) - tsr.SetMetaData("name", "test") - nm, has := tsr.MetaData("name") + metadata.SetName(tsr, "test") + nm := metadata.Name(tsr) assert.Equal(t, "test", nm) - assert.Equal(t, true, has) - _, has = tsr.MetaData("type") - assert.Equal(t, false, has) + _, err := metadata.Get[string](*tsr.Metadata(), "type") + assert.Error(t, err) - var flt []float64 - cln.SetString1D(0, "3.14") + cln.SetString1D("3.14", 0) assert.Equal(t, 3.14, cln.Float1D(0)) - cln.Floats(&flt) - assert.Equal(t, 3.14, flt[0]) - assert.Equal(t, 0.0, flt[1]) + af := AsFloat64Slice(cln) + assert.Equal(t, 3.14, af[0]) + assert.Equal(t, 0.0, af[1]) } func TestTensorFloat64(t *testing.T) { - shp := []int{4, 2} - nms := []string{"Row", "Vals"} - tsr := New[float64](shp, nms...) 
+ tsr := New[float64](4, 2) + // tsr.SetNames("Row") + // assert.Equal(t, []string{"Row", ""}, tsr.Shape().Names) assert.Equal(t, 8, tsr.Len()) assert.Equal(t, false, tsr.IsString()) assert.Equal(t, reflect.Float64, tsr.DataType()) - assert.Equal(t, 2, tsr.SubSpace([]int{0}).Len()) - r, c := tsr.RowCellSize() + assert.Equal(t, 2, tsr.SubSpace(0).Len()) + r, c := tsr.Shape().RowCellSize() assert.Equal(t, 4, r) assert.Equal(t, 2, c) - tsr.SetFloat([]int{2, 0}, 3.14) - assert.Equal(t, 3.14, tsr.Float([]int{2, 0})) - tsr.SetFloat1D(5, 2.17) - assert.Equal(t, 2.17, tsr.Float([]int{2, 1})) + tsr.SetFloat(3.14, 2, 0) + assert.Equal(t, 3.14, tsr.Float(2, 0)) + tsr.SetFloat1D(2.17, 5) + assert.Equal(t, 2.17, tsr.Float(2, 1)) assert.Equal(t, 3.14, tsr.Float1D(4)) - assert.Equal(t, 3.14, tsr.FloatRowCell(2, 0)) - assert.Equal(t, 2.17, tsr.FloatRowCell(2, 1)) - assert.Equal(t, 0.0, tsr.FloatRowCell(3, 0)) + assert.Equal(t, 3.14, tsr.FloatRow(2, 0)) + assert.Equal(t, 2.17, tsr.FloatRow(2, 1)) + assert.Equal(t, 0.0, tsr.FloatRow(3, 0)) cln := tsr.Clone() - assert.Equal(t, 2.17, cln.Float([]int{2, 1})) + assert.Equal(t, 2.17, cln.Float(2, 1)) cln.SetZeros() - assert.Equal(t, 0.0, cln.Float([]int{2, 1})) - assert.Equal(t, 2.17, tsr.Float([]int{2, 1})) + assert.Equal(t, 0.0, cln.Float(2, 1)) + assert.Equal(t, 2.17, tsr.Float(2, 1)) - tsr.SetShape([]int{2, 4}, "Vals", "Row") - assert.Equal(t, 3.14, tsr.Float([]int{1, 0})) - assert.Equal(t, 2.17, tsr.Float([]int{1, 1})) + tsr.SetShapeSizes(2, 4) + assert.Equal(t, 3.14, tsr.Float(1, 0)) + assert.Equal(t, 2.17, tsr.Float(1, 1)) - cln.SetFloat1D(5, 9.9) - cln.CopyShapeFrom(tsr) - assert.Equal(t, 9.9, cln.Float([]int{1, 1})) + cln.SetFloat1D(9.9, 5) + SetShapeFrom(cln, tsr) + assert.Equal(t, 9.9, cln.Float(1, 1)) cln.CopyCellsFrom(tsr, 5, 4, 2) - assert.Equal(t, 3.14, cln.Float([]int{1, 1})) - assert.Equal(t, 2.17, cln.Float([]int{1, 2})) + assert.Equal(t, 3.14, cln.Float(1, 1)) + assert.Equal(t, 2.17, cln.Float(1, 2)) tsr.SetNumRows(5) 
assert.Equal(t, 20, tsr.Len()) - tsr.SetMetaData("name", "test") - nm, has := tsr.MetaData("name") - assert.Equal(t, "test", nm) - assert.Equal(t, true, has) - _, has = tsr.MetaData("type") - assert.Equal(t, false, has) - - var flt []float64 - cln.SetString1D(0, "3.14") + cln.SetString1D("3.14", 0) assert.Equal(t, 3.14, cln.Float1D(0)) - cln.Floats(&flt) - assert.Equal(t, 3.14, flt[0]) - assert.Equal(t, 0.0, flt[1]) + af := AsFloat64Slice(cln) + assert.Equal(t, 3.14, af[0]) + assert.Equal(t, 0.0, af[1]) +} + +func TestSliced(t *testing.T) { + ft := NewFloat64(3, 4) + for y := range 3 { + for x := range 4 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + + res := `[3 4] + [0] [1] [2] [3] +[0] 0 1 2 3 +[1] 10 11 12 13 +[2] 20 21 22 23 +` + assert.Equal(t, res, ft.String()) + + res = `[2 2] + [0] [1] +[0] 23 22 +[1] 13 12 +` + sl := NewSliced(ft, []int{2, 1}, []int{3, 2}) + assert.Equal(t, res, sl.String()) + + vl := sl.AsValues() + assert.Equal(t, res, vl.String()) + res = `[3 1] +[0] 2 +[1] 12 +[2] 22 +` + sl2 := Reslice(ft, FullAxis, Slice{2, 3, 0}) + assert.Equal(t, res, sl2.String()) + + vl = sl2.AsValues() + assert.Equal(t, res, vl.String()) +} + +func TestMasked(t *testing.T) { + ft := NewFloat64(3, 4) + for y := range 3 { + for x := range 4 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + ms := NewMasked(ft) + + res := `[3 4] + [0] [1] [2] [3] +[0] 0 1 2 3 +[1] 10 11 12 13 +[2] 20 21 22 23 +` + assert.Equal(t, res, ms.String()) + + ms.Filter(func(tsr Tensor, idx int) bool { + val := tsr.Float1D(idx) + return int(val)%10 == 2 + }) + res = `[3 4] + [0] [1] [2] [3] +[0] NaN NaN 2 NaN +[1] NaN NaN 12 NaN +[2] NaN NaN 22 NaN +` + assert.Equal(t, res, ms.String()) + + res = `[3] 2 12 22 +` + vl := ms.AsValues() + assert.Equal(t, res, vl.String()) +} + +func TestIndexed(t *testing.T) { + ft := NewFloat64(3, 4) + for y := range 3 { + for x := range 4 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + ixs := NewIntFromValues( + 0, 1, + 0, 
1, + 0, 2, + 0, 2, + 1, 1, + 1, 1, + 2, 2, + 2, 2, + ) + + ixs.SetShapeSizes(2, 2, 2, 2) + ix := NewIndexed(ft, ixs) + + res := `[2 2 2] +[r r c] [0] [1] +[0 0] 1 1 +[0 1] 2 2 +[1 0] 11 11 +[1 1] 22 22 +` + assert.Equal(t, res, ix.String()) + + vl := ix.AsValues() + assert.Equal(t, res, vl.String()) +} + +func TestReshaped(t *testing.T) { + ft := NewFloat64(3, 4) + for y := range 3 { + for x := range 4 { + v := y*10 + x + ft.SetFloat(float64(v), y, x) + } + } + + res := `[4 3] + [0] [1] [2] +[0] 0 1 2 +[1] 3 10 11 +[2] 12 13 20 +[3] 21 22 23 +` + rs := NewReshaped(ft, 4, 3) + assert.Equal(t, res, rs.String()) + + res = `[1 3 4] +[r r c] [0] [1] [2] [3] +[0 0] 0 1 2 3 +[0 1] 10 11 12 13 +[0 2] 20 21 22 23 +` + rs = NewReshaped(ft, int(NewAxis), 3, 4) + assert.Equal(t, res, rs.String()) + + res = `[12] 0 1 2 3 10 11 12 13 20 21 22 23 +` + rs = NewReshaped(ft, -1) + assert.Equal(t, res, rs.String()) + + res = `[4 3] + [0] [1] [2] +[0] 0 1 2 +[1] 3 10 11 +[2] 12 13 20 +[3] 21 22 23 +` + rs = NewReshaped(ft, 4, -1) + assert.Equal(t, res, rs.String()) + + err := rs.SetShapeSizes(5, -1) + assert.Error(t, err) + + res = `[3 4] + [0] [3] [2] [1] +[0] 0 3 12 21 +[0] 1 10 13 22 +[0] 2 11 20 23 +` + tr := Transpose(ft) + assert.Equal(t, res, tr.String()) + +} + +func TestSortFilter(t *testing.T) { + tsr := NewRows(NewFloat64(5)) + for i := range 5 { + tsr.SetFloatRow(float64(i), i, 0) + } + tsr.Sort(Ascending) + assert.Equal(t, []int{0, 1, 2, 3, 4}, tsr.Indexes) + tsr.Sort(Descending) + assert.Equal(t, []int{4, 3, 2, 1, 0}, tsr.Indexes) + + tsr.Sequential() + tsr.FilterString("1", FilterOptions{}) + assert.Equal(t, []int{1}, tsr.Indexes) + + tsr.Sequential() + tsr.FilterString("1", FilterOptions{Exclude: true}) + assert.Equal(t, []int{0, 2, 3, 4}, tsr.Indexes) +} + +func TestGrowRow(t *testing.T) { + tsr := NewFloat64(1000) + assert.Equal(t, 1000, cap(tsr.Values)) + assert.Equal(t, 1000, tsr.Len()) + tsr.SetNumRows(0) + assert.Equal(t, 1000, cap(tsr.Values)) + assert.Equal(t, 
0, tsr.Len()) + tsr.SetNumRows(1) + assert.Equal(t, 1000, cap(tsr.Values)) + assert.Equal(t, 1, tsr.Len()) + + tsr2 := NewFloat64(1000, 10, 10) + assert.Equal(t, 100000, cap(tsr2.Values)) + assert.Equal(t, 100000, tsr2.Len()) + tsr2.SetNumRows(0) + assert.Equal(t, 100000, cap(tsr2.Values)) + assert.Equal(t, 0, tsr2.Len()) + tsr2.SetNumRows(1) + assert.Equal(t, 100000, cap(tsr2.Values)) + assert.Equal(t, 100, tsr2.Len()) + + bits := NewBool(1000) + assert.Equal(t, 1000, bits.Len()) + bits.SetNumRows(0) + assert.Equal(t, 0, bits.Len()) + bits.SetNumRows(1) + assert.Equal(t, 1, bits.Len()) } diff --git a/tensor/tensorcore/gridstyle.go b/tensor/tensorcore/gridstyle.go new file mode 100644 index 0000000000..5ed9b3df7a --- /dev/null +++ b/tensor/tensorcore/gridstyle.go @@ -0,0 +1,125 @@ +// Copyright (c) 2019, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorcore + +import ( + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/core" + "cogentcore.org/core/math32/minmax" +) + +// Layout are layout options for displaying tensors. +type Layout struct { //types:add --setters + + // OddRow means that even-numbered dimensions are displayed as Y*X rectangles. + // This determines along which dimension to display any remaining + // odd dimension: OddRow = true = organize vertically along row + // dimension, false = organize horizontally across column dimension. + OddRow bool + + // TopZero means that the Y=0 coordinate is displayed from the top-down; + // otherwise the Y=0 coordinate is displayed from the bottom up, + // which is typical for emergent network patterns. + TopZero bool + + // Image will display the data as a bitmap image. If a 2D tensor, then it will + // be a greyscale image. If a 3D tensor with size of either the first + // or last dim = either 3 or 4, then it is a RGB(A) color image. 
+ Image bool +} + +// GridStyle are options for displaying tensors +type GridStyle struct { //types:add --setters + Layout + + // Range to plot + Range minmax.Range64 `display:"inline"` + + // MinMax has the actual range of data, if not using fixed Range. + MinMax minmax.F64 `display:"inline"` + + // ColorMap is the name of the color map to use in translating values to colors. + ColorMap core.ColorMapName + + // GridFill sets proportion of grid square filled by the color block: + // 1 = all, .5 = half, etc. + GridFill float32 `min:"0.1" max:"1" step:"0.1" default:"0.9,1"` + + // DimExtra is the amount of extra space to add at dimension boundaries, + // as a proportion of total grid size. + DimExtra float32 `min:"0" max:"1" step:"0.02" default:"0.1,0.3"` + + // Size sets the minimum and maximum size for grid squares. + Size minmax.F32 `display:"inline"` + + // TotalSize sets the total preferred display size along largest dimension. + // Grid squares will be sized to fit within this size, + // subject to the Size.Min / Max constraints, which have precedence. + TotalSize float32 + + // FontSize is the font size in standard point units for labels. + FontSize float32 +} + +// Defaults sets defaults for values that are at nonsensical initial values +func (gs *GridStyle) Defaults() { + gs.Range.SetMin(-1).SetMax(1) + gs.ColorMap = "ColdHot" + gs.GridFill = 0.9 + gs.DimExtra = 0.3 + gs.Size.Set(2, 32) + gs.TotalSize = 100 + gs.FontSize = 24 +} + +// NewGridStyle returns a new GridStyle with defaults. +func NewGridStyle() *GridStyle { + gs := &GridStyle{} + gs.Defaults() + return gs +} + +func (gs *GridStyle) ApplyStylersFrom(obj any) { + st := GetGridStylersFrom(obj) + if st == nil { + return + } + st.Run(gs) +} + +// GridStylers is a list of styling functions that set GridStyle properties. +// These are called in the order added. +type GridStylers []func(s *GridStyle) + +// Add Adds a styling function to the list. 
+func (st *GridStylers) Add(f func(s *GridStyle)) { + *st = append(*st, f) +} + +// Run runs the list of styling functions on given [GridStyle] object. +func (st *GridStylers) Run(s *GridStyle) { + for _, f := range *st { + f(s) + } +} + +// SetGridStylersTo sets the [GridStylers] into given object's [metadata]. +func SetGridStylersTo(obj any, st GridStylers) { + metadata.SetTo(obj, "GridStylers", st) +} + +// GetGridStylersFrom returns [GridStylers] from given object's [metadata]. +// Returns nil if none or no metadata. +func GetGridStylersFrom(obj any) GridStylers { + st, _ := metadata.GetFrom[GridStylers](obj, "GridStylers") + return st +} + +// AddGridStylerTo adds the given [GridStyler] function into given object's [metadata]. +func AddGridStylerTo(obj any, f func(s *GridStyle)) { + st := GetGridStylersFrom(obj) + st.Add(f) + SetGridStylersTo(obj, st) +} diff --git a/tensor/tensorcore/simatgrid.go b/tensor/tensorcore/simatgrid.go index a134c3199f..a0e2f610ad 100644 --- a/tensor/tensorcore/simatgrid.go +++ b/tensor/tensorcore/simatgrid.go @@ -4,6 +4,8 @@ package tensorcore +/* + import ( "cogentcore.org/core/colors" "cogentcore.org/core/math32" @@ -237,3 +239,5 @@ func (tg *SimMatGrid) Render() { } } } + +*/ diff --git a/tensor/tensorcore/table.go b/tensor/tensorcore/table.go index 1bfe75e996..70c189cf6f 100644 --- a/tensor/tensorcore/table.go +++ b/tensor/tensorcore/table.go @@ -34,36 +34,37 @@ import ( type Table struct { core.ListBase - // the idx view of the table that we're a view of - Table *table.IndexView `set:"-"` + // Table is the table that we're a view of. + Table *table.Table `set:"-"` - // overall display options for tensor display - TensorDisplay TensorDisplay `set:"-"` + // GridStyle has global grid display styles. GridStylers on the Table + // are applied to this on top of defaults. 
+ GridStyle GridStyle `set:"-"` - // per column tensor display params - ColumnTensorDisplay map[int]*TensorDisplay `set:"-"` + // ColumnGridStyle has per column grid display styles. + ColumnGridStyle map[int]*GridStyle `set:"-"` - // per column blank tensor values - ColumnTensorBlank map[int]*tensor.Float64 `set:"-"` - - // number of columns in table (as of last update) - NCols int `edit:"-"` - - // current sort index + // current sort index. SortIndex int - // whether current sort order is descending + // whether current sort order is descending. SortDescending bool - // headerWidths has number of characters in each header, per visfields + // number of columns in table (as of last update). + nCols int `edit:"-"` + + // headerWidths has number of characters in each header, per visfields. headerWidths []int `copier:"-" display:"-" json:"-" xml:"-"` - // colMaxWidths records maximum width in chars of string type fields + // colMaxWidths records maximum width in chars of string type fields. colMaxWidths []int `set:"-" copier:"-" json:"-" xml:"-"` - // blank values for out-of-range rows - BlankString string - BlankFloat float64 + // blank values for out-of-range rows. + blankString string + blankFloat float64 + + // blankCells has per column blank tensor cells. 
+ blankCells map[int]*tensor.Float64 `set:"-"` } // check for interface impl @@ -72,9 +73,9 @@ var _ core.Lister = (*Table)(nil) func (tb *Table) Init() { tb.ListBase.Init() tb.SortIndex = -1 - tb.TensorDisplay.Defaults() - tb.ColumnTensorDisplay = map[int]*TensorDisplay{} - tb.ColumnTensorBlank = map[int]*tensor.Float64{} + tb.GridStyle.Defaults() + tb.ColumnGridStyle = map[int]*GridStyle{} + tb.blankCells = map[int]*tensor.Float64{} tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker svi := tb.This.(core.Lister) @@ -109,8 +110,8 @@ func (tb *Table) Init() { func (tb *Table) SliceIndex(i int) (si, vi int, invis bool) { si = tb.StartIndex + i vi = -1 - if si < len(tb.Table.Indexes) { - vi = tb.Table.Indexes[si] + if si < tb.Table.NumRows() { + vi = tb.Table.RowIndex(si) } invis = vi < 0 return @@ -130,14 +131,15 @@ func (tb *Table) StyleValue(w core.Widget, s *styles.Style, row, col int) { s.SetTextWrap(false) } -// SetTable sets the source table that we are viewing, using a sequential IndexView +// SetTable sets the source table that we are viewing, using a sequential view, // and then configures the display -func (tb *Table) SetTable(et *table.Table) *Table { - if et == nil { - return nil +func (tb *Table) SetTable(dt *table.Table) *Table { + if dt == nil { + tb.Table = nil + } else { + tb.Table = table.NewView(dt) + tb.GridStyle.ApplyStylersFrom(tb.Table) } - - tb.Table = table.NewIndexView(et) tb.This.(core.Lister).UpdateSliceSize() tb.SetSliceBase() tb.Update() @@ -160,50 +162,28 @@ func (tb *Table) AsyncUpdateTable() { tb.AsyncUnlock() } -// SetIndexView sets the source IndexView of a table (using a copy so original is not modified) -// and then configures the display -func (tb *Table) SetIndexView(ix *table.IndexView) *Table { - if ix == nil { - return tb - } - - tb.Table = ix.Clone() // always copy - - tb.This.(core.Lister).UpdateSliceSize() - tb.StartIndex = 0 - tb.VisibleRows = tb.MinRows - if !tb.IsReadOnly() { - 
tb.SelectedIndex = -1 - } - tb.ResetSelectedIndexes() - tb.SelectMode = false - tb.MakeIter = 0 - tb.Update() - return tb -} - func (tb *Table) UpdateSliceSize() int { - tb.Table.DeleteInvalid() // table could have changed - if tb.Table.Len() == 0 { + tb.Table.ValidIndexes() // table could have changed + if tb.Table.NumRows() == 0 { tb.Table.Sequential() } - tb.SliceSize = tb.Table.Len() - tb.NCols = tb.Table.Table.NumColumns() + tb.SliceSize = tb.Table.NumRows() + tb.nCols = tb.Table.NumColumns() return tb.SliceSize } func (tb *Table) UpdateMaxWidths() { - if len(tb.headerWidths) != tb.NCols { - tb.headerWidths = make([]int, tb.NCols) - tb.colMaxWidths = make([]int, tb.NCols) + if len(tb.headerWidths) != tb.nCols { + tb.headerWidths = make([]int, tb.nCols) + tb.colMaxWidths = make([]int, tb.nCols) } if tb.SliceSize == 0 { return } - for fli := 0; fli < tb.NCols; fli++ { + for fli := 0; fli < tb.nCols; fli++ { tb.colMaxWidths[fli] = 0 - col := tb.Table.Table.Columns[fli] + col := tb.Table.Columns.Values[fli] stsr, isstr := col.(*tensor.String) if !isstr { @@ -238,18 +218,26 @@ func (tb *Table) MakeHeader(p *tree.Plan) { w.SetText("Index") }) } - for fli := 0; fli < tb.NCols; fli++ { - field := tb.Table.Table.ColumnNames[fli] + for fli := 0; fli < tb.nCols; fli++ { + field := tb.Table.Columns.Keys[fli] tree.AddAt(p, "head-"+field, func(w *core.Button) { w.SetType(core.ButtonAction) w.Styler(func(s *styles.Style) { s.Justify.Content = styles.Start }) w.OnClick(func(e events.Event) { - tb.SortSliceAction(fli) + tb.SortColumn(fli) }) + if tb.Table.Columns.Values[fli].NumDims() > 1 { + w.AddContextMenu(func(m *core.Scene) { + core.NewButton(m).SetText("Edit grid style").SetIcon(icons.Edit). 
+ OnClick(func(e events.Event) { + tb.EditGridStyle(fli) + }) + }) + } w.Updater(func() { - field := tb.Table.Table.ColumnNames[fli] + field := tb.Table.Columns.Keys[fli] w.SetText(field).SetTooltip(field + " (tap to sort by)") tb.headerWidths[fli] = len(field) if fli == tb.SortIndex { @@ -275,7 +263,7 @@ func (tb *Table) SliceHeader() *core.Frame { // RowWidgetNs returns number of widgets per row and offset for index label func (tb *Table) RowWidgetNs() (nWidgPerRow, idxOff int) { - nWidgPerRow = 1 + tb.NCols + nWidgPerRow = 1 + tb.nCols idxOff = 1 if !tb.ShowIndexes { nWidgPerRow -= 1 @@ -293,8 +281,8 @@ func (tb *Table) MakeRow(p *tree.Plan, i int) { tb.MakeGridIndex(p, i, si, itxt, invis) } - for fli := 0; fli < tb.NCols; fli++ { - col := tb.Table.Table.Columns[fli] + for fli := 0; fli < tb.nCols; fli++ { + col := tb.Table.Columns.Values[fli] valnm := fmt.Sprintf("value-%v.%v", fli, itxt) _, isstr := col.(*tensor.String) @@ -313,11 +301,12 @@ func (tb *Table) MakeRow(p *tree.Plan, i int) { w.AsTree().SetProperty(core.ListColProperty, fli) if !tb.IsReadOnly() { wb.OnChange(func(e events.Event) { - if si < len(tb.Table.Indexes) { + _, vi, invis := svi.SliceIndex(i) + if !invis { if isstr { - tb.Table.Table.SetStringIndex(fli, tb.Table.Indexes[si], str) + col.SetString1D(str, vi) } else { - tb.Table.Table.SetFloatIndex(fli, tb.Table.Indexes[si], fval) + col.SetFloat1D(fval, vi) } } tb.This.(core.Lister).UpdateMaxWidths() @@ -328,17 +317,17 @@ func (tb *Table) MakeRow(p *tree.Plan, i int) { _, vi, invis := svi.SliceIndex(i) if !invis { if isstr { - str = tb.Table.Table.StringIndex(fli, vi) + str = col.String1D(vi) core.Bind(&str, w) } else { - fval = tb.Table.Table.FloatIndex(fli, vi) + fval = col.Float1D(vi) core.Bind(&fval, w) } } else { if isstr { - core.Bind(tb.BlankString, w) + core.Bind(tb.blankString, w) } else { - core.Bind(tb.BlankFloat, w) + core.Bind(tb.blankFloat, w) } } wb.SetReadOnly(tb.IsReadOnly()) @@ -364,58 +353,41 @@ func (tb *Table) MakeRow(p 
*tree.Plan, i int) { si, vi, invis := svi.SliceIndex(i) var cell tensor.Tensor if invis { - cell = tb.ColTensorBlank(fli, col) + cell = tb.blankCell(fli, col) } else { - cell = tb.Table.Table.TensorIndex(fli, vi) + cell = col.RowTensor(vi) } wb.ValueTitle = tb.ValueTitle + "[" + strconv.Itoa(si) + "]" w.SetState(invis, states.Invisible) w.SetTensor(cell) - w.Display = *tb.GetColumnTensorDisplay(fli) + w.GridStyle = *tb.GetColumnGridStyle(fli) }) }) } } } -// ColTensorBlank returns tensor blanks for given tensor col -func (tb *Table) ColTensorBlank(cidx int, col tensor.Tensor) *tensor.Float64 { - if ctb, has := tb.ColumnTensorBlank[cidx]; has { +// blankCell returns tensor blanks for given tensor col +func (tb *Table) blankCell(cidx int, col tensor.Tensor) *tensor.Float64 { + if ctb, has := tb.blankCells[cidx]; has { return ctb } - ctb := tensor.New[float64](col.Shape().Sizes, col.Shape().Names...).(*tensor.Float64) - tb.ColumnTensorBlank[cidx] = ctb + ctb := tensor.New[float64](col.ShapeSizes()...).(*tensor.Float64) + tb.blankCells[cidx] = ctb return ctb } -// GetColumnTensorDisplay returns tensor display parameters for this column -// either the overall defaults or the per-column if set -func (tb *Table) GetColumnTensorDisplay(col int) *TensorDisplay { - if ctd, has := tb.ColumnTensorDisplay[col]; has { +// GetColumnGridStyle gets grid style for given column. 
+func (tb *Table) GetColumnGridStyle(col int) *GridStyle { + if ctd, has := tb.ColumnGridStyle[col]; has { return ctd } + ctd := &GridStyle{} + *ctd = tb.GridStyle if tb.Table != nil { - cl := tb.Table.Table.Columns[col] - if len(cl.MetaDataMap()) > 0 { - return tb.SetColumnTensorDisplay(col) - } - } - return &tb.TensorDisplay -} - -// SetColumnTensorDisplay sets per-column tensor display params and returns them -// if already set, just returns them -func (tb *Table) SetColumnTensorDisplay(col int) *TensorDisplay { - if ctd, has := tb.ColumnTensorDisplay[col]; has { - return ctd + cl := tb.Table.Columns.Values[col] + ctd.ApplyStylersFrom(cl) } - ctd := &TensorDisplay{} - *ctd = tb.TensorDisplay - if tb.Table != nil { - cl := tb.Table.Table.Columns[col] - ctd.FromMeta(cl) - } - tb.ColumnTensorDisplay[col] = ctd return ctd } @@ -441,13 +413,13 @@ func (tb *Table) DeleteAt(idx int) { tb.Update() } -// SortSliceAction sorts the slice for given field index -- toggles ascending -// vs. descending if already sorting on this dimension -func (tb *Table) SortSliceAction(fldIndex int) { +// SortColumn sorts the slice for given column index. +// Toggles ascending vs. descending if already sorting on this dimension. 
+func (tb *Table) SortColumn(fldIndex int) { sgh := tb.SliceHeader() _, idxOff := tb.RowWidgetNs() - for fli := 0; fli < tb.NCols; fli++ { + for fli := 0; fli < tb.nCols; fli++ { hdr := sgh.Child(idxOff + fli).(*core.Button) hdr.SetType(core.ButtonAction) if fli == fldIndex { @@ -463,23 +435,38 @@ func (tb *Table) SortSliceAction(fldIndex int) { if fldIndex == -1 { tb.Table.SortIndexes() } else { - tb.Table.SortColumn(tb.SortIndex, !tb.SortDescending) + tb.Table.IndexesNeeded() + col := tb.Table.ColumnByIndex(tb.SortIndex) + col.Sort(!tb.SortDescending) + tb.Table.IndexesFromTensor(col) } tb.Update() // requires full update due to sort button icon } -// TensorDisplayAction allows user to select tensor display options for column -// pass -1 for global params for the entire table -func (tb *Table) TensorDisplayAction(fldIndex int) { - ctd := &tb.TensorDisplay - if fldIndex >= 0 { - ctd = tb.SetColumnTensorDisplay(fldIndex) - } - d := core.NewBody("Tensor grid display options") - core.NewForm(d).SetStruct(ctd) - d.RunFullDialog(tb) - // tv.UpdateSliceGrid() - tb.NeedsRender() +// EditGridStyle shows an editor dialog for grid style for given column index. +func (tb *Table) EditGridStyle(col int) { + ctd := tb.GetColumnGridStyle(col) + d := core.NewBody("Tensor grid style") + core.NewForm(d).SetStruct(ctd). + OnChange(func(e events.Event) { + tb.ColumnGridStyle[col] = ctd + tb.Update() + }) + core.NewButton(d).SetText("Edit global style").SetIcon(icons.Edit). + OnClick(func(e events.Event) { + tb.EditGlobalGridStyle() + }) + d.RunWindowDialog(tb) +} + +// EditGlobalGridStyle shows an editor dialog for global grid styles. +func (tb *Table) EditGlobalGridStyle() { + d := core.NewBody("Tensor grid style") + core.NewForm(d).SetStruct(&tb.GridStyle). 
+ OnChange(func(e events.Event) { + tb.Update() + }) + d.RunWindowDialog(tb) } func (tb *Table) HasStyler() bool { return false } @@ -489,8 +476,8 @@ func (tb *Table) StyleRow(w core.Widget, idx, fidx int) {} // SortFieldName returns the name of the field being sorted, along with :up or // :down depending on descending func (tb *Table) SortFieldName() string { - if tb.SortIndex >= 0 && tb.SortIndex < tb.NCols { - nm := tb.Table.Table.ColumnNames[tb.SortIndex] + if tb.SortIndex >= 0 && tb.SortIndex < tb.nCols { + nm := tb.Table.Columns.Keys[tb.SortIndex] if tb.SortDescending { nm += ":down" } else { @@ -509,8 +496,8 @@ func (tb *Table) SetSortFieldName(nm string) { } spnm := strings.Split(nm, ":") got := false - for fli := 0; fli < tb.NCols; fli++ { - fld := tb.Table.Table.ColumnNames[fli] + for fli := 0; fli < tb.nCols; fli++ { + fld := tb.Table.Columns.Keys[fli] if fld == spnm[0] { got = true // fmt.Println("sorting on:", fld.Name, fli, "from:", nm) @@ -543,7 +530,7 @@ func (tb *Table) RowFirstVisWidget(row int) (*core.WidgetBase, bool) { return w, true } ridx := nWidgPerRow * row - for fli := 0; fli < tb.NCols; fli++ { + for fli := 0; fli < tb.nCols; fli++ { w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget() if w.Geom.TotalBBox != (image.Rectangle{}) { return w, true @@ -563,7 +550,7 @@ func (tb *Table) RowGrabFocus(row int) *core.WidgetBase { ridx := nWidgPerRow * row lg := tb.ListGrid // first check if we already have focus - for fli := 0; fli < tb.NCols; fli++ { + for fli := 0; fli < tb.nCols; fli++ { w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget() if w.StateIs(states.Focused) || w.ContainsFocus() { return w @@ -571,7 +558,7 @@ func (tb *Table) RowGrabFocus(row int) *core.WidgetBase { } tb.InFocusGrab = true defer func() { tb.InFocusGrab = false }() - for fli := 0; fli < tb.NCols; fli++ { + for fli := 0; fli < tb.nCols; fli++ { w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget() if w.CanFocus() { w.SetFocus() @@ -581,8 +568,7 @@ 
func (tb *Table) RowGrabFocus(row int) *core.WidgetBase { return nil } -////////////////////////////////////////////////////// -// Header layout +//////// Header layout func (tb *Table) SizeFinal() { tb.ListBase.SizeFinal() @@ -613,24 +599,24 @@ func (tb *Table) SizeFinal() { // SelectedColumnStrings returns the string values of given column name. func (tb *Table) SelectedColumnStrings(colName string) []string { - dt := tb.Table.Table + dt := tb.Table jis := tb.SelectedIndexesList(false) if len(jis) == 0 || dt == nil { return nil } var s []string + col := dt.Column(colName) for _, i := range jis { - v := dt.StringValue(colName, i) + v := col.StringRow(i, 0) s = append(s, v) } return s } -////////////////////////////////////////////////////////////////////////////// -// Copy / Cut / Paste +//////// Copy / Cut / Paste func (tb *Table) MakeToolbar(p *tree.Plan) { - if tb.Table == nil || tb.Table.Table == nil { + if tb.Table == nil { return } tree.Add(p, func(w *core.FuncButton) { @@ -638,11 +624,11 @@ func (tb *Table) MakeToolbar(p *tree.Plan) { w.SetAfterFunc(func() { tb.Update() }) }) tree.Add(p, func(w *core.FuncButton) { - w.SetFunc(tb.Table.SortColumnName).SetText("Sort").SetIcon(icons.Sort) + w.SetFunc(tb.Table.SortColumns).SetText("Sort").SetIcon(icons.Sort) w.SetAfterFunc(func() { tb.Update() }) }) tree.Add(p, func(w *core.FuncButton) { - w.SetFunc(tb.Table.FilterColumnName).SetText("Filter").SetIcon(icons.FilterAlt) + w.SetFunc(tb.Table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt) w.SetAfterFunc(func() { tb.Update() }) }) tree.Add(p, func(w *core.FuncButton) { @@ -669,16 +655,15 @@ func (tb *Table) CopySelectToMime() mimedata.Mimes { if nitms == 0 { return nil } - ix := &table.IndexView{} - ix.Table = tb.Table.Table + ix := table.NewView(tb.Table) idx := tb.SelectedIndexesList(false) // ascending iidx := make([]int, len(idx)) for i, di := range idx { - iidx[i] = tb.Table.Indexes[di] + iidx[i] = tb.Table.RowIndex(di) } ix.Indexes = iidx var b 
bytes.Buffer - ix.WriteCSV(&b, table.Tab, table.Headers) + ix.WriteCSV(&b, tensor.Tab, table.Headers) md := mimedata.NewTextBytes(b.Bytes()) md[0].Type = fileinfo.DataCsv return md @@ -691,7 +676,7 @@ func (tb *Table) FromMimeData(md mimedata.Mimes) [][]string { if d.Type == fileinfo.DataCsv { b := bytes.NewBuffer(d.Data) cr := csv.NewReader(b) - cr.Comma = table.Tab.Rune() + cr.Comma = tensor.Tab.Rune() rec, err := cr.ReadAll() if err != nil || len(rec) == 0 { log.Printf("Error reading CSV from clipboard: %s\n", err) @@ -709,7 +694,7 @@ func (tb *Table) PasteAssign(md mimedata.Mimes, idx int) { if len(recs) == 0 { return } - tb.Table.Table.ReadCSVRow(recs[1], tb.Table.Indexes[idx]) + tb.Table.ReadCSVRow(recs[1], tb.Table.RowIndex(idx)) tb.UpdateChange() } @@ -724,8 +709,8 @@ func (tb *Table) PasteAtIndex(md mimedata.Mimes, idx int) { tb.Table.InsertRows(idx, nr) for ri := 0; ri < nr; ri++ { rec := recs[1+ri] - rw := tb.Table.Indexes[idx+ri] - tb.Table.Table.ReadCSVRow(rec, rw) + rw := tb.Table.RowIndex(idx + ri) + tb.Table.ReadCSVRow(rec, rw) } tb.SendChange() tb.SelectIndexEvent(idx, events.SelectOne) diff --git a/tensor/tensorcore/tensoreditor.go b/tensor/tensorcore/tensoreditor.go index 65a833cc91..bf0bc901f3 100644 --- a/tensor/tensorcore/tensoreditor.go +++ b/tensor/tensorcore/tensoreditor.go @@ -12,6 +12,7 @@ import ( "cogentcore.org/core/base/fileinfo" "cogentcore.org/core/base/fileinfo/mimedata" + "cogentcore.org/core/base/fsx" "cogentcore.org/core/core" "cogentcore.org/core/events" "cogentcore.org/core/icons" @@ -19,7 +20,6 @@ import ( "cogentcore.org/core/styles/states" "cogentcore.org/core/styles/units" "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/table" "cogentcore.org/core/tree" ) @@ -31,7 +31,7 @@ type TensorEditor struct { Tensor tensor.Tensor `set:"-"` // overall layout options for tensor display - Layout TensorLayout `set:"-"` + Layout Layout `set:"-"` // number of columns in table (as of last update) NCols int `edit:"-"` @@ -52,6 
+52,7 @@ var _ core.Lister = (*TensorEditor)(nil) func (tb *TensorEditor) Init() { tb.ListBase.Init() + tb.Layout.OddRow = true tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker svi := tb.This.(core.Lister) svi.UpdateSliceSize() @@ -70,6 +71,9 @@ func (tb *TensorEditor) Init() { tb.UpdateMaxWidths() tb.Updater(func() { + if tb.Tensor.NumDims() == 1 { + tb.Layout.TopZero = true + } tb.UpdateStartIndex() }) @@ -335,8 +339,7 @@ func (tb *TensorEditor) RowGrabFocus(row int) *core.WidgetBase { return nil } -////////////////////////////////////////////////////// -// Header layout +/////// Header layout func (tb *TensorEditor) SizeFinal() { tb.ListBase.SizeFinal() @@ -365,14 +368,13 @@ func (tb *TensorEditor) SizeFinal() { } } -////////////////////////////////////////////////////////////////////////////// -// Copy / Cut / Paste +//////// Copy / Cut / Paste // SaveTSV writes a tensor to a tab-separated-values (TSV) file. // Outer-most dims are rows in the file, and inner-most is column -- // Reading just grabs all values and doesn't care about shape. func (tb *TensorEditor) SaveCSV(filename core.Filename) error { //types:add - return tensor.SaveCSV(tb.Tensor, filename, table.Tab.Rune()) + return tensor.SaveCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab) } // OpenTSV reads a tensor from a tab-separated-values (TSV) file. @@ -380,7 +382,7 @@ func (tb *TensorEditor) SaveCSV(filename core.Filename) error { //types:add // to the official CSV standard. // Reads all values and assigns as many as fit. 
func (tb *TensorEditor) OpenCSV(filename core.Filename) error { //types:add - return tensor.OpenCSV(tb.Tensor, filename, table.Tab.Rune()) + return tensor.OpenCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab) } func (tb *TensorEditor) MakeToolbar(p *tree.Plan) { @@ -409,7 +411,7 @@ func (tb *TensorEditor) CopySelectToMime() mimedata.Mimes { } // idx := tb.SelectedIndexesList(false) // ascending // var b bytes.Buffer - // ix.WriteCSV(&b, table.Tab, table.Headers) + // ix.WriteCSV(&b, tensor.Tab, table.Headers) // md := mimedata.NewTextBytes(b.Bytes()) // md[0].Type = fileinfo.DataCsv // return md @@ -423,7 +425,7 @@ func (tb *TensorEditor) FromMimeData(md mimedata.Mimes) [][]string { if d.Type == fileinfo.DataCsv { // b := bytes.NewBuffer(d.Data) // cr := csv.NewReader(b) - // cr.Comma = table.Tab.Rune() + // cr.Comma = tensor.Tab.Rune() // rec, err := cr.ReadAll() // if err != nil || len(rec) == 0 { // log.Printf("Error reading CSV from clipboard: %s\n", err) diff --git a/tensor/tensorcore/tensorgrid.go b/tensor/tensorcore/tensorgrid.go index a6d054a744..b9e31ec1f8 100644 --- a/tensor/tensorcore/tensorgrid.go +++ b/tensor/tensorcore/tensorgrid.go @@ -7,7 +7,6 @@ package tensorcore import ( "image/color" "log" - "strconv" "cogentcore.org/core/colors" "cogentcore.org/core/colors/colormap" @@ -22,168 +21,17 @@ import ( "cogentcore.org/core/tensor" ) -// TensorLayout are layout options for displaying tensors -type TensorLayout struct { //types:add - - // even-numbered dimensions are displayed as Y*X rectangles. - // This determines along which dimension to display any remaining - // odd dimension: OddRow = true = organize vertically along row - // dimension, false = organize horizontally across column dimension. - OddRow bool - - // if true, then the Y=0 coordinate is displayed from the top-down; - // otherwise the Y=0 coordinate is displayed from the bottom up, - // which is typical for emergent network patterns. 
- TopZero bool - - // display the data as a bitmap image. if a 2D tensor, then it will - // be a greyscale image. if a 3D tensor with size of either the first - // or last dim = either 3 or 4, then it is a RGB(A) color image. - Image bool -} - -// TensorDisplay are options for displaying tensors -type TensorDisplay struct { //types:add - TensorLayout - - // range to plot - Range minmax.Range64 `display:"inline"` - - // if not using fixed range, this is the actual range of data - MinMax minmax.F64 `display:"inline"` - - // the name of the color map to use in translating values to colors - ColorMap core.ColorMapName - - // what proportion of grid square should be filled by color block -- 1 = all, .5 = half, etc - GridFill float32 `min:"0.1" max:"1" step:"0.1" default:"0.9,1"` - - // amount of extra space to add at dimension boundaries, as a proportion of total grid size - DimExtra float32 `min:"0" max:"1" step:"0.02" default:"0.1,0.3"` - - // minimum size for grid squares -- they will never be smaller than this - GridMinSize float32 - - // maximum size for grid squares -- they will never be larger than this - GridMaxSize float32 - - // total preferred display size along largest dimension. 
- // grid squares will be sized to fit within this size, - // subject to harder GridMin / Max size constraints - TotPrefSize float32 - - // font size in standard point units for labels (e.g., SimMat) - FontSize float32 - - // our gridview, for update method - GridView *TensorGrid `copier:"-" json:"-" xml:"-" display:"-"` -} - -// Defaults sets defaults for values that are at nonsensical initial values -func (td *TensorDisplay) Defaults() { - if td.ColorMap == "" { - td.ColorMap = "ColdHot" - } - if td.Range.Max == 0 && td.Range.Min == 0 { - td.Range.SetMin(-1) - td.Range.SetMax(1) - } - if td.GridMinSize == 0 { - td.GridMinSize = 2 - } - if td.GridMaxSize == 0 { - td.GridMaxSize = 16 - } - if td.TotPrefSize == 0 { - td.TotPrefSize = 100 - } - if td.GridFill == 0 { - td.GridFill = 0.9 - td.DimExtra = 0.3 - } - if td.FontSize == 0 { - td.FontSize = 24 - } -} - -// FromMeta sets display options from Tensor meta-data -func (td *TensorDisplay) FromMeta(tsr tensor.Tensor) { - if op, has := tsr.MetaData("top-zero"); has { - if op == "+" || op == "true" { - td.TopZero = true - } - } - if op, has := tsr.MetaData("odd-row"); has { - if op == "+" || op == "true" { - td.OddRow = true - } - } - if op, has := tsr.MetaData("image"); has { - if op == "+" || op == "true" { - td.Image = true - } - } - if op, has := tsr.MetaData("min"); has { - mv, _ := strconv.ParseFloat(op, 64) - td.Range.Min = mv - } - if op, has := tsr.MetaData("max"); has { - mv, _ := strconv.ParseFloat(op, 64) - td.Range.Max = mv - } - if op, has := tsr.MetaData("fix-min"); has { - if op == "+" || op == "true" { - td.Range.FixMin = true - } else { - td.Range.FixMin = false - } - } - if op, has := tsr.MetaData("fix-max"); has { - if op == "+" || op == "true" { - td.Range.FixMax = true - } else { - td.Range.FixMax = false - } - } - if op, has := tsr.MetaData("colormap"); has { - td.ColorMap = core.ColorMapName(op) - } - if op, has := tsr.MetaData("grid-fill"); has { - mv, _ := strconv.ParseFloat(op, 32) - 
td.GridFill = float32(mv) - } - if op, has := tsr.MetaData("grid-min"); has { - mv, _ := strconv.ParseFloat(op, 32) - td.GridMinSize = float32(mv) - } - if op, has := tsr.MetaData("grid-max"); has { - mv, _ := strconv.ParseFloat(op, 32) - td.GridMaxSize = float32(mv) - } - if op, has := tsr.MetaData("dim-extra"); has { - mv, _ := strconv.ParseFloat(op, 32) - td.DimExtra = float32(mv) - } - if op, has := tsr.MetaData("font-size"); has { - mv, _ := strconv.ParseFloat(op, 32) - td.FontSize = float32(mv) - } -} - -//////////////////////////////////////////////////////////////////////////// -// TensorGrid - // TensorGrid is a widget that displays tensor values as a grid of colored squares. type TensorGrid struct { core.WidgetBase - // the tensor that we view + // Tensor is the tensor that we view. Tensor tensor.Tensor `set:"-"` - // display options - Display TensorDisplay + // GridStyle has grid display style properties. + GridStyle GridStyle - // the actual colormap + // ColorMap is the colormap displayed (based on) ColorMap *colormap.Map } @@ -196,8 +44,7 @@ func (tg *TensorGrid) SetWidgetValue(value any) error { func (tg *TensorGrid) Init() { tg.WidgetBase.Init() - tg.Display.GridView = tg - tg.Display.Defaults() + tg.GridStyle.Defaults() tg.Styler(func(s *styles.Style) { s.SetAbilities(true, abilities.DoubleClickable) ms := tg.MinSize() @@ -206,11 +53,11 @@ func (tg *TensorGrid) Init() { }) tg.OnDoubleClick(func(e events.Event) { - tg.OpenTensorEditor() + tg.TensorEditor() }) tg.AddContextMenu(func(m *core.Scene) { - core.NewFuncButton(m).SetFunc(tg.OpenTensorEditor).SetIcon(icons.Edit) - core.NewFuncButton(m).SetFunc(tg.EditSettings).SetIcon(icons.Edit) + core.NewFuncButton(m).SetFunc(tg.TensorEditor).SetIcon(icons.Edit) + core.NewFuncButton(m).SetFunc(tg.EditStyle).SetIcon(icons.Edit) }) } @@ -222,14 +69,14 @@ func (tg *TensorGrid) SetTensor(tsr tensor.Tensor) *TensorGrid { } tg.Tensor = tsr if tg.Tensor != nil { - tg.Display.FromMeta(tg.Tensor) + 
tg.GridStyle.ApplyStylersFrom(tg.Tensor) } return tg } -// OpenTensorEditor pulls up a TensorEditor of our tensor -func (tg *TensorGrid) OpenTensorEditor() { //types:add - d := core.NewBody("Tensor Editor") +// TensorEditor pulls up a TensorEditor of our tensor +func (tg *TensorGrid) TensorEditor() { //types:add + d := core.NewBody("Tensor editor") tb := core.NewToolbar(d) te := NewTensorEditor(d).SetTensor(tg.Tensor) te.OnChange(func(e events.Event) { @@ -239,9 +86,9 @@ func (tg *TensorGrid) OpenTensorEditor() { //types:add d.RunWindowDialog(tg) } -func (tg *TensorGrid) EditSettings() { //types:add - d := core.NewBody("Tensor Grid Display Options") - core.NewForm(d).SetStruct(&tg.Display). +func (tg *TensorGrid) EditStyle() { //types:add + d := core.NewBody("Tensor grid style") + core.NewForm(d).SetStruct(&tg.GridStyle). OnChange(func(e events.Event) { tg.NeedsRender() }) @@ -253,33 +100,32 @@ func (tg *TensorGrid) MinSize() math32.Vector2 { if tg.Tensor == nil || tg.Tensor.Len() == 0 { return math32.Vector2{} } - if tg.Display.Image { + if tg.GridStyle.Image { return math32.Vec2(float32(tg.Tensor.DimSize(1)), float32(tg.Tensor.DimSize(0))) } - rows, cols, rowEx, colEx := tensor.Projection2DShape(tg.Tensor.Shape(), tg.Display.OddRow) - frw := float32(rows) + float32(rowEx)*tg.Display.DimExtra // extra spacing - fcl := float32(cols) + float32(colEx)*tg.Display.DimExtra // extra spacing + rows, cols, rowEx, colEx := tensor.Projection2DShape(tg.Tensor.Shape(), tg.GridStyle.OddRow) + frw := float32(rows) + float32(rowEx)*tg.GridStyle.DimExtra // extra spacing + fcl := float32(cols) + float32(colEx)*tg.GridStyle.DimExtra // extra spacing mx := float32(max(frw, fcl)) - gsz := tg.Display.TotPrefSize / mx - gsz = max(gsz, tg.Display.GridMinSize) - gsz = min(gsz, tg.Display.GridMaxSize) + gsz := tg.GridStyle.TotalSize / mx + gsz = tg.GridStyle.Size.ClampValue(gsz) gsz = max(gsz, 2) return math32.Vec2(gsz*float32(fcl), gsz*float32(frw)) } // EnsureColorMap makes sure there 
is a valid color map that matches specified name func (tg *TensorGrid) EnsureColorMap() { - if tg.ColorMap != nil && tg.ColorMap.Name != string(tg.Display.ColorMap) { + if tg.ColorMap != nil && tg.ColorMap.Name != string(tg.GridStyle.ColorMap) { tg.ColorMap = nil } if tg.ColorMap == nil { ok := false - tg.ColorMap, ok = colormap.AvailableMaps[string(tg.Display.ColorMap)] + tg.ColorMap, ok = colormap.AvailableMaps[string(tg.GridStyle.ColorMap)] if !ok { - tg.Display.ColorMap = "" - tg.Display.Defaults() + tg.GridStyle.ColorMap = "" + tg.GridStyle.Defaults() } - tg.ColorMap = colormap.AvailableMaps[string(tg.Display.ColorMap)] + tg.ColorMap = colormap.AvailableMaps[string(tg.GridStyle.ColorMap)] } } @@ -287,22 +133,22 @@ func (tg *TensorGrid) Color(val float64) (norm float64, clr color.Color) { if tg.ColorMap.Indexed { clr = tg.ColorMap.MapIndex(int(val)) } else { - norm = tg.Display.Range.ClipNormValue(val) + norm = tg.GridStyle.Range.ClipNormValue(val) clr = tg.ColorMap.Map(float32(norm)) } return } func (tg *TensorGrid) UpdateRange() { - if !tg.Display.Range.FixMin || !tg.Display.Range.FixMax { - min, max, _, _ := tg.Tensor.Range() - if !tg.Display.Range.FixMin { + if !tg.GridStyle.Range.FixMin || !tg.GridStyle.Range.FixMax { + min, max, _, _ := tensor.Range(tg.Tensor.AsValues()) + if !tg.GridStyle.Range.FixMin { nmin := minmax.NiceRoundNumber(min, true) // true = below # - tg.Display.Range.Min = nmin + tg.GridStyle.Range.Min = nmin } - if !tg.Display.Range.FixMax { + if !tg.GridStyle.Range.FixMax { nmax := minmax.NiceRoundNumber(max, false) // false = above # - tg.Display.Range.Max = nmax + tg.GridStyle.Range.Max = nmax } } } @@ -324,7 +170,7 @@ func (tg *TensorGrid) Render() { tsr := tg.Tensor - if tg.Display.Image { + if tg.GridStyle.Image { ysz := tsr.DimSize(0) xsz := tsr.DimSize(1) nclr := 1 @@ -344,18 +190,18 @@ func (tg *TensorGrid) Render() { for y := 0; y < ysz; y++ { for x := 0; x < xsz; x++ { ey := y - if !tg.Display.TopZero { + if 
!tg.GridStyle.TopZero { ey = (ysz - 1) - y } switch { case outclr: var r, g, b, a float64 a = 1 - r = tg.Display.Range.ClipNormValue(tsr.Float([]int{0, y, x})) - g = tg.Display.Range.ClipNormValue(tsr.Float([]int{1, y, x})) - b = tg.Display.Range.ClipNormValue(tsr.Float([]int{2, y, x})) + r = tg.GridStyle.Range.ClipNormValue(tsr.Float(0, y, x)) + g = tg.GridStyle.Range.ClipNormValue(tsr.Float(1, y, x)) + b = tg.GridStyle.Range.ClipNormValue(tsr.Float(2, y, x)) if nclr > 3 { - a = tg.Display.Range.ClipNormValue(tsr.Float([]int{3, y, x})) + a = tg.GridStyle.Range.ClipNormValue(tsr.Float(3, y, x)) } cr := math32.Vec2(float32(x), float32(ey)) pr := pos.Add(cr.Mul(gsz)) @@ -364,18 +210,18 @@ func (tg *TensorGrid) Render() { case nclr > 1: var r, g, b, a float64 a = 1 - r = tg.Display.Range.ClipNormValue(tsr.Float([]int{y, x, 0})) - g = tg.Display.Range.ClipNormValue(tsr.Float([]int{y, x, 1})) - b = tg.Display.Range.ClipNormValue(tsr.Float([]int{y, x, 2})) + r = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 0)) + g = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 1)) + b = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 2)) if nclr > 3 { - a = tg.Display.Range.ClipNormValue(tsr.Float([]int{y, x, 3})) + a = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 3)) } cr := math32.Vec2(float32(x), float32(ey)) pr := pos.Add(cr.Mul(gsz)) pc.StrokeStyle.Color = colors.Uniform(colors.FromFloat64(r, g, b, a)) pc.FillBox(pr, gsz, pc.StrokeStyle.Color) default: - val := tg.Display.Range.ClipNormValue(tsr.Float([]int{y, x})) + val := tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x)) cr := math32.Vec2(float32(x), float32(ey)) pr := pos.Add(cr.Mul(gsz)) pc.StrokeStyle.Color = colors.Uniform(colors.FromFloat64(val, val, val, 1)) @@ -385,9 +231,9 @@ func (tg *TensorGrid) Render() { } return } - rows, cols, rowEx, colEx := tensor.Projection2DShape(tsr.Shape(), tg.Display.OddRow) - frw := float32(rows) + float32(rowEx)*tg.Display.DimExtra // extra spacing - fcl := float32(cols) + 
float32(colEx)*tg.Display.DimExtra // extra spacing + rows, cols, rowEx, colEx := tensor.Projection2DShape(tsr.Shape(), tg.GridStyle.OddRow) + frw := float32(rows) + float32(rowEx)*tg.GridStyle.DimExtra // extra spacing + fcl := float32(cols) + float32(colEx)*tg.GridStyle.DimExtra // extra spacing rowsInner := rows colsInner := cols if rowEx > 0 { @@ -399,16 +245,16 @@ func (tg *TensorGrid) Render() { tsz := math32.Vec2(fcl, frw) gsz := sz.Div(tsz) - ssz := gsz.MulScalar(tg.Display.GridFill) // smaller size with margin + ssz := gsz.MulScalar(tg.GridStyle.GridFill) // smaller size with margin for y := 0; y < rows; y++ { - yex := float32(int(y/rowsInner)) * tg.Display.DimExtra + yex := float32(int(y/rowsInner)) * tg.GridStyle.DimExtra for x := 0; x < cols; x++ { - xex := float32(int(x/colsInner)) * tg.Display.DimExtra + xex := float32(int(x/colsInner)) * tg.GridStyle.DimExtra ey := y - if !tg.Display.TopZero { + if !tg.GridStyle.TopZero { ey = (rows - 1) - y } - val := tensor.Projection2DValue(tsr, tg.Display.OddRow, ey, x) + val := tensor.Projection2DValue(tsr, tg.GridStyle.OddRow, ey, x) cr := math32.Vec2(float32(x)+xex, float32(y)+yex) pr := pos.Add(cr.Mul(gsz)) _, clr := tg.Color(val) diff --git a/tensor/tensorcore/typegen.go b/tensor/tensorcore/typegen.go index bbea16ff82..9ad7727687 100644 --- a/tensor/tensorcore/typegen.go +++ b/tensor/tensorcore/typegen.go @@ -4,46 +4,87 @@ package tensorcore import ( "cogentcore.org/core/colors/colormap" + "cogentcore.org/core/core" + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/simat" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tree" "cogentcore.org/core/types" ) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.SimMatGrid", IDName: "sim-mat-grid", Doc: "SimMatGrid is a widget that displays a similarity / distance matrix\nwith tensor values as a grid of colored squares, and labels for rows and columns.", Directives: 
[]types.Directive{{Tool: "types", Directive: "add"}}, Embeds: []types.Field{{Name: "TensorGrid"}}, Fields: []types.Field{{Name: "SimMat", Doc: "the similarity / distance matrix"}, {Name: "rowMaxSz"}, {Name: "rowMinBlank"}, {Name: "rowNGps"}, {Name: "colMaxSz"}, {Name: "colMinBlank"}, {Name: "colNGps"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.Layout", IDName: "layout", Doc: "Layout are layout options for displaying tensors.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Fields: []types.Field{{Name: "OddRow", Doc: "OddRow means that even-numbered dimensions are displayed as Y*X rectangles.\nThis determines along which dimension to display any remaining\nodd dimension: OddRow = true = organize vertically along row\ndimension, false = organize horizontally across column dimension."}, {Name: "TopZero", Doc: "TopZero means that the Y=0 coordinate is displayed from the top-down;\notherwise the Y=0 coordinate is displayed from the bottom up,\nwhich is typical for emergent network patterns."}, {Name: "Image", Doc: "Image will display the data as a bitmap image. If a 2D tensor, then it will\nbe a greyscale image. If a 3D tensor with size of either the first\nor last dim = either 3 or 4, then it is a RGB(A) color image."}}}) -// NewSimMatGrid returns a new [SimMatGrid] with the given optional parent: -// SimMatGrid is a widget that displays a similarity / distance matrix -// with tensor values as a grid of colored squares, and labels for rows and columns. -func NewSimMatGrid(parent ...tree.Node) *SimMatGrid { return tree.New[SimMatGrid](parent...) } +// SetOddRow sets the [Layout.OddRow]: +// OddRow means that even-numbered dimensions are displayed as Y*X rectangles. +// This determines along which dimension to display any remaining +// odd dimension: OddRow = true = organize vertically along row +// dimension, false = organize horizontally across column dimension. 
+func (t *Layout) SetOddRow(v bool) *Layout { t.OddRow = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.Table", IDName: "table", Doc: "Table provides a GUI widget for representing [table.Table] values.", Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Table", Doc: "the idx view of the table that we're a view of"}, {Name: "TensorDisplay", Doc: "overall display options for tensor display"}, {Name: "ColumnTensorDisplay", Doc: "per column tensor display params"}, {Name: "ColumnTensorBlank", Doc: "per column blank tensor values"}, {Name: "NCols", Doc: "number of columns in table (as of last update)"}, {Name: "SortIndex", Doc: "current sort index"}, {Name: "SortDescending", Doc: "whether current sort order is descending"}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields"}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields"}, {Name: "BlankString", Doc: "\tblank values for out-of-range rows"}, {Name: "BlankFloat"}}}) +// SetTopZero sets the [Layout.TopZero]: +// TopZero means that the Y=0 coordinate is displayed from the top-down; +// otherwise the Y=0 coordinate is displayed from the bottom up, +// which is typical for emergent network patterns. +func (t *Layout) SetTopZero(v bool) *Layout { t.TopZero = v; return t } + +// SetImage sets the [Layout.Image]: +// Image will display the data as a bitmap image. If a 2D tensor, then it will +// be a greyscale image. If a 3D tensor with size of either the first +// or last dim = either 3 or 4, then it is a RGB(A) color image. 
+func (t *Layout) SetImage(v bool) *Layout { t.Image = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.GridStyle", IDName: "grid-style", Doc: "GridStyle are options for displaying tensors", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Embeds: []types.Field{{Name: "Layout"}}, Fields: []types.Field{{Name: "Range", Doc: "Range to plot"}, {Name: "MinMax", Doc: "MinMax has the actual range of data, if not using fixed Range."}, {Name: "ColorMap", Doc: "ColorMap is the name of the color map to use in translating values to colors."}, {Name: "GridFill", Doc: "GridFill sets proportion of grid square filled by the color block:\n1 = all, .5 = half, etc."}, {Name: "DimExtra", Doc: "DimExtra is the amount of extra space to add at dimension boundaries,\nas a proportion of total grid size."}, {Name: "Size", Doc: "Size sets the minimum and maximum size for grid squares."}, {Name: "TotalSize", Doc: "TotalSize sets the total preferred display size along largest dimension.\nGrid squares will be sized to fit within this size,\nsubject to the Size.Min / Max constraints, which have precedence."}, {Name: "FontSize", Doc: "FontSize is the font size in standard point units for labels."}}}) + +// SetRange sets the [GridStyle.Range]: +// Range to plot +func (t *GridStyle) SetRange(v minmax.Range64) *GridStyle { t.Range = v; return t } + +// SetMinMax sets the [GridStyle.MinMax]: +// MinMax has the actual range of data, if not using fixed Range. +func (t *GridStyle) SetMinMax(v minmax.F64) *GridStyle { t.MinMax = v; return t } + +// SetColorMap sets the [GridStyle.ColorMap]: +// ColorMap is the name of the color map to use in translating values to colors. +func (t *GridStyle) SetColorMap(v core.ColorMapName) *GridStyle { t.ColorMap = v; return t } + +// SetGridFill sets the [GridStyle.GridFill]: +// GridFill sets proportion of grid square filled by the color block: +// 1 = all, .5 = half, etc. 
+func (t *GridStyle) SetGridFill(v float32) *GridStyle { t.GridFill = v; return t } + +// SetDimExtra sets the [GridStyle.DimExtra]: +// DimExtra is the amount of extra space to add at dimension boundaries, +// as a proportion of total grid size. +func (t *GridStyle) SetDimExtra(v float32) *GridStyle { t.DimExtra = v; return t } + +// SetSize sets the [GridStyle.Size]: +// Size sets the minimum and maximum size for grid squares. +func (t *GridStyle) SetSize(v minmax.F32) *GridStyle { t.Size = v; return t } + +// SetTotalSize sets the [GridStyle.TotalSize]: +// TotalSize sets the total preferred display size along largest dimension. +// Grid squares will be sized to fit within this size, +// subject to the Size.Min / Max constraints, which have precedence. +func (t *GridStyle) SetTotalSize(v float32) *GridStyle { t.TotalSize = v; return t } + +// SetFontSize sets the [GridStyle.FontSize]: +// FontSize is the font size in standard point units for labels. +func (t *GridStyle) SetFontSize(v float32) *GridStyle { t.FontSize = v; return t } + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.Table", IDName: "table", Doc: "Table provides a GUI widget for representing [table.Table] values.", Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Table", Doc: "Table is the table that we're a view of."}, {Name: "GridStyle", Doc: "GridStyle has global grid display styles. 
GridStylers on the Table\nare applied to this on top of defaults."}, {Name: "ColumnGridStyle", Doc: "ColumnGridStyle has per column grid display styles."}, {Name: "SortIndex", Doc: "current sort index."}, {Name: "SortDescending", Doc: "whether current sort order is descending."}, {Name: "nCols", Doc: "number of columns in table (as of last update)."}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields."}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields."}, {Name: "blankString", Doc: "\tblank values for out-of-range rows."}, {Name: "blankFloat"}, {Name: "blankCells", Doc: "blankCells has per column blank tensor cells."}}}) // NewTable returns a new [Table] with the given optional parent: // Table provides a GUI widget for representing [table.Table] values. func NewTable(parent ...tree.Node) *Table { return tree.New[Table](parent...) } -// SetNCols sets the [Table.NCols]: -// number of columns in table (as of last update) -func (t *Table) SetNCols(v int) *Table { t.NCols = v; return t } - // SetSortIndex sets the [Table.SortIndex]: -// current sort index +// current sort index. func (t *Table) SetSortIndex(v int) *Table { t.SortIndex = v; return t } // SetSortDescending sets the [Table.SortDescending]: -// whether current sort order is descending +// whether current sort order is descending. 
func (t *Table) SetSortDescending(v bool) *Table { t.SortDescending = v; return t } -// SetBlankString sets the [Table.BlankString]: -// -// blank values for out-of-range rows -func (t *Table) SetBlankString(v string) *Table { t.BlankString = v; return t } - -// SetBlankFloat sets the [Table.BlankFloat] -func (t *Table) SetBlankFloat(v float64) *Table { t.BlankFloat = v; return t } - var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorEditor", IDName: "tensor-editor", Doc: "TensorEditor provides a GUI widget for representing [tensor.Tensor] values.", Methods: []types.Method{{Name: "SaveCSV", Doc: "SaveTSV writes a tensor to a tab-separated-values (TSV) file.\nOuter-most dims are rows in the file, and inner-most is column --\nReading just grabs all values and doesn't care about shape.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenTSV reads a tensor from a tab-separated-values (TSV) file.\nusing the Go standard encoding/csv reader conforming\nto the official CSV standard.\nReads all values and assigns as many as fit.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}}, Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "the tensor that we're a view of"}, {Name: "Layout", Doc: "overall layout options for tensor display"}, {Name: "NCols", Doc: "number of columns in table (as of last update)"}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields"}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields"}, {Name: "BlankString", Doc: "\tblank values for out-of-range rows"}, {Name: "BlankFloat"}}}) // NewTensorEditor returns a new [TensorEditor] with the given optional parent: @@ -62,22 +103,18 @@ func (t *TensorEditor) SetBlankString(v string) 
*TensorEditor { t.BlankString = // SetBlankFloat sets the [TensorEditor.BlankFloat] func (t *TensorEditor) SetBlankFloat(v float64) *TensorEditor { t.BlankFloat = v; return t } -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorLayout", IDName: "tensor-layout", Doc: "TensorLayout are layout options for displaying tensors", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "OddRow", Doc: "even-numbered dimensions are displayed as Y*X rectangles.\nThis determines along which dimension to display any remaining\nodd dimension: OddRow = true = organize vertically along row\ndimension, false = organize horizontally across column dimension."}, {Name: "TopZero", Doc: "if true, then the Y=0 coordinate is displayed from the top-down;\notherwise the Y=0 coordinate is displayed from the bottom up,\nwhich is typical for emergent network patterns."}, {Name: "Image", Doc: "display the data as a bitmap image. if a 2D tensor, then it will\nbe a greyscale image. 
if a 3D tensor with size of either the first\nor last dim = either 3 or 4, then it is a RGB(A) color image."}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorDisplay", IDName: "tensor-display", Doc: "TensorDisplay are options for displaying tensors", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Embeds: []types.Field{{Name: "TensorLayout"}}, Fields: []types.Field{{Name: "Range", Doc: "range to plot"}, {Name: "MinMax", Doc: "if not using fixed range, this is the actual range of data"}, {Name: "ColorMap", Doc: "the name of the color map to use in translating values to colors"}, {Name: "GridFill", Doc: "what proportion of grid square should be filled by color block -- 1 = all, .5 = half, etc"}, {Name: "DimExtra", Doc: "amount of extra space to add at dimension boundaries, as a proportion of total grid size"}, {Name: "GridMinSize", Doc: "minimum size for grid squares -- they will never be smaller than this"}, {Name: "GridMaxSize", Doc: "maximum size for grid squares -- they will never be larger than this"}, {Name: "TotPrefSize", Doc: "total preferred display size along largest dimension.\ngrid squares will be sized to fit within this size,\nsubject to harder GridMin / Max size constraints"}, {Name: "FontSize", Doc: "font size in standard point units for labels (e.g., SimMat)"}, {Name: "GridView", Doc: "our gridview, for update method"}}}) - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorGrid", IDName: "tensor-grid", Doc: "TensorGrid is a widget that displays tensor values as a grid of colored squares.", Methods: []types.Method{{Name: "OpenTensorEditor", Doc: "OpenTensorEditor pulls up a TensorEditor of our tensor", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "EditSettings", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "the tensor 
that we view"}, {Name: "Display", Doc: "display options"}, {Name: "ColorMap", Doc: "the actual colormap"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorGrid", IDName: "tensor-grid", Doc: "TensorGrid is a widget that displays tensor values as a grid of colored squares.", Methods: []types.Method{{Name: "TensorEditor", Doc: "TensorEditor pulls up a TensorEditor of our tensor", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "EditStyle", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor is the tensor that we view."}, {Name: "GridStyle", Doc: "GridStyle has grid display style properties."}, {Name: "ColorMap", Doc: "ColorMap is the colormap displayed (based on)"}}}) // NewTensorGrid returns a new [TensorGrid] with the given optional parent: // TensorGrid is a widget that displays tensor values as a grid of colored squares. func NewTensorGrid(parent ...tree.Node) *TensorGrid { return tree.New[TensorGrid](parent...) } -// SetDisplay sets the [TensorGrid.Display]: -// display options -func (t *TensorGrid) SetDisplay(v TensorDisplay) *TensorGrid { t.Display = v; return t } +// SetGridStyle sets the [TensorGrid.GridStyle]: +// GridStyle has grid display style properties. 
+func (t *TensorGrid) SetGridStyle(v GridStyle) *TensorGrid { t.GridStyle = v; return t } // SetColorMap sets the [TensorGrid.ColorMap]: -// the actual colormap +// ColorMap is the colormap displayed (based on) func (t *TensorGrid) SetColorMap(v *colormap.Map) *TensorGrid { t.ColorMap = v; return t } var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.TensorButton", IDName: "tensor-button", Doc: "TensorButton represents a Tensor with a button for making a [TensorGrid]\nviewer for an [tensor.Tensor].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "Tensor"}}}) @@ -98,12 +135,3 @@ func NewTableButton(parent ...tree.Node) *TableButton { return tree.New[TableBut // SetTable sets the [TableButton.Table] func (t *TableButton) SetTable(v *table.Table) *TableButton { t.Table = v; return t } - -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor/tensorcore.SimMatButton", IDName: "sim-mat-button", Doc: "SimMatValue presents a button that pulls up the [SimMatGrid] viewer for a [table.Table].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "SimMat"}}}) - -// NewSimMatButton returns a new [SimMatButton] with the given optional parent: -// SimMatValue presents a button that pulls up the [SimMatGrid] viewer for a [table.Table]. -func NewSimMatButton(parent ...tree.Node) *SimMatButton { return tree.New[SimMatButton](parent...) 
} - -// SetSimMat sets the [SimMatButton.SimMat] -func (t *SimMatButton) SetSimMat(v *simat.SimMat) *SimMatButton { t.SimMat = v; return t } diff --git a/tensor/tensorcore/values.go b/tensor/tensorcore/values.go index 26eee622cd..7f77470b97 100644 --- a/tensor/tensorcore/values.go +++ b/tensor/tensorcore/values.go @@ -5,10 +5,10 @@ package tensorcore import ( + "cogentcore.org/core/base/metadata" "cogentcore.org/core/core" "cogentcore.org/core/icons" "cogentcore.org/core/tensor" - "cogentcore.org/core/tensor/stats/simat" "cogentcore.org/core/tensor/table" ) @@ -20,8 +20,8 @@ func init() { core.AddValueType[tensor.Int32, TensorButton]() core.AddValueType[tensor.Byte, TensorButton]() core.AddValueType[tensor.String, TensorButton]() - core.AddValueType[tensor.Bits, TensorButton]() - core.AddValueType[simat.SimMat, SimMatButton]() + core.AddValueType[tensor.Bool, TensorButton]() + // core.AddValueType[simat.SimMat, SimMatButton]() } // TensorButton represents a Tensor with a button for making a [TensorGrid] @@ -62,7 +62,7 @@ func (tb *TableButton) Init() { tb.Updater(func() { text := "None" if tb.Table != nil { - if nm, has := tb.Table.MetaData["name"]; has { + if nm, err := metadata.Get[string](tb.Table.Meta, "name"); err == nil { text = nm } else { text = "Table" @@ -75,6 +75,7 @@ func (tb *TableButton) Init() { }) } +/* // SimMatValue presents a button that pulls up the [SimMatGrid] viewer for a [table.Table]. type SimMatButton struct { core.Button @@ -101,3 +102,4 @@ func (tb *SimMatButton) Init() { NewSimMatGrid(d).SetSimMat(tb.SimMat) }) } +*/ diff --git a/tensor/tensorfs/README.md b/tensor/tensorfs/README.md new file mode 100644 index 0000000000..e47cfbfd3b --- /dev/null +++ b/tensor/tensorfs/README.md @@ -0,0 +1,82 @@ +# tensorfs: a virtual filesystem for tensor data + +`tensorfs` is a virtual file system that implements the Go `fs` interface, and can be accessed using fs-general tools, including the cogent core `filetree` and the `goal` shell. 
+ +Values are represented using the [tensor] package universal data type: the `tensor.Tensor`, which can represent everything from a single scalar value up to n-dimensional collections of patterns, in a range of data types. + +A given `Node` in the file system is either: +* A _Value_, with a tensor encoding its value. These are terminal "leaves" in the hierarchical data tree, equivalent to "files" in a standard filesystem. +* A _Directory_, with an ordered map of other Node nodes under it. + +Each Node has a name which must be unique within the directory. The nodes in a directory are processed in the order of its ordered map list, which initially reflects the order added, and can be re-ordered as needed. An alphabetical sort is also available with the `Alpha` versions of methods, and is the default sort for standard FS operations. + +The hierarchical structure of a filesystem naturally supports various kinds of functions, such as various time scales of logging, with lower-level data aggregated into upper levels. Or hierarchical splits for a pivot-table effect. + +# Usage + +There are two main APIs, one for direct usage within Go, and another that is used by the `goal` framework for interactive shell-based access, which always operates relative to a current working directory. + +## Go API + +The primary Go access function is the generic `Value`: + +```Go +tsr := tensorfs.Value[float64](dir, "filename", 5, 5) +``` + +This returns a `tensor.Values` for the node `"filename"` in the directory Node `dir` with the tensor shape size of 5x5, and `float64` values. + +If the tensor was previously created, then it is returned, and otherwise it is created. This provides a robust single-function API for access and creation, and it doesn't return any errors, so the return value can be used directly, in inline expressions etc. 
+ +For efficiency, _there are no checks_ on the existing value relative to the arguments passed, so if you end up using the same name for two different things, that will cause problems that will hopefully become evident. If you want to ensure that the size is correct, you should use an explicit `tensor.SetShapeSizes` call, which is still quite efficient if the size is the same. You can also have an initial call to `Value` that has no size args, and then set the size later -- that works fine. + +There are also functions for high-frequency types, defined on the `Node`: `Float64`, `Float32`, `Int`, and `StringValue` (`String` is taken by `fmt.Stringer`, `StringValue` is used in `tensor`), e.g.,: + +```Go +tsr := dir.Float64("filename", 5, 5) +``` + +There are also a few other variants of the `Value` functionality: +* `Scalar` calls `Value` with a size of 1. +* `Values` makes multiple tensor values of the same shape, with a final variadic list of names. +* `ValueType` takes a `reflect.Kind` arg for the data type, which can then be a variable. +* `NewForTensor` creates a node for an existing tensor. + +`DirTable` returns a `table.Table` with all the tensors under a given directory node, which can then be used for making plots or doing other forms of data analysis. This works best when each tensor has the same outer-most row dimension. The table is persistent and very efficient, using direct pointers to the underlying tensor values. + +### Directories + +Directories are `Node` elements that have a `nodes` value (ordered map of named nodes) instead of a tensor value. + +The primary way to make / access a subdirectory is the `Dir` method: +```Go +subdir := dir.Dir("subdir") +``` +If the subdirectory doesn't exist yet, it will be made, and otherwise it is returned. Any errors will be logged and a nil returned, likely causing a panic unless you expect it to fail and check for that. 
+ +There are parallel `Node` and `Value` access methods for directory nodes, with the Value ones being: + +* `tsr := dir.Value("name")` returns tensor directly, will panic if not valid +* `tsrs, err := dir.Values("name1", "name2")` returns a slice of tensors and error if any issues +* `tsrs := dir.ValuesFunc()` walks down directories (unless filtered) and returns a flat list of all tensors found. Goes in "directory order" = order nodes were added. +* `tsrs := dir.ValuesAlphaFunc()` is like `ValuesFunc` but traverses in alpha order at each node. + +### Existing items and unique names + +As in a real filesystem, names must be unique within each directory, which creates issues for how to manage conflicts between existing and new items. To make the overall framework maximally robust and eliminate the need for a controlled initialization-then-access ordering, we generally adopt the "Recycle" logic: + +* _Return an existing item of the same name, or make a new one._ + +In addition, if you really need to know if there is an existing item, you can use the `Node` method to check for yourself -- it will return `nil` if no node of that name exists. Furthermore, the global `NewDir` function returns an `fs.ErrExist` error for existing items (e.g., use `errors.Is(err, fs.ErrExist)`), as used in various `os` package functions. + +## `goal` Command API + +The following shell command style functions always operate relative to the global `CurDir` current directory and `CurRoot` root, and `goal` in math mode exposes these methods directly. Goal operates on tensor valued variables always. + +* `Chdir("subdir")` change current directory to subdir. +* `Mkdir("subdir")` make a new directory. +* `List()` print a list of nodes. +* `tsr := Get("mydata")` get tensor value at "mydata" node. +* `Set("mydata", tsr)` set tensor to "mydata" node. 
+ + diff --git a/tensor/tensorfs/commands.go b/tensor/tensorfs/commands.go new file mode 100644 index 0000000000..a57b75ee31 --- /dev/null +++ b/tensor/tensorfs/commands.go @@ -0,0 +1,156 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "fmt" + "io/fs" + "path" + "strings" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" +) + +var ( + // CurDir is the current working directory. + CurDir *Node + + // CurRoot is the current root tensorfs system. + // A default root tensorfs is created at startup. + CurRoot *Node +) + +func init() { + CurRoot, _ = NewDir("data") + CurDir = CurRoot +} + +// Record saves given tensor to current directory with given name. +func Record(tsr tensor.Tensor, name string) error { + _, err := NewForTensor(CurDir, tsr, name) + return err // todo: could prompt about conficts, or always overwrite existing? +} + +// Chdir changes the current working tensorfs directory to the named directory. +func Chdir(dir string) error { + if CurDir == nil { + CurDir = CurRoot + } + if dir == "" { + CurDir = CurRoot + return nil + } + ndir, err := CurDir.DirAtPath(dir) + if err != nil { + return err + } + CurDir = ndir + return nil +} + +// Mkdir creates a new directory with the specified name in the current directory. +// It returns an existing directory of the same name without error. +func Mkdir(dir string) *Node { + if CurDir == nil { + CurDir = CurRoot + } + if dir == "" { + err := &fs.PathError{Op: "Mkdir", Path: dir, Err: errors.New("path must not be empty")} + errors.Log(err) + return nil + } + return CurDir.Dir(dir) +} + +// List lists files using arguments (options and path) from the current directory. 
+func List(opts ...string) error { + if CurDir == nil { + CurDir = CurRoot + } + + long := false + recursive := false + if len(opts) > 0 && len(opts[0]) > 0 && opts[0][0] == '-' { + op := opts[0] + if strings.Contains(op, "l") { + long = true + } + if strings.Contains(op, "r") { + recursive = true + } + opts = opts[1:] + } + dir := CurDir + if len(opts) > 0 { + nd, err := CurDir.DirAtPath(opts[0]) + if err == nil { + dir = nd + } + } + ls := dir.List(long, recursive) + fmt.Println(ls) + return nil +} + +// Get returns the tensor value at given path relative to the +// current working directory. +// This is the direct pointer to the node, so changes +// to it will change the node. Clone the tensor to make +// a new copy disconnected from the original. +func Get(name string) tensor.Tensor { + if CurDir == nil { + CurDir = CurRoot + } + if name == "" { + err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("name must not be empty")} + errors.Log(err) + return nil + } + nd, err := CurDir.NodeAtPath(name) + if errors.Log(err) != nil { + return nil + } + if nd.IsDir() { + err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("node is a directory, not a data node")} + errors.Log(err) + return nil + } + return nd.Tensor +} + +// Set sets tensor to given name or path relative to the +// current working directory. +// If the node already exists, its previous tensor is updated to the +// given one; if it doesn't, then a new node is created. 
+func Set(name string, tsr tensor.Tensor) error { + if CurDir == nil { + CurDir = CurRoot + } + if name == "" { + err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("name must not be empty")} + return errors.Log(err) + } + itm, err := CurDir.NodeAtPath(name) + if err == nil { + if itm.IsDir() { + err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("existing node is a directory, not a data node")} + return errors.Log(err) + } + itm.Tensor = tsr + return nil + } + cd := CurDir + dir, name := path.Split(name) + if dir != "" { + d, err := CurDir.DirAtPath(dir) + if err != nil { + return errors.Log(err) + } + cd = d + } + _, err = NewForTensor(cd, tsr, name) + return errors.Log(err) +} diff --git a/tensor/tensorfs/copy.go b/tensor/tensorfs/copy.go new file mode 100644 index 0000000000..c85ff42631 --- /dev/null +++ b/tensor/tensorfs/copy.go @@ -0,0 +1,101 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "errors" + "io/fs" + "time" + + "cogentcore.org/core/tensor" +) + +const ( + // Preserve is used for Overwrite flag, indicating to not overwrite and preserve existing. + Preserve = false + + // Overwrite is used for Overwrite flag, indicating to overwrite existing. + Overwrite = true +) + +// CopyFromValue copies value from given source node, cloning it. +func (d *Node) CopyFromValue(frd *Node) { + d.modTime = time.Now() + d.Tensor = tensor.Clone(frd.Tensor) +} + +// Clone returns a copy of this node, recursively cloning directory nodes +// if it is a directory. +func (nd *Node) Clone() *Node { + if !nd.IsDir() { + cp, _ := newNode(nil, nd.name) + cp.Tensor = tensor.Clone(nd.Tensor) + return cp + } + nodes, _ := nd.Nodes() + cp, _ := NewDir(nd.name) + for _, it := range nodes { + cp.Add(it.Clone()) + } + return cp +} + +// Copy copies node(s) from given paths to given path or directory. 
+// if there are multiple from nodes, then to must be a directory. +// must be called on a directory node. +func (dir *Node) Copy(overwrite bool, to string, from ...string) error { + if err := dir.mustDir("Copy", to); err != nil { + return err + } + switch { + case to == "": + return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("to location is empty")} + case len(from) == 0: + return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("no from sources specified")} + } + // todo: check for to conflict first here.. + tod, _ := dir.NodeAtPath(to) + var errs []error + if len(from) > 1 && tod != nil && !tod.IsDir() { + return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("multiple source nodes requires destination to be a directory")} + } + targd := dir + targf := to + if tod != nil && tod.IsDir() { + targd = tod + targf = "" + } + for _, fr := range from { + opstr := fr + " -> " + to + frd, err := dir.NodeAtPath(fr) + if err != nil { + errs = append(errs, err) + continue + } + if targf == "" { + if trg, ok := targd.nodes.AtTry(frd.name); ok { // target exists + switch { + case trg.IsDir() && frd.IsDir(): + // todo: copy all nodes from frd into trg + case trg.IsDir(): // frd is not + errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Value onto directory of same name")}) + case frd.IsDir(): // trg is not + errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Directory onto Value of same name")}) + default: // both nodes + if overwrite { // todo: interactive!? + trg.CopyFromValue(frd) + } + } + continue + } + } + nw := frd.Clone() + if targf != "" { + nw.name = targf + } + targd.Add(nw) + } + return errors.Join(errs...) +} diff --git a/tensor/tensorfs/dir.go b/tensor/tensorfs/dir.go new file mode 100644 index 0000000000..7fda4aaa5d --- /dev/null +++ b/tensor/tensorfs/dir.go @@ -0,0 +1,346 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "fmt" + "io/fs" + "path" + "slices" + "sort" + "strings" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/keylist" + "cogentcore.org/core/tensor" +) + +// Nodes is a map of directory entry names to Nodes. +// It retains the order that nodes were added in, which is +// the natural order nodes are processed in. +type Nodes = keylist.List[string, *Node] + +// NewDir returns a new tensorfs directory with the given name. +// If parent != nil and a directory, this dir is added to it. +// If the parent already has an node of that name, it is returned, +// with an [fs.ErrExist] error. +// If the name is empty, then it is set to "root", the root directory. +// Note that "/" is not allowed for the root directory in Go [fs]. +// If no parent (i.e., a new root) and CurRoot is nil, then it is set +// to this. +func NewDir(name string, parent ...*Node) (*Node, error) { + if name == "" { + name = "root" + } + var par *Node + if len(parent) == 1 { + par = parent[0] + } + dir, err := newNode(par, name) + if dir != nil && dir.nodes == nil { + dir.nodes = &Nodes{} + } + return dir, err +} + +// Dir creates a new directory under given dir with the specified name +// if it doesn't already exist, otherwise returns the existing one. +// Path / slash separators can be used to make a path of multiple directories. +// It logs an error and returns nil if this dir node is not a directory. 
+func (dir *Node) Dir(name string) *Node { + if err := dir.mustDir("Dir", name); errors.Log(err) != nil { + return nil + } + if len(name) == 0 { + return dir + } + path := strings.Split(name, "/") + if cd := dir.nodes.At(path[0]); cd != nil { + if len(path) > 1 { + return cd.Dir(strings.Join(path[1:], "/")) + } + return cd + } + nd, _ := NewDir(path[0], dir) + if len(path) > 1 { + return nd.Dir(strings.Join(path[1:], "/")) + } + return nd +} + +// Node returns a Node in given directory by name. +// This is for fast access and direct usage of known +// nodes, and it will panic if this node is not a directory. +// Returns nil if no node of given name exists. +func (dir *Node) Node(name string) *Node { + return dir.nodes.At(name) +} + +// Value returns the [tensor.Tensor] value for given node +// within this directory. This will panic if node is not +// found, and will return nil if it is not a Value +// (i.e., it is a directory). +func (dir *Node) Value(name string) tensor.Tensor { + return dir.nodes.At(name).Tensor +} + +// Nodes returns a slice of Nodes in given directory by names variadic list. +// If list is empty, then all nodes in the directory are returned. +// returned error reports any nodes not found, or if not a directory. +func (dir *Node) Nodes(names ...string) ([]*Node, error) { + if err := dir.mustDir("Nodes", ""); err != nil { + return nil, err + } + var nds []*Node + if len(names) == 0 { + for _, it := range dir.nodes.Values { + nds = append(nds, it) + } + return nds, nil + } + var errs []error + for _, nm := range names { + dt := dir.nodes.At(nm) + if dt != nil { + nds = append(nds, dt) + } else { + err := fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm) + errs = append(errs, err) + } + } + return nds, errors.Join(errs...) +} + +// Values returns a slice of tensor values in the given directory, +// by names variadic list. If list is empty, then all value nodes +// in the directory are returned. 
+// returned error reports any nodes not found, or if not a directory. +func (dir *Node) Values(names ...string) ([]tensor.Tensor, error) { + if err := dir.mustDir("Values", ""); err != nil { + return nil, err + } + var nds []tensor.Tensor + if len(names) == 0 { + for _, it := range dir.nodes.Values { + if it.Tensor != nil { + nds = append(nds, it.Tensor) + } + } + return nds, nil + } + var errs []error + for _, nm := range names { + it := dir.nodes.At(nm) + if it != nil && it.Tensor != nil { + nds = append(nds, it.Tensor) + } else { + err := fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm) + errs = append(errs, err) + } + } + return nds, errors.Join(errs...) +} + +// ValuesFunc returns all tensor Values under given directory, +// filtered by given function, in directory order (e.g., order added), +// recursively descending into directories to return a flat list of +// the entire subtree. The function can filter out directories to prune +// the tree, e.g., using `IsDir` method. +// If func is nil, all Value nodes are returned. +func (dir *Node) ValuesFunc(fun func(nd *Node) bool) []tensor.Tensor { + if err := dir.mustDir("ValuesFunc", ""); err != nil { + return nil + } + var nds []tensor.Tensor + for _, it := range dir.nodes.Values { + if fun != nil && !fun(it) { + continue + } + if it.IsDir() { + subs := it.ValuesFunc(fun) + nds = append(nds, subs...) + } else { + nds = append(nds, it.Tensor) + } + } + return nds +} + +// NodesFunc returns leaf Nodes under given directory, filtered by +// given function, recursively descending into directories +// to return a flat list of the entire subtree, in directory order +// (e.g., order added). +// The function can filter out directories to prune the tree. +// If func is nil, all leaf Nodes are returned. 
+func (dir *Node) NodesFunc(fun func(nd *Node) bool) []*Node { + if err := dir.mustDir("NodesFunc", ""); err != nil { + return nil + } + var nds []*Node + for _, it := range dir.nodes.Values { + if fun != nil && !fun(it) { + continue + } + if it.IsDir() { + subs := it.NodesFunc(fun) + nds = append(nds, subs...) + } else { + nds = append(nds, it) + } + } + return nds +} + +// ValuesAlphaFunc returns all Value nodes (tensors) in given directory, +// recursively descending into directories to return a flat list of +// the entire subtree, filtered by given function, with nodes at each +// directory level traversed in alphabetical order. +// The function can filter out directories to prune the tree. +// If func is nil, all Values are returned. +func (dir *Node) ValuesAlphaFunc(fun func(nd *Node) bool) []tensor.Tensor { + if err := dir.mustDir("ValuesAlphaFunc", ""); err != nil { + return nil + } + names := dir.dirNamesAlpha() + var nds []tensor.Tensor + for _, nm := range names { + it := dir.nodes.At(nm) + if fun != nil && !fun(it) { + continue + } + if it.IsDir() { + subs := it.ValuesAlphaFunc(fun) + nds = append(nds, subs...) + } else { + nds = append(nds, it.Tensor) + } + } + return nds +} + +// NodesAlphaFunc returns leaf nodes under given directory, filtered +// by given function, with nodes at each directory level +// traversed in alphabetical order, recursively descending into directories +// to return a flat list of the entire subtree, in directory order +// (e.g., order added). +// The function can filter out directories to prune the tree. +// If func is nil, all leaf Nodes are returned. +func (dir *Node) NodesAlphaFunc(fun func(nd *Node) bool) []*Node { + if err := dir.mustDir("NodesAlphaFunc", ""); err != nil { + return nil + } + names := dir.dirNamesAlpha() + var nds []*Node + for _, nm := range names { + it := dir.nodes.At(nm) + if fun != nil && !fun(it) { + continue + } + if it.IsDir() { + subs := it.NodesAlphaFunc(fun) + nds = append(nds, subs...) 
+ } else { + nds = append(nds, it) + } + } + return nds +} + +// todo: these must handle going up the tree using .. + +// DirAtPath returns directory at given relative path +// from this starting dir. +func (dir *Node) DirAtPath(dirPath string) (*Node, error) { + var err error + dirPath = path.Clean(dirPath) + sdf, err := dir.Sub(dirPath) // this ensures that dir is a dir + if err != nil { + return nil, err + } + return sdf.(*Node), nil +} + +// NodeAtPath returns node at given relative path from this starting dir. +func (dir *Node) NodeAtPath(name string) (*Node, error) { + if err := dir.mustDir("NodeAtPath", name); err != nil { + return nil, err + } + if !fs.ValidPath(name) { + return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("invalid path")} + } + dirPath, file := path.Split(name) + sd, err := dir.DirAtPath(dirPath) + if err != nil { + return nil, err + } + nd, ok := sd.nodes.AtTry(file) + if !ok { + if dirPath == "" && (file == dir.name || file == ".") { + return dir, nil + } + return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("file not found")} + } + return nd, nil +} + +// Path returns the full path to this data node +func (dir *Node) Path() string { + pt := dir.name + cur := dir.Parent + loops := make(map[*Node]struct{}) + for { + if cur == nil { + return pt + } + if _, ok := loops[cur]; ok { + return pt + } + pt = path.Join(cur.name, pt) + loops[cur] = struct{}{} + cur = cur.Parent + } +} + +// dirNamesAlpha returns the names of nodes in the directory +// sorted alphabetically. Node must be dir by this point. +func (dir *Node) dirNamesAlpha() []string { + names := slices.Clone(dir.nodes.Keys) + sort.Strings(names) + return names +} + +// dirNamesByTime returns the names of nodes in the directory +// sorted by modTime. Node must be dir by this point. 
+func (dir *Node) dirNamesByTime() []string { + names := slices.Clone(dir.nodes.Keys) + slices.SortFunc(names, func(a, b string) int { + return dir.nodes.At(a).ModTime().Compare(dir.nodes.At(b).ModTime()) + }) + return names +} + +// mustDir returns an error for given operation and path +// if this data node is not a directory. +func (dir *Node) mustDir(op, path string) error { + if !dir.IsDir() { + return &fs.PathError{Op: op, Path: path, Err: errors.New("tensorfs node is not a directory")} + } + return nil +} + +// Add adds an node to this directory data node. +// The only errors are if this node is not a directory, +// or the name already exists, in which case an [fs.ErrExist] is returned. +// Names must be unique within a directory. +func (dir *Node) Add(it *Node) error { + if err := dir.mustDir("Add", it.name); err != nil { + return err + } + err := dir.nodes.Add(it.name, it) + if err != nil { + return fs.ErrExist + } + return nil +} diff --git a/tensor/datafs/file.go b/tensor/tensorfs/file.go similarity index 85% rename from tensor/datafs/file.go rename to tensor/tensorfs/file.go index 74dcfc4a0e..3e3ef0eaca 100644 --- a/tensor/datafs/file.go +++ b/tensor/tensorfs/file.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package datafs +package tensorfs import ( "bytes" @@ -14,17 +14,17 @@ import ( // All io functionality is handled by [bytes.Reader]. 
type File struct { bytes.Reader - Data *Data + Node *Node dirEntries []fs.DirEntry dirsRead int } func (f *File) Stat() (fs.FileInfo, error) { - return f.Data, nil + return f.Node, nil } func (f *File) Close() error { - f.Reader.Reset(f.Data.Bytes()) + f.Reader.Reset(f.Node.Bytes()) return nil } @@ -36,11 +36,11 @@ type DirFile struct { } func (f *DirFile) ReadDir(n int) ([]fs.DirEntry, error) { - if err := f.Data.mustDir("DirFile:ReadDir", ""); err != nil { + if err := f.Node.mustDir("DirFile:ReadDir", ""); err != nil { return nil, err } if f.dirEntries == nil { - f.dirEntries, _ = f.Data.ReadDir(".") + f.dirEntries, _ = f.Node.ReadDir(".") f.dirsRead = 0 } ne := len(f.dirEntries) diff --git a/tensor/tensorfs/fs.go b/tensor/tensorfs/fs.go new file mode 100644 index 0000000000..da095cd91d --- /dev/null +++ b/tensor/tensorfs/fs.go @@ -0,0 +1,179 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "bytes" + "errors" + "io/fs" + "slices" + "time" + + "cogentcore.org/core/base/fileinfo" + "cogentcore.org/core/base/fsx" +) + +// fs.go contains all the io/fs interface implementations, and other fs functionality. + +// Open opens the given node at given path within this tensorfs filesystem. +func (nd *Node) Open(name string) (fs.File, error) { + itm, err := nd.NodeAtPath(name) + if err != nil { + return nil, err + } + if itm.IsDir() { + return &DirFile{File: File{Reader: *bytes.NewReader(itm.Bytes()), Node: itm}}, nil + } + return &File{Reader: *bytes.NewReader(itm.Bytes()), Node: itm}, nil +} + +// Stat returns a FileInfo describing the file. +// If there is an error, it should be of type *PathError. +func (nd *Node) Stat(name string) (fs.FileInfo, error) { + return nd.NodeAtPath(name) +} + +// Sub returns a data FS corresponding to the subtree rooted at dir. 
+func (nd *Node) Sub(dir string) (fs.FS, error) { + if err := nd.mustDir("Sub", dir); err != nil { + return nil, err + } + if !fs.ValidPath(dir) { + return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("invalid name")} + } + if dir == "." || dir == "" || dir == nd.name { + return nd, nil + } + cd := dir + cur := nd + root, rest := fsx.SplitRootPathFS(dir) + if root == "." || root == nd.name { + cd = rest + } + for { + if cd == "." || cd == "" { + return cur, nil + } + root, rest := fsx.SplitRootPathFS(cd) + if root == "." && rest == "" { + return cur, nil + } + cd = rest + sd, ok := cur.nodes.AtTry(root) + if !ok { + return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("directory not found")} + } + if !sd.IsDir() { + return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("is not a directory")} + } + cur = sd + } +} + +// ReadDir returns the contents of the given directory within this filesystem. +// Use "." (or "") to refer to the current directory. +func (nd *Node) ReadDir(dir string) ([]fs.DirEntry, error) { + sd, err := nd.DirAtPath(dir) + if err != nil { + return nil, err + } + names := sd.dirNamesAlpha() + ents := make([]fs.DirEntry, len(names)) + for i, nm := range names { + ents[i] = sd.nodes.At(nm) + } + return ents, nil +} + +// ReadFile reads the named file and returns its contents. +// A successful call returns a nil error, not io.EOF. +// (Because ReadFile reads the whole file, the expected EOF +// from the final Read is not treated as an error to be reported.) +// +// The caller is permitted to modify the returned byte slice. +// This method should return a copy of the underlying data. 
+func (nd *Node) ReadFile(name string) ([]byte, error) { + itm, err := nd.NodeAtPath(name) + if err != nil { + return nil, err + } + if itm.IsDir() { + return nil, &fs.PathError{Op: "ReadFile", Path: name, Err: errors.New("Node is a directory")} + } + return slices.Clone(itm.Bytes()), nil +} + +//////// FileInfo interface: + +func (nd *Node) Name() string { return nd.name } + +// Size returns the size in bytes of the tensor values +// (via Sizeof), or 0 if there is no tensor. +func (nd *Node) Size() int64 { + if nd.Tensor == nil { + return 0 + } + return nd.Tensor.AsValues().Sizeof() +} + +func (nd *Node) IsDir() bool { + return nd.nodes != nil +} + +func (nd *Node) ModTime() time.Time { + return nd.modTime +} + +func (nd *Node) Mode() fs.FileMode { + if nd.IsDir() { + return 0755 | fs.ModeDir + } + return 0444 +} + +// Sys returns the Dir or Value +func (nd *Node) Sys() any { + if nd.Tensor != nil { + return nd.Tensor + } + return nd.nodes +} + +//////// DirEntry interface + +func (nd *Node) Type() fs.FileMode { + return nd.Mode().Type() +} + +func (nd *Node) Info() (fs.FileInfo, error) { + return nd, nil +} + +//////// Misc + +func (nd *Node) KnownFileInfo() fileinfo.Known { + if nd.Tensor == nil { + return fileinfo.Unknown + } + tsr := nd.Tensor + if tsr.Len() > 1 { + return fileinfo.Tensor + } + // scalars by type + if tsr.IsString() { + return fileinfo.String + } + return fileinfo.Number +} + +// Bytes returns the byte-wise representation of the data Value. +// This is the actual underlying data, so make a copy if it can be +// unintentionally modified or retained more than for immediate use. 
+func (nd *Node) Bytes() []byte { + if nd.Tensor == nil || nd.Tensor.NumDims() == 0 || nd.Tensor.Len() == 0 { + return nil + } + return nd.Tensor.AsValues().Bytes() +} diff --git a/tensor/datafs/fs_test.go b/tensor/tensorfs/fs_test.go similarity index 70% rename from tensor/datafs/fs_test.go rename to tensor/tensorfs/fs_test.go index e44f9a2d27..c640a5d709 100644 --- a/tensor/datafs/fs_test.go +++ b/tensor/tensorfs/fs_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package datafs +package tensorfs import ( "fmt" @@ -12,21 +12,17 @@ import ( "github.com/stretchr/testify/assert" ) -func makeTestData(t *testing.T) *Data { +func makeTestNode(t *testing.T) *Node { dfs, err := NewDir("root") assert.NoError(t, err) - net, err := dfs.Mkdir("network") - assert.NoError(t, err) - NewTensor[float32](net, "units", []int{50, 50}) - log, err := dfs.Mkdir("log") - assert.NoError(t, err) - _, err = NewTable(log, "Trial") - assert.NoError(t, err) + net := dfs.Dir("network") + Value[float32](net, "units", 50, 50) + dfs.Dir("log") return dfs } func TestFS(t *testing.T) { - dfs := makeTestData(t) + dfs := makeTestNode(t) dirs, err := dfs.ReadDir(".") assert.NoError(t, err) for _, d := range dirs { diff --git a/tensor/tensorfs/list.go b/tensor/tensorfs/list.go new file mode 100644 index 0000000000..cb15736dde --- /dev/null +++ b/tensor/tensorfs/list.go @@ -0,0 +1,75 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "strings" + + "cogentcore.org/core/base/indent" +) + +const ( + Short = false + Long = true + + DirOnly = false + Recursive = true +) + +// todo: list options string + +func (nd *Node) String() string { + if !nd.IsDir() { + return nd.Tensor.Label() + } + return nd.List(Short, DirOnly) +} + +// List returns a listing of nodes in the given directory. 
+// - long = include detailed information about each node, vs just the name. +// - recursive = descend into subdirectories. +func (dir *Node) List(long, recursive bool) string { + if long { + return dir.ListLong(recursive, 0) + } + return dir.ListShort(recursive, 0) +} + +// ListShort returns a name-only listing of given directory. +func (dir *Node) ListShort(recursive bool, ident int) string { + var b strings.Builder + nodes, _ := dir.Nodes() + for _, it := range nodes { + b.WriteString(indent.Tabs(ident)) + if it.IsDir() { + if recursive { + b.WriteString("\n" + it.ListShort(recursive, ident+1)) + } else { + b.WriteString(it.name + "/ ") + } + } else { + b.WriteString(it.name + " ") + } + } + return b.String() +} + +// ListLong returns a detailed listing of given directory. +func (dir *Node) ListLong(recursive bool, ident int) string { + var b strings.Builder + nodes, _ := dir.Nodes() + for _, it := range nodes { + b.WriteString(indent.Tabs(ident)) + if it.IsDir() { + b.WriteString(it.name + "/\n") + if recursive { + b.WriteString(it.ListLong(recursive, ident+1)) + } + } else { + b.WriteString(it.String() + "\n") + } + } + return b.String() +} diff --git a/tensor/tensorfs/metadata.go b/tensor/tensorfs/metadata.go new file mode 100644 index 0000000000..887e0e0234 --- /dev/null +++ b/tensor/tensorfs/metadata.go @@ -0,0 +1,38 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" +) + +// This file provides standardized metadata options for frequent +// use cases, using codified key names to eliminate typos. + +// SetMetaItems sets given metadata for Value items in given directory +// with given names. Returns error for any items not found. +func (d *Node) SetMetaItems(key string, value any, names ...string) error { + tsrs, err := d.Values(names...) 
+ for _, tsr := range tsrs { + tsr.Metadata().Set(key, value) + } + return err +} + +// CalcAll calls function set by [Node.SetCalcFunc] for all items +// in this directory and all of its subdirectories. +// Calls Calc on items from ValuesFunc(nil) +func (d *Node) CalcAll() error { + var errs []error + items := d.ValuesFunc(nil) + for _, it := range items { + err := tensor.Calc(it) + if err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} diff --git a/tensor/tensorfs/node.go b/tensor/tensorfs/node.go new file mode 100644 index 0000000000..03fa1f18b4 --- /dev/null +++ b/tensor/tensorfs/node.go @@ -0,0 +1,219 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensorfs + +import ( + "io/fs" + "reflect" + "time" + + "cogentcore.org/core/base/errors" + "cogentcore.org/core/base/fsx" + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/table" +) + +// Node is the element type for the filesystem, which can represent either +// a [tensor] Value as a "file" equivalent, or a "directory" containing other Nodes. +// The [tensor.Tensor] can represent everything from a single scalar value up to +// n-dimensional collections of patterns, in a range of data types. +// Directories have an ordered map of nodes. +type Node struct { + // Parent is the parent data directory. + Parent *Node + + // name is the name of this node. it is not a path. + name string + + // modTime tracks time added to directory, used for ordering. + modTime time.Time + + // Tensor is the tensor value for a file or leaf Node in the FS, + // represented using the universal [tensor] data type of + // [tensor.Tensor], which can represent anything from a scalar + // to n-dimensional data, in a range of data types. + Tensor tensor.Tensor + + // nodes is for directory nodes, with all the nodes in the directory. 
+ nodes *Nodes + + // DirTable is a summary [table.Table] with columns comprised of Value + // nodes in the directory, which can be used for plotting or other operations. + DirTable *table.Table +} + +// newNode returns a new Node in given directory Node, which can be nil. +// If dir is not a directory, returns nil and an error. +// If an node already exists in dir with that name, that node is returned +// with an [fs.ErrExist] error, and the caller can decide how to proceed. +// The modTime is set to now. The name must be unique within parent. +func newNode(dir *Node, name string) (*Node, error) { + if dir == nil { + return &Node{name: name, modTime: time.Now()}, nil + } + if err := dir.mustDir("newNode", name); err != nil { + return nil, err + } + if ex, ok := dir.nodes.AtTry(name); ok { + return ex, fs.ErrExist + } + d := &Node{Parent: dir, name: name, modTime: time.Now()} + dir.nodes.Add(name, d) + return d, nil +} + +// Value creates / returns a Node with given name as a [tensor.Tensor] +// of given data type and shape sizes, in given directory Node. +// If it already exists, it is returned as-is (no checking against the +// type or sizes provided, for efficiency -- if there is doubt, check!), +// otherwise a new tensor is created. It is fine to not pass any sizes and +// use `SetShapeSizes` method later to set the size. +func Value[T tensor.DataTypes](dir *Node, name string, sizes ...int) tensor.Values { + it := dir.Node(name) + if it != nil { + return it.Tensor.(tensor.Values) + } + tsr := tensor.New[T](sizes...) + metadata.SetName(tsr, name) + nd, err := newNode(dir, name) + if errors.Log(err) != nil { + return nil + } + nd.Tensor = tsr + return tsr +} + +// NewValues makes new tensor Node value(s) (as a [tensor.Tensor]) +// of given data type and shape sizes, in given directory. +// Any existing nodes with the same names are recycled without checking +// or updating the data type or sizes. +// See the [Value] documentation for more info. 
+func NewValues[T tensor.DataTypes](dir *Node, shape []int, names ...string) { + for _, nm := range names { + Value[T](dir, nm, shape...) + } +} + +// Scalar returns a scalar Node value (as a [tensor.Tensor]) +// of given data type, in given directory and name. +// If it already exists, it is returned without checking against args, +// else a new one is made. See the [Value] documentation for more info. +func Scalar[T tensor.DataTypes](dir *Node, name string) tensor.Values { + return Value[T](dir, name, 1) +} + +// ValueType creates / returns a Node with given name as a [tensor.Tensor] +// of given data type specified as a reflect.Kind, with shape sizes, +// in given directory Node. +// Supported types are string, bool (for [Bool]), float32, float64, int, int32, and byte. +// If it already exists, it is returned as-is (no checking against the +// type or sizes provided, for efficiency -- if there is doubt, check!), +// otherwise a new tensor is created. It is fine to not pass any sizes and +// use `SetShapeSizes` method later to set the size. +func ValueType(dir *Node, name string, typ reflect.Kind, sizes ...int) tensor.Values { + it := dir.Node(name) + if it != nil { + return it.Tensor.(tensor.Values) + } + tsr := tensor.NewOfType(typ, sizes...) + metadata.SetName(tsr, name) + nd, err := newNode(dir, name) + if errors.Log(err) != nil { + return nil + } + nd.Tensor = tsr + return tsr +} + +// NewForTensor creates a new Node for given existing tensor with given name. +// If the name already exists, that Node is returned with an [fs.ErrExist] error. +func NewForTensor(dir *Node, tsr tensor.Tensor, name string) (*Node, error) { + nd, err := newNode(dir, name) + if err != nil { + return nd, err + } + nd.Tensor = tsr + return nd, nil +} + +// DirTable returns a [table.Table] with all of the tensor values under +// the given directory, with columns as the Tensor values elements in the directory +// and any subdirectories, using given filter function. 
+// This is a convenient mechanism for creating a plot of all the data +// in a given directory. +// If such was previously constructed, it is returned from "DirTable" +// where it is stored for later use. +// Row count is updated to current max row. +// Set DirTable = nil to regenerate. +func DirTable(dir *Node, fun func(node *Node) bool) *table.Table { + nds := dir.NodesFunc(fun) + if dir.DirTable != nil { + if dir.DirTable.NumColumns() == len(nds) { + dir.DirTable.SetNumRowsToMax() + return dir.DirTable + } + } + dt := table.New(fsx.DirAndFile(string(dir.Path()))) + for _, it := range nds { + if it.Tensor == nil || it.Tensor.NumDims() == 0 { + continue + } + tsr := it.Tensor + rows := tsr.DimSize(0) + if dt.Columns.Rows < rows { + dt.Columns.Rows = rows + dt.SetNumRows(dt.Columns.Rows) + } + nm := it.name + if it.Parent != dir { + nm = fsx.DirAndFile(string(it.Path())) + } + dt.AddColumn(nm, tsr.AsValues()) + } + dir.DirTable = dt + return dt +} + +// DirFromTable sets tensor values under given directory node to the +// columns of the given [table.Table]. Also sets the DirTable to this table. +func DirFromTable(dir *Node, dt *table.Table) { + for i, cl := range dt.Columns.Values { + nm := dt.Columns.Keys[i] + nd, err := newNode(dir, nm) + if err == nil || err == fs.ErrExist { + nd.Tensor = cl + } + } + dir.DirTable = dt +} + +// Float64 creates / returns a Node with given name as a [tensor.Float64] +// for given shape sizes, in given directory [Node]. +// See [Values] function for more info. +func (dir *Node) Float64(name string, sizes ...int) *tensor.Float64 { + return Value[float64](dir, name, sizes...).(*tensor.Float64) +} + +// Float32 creates / returns a Node with given name as a [tensor.Float32] +// for given shape sizes, in given directory [Node]. +// See [Values] function for more info. 
+func (dir *Node) Float32(name string, sizes ...int) *tensor.Float32 { + return Value[float32](dir, name, sizes...).(*tensor.Float32) +} + +// Int creates / returns a Node with given name as a [tensor.Int] +// for given shape sizes, in given directory [Node]. +// See [Values] function for more info. +func (dir *Node) Int(name string, sizes ...int) *tensor.Int { + return Value[int](dir, name, sizes...).(*tensor.Int) +} + +// StringValue creates / returns a Node with given name as a [tensor.String] +// for given shape sizes, in given directory [Node]. +// See [Values] function for more info. +func (dir *Node) StringValue(name string, sizes ...int) *tensor.String { + return Value[string](dir, name, sizes...).(*tensor.String) +} diff --git a/tensor/tensormpi/table.go b/tensor/tensormpi/table.go index da874b1a8a..c21908293d 100644 --- a/tensor/tensormpi/table.go +++ b/tensor/tensormpi/table.go @@ -13,15 +13,15 @@ import ( // dest will have np * src.Rows Rows, filled with each processor's data, in order. // dest must be a clone of src: if not same number of cols, will be configured from src. func GatherTableRows(dest, src *table.Table, comm *mpi.Comm) { - sr := src.Rows + sr := src.NumRows() np := mpi.WorldSize() dr := np * sr - if len(dest.Columns) != len(src.Columns) { + if dest.NumColumns() != src.NumColumns() { *dest = *src.Clone() } dest.SetNumRows(dr) - for ci, st := range src.Columns { - dt := dest.Columns[ci] + for ci, st := range src.Columns.Values { + dt := dest.Columns.Values[ci] GatherTensorRows(dt, st, comm) } } @@ -33,13 +33,13 @@ func GatherTableRows(dest, src *table.Table, comm *mpi.Comm) { // dest will be a clone of src if not the same (cos & rows), // does nothing for strings. 
func ReduceTable(dest, src *table.Table, comm *mpi.Comm, op mpi.Op) { - sr := src.Rows - if len(dest.Columns) != len(src.Columns) { + sr := src.NumRows() + if dest.NumColumns() != src.NumColumns() { *dest = *src.Clone() } dest.SetNumRows(sr) - for ci, st := range src.Columns { - dt := dest.Columns[ci] + for ci, st := range src.Columns.Values { + dt := dest.Columns.Values[ci] ReduceTensor(dt, st, comm, op) } } diff --git a/tensor/tensormpi/tensor.go b/tensor/tensormpi/tensor.go index f25e42268c..a2e59425a2 100644 --- a/tensor/tensormpi/tensor.go +++ b/tensor/tensormpi/tensor.go @@ -15,13 +15,13 @@ import ( // using a row-based tensor organization (as in an table.Table). // dest will have np * src.Rows Rows, filled with each processor's data, in order. // dest must have same overall shape as src at start, but rows will be enforced. -func GatherTensorRows(dest, src tensor.Tensor, comm *mpi.Comm) error { +func GatherTensorRows(dest, src tensor.Values, comm *mpi.Comm) error { dt := src.DataType() if dt == reflect.String { return GatherTensorRowsString(dest.(*tensor.String), src.(*tensor.String), comm) } - sr, _ := src.RowCellSize() - dr, _ := dest.RowCellSize() + sr, _ := src.Shape().RowCellSize() + dr, _ := dest.Shape().RowCellSize() np := mpi.WorldSize() dl := np * sr if dr != dl { @@ -62,8 +62,8 @@ func GatherTensorRows(dest, src tensor.Tensor, comm *mpi.Comm) error { // dest will have np * src.Rows Rows, filled with each processor's data, in order. // dest must have same overall shape as src at start, but rows will be enforced. func GatherTensorRowsString(dest, src *tensor.String, comm *mpi.Comm) error { - sr, _ := src.RowCellSize() - dr, _ := dest.RowCellSize() + sr, _ := src.Shape().RowCellSize() + dr, _ := dest.Shape().RowCellSize() np := mpi.WorldSize() dl := np * sr if dr != dl { @@ -112,20 +112,20 @@ func GatherTensorRowsString(dest, src *tensor.String, comm *mpi.Comm) error { // IMPORTANT: src and dest must be different slices! 
// each processor must have the same shape and organization for this to make sense. // does nothing for strings. -func ReduceTensor(dest, src tensor.Tensor, comm *mpi.Comm, op mpi.Op) error { +func ReduceTensor(dest, src tensor.Values, comm *mpi.Comm, op mpi.Op) error { dt := src.DataType() if dt == reflect.String { return nil } slen := src.Len() if slen != dest.Len() { - dest.CopyShapeFrom(src) + tensor.SetShapeFrom(dest, src) } var err error switch dt { case reflect.Bool: - dt := dest.(*tensor.Bits) - st := src.(*tensor.Bits) + dt := dest.(*tensor.Bool) + st := src.(*tensor.Bool) err = comm.AllReduceU8(op, dt.Values, st.Values) case reflect.Uint8: dt := dest.(*tensor.Byte) diff --git a/tensor/tmath/README.md b/tensor/tmath/README.md new file mode 100644 index 0000000000..95cda68594 --- /dev/null +++ b/tensor/tmath/README.md @@ -0,0 +1,13 @@ +# tmath is the Tensor math library + +# math functions + +All the standard library [math](https://pkg.go.dev/math) functions are implemented on `*tensor.Tensor`. + +To properly handle the row-wise indexes, all processing is done using row, cell indexes, with the row indirected through the indexes. + +The output result tensor(s) can be the same as the input for all functions (except where specifically noted), to perform an in-place operation on the same data. + +The standard `Add`, `Sub`, `Mul`, `Div` (`+, -, *, /`) mathematical operators all operate element-wise, with a separate MatMul for matrix multiplication, which operates through gonum routines, for 2D Float64 tensor shapes with no indexes, so that the raw float64 values can be passed directly to gonum. + + diff --git a/tensor/tmath/bool.go b/tensor/tmath/bool.go new file mode 100644 index 0000000000..1e80f49013 --- /dev/null +++ b/tensor/tmath/bool.go @@ -0,0 +1,126 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tmath + +import ( + "cogentcore.org/core/base/errors" + "cogentcore.org/core/tensor" +) + +// Equal stores in the output the bool value a == b. +func Equal(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(EqualOut, a, b) +} + +// EqualOut stores in the output the bool value a == b. +func EqualOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a == b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a == b }, a, b, out) +} + +// Less stores in the output the bool value a < b. +func Less(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(LessOut, a, b) +} + +// LessOut stores in the output the bool value a < b. +func LessOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a < b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a < b }, a, b, out) +} + +// Greater stores in the output the bool value a > b. +func Greater(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(GreaterOut, a, b) +} + +// GreaterOut stores in the output the bool value a > b. +func GreaterOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a > b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a > b }, a, b, out) +} + +// NotEqual stores in the output the bool value a != b. +func NotEqual(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(NotEqualOut, a, b) +} + +// NotEqualOut stores in the output the bool value a != b. 
+func NotEqualOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a != b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a != b }, a, b, out) +} + +// LessEqual stores in the output the bool value a <= b. +func LessEqual(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(LessEqualOut, a, b) +} + +// LessEqualOut stores in the output the bool value a <= b. +func LessEqualOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a <= b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a <= b }, a, b, out) +} + +// GreaterEqual stores in the output the bool value a >= b. +func GreaterEqual(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(GreaterEqualOut, a, b) +} + +// GreaterEqualOut stores in the output the bool value a >= b. +func GreaterEqualOut(a, b tensor.Tensor, out *tensor.Bool) error { + if a.IsString() { + return tensor.BoolStringsFuncOut(func(a, b string) bool { return a >= b }, a, b, out) + } + return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a >= b }, a, b, out) +} + +// Or stores in the output the bool value a || b. +func Or(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(OrOut, a, b) +} + +// OrOut stores in the output the bool value a || b. +func OrOut(a, b tensor.Tensor, out *tensor.Bool) error { + return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 || b > 0 }, a, b, out) +} + +// And stores in the output the bool value a && b. +func And(a, b tensor.Tensor) *tensor.Bool { + return tensor.CallOut2Bool(AndOut, a, b) +} + +// AndOut stores in the output the bool value a && b. 
+func AndOut(a, b tensor.Tensor, out *tensor.Bool) error { + return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 && b > 0 }, a, b, out) +} + +// Not stores in the output the bool value !a. +func Not(a tensor.Tensor) *tensor.Bool { + out := tensor.NewBool() + errors.Log(NotOut(a, out)) + return out +} + +// NotOut stores in the output the bool value !a. +func NotOut(a tensor.Tensor, out *tensor.Bool) error { + out.SetShapeSizes(a.Shape().Sizes...) + alen := a.Len() + tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen }, + func(idx int, tsr ...tensor.Tensor) { + out.SetBool1D(tsr[0].Int1D(idx) == 0, idx) + }, a, out) + return nil +} diff --git a/tensor/tmath/bool_test.go b/tensor/tmath/bool_test.go new file mode 100644 index 0000000000..b48e44caca --- /dev/null +++ b/tensor/tmath/bool_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tmath + +import ( + "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +func TestBoolOps(t *testing.T) { + ar := tensor.NewIntRange(12) + // fmt.Println(v) + bo := tensor.NewBool() + sc := tensor.NewIntScalar(6) + + EqualOut(ar, sc, bo) + for i, v := range ar.Values { + assert.Equal(t, v == 6, bo.Bool1D(i)) + } + + LessOut(ar, sc, bo) + for i, v := range ar.Values { + assert.Equal(t, v < 6, bo.Bool1D(i)) + } + + GreaterOut(ar, sc, bo) + // fmt.Println(bo) + for i, v := range ar.Values { + assert.Equal(t, v > 6, bo.Bool1D(i)) + } + + NotEqualOut(ar, sc, bo) + for i, v := range ar.Values { + assert.Equal(t, v != 6, bo.Bool1D(i)) + } + + LessEqualOut(ar, sc, bo) + for i, v := range ar.Values { + assert.Equal(t, v <= 6, bo.Bool1D(i)) + } + + GreaterEqualOut(ar, sc, bo) + // fmt.Println(bo) + for i, v := range ar.Values { + assert.Equal(t, v >= 6, bo.Bool1D(i)) + } + +} diff --git a/tensor/stats/pca/doc.go b/tensor/tmath/doc.go similarity index 51% rename from tensor/stats/pca/doc.go rename to tensor/tmath/doc.go index f1c43f22f3..4a66e705f8 100644 --- a/tensor/stats/pca/doc.go +++ b/tensor/tmath/doc.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. /* -Package pca performs principal component's analysis and associated covariance -matrix computations, operating on table.Table or tensor.Tensor data. +Package tmath provides basic math operations and functions that operate on tensor.Tensor. */ -package pca +package tmath diff --git a/tensor/tmath/math.go b/tensor/tmath/math.go new file mode 100644 index 0000000000..828b718b21 --- /dev/null +++ b/tensor/tmath/math.go @@ -0,0 +1,411 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tmath + +import ( + "math" + + "cogentcore.org/core/tensor" +) + +func Abs(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AbsOut, in) +} + +func AbsOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Abs(a) }, in, out) +} + +func Acos(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AcosOut, in) +} + +func AcosOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Acos(a) }, in, out) +} + +func Acosh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AcoshOut, in) +} + +func AcoshOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Acosh(a) }, in, out) +} + +func Asin(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AsinOut, in) +} + +func AsinOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Asin(a) }, in, out) +} + +func Asinh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AsinhOut, in) +} + +func AsinhOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Asinh(a) }, in, out) +} + +func Atan(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AtanOut, in) +} + +func AtanOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Atan(a) }, in, out) +} + +func Atanh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(AtanhOut, in) +} + +func AtanhOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Atanh(a) }, in, out) +} + +func Cbrt(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(CbrtOut, in) +} + +func CbrtOut(in tensor.Tensor, out tensor.Values) error { + return 
tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cbrt(a) }, in, out) +} + +func Ceil(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(CeilOut, in) +} + +func CeilOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Ceil(a) }, in, out) +} + +func Cos(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(CosOut, in) +} + +func CosOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cos(a) }, in, out) +} + +func Cosh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(CoshOut, in) +} + +func CoshOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cosh(a) }, in, out) +} + +func Erf(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(ErfOut, in) +} + +func ErfOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erf(a) }, in, out) +} + +func Erfc(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(ErfcOut, in) +} + +func ErfcOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfc(a) }, in, out) +} + +func Erfcinv(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(ErfcinvOut, in) +} + +func ErfcinvOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfcinv(a) }, in, out) +} + +func Erfinv(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(ErfinvOut, in) +} + +func ErfinvOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfinv(a) }, in, out) +} + +func Exp(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(ExpOut, in) +} + +func ExpOut(in tensor.Tensor, out tensor.Values) error { + 
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Exp(a) }, in, out) +} + +func Exp2(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Exp2Out, in) +} + +func Exp2Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Exp2(a) }, in, out) +} + +func Expm1(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Expm1Out, in) +} + +func Expm1Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Expm1(a) }, in, out) +} + +func Floor(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(FloorOut, in) +} + +func FloorOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Floor(a) }, in, out) +} + +func Gamma(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(GammaOut, in) +} + +func GammaOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Gamma(a) }, in, out) +} + +func J0(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(J0Out, in) +} + +func J0Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.J0(a) }, in, out) +} + +func J1(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(J1Out, in) +} + +func J1Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.J1(a) }, in, out) +} + +func Log(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(LogOut, in) +} + +func LogOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log(a) }, in, out) +} + +func Log10(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Log10Out, in) +} + +func Log10Out(in tensor.Tensor, out tensor.Values) error { + return 
tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log10(a) }, in, out) +} + +func Log1p(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Log1pOut, in) +} + +func Log1pOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log1p(a) }, in, out) +} + +func Log2(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Log2Out, in) +} + +func Log2Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log2(a) }, in, out) +} + +func Logb(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(LogbOut, in) +} + +func LogbOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Logb(a) }, in, out) +} + +func Round(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(RoundOut, in) +} + +func RoundOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Round(a) }, in, out) +} + +func RoundToEven(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(RoundToEvenOut, in) +} + +func RoundToEvenOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.RoundToEven(a) }, in, out) +} + +func Sin(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(SinOut, in) +} + +func SinOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sin(a) }, in, out) +} + +func Sinh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(SinhOut, in) +} + +func SinhOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sinh(a) }, in, out) +} + +func Sqrt(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(SqrtOut, in) +} + +func SqrtOut(in tensor.Tensor, out 
tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sqrt(a) }, in, out) +} + +func Tan(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(TanOut, in) +} + +func TanOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Tan(a) }, in, out) +} + +func Tanh(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(TanhOut, in) +} + +func TanhOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Tanh(a) }, in, out) +} + +func Trunc(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(TruncOut, in) +} + +func TruncOut(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Trunc(a) }, in, out) +} + +func Y0(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Y0Out, in) +} + +func Y0Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Y0(a) }, in, out) +} + +func Y1(in tensor.Tensor) tensor.Values { + return tensor.CallOut1Float64(Y1Out, in) +} + +func Y1Out(in tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Y1(a) }, in, out) +} + +//////// Binary + +func Atan2(y, x tensor.Tensor) tensor.Values { + return tensor.CallOut2(Atan2Out, y, x) +} + +func Atan2Out(y, x tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Atan2(a, b) }, y, x, out) +} + +func Copysign(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(CopysignOut, x, y) +} + +func CopysignOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Copysign(a, b) }, x, y, out) +} + +func Dim(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(DimOut, x, y) +} + 
+func DimOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Dim(a, b) }, x, y, out) +} + +func Hypot(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(HypotOut, x, y) +} + +func HypotOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Hypot(a, b) }, x, y, out) +} + +func Max(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(MaxOut, x, y) +} + +func MaxOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Max(a, b) }, x, y, out) +} + +func Min(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(MinOut, x, y) +} + +func MinOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Min(a, b) }, x, y, out) +} + +func Nextafter(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(NextafterOut, x, y) +} + +func NextafterOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Nextafter(a, b) }, x, y, out) +} + +func Pow(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(PowOut, x, y) +} + +func PowOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Pow(a, b) }, x, y, out) +} + +func Remainder(x, y tensor.Tensor) tensor.Values { + return tensor.CallOut2(RemainderOut, x, y) +} + +func RemainderOut(x, y tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Remainder(a, b) }, x, y, out) +} + +/* +func Nextafter32(x, y float32) (r float32) + +func Inf(sign int) float64 +func IsInf(f float64, sign int) bool +func IsNaN(f float64) (is bool) +func NaN() float64 +func Signbit(x float64) bool + +func Float32bits(f 
float32) uint32 +func Float32frombits(b uint32) float32 +func Float64bits(f float64) uint64 +func Float64frombits(b uint64) float64 + +func FMA(x, y, z float64) float64 + +func Jn(n int, in tensor.Tensor, out tensor.Values) +func Yn(n int, in tensor.Tensor, out tensor.Values) + +func Ldexp(frac float64, exp int) float64 + +func Ilogb(x float64) int +func Pow10(n int) float64 + +func Frexp(f float64) (frac float64, exp int) +func Modf(f float64) (int float64, frac float64) +func Lgamma(x float64) (lgamma float64, sign int) +func Sincos(x float64) (sin, cos float64) +*/ diff --git a/tensor/tmath/math_test.go b/tensor/tmath/math_test.go new file mode 100644 index 0000000000..0c8a2ba468 --- /dev/null +++ b/tensor/tmath/math_test.go @@ -0,0 +1,116 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tmath + +import ( + "math" + "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +type onef func(x float64) float64 +type tonef func(in tensor.Tensor, out tensor.Values) error + +// testEqual does equal testing taking into account NaN +func testEqual(t *testing.T, trg, val float64) { + if math.IsNaN(trg) { + if !math.IsNaN(val) { + t.Error("target is NaN but actual is not") + } + return + } + assert.InDelta(t, trg, val, 1.0e-4) +} + +func TestMath(t *testing.T) { + scalar := tensor.NewFloat64Scalar(-5.5) + scout := scalar.Clone() + + vals := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 1.507556722888818, .3} + + oned := tensor.NewNumberFromValues(vals...) 
+ oneout := oned.Clone() + + cell2d := tensor.NewFloat32(5, 2, 6) + _, cells := cell2d.Shape().RowCellSize() + assert.Equal(t, cells, 12) + tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) { + ci := idx % cells + cell2d.SetFloat1D(oned.Float1D(ci), idx) + }, cell2d) + cellout := cell2d.Clone() + + mfuncs := []onef{math.Abs, math.Acos, math.Acosh, math.Asin, math.Asinh, math.Atan, math.Atanh, math.Cbrt, math.Ceil, math.Cos, math.Cosh, math.Erf, math.Erfc, math.Erfcinv, math.Erfinv, math.Exp, math.Exp2, math.Expm1, math.Floor, math.Gamma, math.J0, math.J1, math.Log, math.Log10, math.Log1p, math.Log2, math.Logb, math.Round, math.RoundToEven, math.Sin, math.Sinh, math.Sqrt, math.Tan, math.Tanh, math.Trunc, math.Y0, math.Y1} + tfuncs := []tonef{AbsOut, AcosOut, AcoshOut, AsinOut, AsinhOut, AtanOut, AtanhOut, CbrtOut, CeilOut, CosOut, CoshOut, ErfOut, ErfcOut, ErfcinvOut, ErfinvOut, ExpOut, Exp2Out, Expm1Out, FloorOut, GammaOut, J0Out, J1Out, LogOut, Log10Out, Log1pOut, Log2Out, LogbOut, RoundOut, RoundToEvenOut, SinOut, SinhOut, SqrtOut, TanOut, TanhOut, TruncOut, Y0Out, Y1Out} + + for i, fun := range mfuncs { + tf := tfuncs[i] + tf(scalar, scout) + tf(oned, oneout) + tf(cell2d, cellout) + + testEqual(t, fun(scalar.Float1D(0)), scout.Float1D(0)) + for i, v := range vals { + testEqual(t, fun(v), oneout.Float1D(i)) + } + lv := len(vals) + for r := range 5 { + // fmt.Println(r) + si := lv * r + for c, v := range vals { + ov := tensor.AsFloat32(cellout).Values[si+c] + testEqual(t, fun(v), float64(ov)) + } + } + } +} + +type twof func(x, y float64) float64 +type ttwof func(x, y tensor.Tensor, out tensor.Values) error + +func TestMathBinary(t *testing.T) { + scalar := tensor.NewFloat64Scalar(-5.5) + scout := scalar.Clone() + + vals := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 
1.507556722888818, .3} + + oned := tensor.NewNumberFromValues(vals...) + oneout := oned.Clone() + + cell2d := tensor.NewFloat32(5, 2, 6) + _, cells := cell2d.Shape().RowCellSize() + assert.Equal(t, cells, 12) + tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) { + ci := idx % cells + cell2d.SetFloat1D(oned.Float1D(ci), idx) + }, cell2d) + cellout := cell2d.Clone() + + mfuncs := []twof{math.Atan2, math.Copysign, math.Dim, math.Hypot, math.Max, math.Min, math.Nextafter, math.Pow, math.Remainder} + tfuncs := []ttwof{Atan2Out, CopysignOut, DimOut, HypotOut, MaxOut, MinOut, NextafterOut, PowOut, RemainderOut} + + for i, fun := range mfuncs { + tf := tfuncs[i] + tf(scalar, scalar, scout) + tf(oned, oned, oneout) + tf(cell2d, cell2d, cellout) + + testEqual(t, fun(scalar.Float1D(0), scalar.Float1D(0)), scout.Float1D(0)) + for i, v := range vals { + testEqual(t, fun(v, v), oneout.Float1D(i)) + } + lv := len(vals) + for r := range 5 { + // fmt.Println(r) + si := lv * r + for c, v := range vals { + ov := tensor.AsFloat32(cellout).Values[si+c] + testEqual(t, fun(v, v), float64(ov)) + } + } + } +} diff --git a/tensor/tmath/ops.go b/tensor/tmath/ops.go new file mode 100644 index 0000000000..9e400d6a94 --- /dev/null +++ b/tensor/tmath/ops.go @@ -0,0 +1,128 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tmath + +import ( + "math" + + "cogentcore.org/core/tensor" +) + +// Assign assigns values from b into a. +func Assign(a, b tensor.Tensor) error { + return tensor.FloatAssignFunc(func(a, b float64) float64 { return b }, a, b) +} + +// AddAssign does += add assign values from b into a. 
+func AddAssign(a, b tensor.Tensor) error { + if a.IsString() { + return tensor.StringAssignFunc(func(a, b string) string { return a + b }, a, b) + } + return tensor.FloatAssignFunc(func(a, b float64) float64 { return a + b }, a, b) +} + +// SubAssign does -= sub assign values from b into a. +func SubAssign(a, b tensor.Tensor) error { + return tensor.FloatAssignFunc(func(a, b float64) float64 { return a - b }, a, b) +} + +// MulAssign does *= mul assign values from b into a. +func MulAssign(a, b tensor.Tensor) error { + return tensor.FloatAssignFunc(func(a, b float64) float64 { return a * b }, a, b) +} + +// DivAssign does /= divide assign values from b into a. +func DivAssign(a, b tensor.Tensor) error { + return tensor.FloatAssignFunc(func(a, b float64) float64 { return a / b }, a, b) +} + +// ModAssign does %= modulus assign values from b into a. +func ModAssign(a, b tensor.Tensor) error { + return tensor.FloatAssignFunc(func(a, b float64) float64 { return math.Mod(a, b) }, a, b) +} + +// Inc increments values in given tensor by 1. +func Inc(a tensor.Tensor) error { + alen := a.Len() + tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen }, + func(idx int, tsr ...tensor.Tensor) { + tsr[0].SetFloat1D(tsr[0].Float1D(idx)+1.0, idx) + }, a) + return nil +} + +// Dec decrements values in given tensor by 1. +func Dec(a tensor.Tensor) error { + alen := a.Len() + tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen }, + func(idx int, tsr ...tensor.Tensor) { + tsr[0].SetFloat1D(tsr[0].Float1D(idx)-1.0, idx) + }, a) + return nil +} + +// Add adds two tensors into output. +func Add(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(AddOut, a, b) +} + +// AddOut adds two tensors into output. 
+func AddOut(a, b tensor.Tensor, out tensor.Values) error { + if a.IsString() { + return tensor.StringBinaryFuncOut(func(a, b string) string { return a + b }, a, b, out) + } + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a + b }, a, b, out) +} + +// Sub subtracts tensors into output. +func Sub(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(SubOut, a, b) +} + +// SubOut subtracts two tensors into output. +func SubOut(a, b tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a - b }, a, b, out) +} + +// Mul multiplies tensors into output. +func Mul(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(MulOut, a, b) +} + +// MulOut multiplies two tensors into output. +func MulOut(a, b tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a * b }, a, b, out) +} + +// Div divides tensors into output. Always does floating point division, +// even with integer operands. +func Div(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2Float64(DivOut, a, b) +} + +// DivOut divides two tensors into output. +func DivOut(a, b tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a / b }, a, b, out) +} + +// Mod performs modulus a%b on tensors into output. +func Mod(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2(ModOut, a, b) +} + +// ModOut performs modulus a%b on tensors into output. +func ModOut(a, b tensor.Tensor, out tensor.Values) error { + return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Mod(a, b) }, a, b, out) +} + +// Negate returns the numeric negation -a of the input tensor values. +func Negate(a tensor.Tensor) tensor.Values { + return tensor.CallOut1(NegateOut, a) +} + +// NegateOut stores the numeric negation -a of the input tensor values into the output. 
+func NegateOut(a tensor.Tensor, out tensor.Values) error { + return tensor.FloatFuncOut(1, func(in float64) float64 { return -in }, a, out) +} diff --git a/tensor/tmath/ops_test.go b/tensor/tmath/ops_test.go new file mode 100644 index 0000000000..eb276b59cc --- /dev/null +++ b/tensor/tmath/ops_test.go @@ -0,0 +1,224 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tmath + +import ( + "fmt" + "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +func TestOps(t *testing.T) { + scalar := tensor.NewFloat64Scalar(-5.5) + scb := scalar.Clone() + scb.SetFloat1D(-4.0, 0) + scout := scalar.Clone() + + vals := []float64{-1.507556722888818, -1.2060453783110545, -0.9045340337332908, -0.6030226891555273, -0.3015113445777635, 0.1, 0.3015113445777635, 0.603022689155527, 0.904534033733291, 1.2060453783110545, 1.507556722888818, .3} + + oned := tensor.NewNumberFromValues(vals...) 
+ oneout := oned.Clone() + + cell2d := tensor.NewFloat32(5, 12) + _, cells := cell2d.Shape().RowCellSize() + assert.Equal(t, cells, 12) + tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) { + ci := idx % cells + cell2d.SetFloat1D(oned.Float1D(ci), idx) + }, cell2d) + // cell2d.DeleteRows(3, 1) + cellout := cell2d.Clone() + _ = cellout + + AddOut(scalar, scb, scout) + assert.Equal(t, -5.5+-4, scout.Float1D(0)) + + AddOut(scalar, oned, oneout) + for i, v := range vals { + assert.Equal(t, v+-5.5, oneout.Float1D(i)) + } + + AddOut(oned, oned, oneout) + for i, v := range vals { + assert.Equal(t, v+v, oneout.Float1D(i)) + } + + AddOut(cell2d, oned, cellout) + for ri := range 5 { + for i, v := range vals { + assert.InDelta(t, v+v, cellout.FloatRow(ri, i), 1.0e-6) + } + } + + SubOut(scalar, scb, scout) + assert.Equal(t, -5.5 - -4, scout.Float1D(0)) + + SubOut(scb, scalar, scout) + assert.Equal(t, -4 - -5.5, scout.Float1D(0)) + + SubOut(scalar, oned, oneout) + for i, v := range vals { + assert.Equal(t, -5.5-v, oneout.Float1D(i)) + } + + SubOut(oned, scalar, oneout) + for i, v := range vals { + assert.Equal(t, v - -5.5, oneout.Float1D(i)) + } + + SubOut(oned, oned, oneout) + for i, v := range vals { + assert.Equal(t, v-v, oneout.Float1D(i)) + } + + SubOut(cell2d, oned, cellout) + for ri := range 5 { + for i, v := range vals { + assert.InDelta(t, v-v, cellout.FloatRow(ri, i), 1.0e-6) + } + } + + MulOut(scalar, scb, scout) + assert.Equal(t, -5.5*-4, scout.Float1D(0)) + + MulOut(scalar, oned, oneout) + for i, v := range vals { + assert.Equal(t, v*-5.5, oneout.Float1D(i)) + } + + MulOut(oned, oned, oneout) + for i, v := range vals { + assert.Equal(t, v*v, oneout.Float1D(i)) + } + + MulOut(cell2d, oned, cellout) + for ri := range 5 { + for i, v := range vals { + assert.InDelta(t, v*v, cellout.FloatRow(ri, i), 1.0e-6) + } + } + + DivOut(scalar, scb, scout) + assert.Equal(t, -5.5/-4, scout.Float1D(0)) + + DivOut(scb, scalar, scout) + assert.Equal(t, 
-4/-5.5, scout.Float1D(0)) + + DivOut(scalar, oned, oneout) + for i, v := range vals { + assert.Equal(t, -5.5/v, oneout.Float1D(i)) + } + + DivOut(oned, scalar, oneout) + for i, v := range vals { + assert.Equal(t, v/-5.5, oneout.Float1D(i)) + } + + DivOut(oned, oned, oneout) + for i, v := range vals { + assert.Equal(t, v/v, oneout.Float1D(i)) + } + + DivOut(cell2d, oned, cellout) + for ri := range 5 { + for i, v := range vals { + assert.InDelta(t, v/v, cellout.FloatRow(ri, i), 1.0e-6) + } + } + + onedc := tensor.Clone(oned) + AddAssign(onedc, scalar) + for i, v := range vals { + assert.Equal(t, v+-5.5, onedc.Float1D(i)) + } + + SubAssign(onedc, scalar) + for i, v := range vals { + assert.InDelta(t, v, onedc.Float1D(i), 1.0e-8) + } + + MulAssign(onedc, scalar) + for i, v := range vals { + assert.InDelta(t, v*-5.5, onedc.Float1D(i), 1.0e-7) + } + + DivAssign(onedc, scalar) + for i, v := range vals { + assert.InDelta(t, v, onedc.Float1D(i), 1.0e-7) + } + + Inc(onedc) + for i, v := range vals { + assert.InDelta(t, v+1, onedc.Float1D(i), 1.0e-7) + } + + Dec(onedc) + for i, v := range vals { + assert.InDelta(t, v, onedc.Float1D(i), 1.0e-7) + } +} + +func runBenchMult(b *testing.B, n int, thread bool) { + if thread { + tensor.ThreadingThreshold = 1 + } else { + tensor.ThreadingThreshold = 100_000_000 + } + av := tensor.NewFloat64(n) + bv := tensor.NewFloat64(n) + ov := tensor.NewFloat64(n) + for i := range n { + av.SetFloat1D(1.0/float64(n), i) + bv.SetFloat1D(1.0/float64(n), i) + } + b.ResetTimer() + for range b.N { + MulOut(av, bv, ov) + } +} + +// to run this benchmark, do: +// go test -bench BenchmarkMult -count 10 >bench.txt +// go install golang.org/x/perf/cmd/benchstat@latest +// benchstat -row /n -col .name bench.txt + +// goos: darwin +// goarch: arm64 +// pkg: cogentcore.org/core/tensor/tmath +// │ MultThreaded │ MultSingle │ +// │ sec/op │ sec/op vs base │ +// 10 3656.5n ± 0% 699.8n ± 0% -80.86% (p=0.000 n=10) +// 100 7.288µ ± 1% 4.510µ ± 0% -38.12% (p=0.000 
n=10) +// 200 9.813µ ± 1% 8.761µ ± 0% -10.72% (p=0.000 n=10) +// 300 12.06µ ± 2% 13.04µ ± 0% +8.12% (p=0.000 n=10) +// 400 14.53µ ± 3% 17.13µ ± 1% +17.88% (p=0.000 n=10) +// 500 16.65µ ± 2% 21.35µ ± 1% +28.19% (p=0.000 n=10) +// 600 18.74µ ± 2% 25.68µ ± 0% +37.00% (p=0.000 n=10) +// 700 20.83µ ± 3% 29.94µ ± 0% +43.74% (p=0.000 n=10) +// 800 22.33µ ± 1% 34.11µ ± 0% +52.72% (p=0.000 n=10) +// 900 24.13µ ± 2% 38.23µ ± 0% +58.44% (p=0.000 n=10) +// 1000 26.25µ ± 1% 42.41µ ± 0% +61.55% (p=0.000 n=10) +// 10000 127.2µ ± 1% 424.1µ ± 1% +233.37% (p=0.000 n=10) +// geomean 16.88µ 19.11µ +13.21% + +var ns = []int{10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 10_000} + +func BenchmarkMultThreaded(b *testing.B) { + for _, n := range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchMult(b, n, true) + }) + } +} + +func BenchmarkMultSingle(b *testing.B) { + for _, n := range ns { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + runBenchMult(b, n, false) + }) + } +} diff --git a/tensor/typegen.go b/tensor/typegen.go new file mode 100644 index 0000000000..b9d87260fc --- /dev/null +++ b/tensor/typegen.go @@ -0,0 +1,19 @@ +// Code generated by "core generate"; DO NOT EDIT. + +package tensor + +import ( + "cogentcore.org/core/types" +) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.Indexed", IDName: "indexed", Doc: "Indexed provides an arbitrarily indexed view onto another \"source\" [Tensor]\nwith each index value providing a full n-dimensional index into the source.\nThe shape of this view is determined by the shape of the [Indexed.Indexes]\ntensor up to the final innermost dimension, which holds the index values.\nThus the innermost dimension size of the indexes is equal to the number\nof dimensions in the source tensor. 
Given the essential role of the\nindexes in this view, it is not usable without the indexes.\nThis view is not memory-contiguous and does not support the [RowMajor]\ninterface or efficient access to inner-dimensional subspaces.\nTo produce a new concrete [Values] that has raw data actually\norganized according to the indexed order (i.e., the copy function\nof numpy), call [Indexed.AsValues].", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor source that we are an indexed view onto."}, {Name: "Indexes", Doc: "Indexes is the list of indexes into the source tensor,\nwith the innermost dimension providing the index values\n(size = number of dimensions in the source tensor), and\nthe remaining outer dimensions determine the shape\nof this [Indexed] tensor view."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.Masked", IDName: "masked", Doc: "Masked is a filtering wrapper around another \"source\" [Tensor],\nthat provides a bit-masked view onto the Tensor defined by a [Bool] [Values]\ntensor with a matching shape. 
If the bool mask has a 'false'\nthen the corresponding value cannot be Set, and Float access returns\nNaN indicating missing data (other type access returns the zero value).\nA new Masked view defaults to a full transparent view of the source tensor.\nTo produce a new [Values] tensor with only the 'true' cases,\n(i.e., the copy function of numpy), call [Masked.AsValues].", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor source that we are a masked view onto."}, {Name: "Mask", Doc: "Bool tensor with same shape as source tensor, providing mask."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.Reshaped", IDName: "reshaped", Doc: "Reshaped is a reshaping wrapper around another \"source\" [Tensor],\nthat provides a length-preserving reshaped view onto the source Tensor.\nReshaping by adding new size=1 dimensions (via [NewAxis] value) is\noften important for properly aligning two tensors in a computationally\ncompatible manner; see the [AlignShapes] function.\n[Reshaped.AsValues] on this view returns a new [Values] with the view\nshape, calling [Clone] on the source tensor to get the values.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor source that we are a masked view onto."}, {Name: "Reshape", Doc: "Reshape is the effective shape we use for access.\nThis must have the same Len() as the source Tensor."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.Rows", IDName: "rows", Doc: "Rows is a row-indexed wrapper view around a [Values] [Tensor] that allows\narbitrary row-wise ordering and filtering according to the [Rows.Indexes].\nSorting and filtering a tensor along this outermost row dimension only\nrequires updating the indexes while leaving the underlying Tensor alone.\nUnlike the more general [Sliced] view, Rows maintains memory contiguity\nfor the inner dimensions 
(\"cells\") within each row, and supports the [RowMajor]\ninterface, with the [Set]FloatRow[Cell] methods providing efficient access.\nUse [Rows.AsValues] to obtain a concrete [Values] representation with the\ncurrent row sorting.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Sequential", Doc: "Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "ExcludeMissing", Doc: "ExcludeMissing deletes indexes where the values are missing, as indicated by NaN.\nUses first cell of higher dimensional data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "FilterString", Doc: "FilterString filters the indexes using string values compared to given\nstring. Includes rows with matching values unless the Exclude option is set.\nIf Contains option is set, it only checks if row contains string;\nif IgnoreCase, ignores case, otherwise filtering is case sensitive.\nUses first cell of higher dimensional data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"str", "opts"}}, {Name: "addRowsIndexes", Doc: "addRowsIndexes adds n rows to indexes starting at end of current tensor size", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"n"}}, {Name: "AddRows", Doc: "AddRows adds n rows to end of underlying Tensor, and to the indexes in this view", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"n"}}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor source that we are an indexed view onto.\nNote that this must be a concrete [Values] tensor, to enable efficient\n[RowMajor] access and subspace functions."}, {Name: "Indexes", Doc: "Indexes are the indexes into Tensor rows, with nil = sequential.\nOnly set if order is different from default sequential order.\nUse the [Rows.RowIndex] method for nil-aware 
logic."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.FilterOptions", IDName: "filter-options", Doc: "FilterOptions are options to a Filter function\ndetermining how the string filter value is used for matching.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Exclude", Doc: "Exclude means to exclude matches,\nwith the default (false) being to include"}, {Name: "Contains", Doc: "Contains means the string only needs to contain the target string,\nwith the default (false) requiring a complete match to entire string."}, {Name: "IgnoreCase", Doc: "IgnoreCase means that differences in case are ignored in comparing strings,\nwith the default (false) using case."}}}) + +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/tensor.Sliced", IDName: "sliced", Doc: "Sliced provides a re-sliced view onto another \"source\" [Tensor],\ndefined by a set of [Sliced.Indexes] for each dimension (must have\nat least 1 index per dimension to avoid a null view).\nThus, each dimension can be transformed in arbitrary ways relative\nto the original tensor (filtered subsets, reversals, sorting, etc).\nThis view is not memory-contiguous and does not support the [RowMajor]\ninterface or efficient access to inner-dimensional subspaces.\nA new Sliced view defaults to a full transparent view of the source tensor.\nThere is additional cost for every access operation associated with the\nindexed indirection, and access is always via the full n-dimensional indexes.\nSee also [Rows] for a version that only indexes the outermost row dimension,\nwhich is much more efficient for this common use-case, and does support [RowMajor].\nTo produce a new concrete [Values] that has raw data actually organized according\nto the indexed order (i.e., the copy function of numpy), call [Sliced.AsValues].", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Sequential", Doc: "Sequential 
sets all Indexes to nil, resulting in full sequential access into tensor.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor source that we are an indexed view onto."}, {Name: "Indexes", Doc: "Indexes are the indexes for each dimension, with dimensions as the outer\nslice (enforced to be the same length as the NumDims of the source Tensor),\nand a list of dimension index values (within range of DimSize(d)).\nA nil list of indexes for a dimension automatically provides a full,\nsequential view of that dimension."}}}) diff --git a/tensor/values.go b/tensor/values.go new file mode 100644 index 0000000000..6fd9a41d3a --- /dev/null +++ b/tensor/values.go @@ -0,0 +1,127 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "fmt" + "reflect" + + "cogentcore.org/core/base/metadata" +) + +// Values is an extended [Tensor] interface for raw value tensors. +// This supports direct setting of the shape of the underlying values, +// sub-space access to inner-dimensional subspaces of values, etc. +type Values interface { + RowMajor + + // SetShapeSizes sets the dimension sizes of the tensor, and resizes + // backing storage appropriately, retaining all existing data that fits. + SetShapeSizes(sizes ...int) + + // SetNumRows sets the number of rows (outermost dimension). + // It is safe to set this to 0. For incrementally growing tensors (e.g., a log) + // it is best to first set the anticipated full size, which allocates the + // full amount of memory, and then set to 0 and grow incrementally. + SetNumRows(rows int) + + // Sizeof returns the number of bytes contained in the Values of this tensor. + // For String types, this is just the string pointers, not the string content. 
+ Sizeof() int64 + + // Bytes returns the underlying byte representation of the tensor values. + // This is the actual underlying data, so make a copy if it can be + // unintentionally modified or retained more than for immediate use. + Bytes() []byte + + // SetZeros is a convenience function initialize all values to the + // zero value of the type (empty strings for string type). + // New tensors always start out with zeros. + SetZeros() + + // Clone clones this tensor, creating a duplicate copy of itself with its + // own separate memory representation of all the values. + Clone() Values + + // CopyFrom copies all values from other tensor into this tensor, with an + // optimized implementation if the other tensor is of the same type, and + // otherwise it goes through the appropriate standard type (Float, Int, String). + CopyFrom(from Values) + + // CopyCellsFrom copies given range of values from other tensor into this tensor, + // using flat 1D indexes: to = starting index in this Tensor to start copying into, + // start = starting index on from Tensor to start copying from, and n = number of + // values to copy. Uses an optimized implementation if the other tensor is + // of the same type, and otherwise it goes through appropriate standard type. + CopyCellsFrom(from Values, to, start, n int) + + // AppendFrom appends all values from other tensor into this tensor, with an + // optimized implementation if the other tensor is of the same type, and + // otherwise it goes through the appropriate standard type (Float, Int, String). + AppendFrom(from Values) error +} + +// New returns a new n-dimensional tensor of given value type +// with the given sizes per dimension (shape). +func New[T DataTypes](sizes ...int) Values { + var v T + switch any(v).(type) { + case string: + return NewString(sizes...) + case bool: + return NewBool(sizes...) + case float64: + return NewNumber[float64](sizes...) + case float32: + return NewNumber[float32](sizes...) 
+ case int: + return NewNumber[int](sizes...) + case int32: + return NewNumber[int32](sizes...) + case uint32: + return NewNumber[uint32](sizes...) + case byte: + return NewNumber[byte](sizes...) + default: + panic("tensor.New: unexpected error: type not supported") + } +} + +// NewOfType returns a new n-dimensional tensor of given reflect.Kind type +// with the given sizes per dimension (shape). +// Supported types are in [DataTypes]. +func NewOfType(typ reflect.Kind, sizes ...int) Values { + switch typ { + case reflect.String: + return NewString(sizes...) + case reflect.Bool: + return NewBool(sizes...) + case reflect.Float64: + return NewNumber[float64](sizes...) + case reflect.Float32: + return NewNumber[float32](sizes...) + case reflect.Int: + return NewNumber[int](sizes...) + case reflect.Int32: + return NewNumber[int32](sizes...) + case reflect.Uint8: + return NewNumber[byte](sizes...) + default: + panic(fmt.Sprintf("tensor.NewOfType: type not supported: %v", typ)) + } +} + +// metadata helpers + +// SetShapeNames sets the tensor shape dimension names into given metadata. +func SetShapeNames(md *metadata.Data, names ...string) { + md.Set("ShapeNames", names) +} + +// ShapeNames gets the tensor shape dimension names from given metadata. +func ShapeNames(md *metadata.Data) []string { + names, _ := metadata.Get[[]string](*md, "ShapeNames") + return names +} diff --git a/tensor/vector/README.md b/tensor/vector/README.md new file mode 100644 index 0000000000..e821fce072 --- /dev/null +++ b/tensor/vector/README.md @@ -0,0 +1,5 @@ +# vector + +vector provides standard vector math functions that always operate on 1D views of tensor inputs, regardless of the original tensor shape. + + diff --git a/tensor/vector/vector.go b/tensor/vector/vector.go new file mode 100644 index 0000000000..b6fd70d9df --- /dev/null +++ b/tensor/vector/vector.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package vector provides standard vector math functions that +// always operate on 1D views of tensor inputs regardless of the original +// vector shape. +package vector + +import ( + "math" + + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/tmath" +) + +// Mul multiplies two vectors element-wise, using a 1D vector +// view of the two vectors, returning the output 1D vector. +func Mul(a, b tensor.Tensor) tensor.Values { + return tensor.CallOut2Float64(MulOut, a, b) +} + +// MulOut multiplies two vectors element-wise, using a 1D vector +// view of the two vectors, filling in values in the output 1D vector. +func MulOut(a, b tensor.Tensor, out tensor.Values) error { + return tmath.MulOut(tensor.As1D(a), tensor.As1D(b), out) +} + +// Sum returns the sum of all values in the tensor, as a scalar. +func Sum(a tensor.Tensor) tensor.Values { + n := a.Len() + sum := 0.0 + tensor.Vectorize(func(tsr ...tensor.Tensor) int { return n }, + func(idx int, tsr ...tensor.Tensor) { + sum += tsr[0].Float1D(idx) + }, a) + return tensor.NewFloat64Scalar(sum) +} + +// Dot performs the vector dot product: the [Sum] of the [Mul] product +// of the two tensors, returning a scalar value. Also known as the inner product. +func Dot(a, b tensor.Tensor) tensor.Values { + return Sum(Mul(a, b)) +} + +// L2Norm returns the length of the vector as the L2 Norm: +// square root of the sum of squared values of the vector, as a scalar. +// This is the Sqrt of the [Dot] product of the vector with itself. +func L2Norm(a tensor.Tensor) tensor.Values { + dot := Dot(a, a).Float1D(0) + return tensor.NewFloat64Scalar(math.Sqrt(dot)) +} + +// L1Norm returns the length of the vector as the L1 Norm: +// sum of the absolute values of the tensor, as a scalar. 
+func L1Norm(a tensor.Tensor) tensor.Values { + n := a.Len() + sum := 0.0 + tensor.Vectorize(func(tsr ...tensor.Tensor) int { return n }, + func(idx int, tsr ...tensor.Tensor) { + sum += math.Abs(tsr[0].Float1D(idx)) + }, a) + return tensor.NewFloat64Scalar(sum) +} diff --git a/tensor/vector/vector_test.go b/tensor/vector/vector_test.go new file mode 100644 index 0000000000..3fde1fa7e1 --- /dev/null +++ b/tensor/vector/vector_test.go @@ -0,0 +1,31 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vector + +import ( + "math" + "testing" + + "cogentcore.org/core/tensor" + "github.com/stretchr/testify/assert" +) + +func TestVector(t *testing.T) { + v := tensor.NewFloat64FromValues(1, 2, 3) + ip := Mul(v, v).(*tensor.Float64) + assert.Equal(t, []float64{1, 4, 9}, ip.Values) + + smv := Sum(ip).(*tensor.Float64) + assert.Equal(t, 14.0, smv.Values[0]) + + dpv := Dot(v, v).(*tensor.Float64) + assert.Equal(t, 14.0, dpv.Values[0]) + + nl2v := L2Norm(v).(*tensor.Float64) + assert.Equal(t, math.Sqrt(14.0), nl2v.Values[0]) + + nl1v := L1Norm(v).(*tensor.Float64) + assert.Equal(t, 6.0, nl1v.Values[0]) +} diff --git a/tensor/vectorize.go b/tensor/vectorize.go new file mode 100644 index 0000000000..f81388d00f --- /dev/null +++ b/tensor/vectorize.go @@ -0,0 +1,142 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tensor + +import ( + "math" + "runtime" + "sync" +) + +var ( + // ThreadingThreshold is the threshold in number of flops (floating point ops), + // computed as tensor N * flops per element, to engage actual parallel processing. + // Heuristically, numbers below this threshold do not result in + // an overall speedup, due to overhead costs. See tmath/ops_test.go for benchmark.
+ ThreadingThreshold = 300 + + // NumThreads is the number of threads to use for parallel threading. + // The default of 0 causes the [runtime.GOMAXPROCS] to be used. + NumThreads = 0 +) + +// Vectorize applies given function 'fun' to tensor elements indexed +// by given index, with the 'nfun' providing the number of indexes +// to vectorize over, and initializing any output vectors. +// Thus the nfun is often specific to a particular class of functions. +// Both functions are called with the same set +// of Tensors passed as the final argument(s). +// The role of each tensor is function-dependent: there could be multiple +// inputs and outputs, and the output could be effectively scalar, +// as in a sum operation. The interpretation of the index is +// function dependent as well, but often is used to iterate over +// the outermost row dimension of the tensor. +// This version runs purely sequentially on this goroutine. +// See VectorizeThreaded and VectorizeGPU for other versions. +func Vectorize(nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) { + n := nfun(tsr...) + if n <= 0 { + return + } + for idx := range n { + fun(idx, tsr...) + } +} + +// VectorizeThreaded is a version of [Vectorize] that will automatically +// distribute the computation in parallel across multiple "threads" (goroutines) +// if the number of elements to be computed times the given flops +// (floating point operations) for the function exceeds the [ThreadingThreshold]. +// Heuristically, numbers below this threshold do not result +// in an overall speedup, due to overhead costs. +// Each elemental math operation in the function adds a flop. +// See estimates in [tmath] for basic math functions. +func VectorizeThreaded(flops int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) { + n := nfun(tsr...) + if n <= 0 { + return + } + if flops < 0 { + flops = 1 + } + if n*flops < ThreadingThreshold { + Vectorize(nfun, fun, tsr...)
+ return + } + VectorizeOnThreads(0, nfun, fun, tsr...) +} + +// DefaultNumThreads returns the default number of threads to use: +// NumThreads if non-zero, otherwise [runtime.GOMAXPROCS]. +func DefaultNumThreads() int { + if NumThreads > 0 { + return NumThreads + } + return runtime.GOMAXPROCS(0) +} + +// VectorizeOnThreads runs given [Vectorize] function on given number +// of threads. Use [VectorizeThreaded] to only use parallel threads when +// it is likely to be beneficial, in terms of the ThreadingThreshold. +// If threads is 0, then the [DefaultNumThreads] will be used: +// GOMAXPROCS subject to NumThreads constraint if non-zero. +func VectorizeOnThreads(threads int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) { + if threads == 0 { + threads = DefaultNumThreads() + } + n := nfun(tsr...) + if n <= 0 { + return + } + nper := int(math.Ceil(float64(n) / float64(threads))) + wait := sync.WaitGroup{} + for start := 0; start < n; start += nper { + end := start + nper + if end > n { + end = n + } + wait.Add(1) // todo: move out of loop + go func() { + for idx := start; idx < end; idx++ { + fun(idx, tsr...) + } + wait.Done() + }() + } + wait.Wait() +} + +// NFirstRows is an N function for Vectorize that returns the number of +// outer-dimension rows (or Indexes) of the first tensor. +func NFirstRows(tsr ...Tensor) int { + if len(tsr) == 0 { + return 0 + } + return tsr[0].DimSize(0) +} + +// NFirstLen is an N function for Vectorize that returns the number of +// elements in the tensor, taking into account the Indexes view. +func NFirstLen(tsr ...Tensor) int { + if len(tsr) == 0 { + return 0 + } + return tsr[0].Len() +} + +// NMinLen is an N function for Vectorize that returns the min number of +// elements across given number of tensors in the list. Use a closure +// to call this with the nt. 
+func NMinLen(nt int, tsr ...Tensor) int { + nt = min(len(tsr), nt) + if nt == 0 { + return 0 + } + n := tsr[0].Len() + for i := 1; i < nt; i++ { + n = min(n, tsr[i].Len()) + } + return n +} diff --git a/texteditor/editor.go b/texteditor/editor.go index 46b026ebe7..7819bd436a 100644 --- a/texteditor/editor.go +++ b/texteditor/editor.go @@ -248,6 +248,9 @@ func (ed *Editor) Init() { s.MaxBorder.Width.Set(units.Dp(2)) s.Background = colors.Scheme.SurfaceContainerLow + if s.IsReadOnly() { + s.Background = colors.Scheme.SurfaceContainer + } // note: a blank background does NOT work for depth color rendering if s.Is(states.Focused) { s.StateLayer = 0 @@ -317,9 +320,6 @@ func (ed *Editor) resetState() { if ed.Buffer == nil || ed.lastFilename != ed.Buffer.Filename { // don't reset if reopening.. ed.CursorPos = lexer.Pos{} } - if ed.Buffer != nil { - ed.Buffer.SetReadOnly(ed.IsReadOnly()) - } } // SetBuffer sets the [Buffer] that this is an editor of, and interconnects their events. diff --git a/texteditor/events.go b/texteditor/events.go index 90240efbbd..2800ca56de 100644 --- a/texteditor/events.go +++ b/texteditor/events.go @@ -729,5 +729,13 @@ func (ed *Editor) contextMenu(m *core.Scene) { OnClick(func(e events.Event) { ed.Clear() }) + core.NewButton(m).SetText("Editable").SetIcon(icons.Edit).
+ OnClick(func(e events.Event) { + ed.SetReadOnly(false) + if ed.Buffer != nil { + ed.Buffer.Info.Generated = false // another reason it is !editable + } + ed.Update() + }) } } diff --git a/texteditor/highlighting/defaults.highlighting b/texteditor/highlighting/defaults.highlighting index 3e42a6b9a9..cb7dd0aff8 100644 --- a/texteditor/highlighting/defaults.highlighting +++ b/texteditor/highlighting/defaults.highlighting @@ -6458,10 +6458,10 @@ "A": 0 }, "Background": { - "R": 225, - "G": 225, - "B": 225, - "A": 255 + "R": 0, + "G": 0, + "B": 0, + "A": 0 }, "Border": { "R": 0, diff --git a/texteditor/highlighting/style.go b/texteditor/highlighting/style.go index 634dc76e83..ec38c876ca 100644 --- a/texteditor/highlighting/style.go +++ b/texteditor/highlighting/style.go @@ -46,26 +46,34 @@ func (t Trilean) Prefix(s string) string { // StyleEntry is one value in the map of highlight style values type StyleEntry struct { - // text color + // Color is the text color. Color color.RGBA - // background color + // Background color. + // In general it is not good to use this because it obscures highlighting. Background color.RGBA - // border color? not sure what this is -- not really used + // Border color? not sure what this is -- not really used. Border color.RGBA `display:"-"` - // bold font + // Bold font. Bold Trilean - // italic font + // Italic font. Italic Trilean - // underline + // Underline. Underline Trilean - // don't inherit these settings from sub-category or category levels -- otherwise everything with a Pass is inherited + // NoInherit indicates to not inherit these settings from sub-category or category levels. + // Otherwise everything with a Pass is inherited. NoInherit bool + + // themeColor is the theme-adjusted text color. + themeColor color.RGBA + + // themeBackground is the theme-adjusted background color. 
+ themeBackground color.RGBA } // // FromChroma copies styles from chroma @@ -108,7 +116,7 @@ func (se *StyleEntry) UpdateFromTheme() { if matcolor.SchemeIsDark { ctone = 80 } - se.Color = hc.WithChroma(max(hc.Chroma, 48)).WithTone(ctone).AsRGBA() + se.themeColor = hc.WithChroma(max(hc.Chroma, 48)).WithTone(ctone).AsRGBA() if !colors.IsNil(se.Background) { hb := hct.FromColor(se.Background) @@ -116,7 +124,7 @@ func (se *StyleEntry) UpdateFromTheme() { if matcolor.SchemeIsDark { btone = min(hb.Tone, 17) } - se.Background = hb.WithChroma(max(hb.Chroma, 6)).WithTone(btone).AsRGBA() + se.themeBackground = hb.WithChroma(max(hb.Chroma, 6)).WithTone(btone).AsRGBA() } } @@ -134,11 +142,11 @@ func (se StyleEntry) String() string { if se.NoInherit { out = append(out, "noinherit") } - if !colors.IsNil(se.Color) { - out = append(out, colors.AsString(se.Color)) + if !colors.IsNil(se.themeColor) { + out = append(out, colors.AsString(se.themeColor)) } - if !colors.IsNil(se.Background) { - out = append(out, "bg:"+colors.AsString(se.Background)) + if !colors.IsNil(se.themeBackground) { + out = append(out, "bg:"+colors.AsString(se.themeBackground)) } if !colors.IsNil(se.Border) { out = append(out, "border:"+colors.AsString(se.Border)) @@ -149,11 +157,11 @@ func (se StyleEntry) String() string { // ToCSS converts StyleEntry to CSS attributes. 
func (se StyleEntry) ToCSS() string { styles := []string{} - if !colors.IsNil(se.Color) { - styles = append(styles, "color: "+colors.AsString(se.Color)) + if !colors.IsNil(se.themeColor) { + styles = append(styles, "color: "+colors.AsString(se.themeColor)) } - if !colors.IsNil(se.Background) { - styles = append(styles, "background-color: "+colors.AsString(se.Background)) + if !colors.IsNil(se.themeBackground) { + styles = append(styles, "background-color: "+colors.AsString(se.themeBackground)) } if se.Bold == Yes { styles = append(styles, "font-weight: bold") @@ -170,11 +178,11 @@ func (se StyleEntry) ToCSS() string { // ToProperties converts the StyleEntry to key-value properties. func (se StyleEntry) ToProperties() map[string]any { pr := map[string]any{} - if !colors.IsNil(se.Color) { - pr["color"] = se.Color + if !colors.IsNil(se.themeColor) { + pr["color"] = se.themeColor } - if !colors.IsNil(se.Background) { - pr["background-color"] = se.Background + if !colors.IsNil(se.themeBackground) { + pr["background-color"] = se.themeBackground } if se.Bold == Yes { pr["font-weight"] = styles.WeightBold @@ -189,25 +197,27 @@ func (se StyleEntry) ToProperties() map[string]any { } // Sub subtracts two style entries, returning an entry with only the differences set -func (s StyleEntry) Sub(e StyleEntry) StyleEntry { +func (se StyleEntry) Sub(e StyleEntry) StyleEntry { out := StyleEntry{} - if e.Color != s.Color { - out.Color = s.Color + if e.Color != se.Color { + out.Color = se.Color + out.themeColor = se.themeColor } - if e.Background != s.Background { - out.Background = s.Background + if e.Background != se.Background { + out.Background = se.Background + out.themeBackground = se.themeBackground } - if e.Border != s.Border { - out.Border = s.Border + if e.Border != se.Border { + out.Border = se.Border } - if e.Bold != s.Bold { - out.Bold = s.Bold + if e.Bold != se.Bold { + out.Bold = se.Bold } - if e.Italic != s.Italic { - out.Italic = s.Italic + if e.Italic != se.Italic { 
+ out.Italic = se.Italic } - if e.Underline != s.Underline { - out.Underline = s.Underline + if e.Underline != se.Underline { + out.Underline = se.Underline } return out } @@ -215,18 +225,20 @@ func (s StyleEntry) Sub(e StyleEntry) StyleEntry { // Inherit styles from ancestors. // // Ancestors should be provided from oldest, furthest away to newest, closest. -func (s StyleEntry) Inherit(ancestors ...StyleEntry) StyleEntry { - out := s +func (se StyleEntry) Inherit(ancestors ...StyleEntry) StyleEntry { + out := se for i := len(ancestors) - 1; i >= 0; i-- { if out.NoInherit { return out } ancestor := ancestors[i] - if colors.IsNil(out.Color) { + if colors.IsNil(out.themeColor) { out.Color = ancestor.Color + out.themeColor = ancestor.themeColor } - if colors.IsNil(out.Background) { + if colors.IsNil(out.themeBackground) { out.Background = ancestor.Background + out.themeBackground = ancestor.themeBackground } if colors.IsNil(out.Border) { out.Border = ancestor.Border @@ -244,9 +256,9 @@ func (s StyleEntry) Inherit(ancestors ...StyleEntry) StyleEntry { return out } -func (s StyleEntry) IsZero() bool { - return colors.IsNil(s.Color) && colors.IsNil(s.Background) && colors.IsNil(s.Border) && s.Bold == Pass && s.Italic == Pass && - s.Underline == Pass && !s.NoInherit +func (se StyleEntry) IsZero() bool { + return colors.IsNil(se.Color) && colors.IsNil(se.Background) && colors.IsNil(se.Border) && se.Bold == Pass && se.Italic == Pass && + se.Underline == Pass && !se.NoInherit } /////////////////////////////////////////////////////////////////////////////////// diff --git a/texteditor/highlighting/typegen.go b/texteditor/highlighting/typegen.go index 2566f3a9a0..ac1180c6c2 100644 --- a/texteditor/highlighting/typegen.go +++ b/texteditor/highlighting/typegen.go @@ -10,7 +10,7 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor/highligh var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor/highlighting.Trilean", IDName: "trilean", Doc: 
"Trilean value for StyleEntry value inheritance."}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor/highlighting.StyleEntry", IDName: "style-entry", Doc: "StyleEntry is one value in the map of highlight style values", Fields: []types.Field{{Name: "Color", Doc: "text color"}, {Name: "Background", Doc: "background color"}, {Name: "Border", Doc: "border color? not sure what this is -- not really used"}, {Name: "Bold", Doc: "bold font"}, {Name: "Italic", Doc: "italic font"}, {Name: "Underline", Doc: "underline"}, {Name: "NoInherit", Doc: "don't inherit these settings from sub-category or category levels -- otherwise everything with a Pass is inherited"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor/highlighting.StyleEntry", IDName: "style-entry", Doc: "StyleEntry is one value in the map of highlight style values", Fields: []types.Field{{Name: "Color", Doc: "Color is the text color."}, {Name: "Background", Doc: "Background color.\nIn general it is not good to use this because it obscures highlighting."}, {Name: "Border", Doc: "Border color? 
not sure what this is -- not really used."}, {Name: "Bold", Doc: "Bold font."}, {Name: "Italic", Doc: "Italic font."}, {Name: "Underline", Doc: "Underline."}, {Name: "NoInherit", Doc: "NoInherit indicates to not inherit these settings from sub-category or category levels.\nOtherwise everything with a Pass is inherited."}, {Name: "themeColor", Doc: "themeColor is the theme-adjusted text color."}, {Name: "themeBackground", Doc: "themeBackground is the theme-adjusted background color."}}}) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor/highlighting.Style", IDName: "style", Doc: "Style is a full style map of styles for different token.Tokens tag values"}) diff --git a/texteditor/typegen.go b/texteditor/typegen.go index 22d09b6999..d603903f29 100644 --- a/texteditor/typegen.go +++ b/texteditor/typegen.go @@ -15,7 +15,7 @@ import ( var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor.Buffer", IDName: "buffer", Doc: "Buffer is a buffer of text, which can be viewed by [Editor](s).\nIt holds the raw text lines (in original string and rune formats,\nand marked-up from syntax highlighting), and sends signals for making\nedits to the text and coordinating those edits across multiple views.\nEditors always only view a single buffer, so they directly call methods\non the buffer to drive updates, which are then broadcast.\nIt also has methods for loading and saving buffers to files.\nUnlike GUI widgets, its methods generally send events, without an\nexplicit Event suffix.\nInternally, the buffer represents new lines using \\n = LF, but saving\nand loading can deal with Windows/DOS CRLF format.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "Open", Doc: "Open loads the given file into the buffer.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}, {Name: "Revert", Doc: "Revert re-opens text from the current 
file,\nif the filename is set; returns false if not.\nIt uses an optimized diff-based update to preserve\nexisting formatting, making it very fast if not very different.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"bool"}}, {Name: "SaveAs", Doc: "SaveAs saves the current text into given file; does an editDone first to save edits\nand checks for an existing file; if it does exist then prompts to overwrite or not.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}}, {Name: "Save", Doc: "Save saves the current text into the current filename associated with this buffer.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"error"}}}, Embeds: []types.Field{{Name: "Lines"}}, Fields: []types.Field{{Name: "Filename", Doc: "Filename is the filename of the file that was last loaded or saved.\nIt is used when highlighting code."}, {Name: "Autosave", Doc: "Autosave specifies whether the file should be automatically\nsaved after changes are made."}, {Name: "Info", Doc: "Info is the full information about the current file."}, {Name: "LineColors", Doc: "LineColors are the colors to use for rendering circles\nnext to the line numbers of certain lines."}, {Name: "editors", Doc: "editors are the editors that are currently viewing this buffer."}, {Name: "posHistory", Doc: "posHistory is the history of cursor positions.\nIt can be used to move back through them."}, {Name: "Complete", Doc: "Complete is the functions and data for text completion."}, {Name: "spell", Doc: "spell is the functions and data for spelling correction."}, {Name: "currentEditor", Doc: "currentEditor is the current text editor, such as the one that initiated the\nComplete or Correct process. 
The cursor position in this view is updated, and\nit is reset to nil after usage."}, {Name: "listeners", Doc: "listeners is used for sending standard system events.\nChange is sent for BufferDone, BufferInsert, and BufferDelete."}, {Name: "autoSaving", Doc: "autoSaving is used in atomically safe way to protect autosaving"}, {Name: "notSaved", Doc: "notSaved indicates if the text has been changed (edited) relative to the\noriginal, since last Save. This can be true even when changed flag is\nfalse, because changed is cleared on EditDone, e.g., when texteditor\nis being monitored for OnChange and user does Control+Enter.\nUse IsNotSaved() method to query state."}, {Name: "fileModOK", Doc: "fileModOK have already asked about fact that file has changed since being\nopened, user is ok"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor.DiffEditor", IDName: "diff-editor", Doc: "DiffEditor presents two side-by-side [Editor]s showing the differences\nbetween two files (represented as lines of strings).", Methods: []types.Method{{Name: "saveFileA", Doc: "saveFileA saves the current state of file A to given filename", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "saveFileB", Doc: "saveFileB saves the current state of file B to given filename", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "FileA", Doc: "first file name being compared"}, {Name: "FileB", Doc: "second file name being compared"}, {Name: "RevisionA", Doc: "revision for first file, if relevant"}, {Name: "RevisionB", Doc: "revision for second file, if relevant"}, {Name: "bufferA", Doc: "[Buffer] for A showing the aligned edit view"}, {Name: "bufferB", Doc: "[Buffer] for B showing the aligned edit view"}, {Name: "alignD", Doc: "aligned diffs records diff for aligned lines"}, {Name: "diffs", Doc: "diffs applied"}, {Name: 
"inInputEvent"}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/texteditor.DiffEditor", IDName: "diff-editor", Doc: "DiffEditor presents two side-by-side [Editor]s showing the differences\nbetween two files (represented as lines of strings).", Methods: []types.Method{{Name: "saveFileA", Doc: "saveFileA saves the current state of file A to given filename", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "saveFileB", Doc: "saveFileB saves the current state of file B to given filename", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "FileA", Doc: "first file name being compared"}, {Name: "FileB", Doc: "second file name being compared"}, {Name: "RevisionA", Doc: "revision for first file, if relevant"}, {Name: "RevisionB", Doc: "revision for second file, if relevant"}, {Name: "bufferA", Doc: "[Buffer] for A showing the aligned edit view"}, {Name: "bufferB", Doc: "[Buffer] for B showing the aligned edit view"}, {Name: "alignD", Doc: "aligned diffs records diff for aligned lines"}, {Name: "diffs", Doc: "diffs applied"}, {Name: "inInputEvent"}, {Name: "toolbar"}}}) // NewDiffEditor returns a new [DiffEditor] with the given optional parent: // DiffEditor presents two side-by-side [Editor]s showing the differences diff --git a/types/typegen/generator.go b/types/typegen/generator.go index 4c7e371adf..733c95809b 100644 --- a/types/typegen/generator.go +++ b/types/typegen/generator.go @@ -80,7 +80,7 @@ func (g *Generator) Find() error { return err } g.Types = []*Type{} - err = generate.Inspect(g.Pkg, g.Inspect) + err = generate.Inspect(g.Pkg, g.Inspect, "enumgen.go", "typegen.go") if err != nil { return fmt.Errorf("error while inspecting: %w", err) } diff --git a/types/typegen/testdata/typegen.go b/types/typegen/testdata/typegen.go index 3149353580..14ab9d7984 100644 --- 
a/types/typegen/testdata/typegen.go +++ b/types/typegen/testdata/typegen.go @@ -1,4 +1,4 @@ -// Code generated by "typegen.test -test.testlogfile=/var/folders/x1/r8shprmj7j71zbw3qvgl9dqc0000gq/T/go-build1829688390/b982/testlog.txt -test.paniconexit0 -test.timeout=20s"; DO NOT EDIT. +// Code generated by "typegen.test -test.paniconexit0 -test.timeout=10m0s -test.v=true"; DO NOT EDIT. package testdata diff --git a/types/typegen/typegen_gen.go b/types/typegen/typegen_gen.go index 1e087d2d29..41753e4872 100644 --- a/types/typegen/typegen_gen.go +++ b/types/typegen/typegen_gen.go @@ -7,5 +7,3 @@ import ( ) var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/types/typegen.Config", IDName: "config", Doc: "Config contains the configuration information\nused by typegen", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Dir", Doc: "the source directory to run typegen on (can be set to multiple through paths like ./...)"}, {Name: "Output", Doc: "the output file location relative to the package on which typegen is being called"}, {Name: "AddTypes", Doc: "whether to add types to typegen by default"}, {Name: "AddMethods", Doc: "whether to add methods to typegen by default"}, {Name: "AddFuncs", Doc: "whether to add functions to typegen by default"}, {Name: "InterfaceConfigs", Doc: "An ordered map of configs keyed by fully qualified interface type names; if a type implements the interface, the config will be applied to it.\nThe configs are applied in sequential ascending order, which means that\nthe last config overrides the other ones, so the most specific\ninterfaces should typically be put last.\nNote: the package typegen is run on must explicitly reference this interface at some point for this to work; adding a simple\n`var _ MyInterface = (*MyType)(nil)` statement to check for interface implementation is an easy way to accomplish that.\nNote: typegen will still succeed if it can not find one of the interfaces specified 
here in order to allow it to work generically across multiple directories; you can use the -v flag to get log warnings about this if you suspect that it is not finding interfaces when it should."}, {Name: "Setters", Doc: "Whether to generate chaining `Set*` methods for each exported field of each type (eg: \"SetText\" for field \"Text\").\nIf this is set to true, then you can add `set:\"-\"` struct tags to individual fields\nto prevent Set methods being generated for them."}, {Name: "Templates", Doc: "a slice of templates to execute on each type being added; the template data is of the type typegen.Type"}}}) - -var _ = types.AddFunc(&types.Func{Name: "cogentcore.org/core/types/typegen.Generate", Doc: "Generate generates typegen type info, using the\nconfiguration information, loading the packages from the\nconfiguration source directory, and writing the result\nto the configuration output file.\n\nIt is a simple entry point to typegen that does all\nof the steps; for more specific functionality, create\na new [Generator] with [NewGenerator] and call methods on it.", Directives: []types.Directive{{Tool: "cli", Directive: "cmd", Args: []string{"-root"}}, {Tool: "types", Directive: "add"}}, Args: []string{"cfg"}, Returns: []string{"error"}}) diff --git a/xyz/typegen.go b/xyz/typegen.go index 801ba84539..66e8ac870c 100644 --- a/xyz/typegen.go +++ b/xyz/typegen.go @@ -62,8 +62,6 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.LightColors", I var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.Lines", IDName: "lines", Doc: "Lines are lines rendered as long thin boxes defined by points\nand width parameters. 
The Mesh must be drawn in the XY plane (i.e., use Z = 0\nor a constant unless specifically relevant to have full 3D variation).\nRotate the solid to put into other planes.", Embeds: []types.Field{{Name: "MeshBase"}}, Fields: []types.Field{{Name: "Points", Doc: "line points (must be 2 or more)"}, {Name: "Width", Doc: "line width, Y = height perpendicular to line direction, and X = depth"}, {Name: "Colors", Doc: "optional colors for each point -- actual color interpolates between"}, {Name: "Closed", Doc: "if true, connect the first and last points to form a closed shape"}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.Line", IDName: "line", Doc: "Line is a Solid that is used for line elements.\nType is need to trigger more precise event handling.", Directives: []types.Directive{{Tool: "core", Directive: "no-new"}}, Embeds: []types.Field{{Name: "Solid"}}}) - var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.Material", IDName: "material", Doc: "Material describes the material properties of a surface (colors, shininess, texture)\ni.e., phong lighting parameters.\nMain color is used for both ambient and diffuse color, and alpha component\nis used for opacity. The Emissive color is only for glowing objects.\nThe Specular color is always white (multiplied by light color).\nTextures are stored on the Scene and accessed by name", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Color", Doc: "Color is the main color of surface, used for both ambient and diffuse color in standard Phong model -- alpha component determines transparency -- note that transparent objects require more complex rendering"}, {Name: "Emissive", Doc: "Emissive is the color that surface emits independent of any lighting -- i.e., glow -- can be used for marking lights with an object"}, {Name: "Shiny", Doc: "Shiny is the specular shininess factor -- how focally vs. 
broad the surface shines back directional light -- this is an exponential factor, with 0 = very broad diffuse reflection, and higher values (typically max of 128 or so but can go higher) having a smaller more focal specular reflection. Also set Reflective factor to change overall shininess effect."}, {Name: "Reflective", Doc: "Reflective is the specular reflectiveness factor -- how much it shines back directional light. The specular reflection color is always white * the incoming light."}, {Name: "Bright", Doc: "Bright is an overall multiplier on final computed color value -- can be used to tune the overall brightness of various surfaces relative to each other for a given set of lighting parameters"}, {Name: "TextureName", Doc: "TextureName is the name of the texture to provide color for the surface."}, {Name: "Tiling", Doc: "Tiling is the texture tiling parameters: repeat and offset."}, {Name: "CullBack", Doc: "CullBack indicates to cull the back-facing surfaces."}, {Name: "CullFront", Doc: "CullFront indicates to cull the front-facing surfaces."}, {Name: "Texture", Doc: "Texture is the cached [Texture] object set based on [Material.TextureName]."}}}) // SetColor sets the [Material.Color]: @@ -124,7 +122,7 @@ var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.GenMesh", IDNam var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.Node", IDName: "node", Doc: "Node is the common interface for all xyz 3D tree nodes.\n[Solid] and [Group] are the two main types of nodes,\nwhich both extend [NodeBase] for the core functionality.", Methods: []types.Method{{Name: "AsNodeBase", Doc: "AsNodeBase returns the [NodeBase] for our node, which gives\naccess to all the base-level data structures and methods\nwithout requiring interface methods.", Returns: []string{"NodeBase"}}, {Name: "IsSolid", Doc: "IsSolid returns true if this is an [Solid] node (otherwise a [Group]).", Returns: []string{"bool"}}, {Name: "AsSolid", Doc: "AsSolid returns the node as a 
[Solid] (nil if not).", Returns: []string{"Solid"}}, {Name: "Validate", Doc: "Validate checks that scene element is valid.", Returns: []string{"error"}}, {Name: "UpdateWorldMatrix", Doc: "UpdateWorldMatrix updates this node's local and world matrix based on parent's world matrix.", Args: []string{"parWorld"}}, {Name: "UpdateMeshBBox", Doc: "UpdateMeshBBox updates the Mesh-based BBox info for all nodes.\ngroups aggregate over elements. It is called from WalkPost traversal."}, {Name: "IsVisible", Doc: "IsVisible provides the definitive answer as to whether a given node\nis currently visible. It is only entirely valid after a render pass\nfor widgets in a visible window, but it checks the window and viewport\nfor their visibility status as well, which is available always.\nNon-visible nodes are automatically not rendered and not connected to\nwindow events. The Invisible flag is one key element of the IsVisible\ncalculus; it is set by e.g., TabView for invisible tabs, and is also\nset if a widget is entirely out of render range. 
But again, use\nIsVisible as the main end-user method.\nFor robustness, it recursively calls the parent; this is typically\na short path; propagating the Invisible flag properly can be\nvery challenging without mistakenly overwriting invisibility at various\nlevels.", Returns: []string{"bool"}}, {Name: "IsTransparent", Doc: "IsTransparent returns true if solid has transparent color.", Returns: []string{"bool"}}, {Name: "Config", Doc: "Config configures the node."}, {Name: "RenderClass", Doc: "RenderClass returns the class of rendering for this solid.\nIt is used for organizing the ordering of rendering.", Returns: []string{"RenderClasses"}}, {Name: "PreRender", Doc: "PreRender is called by Scene Render to upload\nall the object data to the Phong renderer."}, {Name: "Render", Doc: "Render is called by Scene Render to actually render.", Args: []string{"rp"}}}}) -var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.NodeBase", IDName: "node-base", Doc: "NodeBase is the basic 3D tree node, which has the full transform information\nrelative to parent, and computed bounding boxes, etc.\nIt implements the [Node] interface and contains the core functionality\ncommon to all 3D nodes.", Embeds: []types.Field{{Name: "NodeBase"}}, Fields: []types.Field{{Name: "Invisible", Doc: "Invisible is whether this node is invisible."}, {Name: "Pose", Doc: "Pose is the complete specification of position and orientation."}, {Name: "Scene", Doc: "Scene is the cached [Scene]."}, {Name: "MeshBBox", Doc: "mesh-based local bounding box (aggregated for groups)"}, {Name: "WorldBBox", Doc: "world coordinates bounding box"}, {Name: "NDCBBox", Doc: "normalized display coordinates bounding box, used for frustrum clipping"}, {Name: "BBox", Doc: "raw original bounding box for the widget within its parent Scene.\nThis is prior to intersecting with Frame bounds."}, {Name: "SceneBBox", Doc: "2D bounding box for region occupied within Scene Frame that we render onto.\nThis is BBox intersected 
with Frame bounds."}}}) +var _ = types.AddType(&types.Type{Name: "cogentcore.org/core/xyz.NodeBase", IDName: "node-base", Doc: "NodeBase is the basic 3D tree node, which has the full transform information\nrelative to parent, and computed bounding boxes, etc.\nIt implements the [Node] interface and contains the core functionality\ncommon to all 3D nodes.", Embeds: []types.Field{{Name: "NodeBase"}}, Fields: []types.Field{{Name: "Invisible", Doc: "Invisible is whether this node is invisible."}, {Name: "Pose", Doc: "Pose is the complete specification of position and orientation."}, {Name: "Scene", Doc: "Scene is the cached [Scene]."}, {Name: "MeshBBox", Doc: "mesh-based local bounding box (aggregated for groups)"}, {Name: "WorldBBox", Doc: "world coordinates bounding box"}, {Name: "NDCBBox", Doc: "normalized display coordinates bounding box, used for frustrum clipping"}, {Name: "BBox", Doc: "raw original bounding box for the widget within its parent Scene.\nThis is prior to intersecting with Frame bounds."}, {Name: "SceneBBox", Doc: "2D bounding box for region occupied within Scene Frame that we render onto.\nThis is BBox intersected with Frame bounds."}, {Name: "isLinear", Doc: "isLinear indicates that this element contains a line-like shape,\nwhich engages a more selective event processing logic to determine\nif the node was selected based on a mouse click point."}}}) // NewNodeBase returns a new [NodeBase] with the given optional parent: // NodeBase is the basic 3D tree node, which has the full transform information diff --git a/yaegicore/symbols/cogentcore_org-core-base-errors.go b/yaegicore/nogui/cogentcore_org-core-base-errors.go similarity index 98% rename from yaegicore/symbols/cogentcore_org-core-base-errors.go rename to yaegicore/nogui/cogentcore_org-core-base-errors.go index 6e34726b01..fd0777ae5d 100644 --- a/yaegicore/symbols/cogentcore_org-core-base-errors.go +++ b/yaegicore/nogui/cogentcore_org-core-base-errors.go @@ -1,6 +1,6 @@ // Code generated by 
'yaegi extract cogentcore.org/core/base/errors'. DO NOT EDIT. -package symbols +package nogui import ( "cogentcore.org/core/base/errors" diff --git a/yaegicore/symbols/cogentcore_org-core-base-fileinfo.go b/yaegicore/nogui/cogentcore_org-core-base-fileinfo.go similarity index 99% rename from yaegicore/symbols/cogentcore_org-core-base-fileinfo.go rename to yaegicore/nogui/cogentcore_org-core-base-fileinfo.go index 0f3a5a442f..f1ea76f300 100644 --- a/yaegicore/symbols/cogentcore_org-core-base-fileinfo.go +++ b/yaegicore/nogui/cogentcore_org-core-base-fileinfo.go @@ -1,6 +1,6 @@ // Code generated by 'yaegi extract cogentcore.org/core/base/fileinfo'. DO NOT EDIT. -package symbols +package nogui import ( "cogentcore.org/core/base/fileinfo" @@ -89,6 +89,7 @@ func init() { "Icons": reflect.ValueOf(&fileinfo.Icons).Elem(), "Image": reflect.ValueOf(fileinfo.Image), "Ini": reflect.ValueOf(fileinfo.Ini), + "IsGeneratedFile": reflect.ValueOf(fileinfo.IsGeneratedFile), "IsMatch": reflect.ValueOf(fileinfo.IsMatch), "IsMatchList": reflect.ValueOf(fileinfo.IsMatchList), "Java": reflect.ValueOf(fileinfo.Java), diff --git a/yaegicore/symbols/cogentcore_org-core-base-fsx.go b/yaegicore/nogui/cogentcore_org-core-base-fsx.go similarity index 92% rename from yaegicore/symbols/cogentcore_org-core-base-fsx.go rename to yaegicore/nogui/cogentcore_org-core-base-fsx.go index 4083d45fd0..7dde22275d 100644 --- a/yaegicore/symbols/cogentcore_org-core-base-fsx.go +++ b/yaegicore/nogui/cogentcore_org-core-base-fsx.go @@ -1,6 +1,6 @@ // Code generated by 'yaegi extract cogentcore.org/core/base/fsx'. DO NOT EDIT. 
-package symbols +package nogui import ( "cogentcore.org/core/base/fsx" @@ -25,5 +25,8 @@ func init() { "RelativeFilePath": reflect.ValueOf(fsx.RelativeFilePath), "SplitRootPathFS": reflect.ValueOf(fsx.SplitRootPathFS), "Sub": reflect.ValueOf(fsx.Sub), + + // type definitions + "Filename": reflect.ValueOf((*fsx.Filename)(nil)), } } diff --git a/yaegicore/symbols/cogentcore_org-core-base-labels.go b/yaegicore/nogui/cogentcore_org-core-base-labels.go similarity index 99% rename from yaegicore/symbols/cogentcore_org-core-base-labels.go rename to yaegicore/nogui/cogentcore_org-core-base-labels.go index 8e95a74de5..b00adf46fe 100644 --- a/yaegicore/symbols/cogentcore_org-core-base-labels.go +++ b/yaegicore/nogui/cogentcore_org-core-base-labels.go @@ -1,6 +1,6 @@ // Code generated by 'yaegi extract cogentcore.org/core/base/labels'. DO NOT EDIT. -package symbols +package nogui import ( "cogentcore.org/core/base/labels" diff --git a/yaegicore/nogui/cogentcore_org-core-base-num.go b/yaegicore/nogui/cogentcore_org-core-base-num.go new file mode 100644 index 0000000000..7896ff8d08 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-base-num.go @@ -0,0 +1,11 @@ +// Code generated by 'yaegi extract cogentcore.org/core/base/num'. DO NOT EDIT. + +package nogui + +import ( + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/base/num/num"] = map[string]reflect.Value{} +} diff --git a/yaegicore/symbols/cogentcore_org-core-base-reflectx.go b/yaegicore/nogui/cogentcore_org-core-base-reflectx.go similarity index 92% rename from yaegicore/symbols/cogentcore_org-core-base-reflectx.go rename to yaegicore/nogui/cogentcore_org-core-base-reflectx.go index c341675c80..540385eedc 100644 --- a/yaegicore/symbols/cogentcore_org-core-base-reflectx.go +++ b/yaegicore/nogui/cogentcore_org-core-base-reflectx.go @@ -1,6 +1,6 @@ // Code generated by 'yaegi extract cogentcore.org/core/base/reflectx'. DO NOT EDIT. 
-package symbols +package nogui import ( "cogentcore.org/core/base/reflectx" @@ -12,11 +12,16 @@ func init() { Symbols["cogentcore.org/core/base/reflectx/reflectx"] = map[string]reflect.Value{ // function, constant and variable definitions "CloneToType": reflect.ValueOf(reflectx.CloneToType), + "CopyFields": reflect.ValueOf(reflectx.CopyFields), "CopyMapRobust": reflect.ValueOf(reflectx.CopyMapRobust), "CopySliceRobust": reflect.ValueOf(reflectx.CopySliceRobust), + "FieldAtPath": reflect.ValueOf(reflectx.FieldAtPath), + "FieldValue": reflect.ValueOf(reflectx.FieldValue), "FormatDefault": reflect.ValueOf(reflectx.FormatDefault), "IsNil": reflect.ValueOf(reflectx.IsNil), "KindIsBasic": reflect.ValueOf(reflectx.KindIsBasic), + "KindIsFloat": reflect.ValueOf(reflectx.KindIsFloat), + "KindIsInt": reflect.ValueOf(reflectx.KindIsInt), "KindIsNumber": reflect.ValueOf(reflectx.KindIsNumber), "LongTypeName": reflect.ValueOf(reflectx.LongTypeName), "MapAdd": reflect.ValueOf(reflectx.MapAdd), @@ -33,6 +38,7 @@ func init() { "NumAllFields": reflect.ValueOf(reflectx.NumAllFields), "OnePointerValue": reflect.ValueOf(reflectx.OnePointerValue), "PointerValue": reflect.ValueOf(reflectx.PointerValue), + "SetFieldsFromMap": reflect.ValueOf(reflectx.SetFieldsFromMap), "SetFromDefaultTag": reflect.ValueOf(reflectx.SetFromDefaultTag), "SetFromDefaultTags": reflect.ValueOf(reflectx.SetFromDefaultTags), "SetMapRobust": reflect.ValueOf(reflectx.SetMapRobust), diff --git a/yaegicore/nogui/cogentcore_org-core-goal-goalib.go b/yaegicore/nogui/cogentcore_org-core-goal-goalib.go new file mode 100644 index 0000000000..2333b7c009 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-goal-goalib.go @@ -0,0 +1,21 @@ +// Code generated by 'yaegi extract cogentcore.org/core/goal/goalib'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/goal/goalib" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/goal/goalib/goalib"] = map[string]reflect.Value{ + // function, constant and variable definitions + "AllFiles": reflect.ValueOf(goalib.AllFiles), + "FileExists": reflect.ValueOf(goalib.FileExists), + "ReadFile": reflect.ValueOf(goalib.ReadFile), + "ReplaceInFile": reflect.ValueOf(goalib.ReplaceInFile), + "SplitLines": reflect.ValueOf(goalib.SplitLines), + "StringsToAnys": reflect.ValueOf(goalib.StringsToAnys), + "WriteFile": reflect.ValueOf(goalib.WriteFile), + } +} diff --git a/yaegicore/symbols/cogentcore_org-core-math32.go b/yaegicore/nogui/cogentcore_org-core-math32.go similarity index 99% rename from yaegicore/symbols/cogentcore_org-core-math32.go rename to yaegicore/nogui/cogentcore_org-core-math32.go index df1c0a89e8..93cd2e2bf8 100644 --- a/yaegicore/symbols/cogentcore_org-core-math32.go +++ b/yaegicore/nogui/cogentcore_org-core-math32.go @@ -1,6 +1,6 @@ // Code generated by 'yaegi extract cogentcore.org/core/math32'. DO NOT EDIT. -package symbols +package nogui import ( "cogentcore.org/core/math32" @@ -29,8 +29,6 @@ func init() { "BarycoordFromPoint": reflect.ValueOf(math32.BarycoordFromPoint), "Cbrt": reflect.ValueOf(math32.Cbrt), "Ceil": reflect.ValueOf(math32.Ceil), - "Clamp": reflect.ValueOf(math32.Clamp), - "ClampInt": reflect.ValueOf(math32.ClampInt), "ContainsPoint": reflect.ValueOf(math32.ContainsPoint), "CopyFloat32s": reflect.ValueOf(math32.CopyFloat32s), "CopyFloat64s": reflect.ValueOf(math32.CopyFloat64s), diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-matrix.go b/yaegicore/nogui/cogentcore_org-core-tensor-matrix.go new file mode 100644 index 0000000000..9c333922e5 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-matrix.go @@ -0,0 +1,56 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/matrix'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/tensor/matrix" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/matrix/matrix"] = map[string]reflect.Value{ + // function, constant and variable definitions + "CallOut1": reflect.ValueOf(matrix.CallOut1), + "CallOut2": reflect.ValueOf(matrix.CallOut2), + "CopyFromDense": reflect.ValueOf(matrix.CopyFromDense), + "Det": reflect.ValueOf(matrix.Det), + "Diagonal": reflect.ValueOf(matrix.Diagonal), + "DiagonalIndices": reflect.ValueOf(matrix.DiagonalIndices), + "DiagonalN": reflect.ValueOf(matrix.DiagonalN), + "Eig": reflect.ValueOf(matrix.Eig), + "EigOut": reflect.ValueOf(matrix.EigOut), + "EigSym": reflect.ValueOf(matrix.EigSym), + "EigSymOut": reflect.ValueOf(matrix.EigSymOut), + "Identity": reflect.ValueOf(matrix.Identity), + "Inverse": reflect.ValueOf(matrix.Inverse), + "InverseOut": reflect.ValueOf(matrix.InverseOut), + "LogDet": reflect.ValueOf(matrix.LogDet), + "Mul": reflect.ValueOf(matrix.Mul), + "MulOut": reflect.ValueOf(matrix.MulOut), + "NewDense": reflect.ValueOf(matrix.NewDense), + "NewMatrix": reflect.ValueOf(matrix.NewMatrix), + "NewSymmetric": reflect.ValueOf(matrix.NewSymmetric), + "ProjectOnMatrixColumn": reflect.ValueOf(matrix.ProjectOnMatrixColumn), + "ProjectOnMatrixColumnOut": reflect.ValueOf(matrix.ProjectOnMatrixColumnOut), + "SVD": reflect.ValueOf(matrix.SVD), + "SVDOut": reflect.ValueOf(matrix.SVDOut), + "SVDValues": reflect.ValueOf(matrix.SVDValues), + "SVDValuesOut": reflect.ValueOf(matrix.SVDValuesOut), + "StringCheck": reflect.ValueOf(matrix.StringCheck), + "Trace": reflect.ValueOf(matrix.Trace), + "Tri": reflect.ValueOf(matrix.Tri), + "TriL": reflect.ValueOf(matrix.TriL), + "TriLIndicies": reflect.ValueOf(matrix.TriLIndicies), + "TriLNum": reflect.ValueOf(matrix.TriLNum), + "TriLView": reflect.ValueOf(matrix.TriLView), + "TriU": reflect.ValueOf(matrix.TriU), + "TriUIndicies": reflect.ValueOf(matrix.TriUIndicies), + "TriUNum": reflect.ValueOf(matrix.TriUNum), + 
"TriUView": reflect.ValueOf(matrix.TriUView), + "TriUpper": reflect.ValueOf(matrix.TriUpper), + + // type definitions + "Matrix": reflect.ValueOf((*matrix.Matrix)(nil)), + "Symmetric": reflect.ValueOf((*matrix.Symmetric)(nil)), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-stats-metric.go b/yaegicore/nogui/cogentcore_org-core-tensor-stats-metric.go new file mode 100644 index 0000000000..0071bc5554 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-stats-metric.go @@ -0,0 +1,80 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/stats/metric'. DO NOT EDIT. + +package nogui + +import ( + "cogentcore.org/core/tensor/stats/metric" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/stats/metric/metric"] = map[string]reflect.Value{ + // function, constant and variable definitions + "AsMetricFunc": reflect.ValueOf(metric.AsMetricFunc), + "AsMetricOutFunc": reflect.ValueOf(metric.AsMetricOutFunc), + "ClosestRow": reflect.ValueOf(metric.ClosestRow), + "ClosestRowOut": reflect.ValueOf(metric.ClosestRowOut), + "Correlation": reflect.ValueOf(metric.Correlation), + "CorrelationOut": reflect.ValueOf(metric.CorrelationOut), + "CorrelationOut64": reflect.ValueOf(metric.CorrelationOut64), + "Cosine": reflect.ValueOf(metric.Cosine), + "CosineOut": reflect.ValueOf(metric.CosineOut), + "CosineOut64": reflect.ValueOf(metric.CosineOut64), + "Covariance": reflect.ValueOf(metric.Covariance), + "CovarianceMatrix": reflect.ValueOf(metric.CovarianceMatrix), + "CovarianceMatrixOut": reflect.ValueOf(metric.CovarianceMatrixOut), + "CovarianceOut": reflect.ValueOf(metric.CovarianceOut), + "CrossEntropy": reflect.ValueOf(metric.CrossEntropy), + "CrossEntropyOut": reflect.ValueOf(metric.CrossEntropyOut), + "CrossMatrix": reflect.ValueOf(metric.CrossMatrix), + "CrossMatrixOut": reflect.ValueOf(metric.CrossMatrixOut), + "DotProduct": reflect.ValueOf(metric.DotProduct), + "DotProductOut": reflect.ValueOf(metric.DotProductOut), + 
"Hamming": reflect.ValueOf(metric.Hamming), + "HammingOut": reflect.ValueOf(metric.HammingOut), + "InvCorrelation": reflect.ValueOf(metric.InvCorrelation), + "InvCorrelationOut": reflect.ValueOf(metric.InvCorrelationOut), + "InvCosine": reflect.ValueOf(metric.InvCosine), + "InvCosineOut": reflect.ValueOf(metric.InvCosineOut), + "L1Norm": reflect.ValueOf(metric.L1Norm), + "L1NormOut": reflect.ValueOf(metric.L1NormOut), + "L2Norm": reflect.ValueOf(metric.L2Norm), + "L2NormBinTol": reflect.ValueOf(metric.L2NormBinTol), + "L2NormBinTolOut": reflect.ValueOf(metric.L2NormBinTolOut), + "L2NormOut": reflect.ValueOf(metric.L2NormOut), + "Matrix": reflect.ValueOf(metric.Matrix), + "MatrixOut": reflect.ValueOf(metric.MatrixOut), + "MetricCorrelation": reflect.ValueOf(metric.MetricCorrelation), + "MetricCosine": reflect.ValueOf(metric.MetricCosine), + "MetricCovariance": reflect.ValueOf(metric.MetricCovariance), + "MetricCrossEntropy": reflect.ValueOf(metric.MetricCrossEntropy), + "MetricDotProduct": reflect.ValueOf(metric.MetricDotProduct), + "MetricHamming": reflect.ValueOf(metric.MetricHamming), + "MetricInvCorrelation": reflect.ValueOf(metric.MetricInvCorrelation), + "MetricInvCosine": reflect.ValueOf(metric.MetricInvCosine), + "MetricL1Norm": reflect.ValueOf(metric.MetricL1Norm), + "MetricL2Norm": reflect.ValueOf(metric.MetricL2Norm), + "MetricL2NormBinTol": reflect.ValueOf(metric.MetricL2NormBinTol), + "MetricSumSquares": reflect.ValueOf(metric.MetricSumSquares), + "MetricSumSquaresBinTol": reflect.ValueOf(metric.MetricSumSquaresBinTol), + "MetricsN": reflect.ValueOf(metric.MetricsN), + "MetricsValues": reflect.ValueOf(metric.MetricsValues), + "SumSquares": reflect.ValueOf(metric.SumSquares), + "SumSquaresBinTol": reflect.ValueOf(metric.SumSquaresBinTol), + "SumSquaresBinTolOut": reflect.ValueOf(metric.SumSquaresBinTolOut), + "SumSquaresBinTolScaleOut64": reflect.ValueOf(metric.SumSquaresBinTolScaleOut64), + "SumSquaresOut": reflect.ValueOf(metric.SumSquaresOut), + 
"SumSquaresOut64": reflect.ValueOf(metric.SumSquaresOut64), + "SumSquaresScaleOut64": reflect.ValueOf(metric.SumSquaresScaleOut64), + "Vectorize2Out64": reflect.ValueOf(metric.Vectorize2Out64), + "Vectorize3Out64": reflect.ValueOf(metric.Vectorize3Out64), + "VectorizeOut64": reflect.ValueOf(metric.VectorizeOut64), + "VectorizePre3Out64": reflect.ValueOf(metric.VectorizePre3Out64), + "VectorizePreOut64": reflect.ValueOf(metric.VectorizePreOut64), + + // type definitions + "MetricFunc": reflect.ValueOf((*metric.MetricFunc)(nil)), + "MetricOutFunc": reflect.ValueOf((*metric.MetricOutFunc)(nil)), + "Metrics": reflect.ValueOf((*metric.Metrics)(nil)), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-stats-stats.go b/yaegicore/nogui/cogentcore_org-core-tensor-stats-stats.go new file mode 100644 index 0000000000..d30f8ea304 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-stats-stats.go @@ -0,0 +1,119 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/stats/stats'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/tensor/stats/stats" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/stats/stats/stats"] = map[string]reflect.Value{ + // function, constant and variable definitions + "AsStatsFunc": reflect.ValueOf(stats.AsStatsFunc), + "Binarize": reflect.ValueOf(stats.Binarize), + "BinarizeOut": reflect.ValueOf(stats.BinarizeOut), + "Clamp": reflect.ValueOf(stats.Clamp), + "ClampOut": reflect.ValueOf(stats.ClampOut), + "Count": reflect.ValueOf(stats.Count), + "CountOut": reflect.ValueOf(stats.CountOut), + "CountOut64": reflect.ValueOf(stats.CountOut64), + "Describe": reflect.ValueOf(stats.Describe), + "DescribeTable": reflect.ValueOf(stats.DescribeTable), + "DescribeTableAll": reflect.ValueOf(stats.DescribeTableAll), + "DescriptiveStats": reflect.ValueOf(&stats.DescriptiveStats).Elem(), + "GroupAll": reflect.ValueOf(stats.GroupAll), + "GroupDescribe": reflect.ValueOf(stats.GroupDescribe), + "GroupStats": reflect.ValueOf(stats.GroupStats), + "GroupStatsAsTable": reflect.ValueOf(stats.GroupStatsAsTable), + "GroupStatsAsTableNoStatName": reflect.ValueOf(stats.GroupStatsAsTableNoStatName), + "Groups": reflect.ValueOf(stats.Groups), + "L1Norm": reflect.ValueOf(stats.L1Norm), + "L1NormOut": reflect.ValueOf(stats.L1NormOut), + "L2Norm": reflect.ValueOf(stats.L2Norm), + "L2NormOut": reflect.ValueOf(stats.L2NormOut), + "L2NormOut64": reflect.ValueOf(stats.L2NormOut64), + "Max": reflect.ValueOf(stats.Max), + "MaxAbs": reflect.ValueOf(stats.MaxAbs), + "MaxAbsOut": reflect.ValueOf(stats.MaxAbsOut), + "MaxOut": reflect.ValueOf(stats.MaxOut), + "Mean": reflect.ValueOf(stats.Mean), + "MeanOut": reflect.ValueOf(stats.MeanOut), + "MeanOut64": reflect.ValueOf(stats.MeanOut64), + "Median": reflect.ValueOf(stats.Median), + "MedianOut": reflect.ValueOf(stats.MedianOut), + "Min": reflect.ValueOf(stats.Min), + "MinAbs": reflect.ValueOf(stats.MinAbs), + "MinAbsOut": reflect.ValueOf(stats.MinAbsOut), + "MinOut": 
reflect.ValueOf(stats.MinOut), + "Prod": reflect.ValueOf(stats.Prod), + "ProdOut": reflect.ValueOf(stats.ProdOut), + "Q1": reflect.ValueOf(stats.Q1), + "Q1Out": reflect.ValueOf(stats.Q1Out), + "Q3": reflect.ValueOf(stats.Q3), + "Q3Out": reflect.ValueOf(stats.Q3Out), + "Quantiles": reflect.ValueOf(stats.Quantiles), + "QuantilesOut": reflect.ValueOf(stats.QuantilesOut), + "Sem": reflect.ValueOf(stats.Sem), + "SemOut": reflect.ValueOf(stats.SemOut), + "SemPop": reflect.ValueOf(stats.SemPop), + "SemPopOut": reflect.ValueOf(stats.SemPopOut), + "StatCount": reflect.ValueOf(stats.StatCount), + "StatL1Norm": reflect.ValueOf(stats.StatL1Norm), + "StatL2Norm": reflect.ValueOf(stats.StatL2Norm), + "StatMax": reflect.ValueOf(stats.StatMax), + "StatMaxAbs": reflect.ValueOf(stats.StatMaxAbs), + "StatMean": reflect.ValueOf(stats.StatMean), + "StatMedian": reflect.ValueOf(stats.StatMedian), + "StatMin": reflect.ValueOf(stats.StatMin), + "StatMinAbs": reflect.ValueOf(stats.StatMinAbs), + "StatProd": reflect.ValueOf(stats.StatProd), + "StatQ1": reflect.ValueOf(stats.StatQ1), + "StatQ3": reflect.ValueOf(stats.StatQ3), + "StatSem": reflect.ValueOf(stats.StatSem), + "StatSemPop": reflect.ValueOf(stats.StatSemPop), + "StatStd": reflect.ValueOf(stats.StatStd), + "StatStdPop": reflect.ValueOf(stats.StatStdPop), + "StatSum": reflect.ValueOf(stats.StatSum), + "StatSumSq": reflect.ValueOf(stats.StatSumSq), + "StatVar": reflect.ValueOf(stats.StatVar), + "StatVarPop": reflect.ValueOf(stats.StatVarPop), + "StatsN": reflect.ValueOf(stats.StatsN), + "StatsValues": reflect.ValueOf(stats.StatsValues), + "Std": reflect.ValueOf(stats.Std), + "StdOut": reflect.ValueOf(stats.StdOut), + "StdOut64": reflect.ValueOf(stats.StdOut64), + "StdPop": reflect.ValueOf(stats.StdPop), + "StdPopOut": reflect.ValueOf(stats.StdPopOut), + "StripPackage": reflect.ValueOf(stats.StripPackage), + "Sum": reflect.ValueOf(stats.Sum), + "SumOut": reflect.ValueOf(stats.SumOut), + "SumOut64": reflect.ValueOf(stats.SumOut64), + 
"SumSq": reflect.ValueOf(stats.SumSq), + "SumSqDevOut64": reflect.ValueOf(stats.SumSqDevOut64), + "SumSqOut": reflect.ValueOf(stats.SumSqOut), + "SumSqOut64": reflect.ValueOf(stats.SumSqOut64), + "SumSqScaleOut64": reflect.ValueOf(stats.SumSqScaleOut64), + "TableGroupDescribe": reflect.ValueOf(stats.TableGroupDescribe), + "TableGroupStats": reflect.ValueOf(stats.TableGroupStats), + "TableGroups": reflect.ValueOf(stats.TableGroups), + "UnitNorm": reflect.ValueOf(stats.UnitNorm), + "UnitNormOut": reflect.ValueOf(stats.UnitNormOut), + "Var": reflect.ValueOf(stats.Var), + "VarOut": reflect.ValueOf(stats.VarOut), + "VarOut64": reflect.ValueOf(stats.VarOut64), + "VarPop": reflect.ValueOf(stats.VarPop), + "VarPopOut": reflect.ValueOf(stats.VarPopOut), + "VarPopOut64": reflect.ValueOf(stats.VarPopOut64), + "Vectorize2Out64": reflect.ValueOf(stats.Vectorize2Out64), + "VectorizeOut64": reflect.ValueOf(stats.VectorizeOut64), + "VectorizePreOut64": reflect.ValueOf(stats.VectorizePreOut64), + "ZScore": reflect.ValueOf(stats.ZScore), + "ZScoreOut": reflect.ValueOf(stats.ZScoreOut), + + // type definitions + "Stats": reflect.ValueOf((*stats.Stats)(nil)), + "StatsFunc": reflect.ValueOf((*stats.StatsFunc)(nil)), + "StatsOutFunc": reflect.ValueOf((*stats.StatsOutFunc)(nil)), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-table.go b/yaegicore/nogui/cogentcore_org-core-tensor-table.go new file mode 100644 index 0000000000..6e5fbd455b --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-table.go @@ -0,0 +1,37 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/table'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/tensor/table" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/table/table"] = map[string]reflect.Value{ + // function, constant and variable definitions + "CleanCatTSV": reflect.ValueOf(table.CleanCatTSV), + "ConfigFromDataValues": reflect.ValueOf(table.ConfigFromDataValues), + "ConfigFromHeaders": reflect.ValueOf(table.ConfigFromHeaders), + "ConfigFromTableHeaders": reflect.ValueOf(table.ConfigFromTableHeaders), + "DetectTableHeaders": reflect.ValueOf(table.DetectTableHeaders), + "ErrLogNoNewRows": reflect.ValueOf(&table.ErrLogNoNewRows).Elem(), + "Headers": reflect.ValueOf(table.Headers), + "InferDataType": reflect.ValueOf(table.InferDataType), + "New": reflect.ValueOf(table.New), + "NewColumns": reflect.ValueOf(table.NewColumns), + "NewSliceTable": reflect.ValueOf(table.NewSliceTable), + "NewView": reflect.ValueOf(table.NewView), + "NoHeaders": reflect.ValueOf(table.NoHeaders), + "ShapeFromString": reflect.ValueOf(table.ShapeFromString), + "TableColumnType": reflect.ValueOf(table.TableColumnType), + "TableHeaderChar": reflect.ValueOf(table.TableHeaderChar), + "TableHeaderToType": reflect.ValueOf(&table.TableHeaderToType).Elem(), + "UpdateSliceTable": reflect.ValueOf(table.UpdateSliceTable), + + // type definitions + "Columns": reflect.ValueOf((*table.Columns)(nil)), + "FilterFunc": reflect.ValueOf((*table.FilterFunc)(nil)), + "Table": reflect.ValueOf((*table.Table)(nil)), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-tensorfs.go b/yaegicore/nogui/cogentcore_org-core-tensor-tensorfs.go new file mode 100644 index 0000000000..8fe8657caf --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-tensorfs.go @@ -0,0 +1,38 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/tensorfs'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/tensor/tensorfs" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/tensorfs/tensorfs"] = map[string]reflect.Value{ + // function, constant and variable definitions + "Chdir": reflect.ValueOf(tensorfs.Chdir), + "CurDir": reflect.ValueOf(&tensorfs.CurDir).Elem(), + "CurRoot": reflect.ValueOf(&tensorfs.CurRoot).Elem(), + "DirFromTable": reflect.ValueOf(tensorfs.DirFromTable), + "DirOnly": reflect.ValueOf(tensorfs.DirOnly), + "DirTable": reflect.ValueOf(tensorfs.DirTable), + "Get": reflect.ValueOf(tensorfs.Get), + "List": reflect.ValueOf(tensorfs.List), + "Long": reflect.ValueOf(tensorfs.Long), + "Mkdir": reflect.ValueOf(tensorfs.Mkdir), + "NewDir": reflect.ValueOf(tensorfs.NewDir), + "NewForTensor": reflect.ValueOf(tensorfs.NewForTensor), + "Overwrite": reflect.ValueOf(tensorfs.Overwrite), + "Preserve": reflect.ValueOf(tensorfs.Preserve), + "Record": reflect.ValueOf(tensorfs.Record), + "Recursive": reflect.ValueOf(tensorfs.Recursive), + "Set": reflect.ValueOf(tensorfs.Set), + "Short": reflect.ValueOf(tensorfs.Short), + "ValueType": reflect.ValueOf(tensorfs.ValueType), + + // type definitions + "DirFile": reflect.ValueOf((*tensorfs.DirFile)(nil)), + "File": reflect.ValueOf((*tensorfs.File)(nil)), + "Node": reflect.ValueOf((*tensorfs.Node)(nil)), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-tmath.go b/yaegicore/nogui/cogentcore_org-core-tensor-tmath.go new file mode 100644 index 0000000000..a6e8242258 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-tmath.go @@ -0,0 +1,144 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/tmath'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/tensor/tmath" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/tmath/tmath"] = map[string]reflect.Value{ + // function, constant and variable definitions + "Abs": reflect.ValueOf(tmath.Abs), + "AbsOut": reflect.ValueOf(tmath.AbsOut), + "Acos": reflect.ValueOf(tmath.Acos), + "AcosOut": reflect.ValueOf(tmath.AcosOut), + "Acosh": reflect.ValueOf(tmath.Acosh), + "AcoshOut": reflect.ValueOf(tmath.AcoshOut), + "Add": reflect.ValueOf(tmath.Add), + "AddAssign": reflect.ValueOf(tmath.AddAssign), + "AddOut": reflect.ValueOf(tmath.AddOut), + "And": reflect.ValueOf(tmath.And), + "AndOut": reflect.ValueOf(tmath.AndOut), + "Asin": reflect.ValueOf(tmath.Asin), + "AsinOut": reflect.ValueOf(tmath.AsinOut), + "Asinh": reflect.ValueOf(tmath.Asinh), + "AsinhOut": reflect.ValueOf(tmath.AsinhOut), + "Assign": reflect.ValueOf(tmath.Assign), + "Atan": reflect.ValueOf(tmath.Atan), + "Atan2": reflect.ValueOf(tmath.Atan2), + "Atan2Out": reflect.ValueOf(tmath.Atan2Out), + "AtanOut": reflect.ValueOf(tmath.AtanOut), + "Atanh": reflect.ValueOf(tmath.Atanh), + "AtanhOut": reflect.ValueOf(tmath.AtanhOut), + "Cbrt": reflect.ValueOf(tmath.Cbrt), + "CbrtOut": reflect.ValueOf(tmath.CbrtOut), + "Ceil": reflect.ValueOf(tmath.Ceil), + "CeilOut": reflect.ValueOf(tmath.CeilOut), + "Copysign": reflect.ValueOf(tmath.Copysign), + "CopysignOut": reflect.ValueOf(tmath.CopysignOut), + "Cos": reflect.ValueOf(tmath.Cos), + "CosOut": reflect.ValueOf(tmath.CosOut), + "Cosh": reflect.ValueOf(tmath.Cosh), + "CoshOut": reflect.ValueOf(tmath.CoshOut), + "Dec": reflect.ValueOf(tmath.Dec), + "Dim": reflect.ValueOf(tmath.Dim), + "DimOut": reflect.ValueOf(tmath.DimOut), + "Div": reflect.ValueOf(tmath.Div), + "DivAssign": reflect.ValueOf(tmath.DivAssign), + "DivOut": reflect.ValueOf(tmath.DivOut), + "Equal": reflect.ValueOf(tmath.Equal), + "EqualOut": reflect.ValueOf(tmath.EqualOut), + "Erf": reflect.ValueOf(tmath.Erf), + "ErfOut": 
reflect.ValueOf(tmath.ErfOut), + "Erfc": reflect.ValueOf(tmath.Erfc), + "ErfcOut": reflect.ValueOf(tmath.ErfcOut), + "Erfcinv": reflect.ValueOf(tmath.Erfcinv), + "ErfcinvOut": reflect.ValueOf(tmath.ErfcinvOut), + "Erfinv": reflect.ValueOf(tmath.Erfinv), + "ErfinvOut": reflect.ValueOf(tmath.ErfinvOut), + "Exp": reflect.ValueOf(tmath.Exp), + "Exp2": reflect.ValueOf(tmath.Exp2), + "Exp2Out": reflect.ValueOf(tmath.Exp2Out), + "ExpOut": reflect.ValueOf(tmath.ExpOut), + "Expm1": reflect.ValueOf(tmath.Expm1), + "Expm1Out": reflect.ValueOf(tmath.Expm1Out), + "Floor": reflect.ValueOf(tmath.Floor), + "FloorOut": reflect.ValueOf(tmath.FloorOut), + "Gamma": reflect.ValueOf(tmath.Gamma), + "GammaOut": reflect.ValueOf(tmath.GammaOut), + "Greater": reflect.ValueOf(tmath.Greater), + "GreaterEqual": reflect.ValueOf(tmath.GreaterEqual), + "GreaterEqualOut": reflect.ValueOf(tmath.GreaterEqualOut), + "GreaterOut": reflect.ValueOf(tmath.GreaterOut), + "Hypot": reflect.ValueOf(tmath.Hypot), + "HypotOut": reflect.ValueOf(tmath.HypotOut), + "Inc": reflect.ValueOf(tmath.Inc), + "J0": reflect.ValueOf(tmath.J0), + "J0Out": reflect.ValueOf(tmath.J0Out), + "J1": reflect.ValueOf(tmath.J1), + "J1Out": reflect.ValueOf(tmath.J1Out), + "Less": reflect.ValueOf(tmath.Less), + "LessEqual": reflect.ValueOf(tmath.LessEqual), + "LessEqualOut": reflect.ValueOf(tmath.LessEqualOut), + "LessOut": reflect.ValueOf(tmath.LessOut), + "Log": reflect.ValueOf(tmath.Log), + "Log10": reflect.ValueOf(tmath.Log10), + "Log10Out": reflect.ValueOf(tmath.Log10Out), + "Log1p": reflect.ValueOf(tmath.Log1p), + "Log1pOut": reflect.ValueOf(tmath.Log1pOut), + "Log2": reflect.ValueOf(tmath.Log2), + "Log2Out": reflect.ValueOf(tmath.Log2Out), + "LogOut": reflect.ValueOf(tmath.LogOut), + "Logb": reflect.ValueOf(tmath.Logb), + "LogbOut": reflect.ValueOf(tmath.LogbOut), + "Max": reflect.ValueOf(tmath.Max), + "MaxOut": reflect.ValueOf(tmath.MaxOut), + "Min": reflect.ValueOf(tmath.Min), + "MinOut": reflect.ValueOf(tmath.MinOut), + 
"Mod": reflect.ValueOf(tmath.Mod), + "ModAssign": reflect.ValueOf(tmath.ModAssign), + "ModOut": reflect.ValueOf(tmath.ModOut), + "Mul": reflect.ValueOf(tmath.Mul), + "MulAssign": reflect.ValueOf(tmath.MulAssign), + "MulOut": reflect.ValueOf(tmath.MulOut), + "Negate": reflect.ValueOf(tmath.Negate), + "NegateOut": reflect.ValueOf(tmath.NegateOut), + "Nextafter": reflect.ValueOf(tmath.Nextafter), + "NextafterOut": reflect.ValueOf(tmath.NextafterOut), + "Not": reflect.ValueOf(tmath.Not), + "NotEqual": reflect.ValueOf(tmath.NotEqual), + "NotEqualOut": reflect.ValueOf(tmath.NotEqualOut), + "NotOut": reflect.ValueOf(tmath.NotOut), + "Or": reflect.ValueOf(tmath.Or), + "OrOut": reflect.ValueOf(tmath.OrOut), + "Pow": reflect.ValueOf(tmath.Pow), + "PowOut": reflect.ValueOf(tmath.PowOut), + "Remainder": reflect.ValueOf(tmath.Remainder), + "RemainderOut": reflect.ValueOf(tmath.RemainderOut), + "Round": reflect.ValueOf(tmath.Round), + "RoundOut": reflect.ValueOf(tmath.RoundOut), + "RoundToEven": reflect.ValueOf(tmath.RoundToEven), + "RoundToEvenOut": reflect.ValueOf(tmath.RoundToEvenOut), + "Sin": reflect.ValueOf(tmath.Sin), + "SinOut": reflect.ValueOf(tmath.SinOut), + "Sinh": reflect.ValueOf(tmath.Sinh), + "SinhOut": reflect.ValueOf(tmath.SinhOut), + "Sqrt": reflect.ValueOf(tmath.Sqrt), + "SqrtOut": reflect.ValueOf(tmath.SqrtOut), + "Sub": reflect.ValueOf(tmath.Sub), + "SubAssign": reflect.ValueOf(tmath.SubAssign), + "SubOut": reflect.ValueOf(tmath.SubOut), + "Tan": reflect.ValueOf(tmath.Tan), + "TanOut": reflect.ValueOf(tmath.TanOut), + "Tanh": reflect.ValueOf(tmath.Tanh), + "TanhOut": reflect.ValueOf(tmath.TanhOut), + "Trunc": reflect.ValueOf(tmath.Trunc), + "TruncOut": reflect.ValueOf(tmath.TruncOut), + "Y0": reflect.ValueOf(tmath.Y0), + "Y0Out": reflect.ValueOf(tmath.Y0Out), + "Y1": reflect.ValueOf(tmath.Y1), + "Y1Out": reflect.ValueOf(tmath.Y1Out), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor-vector.go 
b/yaegicore/nogui/cogentcore_org-core-tensor-vector.go new file mode 100644 index 0000000000..bddb196bf5 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor-vector.go @@ -0,0 +1,20 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/vector'. DO NOT EDIT. + +package nogui + +import ( + "cogentcore.org/core/tensor/vector" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/vector/vector"] = map[string]reflect.Value{ + // function, constant and variable definitions + "Dot": reflect.ValueOf(vector.Dot), + "L1Norm": reflect.ValueOf(vector.L1Norm), + "L2Norm": reflect.ValueOf(vector.L2Norm), + "Mul": reflect.ValueOf(vector.Mul), + "MulOut": reflect.ValueOf(vector.MulOut), + "Sum": reflect.ValueOf(vector.Sum), + } +} diff --git a/yaegicore/nogui/cogentcore_org-core-tensor.go b/yaegicore/nogui/cogentcore_org-core-tensor.go new file mode 100644 index 0000000000..7ee6f8c679 --- /dev/null +++ b/yaegicore/nogui/cogentcore_org-core-tensor.go @@ -0,0 +1,486 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor'. DO NOT EDIT. 
+ +package nogui + +import ( + "cogentcore.org/core/base/metadata" + "cogentcore.org/core/tensor" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/tensor"] = map[string]reflect.Value{ + // function, constant and variable definitions + "AddFunc": reflect.ValueOf(tensor.AddFunc), + "AddShapes": reflect.ValueOf(tensor.AddShapes), + "AlignForAssign": reflect.ValueOf(tensor.AlignForAssign), + "AlignShapes": reflect.ValueOf(tensor.AlignShapes), + "As1D": reflect.ValueOf(tensor.As1D), + "AsFloat32": reflect.ValueOf(tensor.AsFloat32), + "AsFloat64": reflect.ValueOf(tensor.AsFloat64), + "AsFloat64Scalar": reflect.ValueOf(tensor.AsFloat64Scalar), + "AsFloat64Slice": reflect.ValueOf(tensor.AsFloat64Slice), + "AsIndexed": reflect.ValueOf(tensor.AsIndexed), + "AsInt": reflect.ValueOf(tensor.AsInt), + "AsIntScalar": reflect.ValueOf(tensor.AsIntScalar), + "AsIntSlice": reflect.ValueOf(tensor.AsIntSlice), + "AsMasked": reflect.ValueOf(tensor.AsMasked), + "AsReshaped": reflect.ValueOf(tensor.AsReshaped), + "AsRows": reflect.ValueOf(tensor.AsRows), + "AsSliced": reflect.ValueOf(tensor.AsSliced), + "AsString": reflect.ValueOf(tensor.AsString), + "AsStringScalar": reflect.ValueOf(tensor.AsStringScalar), + "AsStringSlice": reflect.ValueOf(tensor.AsStringSlice), + "Ascending": reflect.ValueOf(tensor.Ascending), + "BoolFloatsFunc": reflect.ValueOf(tensor.BoolFloatsFunc), + "BoolFloatsFuncOut": reflect.ValueOf(tensor.BoolFloatsFuncOut), + "BoolIntsFunc": reflect.ValueOf(tensor.BoolIntsFunc), + "BoolIntsFuncOut": reflect.ValueOf(tensor.BoolIntsFuncOut), + "BoolStringsFunc": reflect.ValueOf(tensor.BoolStringsFunc), + "BoolStringsFuncOut": reflect.ValueOf(tensor.BoolStringsFuncOut), + "BoolToFloat64": reflect.ValueOf(tensor.BoolToFloat64), + "BoolToInt": reflect.ValueOf(tensor.BoolToInt), + "Calc": reflect.ValueOf(tensor.Calc), + "CallOut1": reflect.ValueOf(tensor.CallOut1), + "CallOut1Float64": reflect.ValueOf(tensor.CallOut1Float64), + "CallOut2": 
reflect.ValueOf(tensor.CallOut2), + "CallOut2Bool": reflect.ValueOf(tensor.CallOut2Bool), + "CallOut2Float64": reflect.ValueOf(tensor.CallOut2Float64), + "CallOut3": reflect.ValueOf(tensor.CallOut3), + "Cells1D": reflect.ValueOf(tensor.Cells1D), + "CellsSize": reflect.ValueOf(tensor.CellsSize), + "Clone": reflect.ValueOf(tensor.Clone), + "ColumnMajorStrides": reflect.ValueOf(tensor.ColumnMajorStrides), + "Comma": reflect.ValueOf(tensor.Comma), + "DefaultNumThreads": reflect.ValueOf(tensor.DefaultNumThreads), + "DelimsN": reflect.ValueOf(tensor.DelimsN), + "DelimsValues": reflect.ValueOf(tensor.DelimsValues), + "Descending": reflect.ValueOf(tensor.Descending), + "Detect": reflect.ValueOf(tensor.Detect), + "Ellipsis": reflect.ValueOf(tensor.Ellipsis), + "Flatten": reflect.ValueOf(tensor.Flatten), + "Float64ToBool": reflect.ValueOf(tensor.Float64ToBool), + "Float64ToString": reflect.ValueOf(tensor.Float64ToString), + "FloatAssignFunc": reflect.ValueOf(tensor.FloatAssignFunc), + "FloatBinaryFunc": reflect.ValueOf(tensor.FloatBinaryFunc), + "FloatBinaryFuncOut": reflect.ValueOf(tensor.FloatBinaryFuncOut), + "FloatFunc": reflect.ValueOf(tensor.FloatFunc), + "FloatFuncOut": reflect.ValueOf(tensor.FloatFuncOut), + "FloatPromoteType": reflect.ValueOf(tensor.FloatPromoteType), + "FloatSetFunc": reflect.ValueOf(tensor.FloatSetFunc), + "FullAxis": reflect.ValueOf(tensor.FullAxis), + "FuncByName": reflect.ValueOf(tensor.FuncByName), + "Funcs": reflect.ValueOf(&tensor.Funcs).Elem(), + "IntToBool": reflect.ValueOf(tensor.IntToBool), + "Mask": reflect.ValueOf(tensor.Mask), + "MaxPrintLineWidth": reflect.ValueOf(&tensor.MaxPrintLineWidth).Elem(), + "MaxSprintLength": reflect.ValueOf(&tensor.MaxSprintLength).Elem(), + "MustBeSameShape": reflect.ValueOf(tensor.MustBeSameShape), + "MustBeValues": reflect.ValueOf(tensor.MustBeValues), + "NFirstLen": reflect.ValueOf(tensor.NFirstLen), + "NFirstRows": reflect.ValueOf(tensor.NFirstRows), + "NMinLen": reflect.ValueOf(tensor.NMinLen), + 
"NegIndex": reflect.ValueOf(tensor.NegIndex), + "NewAxis": reflect.ValueOf(tensor.NewAxis), + "NewBool": reflect.ValueOf(tensor.NewBool), + "NewBoolShape": reflect.ValueOf(tensor.NewBoolShape), + "NewByte": reflect.ValueOf(tensor.NewByte), + "NewFloat32": reflect.ValueOf(tensor.NewFloat32), + "NewFloat32FromValues": reflect.ValueOf(tensor.NewFloat32FromValues), + "NewFloat32Scalar": reflect.ValueOf(tensor.NewFloat32Scalar), + "NewFloat64": reflect.ValueOf(tensor.NewFloat64), + "NewFloat64FromValues": reflect.ValueOf(tensor.NewFloat64FromValues), + "NewFloat64Full": reflect.ValueOf(tensor.NewFloat64Full), + "NewFloat64Ones": reflect.ValueOf(tensor.NewFloat64Ones), + "NewFloat64Rand": reflect.ValueOf(tensor.NewFloat64Rand), + "NewFloat64Scalar": reflect.ValueOf(tensor.NewFloat64Scalar), + "NewFloat64SpacedLinear": reflect.ValueOf(tensor.NewFloat64SpacedLinear), + "NewFunc": reflect.ValueOf(tensor.NewFunc), + "NewIndexed": reflect.ValueOf(tensor.NewIndexed), + "NewInt": reflect.ValueOf(tensor.NewInt), + "NewInt32": reflect.ValueOf(tensor.NewInt32), + "NewIntFromValues": reflect.ValueOf(tensor.NewIntFromValues), + "NewIntFull": reflect.ValueOf(tensor.NewIntFull), + "NewIntRange": reflect.ValueOf(tensor.NewIntRange), + "NewIntScalar": reflect.ValueOf(tensor.NewIntScalar), + "NewMasked": reflect.ValueOf(tensor.NewMasked), + "NewOfType": reflect.ValueOf(tensor.NewOfType), + "NewReshaped": reflect.ValueOf(tensor.NewReshaped), + "NewRowCellsView": reflect.ValueOf(tensor.NewRowCellsView), + "NewRows": reflect.ValueOf(tensor.NewRows), + "NewShape": reflect.ValueOf(tensor.NewShape), + "NewSlice": reflect.ValueOf(tensor.NewSlice), + "NewSliced": reflect.ValueOf(tensor.NewSliced), + "NewString": reflect.ValueOf(tensor.NewString), + "NewStringFromValues": reflect.ValueOf(tensor.NewStringFromValues), + "NewStringFull": reflect.ValueOf(tensor.NewStringFull), + "NewStringScalar": reflect.ValueOf(tensor.NewStringScalar), + "NewStringShape": reflect.ValueOf(tensor.NewStringShape), + 
"NewUint32": reflect.ValueOf(tensor.NewUint32), + "NumThreads": reflect.ValueOf(&tensor.NumThreads).Elem(), + "OnedColumn": reflect.ValueOf(tensor.OnedColumn), + "OnedRow": reflect.ValueOf(tensor.OnedRow), + "OpenCSV": reflect.ValueOf(tensor.OpenCSV), + "Precision": reflect.ValueOf(tensor.Precision), + "Projection2DCoords": reflect.ValueOf(tensor.Projection2DCoords), + "Projection2DDimShapes": reflect.ValueOf(tensor.Projection2DDimShapes), + "Projection2DIndex": reflect.ValueOf(tensor.Projection2DIndex), + "Projection2DSet": reflect.ValueOf(tensor.Projection2DSet), + "Projection2DSetString": reflect.ValueOf(tensor.Projection2DSetString), + "Projection2DShape": reflect.ValueOf(tensor.Projection2DShape), + "Projection2DString": reflect.ValueOf(tensor.Projection2DString), + "Projection2DValue": reflect.ValueOf(tensor.Projection2DValue), + "Range": reflect.ValueOf(tensor.Range), + "ReadCSV": reflect.ValueOf(tensor.ReadCSV), + "Reshape": reflect.ValueOf(tensor.Reshape), + "Reslice": reflect.ValueOf(tensor.Reslice), + "RowMajorStrides": reflect.ValueOf(tensor.RowMajorStrides), + "SaveCSV": reflect.ValueOf(tensor.SaveCSV), + "SetAllFloat64": reflect.ValueOf(tensor.SetAllFloat64), + "SetAllInt": reflect.ValueOf(tensor.SetAllInt), + "SetAllString": reflect.ValueOf(tensor.SetAllString), + "SetCalcFunc": reflect.ValueOf(tensor.SetCalcFunc), + "SetPrecision": reflect.ValueOf(tensor.SetPrecision), + "SetShape": reflect.ValueOf(tensor.SetShape), + "SetShapeFrom": reflect.ValueOf(tensor.SetShapeFrom), + "SetShapeNames": reflect.ValueOf(tensor.SetShapeNames), + "SetShapeSizesFromTensor": reflect.ValueOf(tensor.SetShapeSizesFromTensor), + "ShapeNames": reflect.ValueOf(tensor.ShapeNames), + "SlicesMagicN": reflect.ValueOf(tensor.SlicesMagicN), + "SlicesMagicValues": reflect.ValueOf(tensor.SlicesMagicValues), + "Space": reflect.ValueOf(tensor.Space), + "SplitAtInnerDims": reflect.ValueOf(tensor.SplitAtInnerDims), + "Sprintf": reflect.ValueOf(tensor.Sprintf), + "Squeeze": 
reflect.ValueOf(tensor.Squeeze), + "StableSort": reflect.ValueOf(tensor.StableSort), + "StringAssignFunc": reflect.ValueOf(tensor.StringAssignFunc), + "StringBinaryFunc": reflect.ValueOf(tensor.StringBinaryFunc), + "StringBinaryFuncOut": reflect.ValueOf(tensor.StringBinaryFuncOut), + "StringToFloat64": reflect.ValueOf(tensor.StringToFloat64), + "Tab": reflect.ValueOf(tensor.Tab), + "ThreadingThreshold": reflect.ValueOf(&tensor.ThreadingThreshold).Elem(), + "Transpose": reflect.ValueOf(tensor.Transpose), + "UnstableSort": reflect.ValueOf(tensor.UnstableSort), + "Vectorize": reflect.ValueOf(tensor.Vectorize), + "VectorizeOnThreads": reflect.ValueOf(tensor.VectorizeOnThreads), + "VectorizeThreaded": reflect.ValueOf(tensor.VectorizeThreaded), + "WrapIndex1D": reflect.ValueOf(tensor.WrapIndex1D), + "WriteCSV": reflect.ValueOf(tensor.WriteCSV), + + // type definitions + "Arg": reflect.ValueOf((*tensor.Arg)(nil)), + "Bool": reflect.ValueOf((*tensor.Bool)(nil)), + "Delims": reflect.ValueOf((*tensor.Delims)(nil)), + "FilterFunc": reflect.ValueOf((*tensor.FilterFunc)(nil)), + "FilterOptions": reflect.ValueOf((*tensor.FilterOptions)(nil)), + "Func": reflect.ValueOf((*tensor.Func)(nil)), + "Indexed": reflect.ValueOf((*tensor.Indexed)(nil)), + "Masked": reflect.ValueOf((*tensor.Masked)(nil)), + "Reshaped": reflect.ValueOf((*tensor.Reshaped)(nil)), + "RowMajor": reflect.ValueOf((*tensor.RowMajor)(nil)), + "Rows": reflect.ValueOf((*tensor.Rows)(nil)), + "Shape": reflect.ValueOf((*tensor.Shape)(nil)), + "Slice": reflect.ValueOf((*tensor.Slice)(nil)), + "Sliced": reflect.ValueOf((*tensor.Sliced)(nil)), + "SlicesMagic": reflect.ValueOf((*tensor.SlicesMagic)(nil)), + "String": reflect.ValueOf((*tensor.String)(nil)), + "Tensor": reflect.ValueOf((*tensor.Tensor)(nil)), + "Values": reflect.ValueOf((*tensor.Values)(nil)), + + // interface wrapper definitions + "_RowMajor": reflect.ValueOf((*_cogentcore_org_core_tensor_RowMajor)(nil)), + "_Tensor": 
reflect.ValueOf((*_cogentcore_org_core_tensor_Tensor)(nil)), + "_Values": reflect.ValueOf((*_cogentcore_org_core_tensor_Values)(nil)), + } +} + +// _cogentcore_org_core_tensor_RowMajor is an interface wrapper for RowMajor type +type _cogentcore_org_core_tensor_RowMajor struct { + IValue interface{} + WAppendRow func(val tensor.Values) + WAppendRowFloat func(val ...float64) + WAppendRowInt func(val ...int) + WAppendRowString func(val ...string) + WAsValues func() tensor.Values + WDataType func() reflect.Kind + WDimSize func(dim int) int + WFloat func(i ...int) float64 + WFloat1D func(i int) float64 + WFloatRow func(row int, cell int) float64 + WInt func(i ...int) int + WInt1D func(i int) int + WIntRow func(row int, cell int) int + WIsString func() bool + WLabel func() string + WLen func() int + WMetadata func() *metadata.Data + WNumDims func() int + WRowTensor func(row int) tensor.Values + WSetFloat func(val float64, i ...int) + WSetFloat1D func(val float64, i int) + WSetFloatRow func(val float64, row int, cell int) + WSetInt func(val int, i ...int) + WSetInt1D func(val int, i int) + WSetIntRow func(val int, row int, cell int) + WSetRowTensor func(val tensor.Values, row int) + WSetString func(val string, i ...int) + WSetString1D func(val string, i int) + WSetStringRow func(val string, row int, cell int) + WShape func() *tensor.Shape + WShapeSizes func() []int + WString func() string + WString1D func(i int) string + WStringRow func(row int, cell int) string + WStringValue func(i ...int) string + WSubSpace func(offs ...int) tensor.Values +} + +func (W _cogentcore_org_core_tensor_RowMajor) AppendRow(val tensor.Values) { W.WAppendRow(val) } +func (W _cogentcore_org_core_tensor_RowMajor) AppendRowFloat(val ...float64) { + W.WAppendRowFloat(val...) +} +func (W _cogentcore_org_core_tensor_RowMajor) AppendRowInt(val ...int) { W.WAppendRowInt(val...) } +func (W _cogentcore_org_core_tensor_RowMajor) AppendRowString(val ...string) { + W.WAppendRowString(val...) 
+} +func (W _cogentcore_org_core_tensor_RowMajor) AsValues() tensor.Values { return W.WAsValues() } +func (W _cogentcore_org_core_tensor_RowMajor) DataType() reflect.Kind { return W.WDataType() } +func (W _cogentcore_org_core_tensor_RowMajor) DimSize(dim int) int { return W.WDimSize(dim) } +func (W _cogentcore_org_core_tensor_RowMajor) Float(i ...int) float64 { return W.WFloat(i...) } +func (W _cogentcore_org_core_tensor_RowMajor) Float1D(i int) float64 { return W.WFloat1D(i) } +func (W _cogentcore_org_core_tensor_RowMajor) FloatRow(row int, cell int) float64 { + return W.WFloatRow(row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) Int(i ...int) int { return W.WInt(i...) } +func (W _cogentcore_org_core_tensor_RowMajor) Int1D(i int) int { return W.WInt1D(i) } +func (W _cogentcore_org_core_tensor_RowMajor) IntRow(row int, cell int) int { + return W.WIntRow(row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) IsString() bool { return W.WIsString() } +func (W _cogentcore_org_core_tensor_RowMajor) Label() string { return W.WLabel() } +func (W _cogentcore_org_core_tensor_RowMajor) Len() int { return W.WLen() } +func (W _cogentcore_org_core_tensor_RowMajor) Metadata() *metadata.Data { return W.WMetadata() } +func (W _cogentcore_org_core_tensor_RowMajor) NumDims() int { return W.WNumDims() } +func (W _cogentcore_org_core_tensor_RowMajor) RowTensor(row int) tensor.Values { + return W.WRowTensor(row) +} +func (W _cogentcore_org_core_tensor_RowMajor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) } +func (W _cogentcore_org_core_tensor_RowMajor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) } +func (W _cogentcore_org_core_tensor_RowMajor) SetFloatRow(val float64, row int, cell int) { + W.WSetFloatRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) SetInt(val int, i ...int) { W.WSetInt(val, i...) 
} +func (W _cogentcore_org_core_tensor_RowMajor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) } +func (W _cogentcore_org_core_tensor_RowMajor) SetIntRow(val int, row int, cell int) { + W.WSetIntRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) SetRowTensor(val tensor.Values, row int) { + W.WSetRowTensor(val, row) +} +func (W _cogentcore_org_core_tensor_RowMajor) SetString(val string, i ...int) { + W.WSetString(val, i...) +} +func (W _cogentcore_org_core_tensor_RowMajor) SetString1D(val string, i int) { W.WSetString1D(val, i) } +func (W _cogentcore_org_core_tensor_RowMajor) SetStringRow(val string, row int, cell int) { + W.WSetStringRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) Shape() *tensor.Shape { return W.WShape() } +func (W _cogentcore_org_core_tensor_RowMajor) ShapeSizes() []int { return W.WShapeSizes() } +func (W _cogentcore_org_core_tensor_RowMajor) String() string { + if W.WString == nil { + return "" + } + return W.WString() +} +func (W _cogentcore_org_core_tensor_RowMajor) String1D(i int) string { return W.WString1D(i) } +func (W _cogentcore_org_core_tensor_RowMajor) StringRow(row int, cell int) string { + return W.WStringRow(row, cell) +} +func (W _cogentcore_org_core_tensor_RowMajor) StringValue(i ...int) string { + return W.WStringValue(i...) +} +func (W _cogentcore_org_core_tensor_RowMajor) SubSpace(offs ...int) tensor.Values { + return W.WSubSpace(offs...) 
+} + +// _cogentcore_org_core_tensor_Tensor is an interface wrapper for Tensor type +type _cogentcore_org_core_tensor_Tensor struct { + IValue interface{} + WAsValues func() tensor.Values + WDataType func() reflect.Kind + WDimSize func(dim int) int + WFloat func(i ...int) float64 + WFloat1D func(i int) float64 + WInt func(i ...int) int + WInt1D func(i int) int + WIsString func() bool + WLabel func() string + WLen func() int + WMetadata func() *metadata.Data + WNumDims func() int + WSetFloat func(val float64, i ...int) + WSetFloat1D func(val float64, i int) + WSetInt func(val int, i ...int) + WSetInt1D func(val int, i int) + WSetString func(val string, i ...int) + WSetString1D func(val string, i int) + WShape func() *tensor.Shape + WShapeSizes func() []int + WString func() string + WString1D func(i int) string + WStringValue func(i ...int) string +} + +func (W _cogentcore_org_core_tensor_Tensor) AsValues() tensor.Values { return W.WAsValues() } +func (W _cogentcore_org_core_tensor_Tensor) DataType() reflect.Kind { return W.WDataType() } +func (W _cogentcore_org_core_tensor_Tensor) DimSize(dim int) int { return W.WDimSize(dim) } +func (W _cogentcore_org_core_tensor_Tensor) Float(i ...int) float64 { return W.WFloat(i...) } +func (W _cogentcore_org_core_tensor_Tensor) Float1D(i int) float64 { return W.WFloat1D(i) } +func (W _cogentcore_org_core_tensor_Tensor) Int(i ...int) int { return W.WInt(i...) 
} +func (W _cogentcore_org_core_tensor_Tensor) Int1D(i int) int { return W.WInt1D(i) } +func (W _cogentcore_org_core_tensor_Tensor) IsString() bool { return W.WIsString() } +func (W _cogentcore_org_core_tensor_Tensor) Label() string { return W.WLabel() } +func (W _cogentcore_org_core_tensor_Tensor) Len() int { return W.WLen() } +func (W _cogentcore_org_core_tensor_Tensor) Metadata() *metadata.Data { return W.WMetadata() } +func (W _cogentcore_org_core_tensor_Tensor) NumDims() int { return W.WNumDims() } +func (W _cogentcore_org_core_tensor_Tensor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) } +func (W _cogentcore_org_core_tensor_Tensor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) } +func (W _cogentcore_org_core_tensor_Tensor) SetInt(val int, i ...int) { W.WSetInt(val, i...) } +func (W _cogentcore_org_core_tensor_Tensor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) } +func (W _cogentcore_org_core_tensor_Tensor) SetString(val string, i ...int) { W.WSetString(val, i...) } +func (W _cogentcore_org_core_tensor_Tensor) SetString1D(val string, i int) { W.WSetString1D(val, i) } +func (W _cogentcore_org_core_tensor_Tensor) Shape() *tensor.Shape { return W.WShape() } +func (W _cogentcore_org_core_tensor_Tensor) ShapeSizes() []int { return W.WShapeSizes() } +func (W _cogentcore_org_core_tensor_Tensor) String() string { + if W.WString == nil { + return "" + } + return W.WString() +} +func (W _cogentcore_org_core_tensor_Tensor) String1D(i int) string { return W.WString1D(i) } +func (W _cogentcore_org_core_tensor_Tensor) StringValue(i ...int) string { return W.WStringValue(i...) 
} + +// _cogentcore_org_core_tensor_Values is an interface wrapper for Values type +type _cogentcore_org_core_tensor_Values struct { + IValue interface{} + WAppendFrom func(from tensor.Values) error + WAppendRow func(val tensor.Values) + WAppendRowFloat func(val ...float64) + WAppendRowInt func(val ...int) + WAppendRowString func(val ...string) + WAsValues func() tensor.Values + WBytes func() []byte + WClone func() tensor.Values + WCopyCellsFrom func(from tensor.Values, to int, start int, n int) + WCopyFrom func(from tensor.Values) + WDataType func() reflect.Kind + WDimSize func(dim int) int + WFloat func(i ...int) float64 + WFloat1D func(i int) float64 + WFloatRow func(row int, cell int) float64 + WInt func(i ...int) int + WInt1D func(i int) int + WIntRow func(row int, cell int) int + WIsString func() bool + WLabel func() string + WLen func() int + WMetadata func() *metadata.Data + WNumDims func() int + WRowTensor func(row int) tensor.Values + WSetFloat func(val float64, i ...int) + WSetFloat1D func(val float64, i int) + WSetFloatRow func(val float64, row int, cell int) + WSetInt func(val int, i ...int) + WSetInt1D func(val int, i int) + WSetIntRow func(val int, row int, cell int) + WSetNumRows func(rows int) + WSetRowTensor func(val tensor.Values, row int) + WSetShapeSizes func(sizes ...int) + WSetString func(val string, i ...int) + WSetString1D func(val string, i int) + WSetStringRow func(val string, row int, cell int) + WSetZeros func() + WShape func() *tensor.Shape + WShapeSizes func() []int + WSizeof func() int64 + WString func() string + WString1D func(i int) string + WStringRow func(row int, cell int) string + WStringValue func(i ...int) string + WSubSpace func(offs ...int) tensor.Values +} + +func (W _cogentcore_org_core_tensor_Values) AppendFrom(from tensor.Values) error { + return W.WAppendFrom(from) +} +func (W _cogentcore_org_core_tensor_Values) AppendRow(val tensor.Values) { W.WAppendRow(val) } +func (W _cogentcore_org_core_tensor_Values) 
AppendRowFloat(val ...float64) { W.WAppendRowFloat(val...) } +func (W _cogentcore_org_core_tensor_Values) AppendRowInt(val ...int) { W.WAppendRowInt(val...) } +func (W _cogentcore_org_core_tensor_Values) AppendRowString(val ...string) { + W.WAppendRowString(val...) +} +func (W _cogentcore_org_core_tensor_Values) AsValues() tensor.Values { return W.WAsValues() } +func (W _cogentcore_org_core_tensor_Values) Bytes() []byte { return W.WBytes() } +func (W _cogentcore_org_core_tensor_Values) Clone() tensor.Values { return W.WClone() } +func (W _cogentcore_org_core_tensor_Values) CopyCellsFrom(from tensor.Values, to int, start int, n int) { + W.WCopyCellsFrom(from, to, start, n) +} +func (W _cogentcore_org_core_tensor_Values) CopyFrom(from tensor.Values) { W.WCopyFrom(from) } +func (W _cogentcore_org_core_tensor_Values) DataType() reflect.Kind { return W.WDataType() } +func (W _cogentcore_org_core_tensor_Values) DimSize(dim int) int { return W.WDimSize(dim) } +func (W _cogentcore_org_core_tensor_Values) Float(i ...int) float64 { return W.WFloat(i...) } +func (W _cogentcore_org_core_tensor_Values) Float1D(i int) float64 { return W.WFloat1D(i) } +func (W _cogentcore_org_core_tensor_Values) FloatRow(row int, cell int) float64 { + return W.WFloatRow(row, cell) +} +func (W _cogentcore_org_core_tensor_Values) Int(i ...int) int { return W.WInt(i...) 
} +func (W _cogentcore_org_core_tensor_Values) Int1D(i int) int { return W.WInt1D(i) } +func (W _cogentcore_org_core_tensor_Values) IntRow(row int, cell int) int { + return W.WIntRow(row, cell) +} +func (W _cogentcore_org_core_tensor_Values) IsString() bool { return W.WIsString() } +func (W _cogentcore_org_core_tensor_Values) Label() string { return W.WLabel() } +func (W _cogentcore_org_core_tensor_Values) Len() int { return W.WLen() } +func (W _cogentcore_org_core_tensor_Values) Metadata() *metadata.Data { return W.WMetadata() } +func (W _cogentcore_org_core_tensor_Values) NumDims() int { return W.WNumDims() } +func (W _cogentcore_org_core_tensor_Values) RowTensor(row int) tensor.Values { + return W.WRowTensor(row) +} +func (W _cogentcore_org_core_tensor_Values) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) } +func (W _cogentcore_org_core_tensor_Values) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) } +func (W _cogentcore_org_core_tensor_Values) SetFloatRow(val float64, row int, cell int) { + W.WSetFloatRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_Values) SetInt(val int, i ...int) { W.WSetInt(val, i...) } +func (W _cogentcore_org_core_tensor_Values) SetInt1D(val int, i int) { W.WSetInt1D(val, i) } +func (W _cogentcore_org_core_tensor_Values) SetIntRow(val int, row int, cell int) { + W.WSetIntRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_Values) SetNumRows(rows int) { W.WSetNumRows(rows) } +func (W _cogentcore_org_core_tensor_Values) SetRowTensor(val tensor.Values, row int) { + W.WSetRowTensor(val, row) +} +func (W _cogentcore_org_core_tensor_Values) SetShapeSizes(sizes ...int) { W.WSetShapeSizes(sizes...) } +func (W _cogentcore_org_core_tensor_Values) SetString(val string, i ...int) { W.WSetString(val, i...) 
} +func (W _cogentcore_org_core_tensor_Values) SetString1D(val string, i int) { W.WSetString1D(val, i) } +func (W _cogentcore_org_core_tensor_Values) SetStringRow(val string, row int, cell int) { + W.WSetStringRow(val, row, cell) +} +func (W _cogentcore_org_core_tensor_Values) SetZeros() { W.WSetZeros() } +func (W _cogentcore_org_core_tensor_Values) Shape() *tensor.Shape { return W.WShape() } +func (W _cogentcore_org_core_tensor_Values) ShapeSizes() []int { return W.WShapeSizes() } +func (W _cogentcore_org_core_tensor_Values) Sizeof() int64 { return W.WSizeof() } +func (W _cogentcore_org_core_tensor_Values) String() string { + if W.WString == nil { + return "" + } + return W.WString() +} +func (W _cogentcore_org_core_tensor_Values) String1D(i int) string { return W.WString1D(i) } +func (W _cogentcore_org_core_tensor_Values) StringRow(row int, cell int) string { + return W.WStringRow(row, cell) +} +func (W _cogentcore_org_core_tensor_Values) StringValue(i ...int) string { return W.WStringValue(i...) } +func (W _cogentcore_org_core_tensor_Values) SubSpace(offs ...int) tensor.Values { + return W.WSubSpace(offs...) 
+} diff --git a/yaegicore/symbols/fmt.go b/yaegicore/nogui/fmt.go similarity index 99% rename from yaegicore/symbols/fmt.go rename to yaegicore/nogui/fmt.go index 4a2e90c6a0..b9f4ee7322 100644 --- a/yaegicore/symbols/fmt.go +++ b/yaegicore/nogui/fmt.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "fmt" diff --git a/yaegicore/symbols/log-slog.go b/yaegicore/nogui/log-slog.go similarity index 99% rename from yaegicore/symbols/log-slog.go rename to yaegicore/nogui/log-slog.go index 267e584759..6c7c118253 100644 --- a/yaegicore/symbols/log-slog.go +++ b/yaegicore/nogui/log-slog.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "context" diff --git a/yaegicore/nogui/make b/yaegicore/nogui/make new file mode 100755 index 0000000000..3046666e79 --- /dev/null +++ b/yaegicore/nogui/make @@ -0,0 +1,12 @@ +#!/usr/bin/env cosh + +command extract { + for _, pkg := range args { + yaegi extract {"cogentcore.org/core/"+pkg} + } +} + +yaegi extract fmt strconv strings math time log/slog reflect + +extract math32 tensor tensor/tmath tensor/table tensor/vector tensor/matrix tensor/stats/stats tensor/stats/metric tensor/tensorfs base/errors base/fsx base/reflectx base/labels base/fileinfo base/num goal/goalib + diff --git a/yaegicore/nogui/math.go b/yaegicore/nogui/math.go new file mode 100644 index 0000000000..33a942261e --- /dev/null +++ b/yaegicore/nogui/math.go @@ -0,0 +1,116 @@ +// Code generated by 'yaegi extract math'. DO NOT EDIT. 
+ +//go:build go1.22 +// +build go1.22 + +package nogui + +import ( + "go/constant" + "go/token" + "math" + "reflect" +) + +func init() { + Symbols["math/math"] = map[string]reflect.Value{ + // function, constant and variable definitions + "Abs": reflect.ValueOf(math.Abs), + "Acos": reflect.ValueOf(math.Acos), + "Acosh": reflect.ValueOf(math.Acosh), + "Asin": reflect.ValueOf(math.Asin), + "Asinh": reflect.ValueOf(math.Asinh), + "Atan": reflect.ValueOf(math.Atan), + "Atan2": reflect.ValueOf(math.Atan2), + "Atanh": reflect.ValueOf(math.Atanh), + "Cbrt": reflect.ValueOf(math.Cbrt), + "Ceil": reflect.ValueOf(math.Ceil), + "Copysign": reflect.ValueOf(math.Copysign), + "Cos": reflect.ValueOf(math.Cos), + "Cosh": reflect.ValueOf(math.Cosh), + "Dim": reflect.ValueOf(math.Dim), + "E": reflect.ValueOf(constant.MakeFromLiteral("2.71828182845904523536028747135266249775724709369995957496696762566337824315673231520670375558666729784504486779277967997696994772644702281675346915668215131895555530285035761295375777990557253360748291015625", token.FLOAT, 0)), + "Erf": reflect.ValueOf(math.Erf), + "Erfc": reflect.ValueOf(math.Erfc), + "Erfcinv": reflect.ValueOf(math.Erfcinv), + "Erfinv": reflect.ValueOf(math.Erfinv), + "Exp": reflect.ValueOf(math.Exp), + "Exp2": reflect.ValueOf(math.Exp2), + "Expm1": reflect.ValueOf(math.Expm1), + "FMA": reflect.ValueOf(math.FMA), + "Float32bits": reflect.ValueOf(math.Float32bits), + "Float32frombits": reflect.ValueOf(math.Float32frombits), + "Float64bits": reflect.ValueOf(math.Float64bits), + "Float64frombits": reflect.ValueOf(math.Float64frombits), + "Floor": reflect.ValueOf(math.Floor), + "Frexp": reflect.ValueOf(math.Frexp), + "Gamma": reflect.ValueOf(math.Gamma), + "Hypot": reflect.ValueOf(math.Hypot), + "Ilogb": reflect.ValueOf(math.Ilogb), + "Inf": reflect.ValueOf(math.Inf), + "IsInf": reflect.ValueOf(math.IsInf), + "IsNaN": reflect.ValueOf(math.IsNaN), + "J0": reflect.ValueOf(math.J0), + "J1": reflect.ValueOf(math.J1), + "Jn": 
reflect.ValueOf(math.Jn), + "Ldexp": reflect.ValueOf(math.Ldexp), + "Lgamma": reflect.ValueOf(math.Lgamma), + "Ln10": reflect.ValueOf(constant.MakeFromLiteral("2.30258509299404568401799145468436420760110148862877297603332784146804725494827975466552490443295866962642372461496758838959542646932914211937012833592062802600362869664962772731087170541286468505859375", token.FLOAT, 0)), + "Ln2": reflect.ValueOf(constant.MakeFromLiteral("0.6931471805599453094172321214581765680755001343602552541206800092715999496201383079363438206637927920954189307729314303884387720696314608777673678644642390655170150035209453154294578780536539852619171142578125", token.FLOAT, 0)), + "Log": reflect.ValueOf(math.Log), + "Log10": reflect.ValueOf(math.Log10), + "Log10E": reflect.ValueOf(constant.MakeFromLiteral("0.43429448190325182765112891891660508229439700580366656611445378416636798190620320263064286300825210972160277489744884502676719847561509639618196799746596688688378591625127711495224502868950366973876953125", token.FLOAT, 0)), + "Log1p": reflect.ValueOf(math.Log1p), + "Log2": reflect.ValueOf(math.Log2), + "Log2E": reflect.ValueOf(constant.MakeFromLiteral("1.44269504088896340735992468100189213742664595415298593413544940772066427768997545329060870636212628972710992130324953463427359402479619301286929040235571747101382214539290471666532766903401352465152740478515625", token.FLOAT, 0)), + "Logb": reflect.ValueOf(math.Logb), + "Max": reflect.ValueOf(math.Max), + "MaxFloat32": reflect.ValueOf(constant.MakeFromLiteral("340282346638528859811704183484516925440", token.FLOAT, 0)), + "MaxFloat64": reflect.ValueOf(constant.MakeFromLiteral("179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", token.FLOAT, 0)), + "MaxInt": 
reflect.ValueOf(constant.MakeFromLiteral("9223372036854775807", token.INT, 0)), + "MaxInt16": reflect.ValueOf(constant.MakeFromLiteral("32767", token.INT, 0)), + "MaxInt32": reflect.ValueOf(constant.MakeFromLiteral("2147483647", token.INT, 0)), + "MaxInt64": reflect.ValueOf(constant.MakeFromLiteral("9223372036854775807", token.INT, 0)), + "MaxInt8": reflect.ValueOf(constant.MakeFromLiteral("127", token.INT, 0)), + "MaxUint": reflect.ValueOf(constant.MakeFromLiteral("18446744073709551615", token.INT, 0)), + "MaxUint16": reflect.ValueOf(constant.MakeFromLiteral("65535", token.INT, 0)), + "MaxUint32": reflect.ValueOf(constant.MakeFromLiteral("4294967295", token.INT, 0)), + "MaxUint64": reflect.ValueOf(constant.MakeFromLiteral("18446744073709551615", token.INT, 0)), + "MaxUint8": reflect.ValueOf(constant.MakeFromLiteral("255", token.INT, 0)), + "Min": reflect.ValueOf(math.Min), + "MinInt": reflect.ValueOf(constant.MakeFromLiteral("-9223372036854775808", token.INT, 0)), + "MinInt16": reflect.ValueOf(constant.MakeFromLiteral("-32768", token.INT, 0)), + "MinInt32": reflect.ValueOf(constant.MakeFromLiteral("-2147483648", token.INT, 0)), + "MinInt64": reflect.ValueOf(constant.MakeFromLiteral("-9223372036854775808", token.INT, 0)), + "MinInt8": reflect.ValueOf(constant.MakeFromLiteral("-128", token.INT, 0)), + "Mod": reflect.ValueOf(math.Mod), + "Modf": reflect.ValueOf(math.Modf), + "NaN": reflect.ValueOf(math.NaN), + "Nextafter": reflect.ValueOf(math.Nextafter), + "Nextafter32": reflect.ValueOf(math.Nextafter32), + "Phi": reflect.ValueOf(constant.MakeFromLiteral("1.6180339887498948482045868343656381177203091798057628621354486119746080982153796619881086049305501566952211682590824739205931370737029882996587050475921915678674035433959321750307935872115194797515869140625", token.FLOAT, 0)), + "Pi": 
reflect.ValueOf(constant.MakeFromLiteral("3.141592653589793238462643383279502884197169399375105820974944594789982923695635954704435713335896673485663389728754819466702315787113662862838515639906529162340867271374644786874341662041842937469482421875", token.FLOAT, 0)), + "Pow": reflect.ValueOf(math.Pow), + "Pow10": reflect.ValueOf(math.Pow10), + "Remainder": reflect.ValueOf(math.Remainder), + "Round": reflect.ValueOf(math.Round), + "RoundToEven": reflect.ValueOf(math.RoundToEven), + "Signbit": reflect.ValueOf(math.Signbit), + "Sin": reflect.ValueOf(math.Sin), + "Sincos": reflect.ValueOf(math.Sincos), + "Sinh": reflect.ValueOf(math.Sinh), + "SmallestNonzeroFloat32": reflect.ValueOf(constant.MakeFromLiteral("1.40129846432481707092372958328991613128026194187651577175706828388979108268586060148663818836212158203125e-45", token.FLOAT, 0)), + "SmallestNonzeroFloat64": reflect.ValueOf(constant.MakeFromLiteral("4.940656458412465441765687928682213723650598026143247644255856825006755072702087518652998363616359923797965646954457177309266567103559397963987747960107818781263007131903114045278458171678489821036887186360569987307230500063874091535649843873124733972731696151400317153853980741262385655911710266585566867681870395603106249319452715914924553293054565444011274801297099995419319894090804165633245247571478690147267801593552386115501348035264934720193790268107107491703332226844753335720832431936092382893458368060106011506169809753078342277318329247904982524730776375927247874656084778203734469699533647017972677717585125660551199131504891101451037862738167250955837389733598993664809941164205702637090279242767544565229087538682506419718265533447265625e-324", token.FLOAT, 0)), + "Sqrt": reflect.ValueOf(math.Sqrt), + "Sqrt2": reflect.ValueOf(constant.MakeFromLiteral("1.414213562373095048801688724209698078569671875376948073176679739576083351575381440094441524123797447886801949755143139115339040409162552642832693297721230919563348109313505318596071447245776653289794921875", 
token.FLOAT, 0)), + "SqrtE": reflect.ValueOf(constant.MakeFromLiteral("1.64872127070012814684865078781416357165377610071014801157507931167328763229187870850146925823776361770041160388013884200789716007979526823569827080974091691342077871211546646890155898290686309337615966796875", token.FLOAT, 0)), + "SqrtPhi": reflect.ValueOf(constant.MakeFromLiteral("1.2720196495140689642524224617374914917156080418400962486166403754616080542166459302584536396369727769747312116100875915825863540562126478288118732191412003988041797518382391984914647764526307582855224609375", token.FLOAT, 0)), + "SqrtPi": reflect.ValueOf(constant.MakeFromLiteral("1.772453850905516027298167483341145182797549456122387128213807789740599698370237052541269446184448945647349951047154197675245574635259260134350885938555625028620527962319730619356050738133490085601806640625", token.FLOAT, 0)), + "Tan": reflect.ValueOf(math.Tan), + "Tanh": reflect.ValueOf(math.Tanh), + "Trunc": reflect.ValueOf(math.Trunc), + "Y0": reflect.ValueOf(math.Y0), + "Y1": reflect.ValueOf(math.Y1), + "Yn": reflect.ValueOf(math.Yn), + } +} diff --git a/yaegicore/symbols/reflect.go b/yaegicore/nogui/reflect.go similarity index 99% rename from yaegicore/symbols/reflect.go rename to yaegicore/nogui/reflect.go index d88982d619..d1868b1649 100644 --- a/yaegicore/symbols/reflect.go +++ b/yaegicore/nogui/reflect.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "reflect" diff --git a/yaegicore/symbols/strconv.go b/yaegicore/nogui/strconv.go similarity index 99% rename from yaegicore/symbols/strconv.go rename to yaegicore/nogui/strconv.go index 99a1e3a547..3c1471e8ea 100644 --- a/yaegicore/symbols/strconv.go +++ b/yaegicore/nogui/strconv.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "go/constant" diff --git a/yaegicore/symbols/strings.go b/yaegicore/nogui/strings.go similarity index 99% rename from yaegicore/symbols/strings.go rename to 
yaegicore/nogui/strings.go index 6d0eba0d58..17c6738310 100644 --- a/yaegicore/symbols/strings.go +++ b/yaegicore/nogui/strings.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "reflect" diff --git a/yaegicore/nogui/symbols.go b/yaegicore/nogui/symbols.go new file mode 100644 index 0000000000..5a47f34e8e --- /dev/null +++ b/yaegicore/nogui/symbols.go @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Cogent Core. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package symbols contains yaegi symbols for core packages. +package nogui + +//go:generate ./make + +import "reflect" + +var Symbols = map[string]map[string]reflect.Value{} diff --git a/yaegicore/symbols/time.go b/yaegicore/nogui/time.go similarity index 99% rename from yaegicore/symbols/time.go rename to yaegicore/nogui/time.go index 9bba72045c..7c50adc202 100644 --- a/yaegicore/symbols/time.go +++ b/yaegicore/nogui/time.go @@ -3,7 +3,7 @@ //go:build go1.22 // +build go1.22 -package symbols +package nogui import ( "go/constant" diff --git a/yaegicore/symbols/cogentcore_org-core-core.go b/yaegicore/symbols/cogentcore_org-core-core.go index c025a127fb..55ac3ba110 100644 --- a/yaegicore/symbols/cogentcore_org-core-core.go +++ b/yaegicore/symbols/cogentcore_org-core-core.go @@ -305,6 +305,7 @@ func init() { "SystemSettingsData": reflect.ValueOf((*core.SystemSettingsData)(nil)), "Tab": reflect.ValueOf((*core.Tab)(nil)), "TabTypes": reflect.ValueOf((*core.TabTypes)(nil)), + "Tabber": reflect.ValueOf((*core.Tabber)(nil)), "Table": reflect.ValueOf((*core.Table)(nil)), "TableStyler": reflect.ValueOf((*core.TableStyler)(nil)), "Tabs": reflect.ValueOf((*core.Tabs)(nil)), @@ -340,6 +341,7 @@ func init() { "_SettingsOpener": reflect.ValueOf((*_cogentcore_org_core_core_SettingsOpener)(nil)), "_SettingsSaver": reflect.ValueOf((*_cogentcore_org_core_core_SettingsSaver)(nil)), "_ShouldDisplayer": 
reflect.ValueOf((*_cogentcore_org_core_core_ShouldDisplayer)(nil)), + "_Tabber": reflect.ValueOf((*_cogentcore_org_core_core_Tabber)(nil)), "_TextFieldEmbedder": reflect.ValueOf((*_cogentcore_org_core_core_TextFieldEmbedder)(nil)), "_ToolbarMaker": reflect.ValueOf((*_cogentcore_org_core_core_ToolbarMaker)(nil)), "_Treer": reflect.ValueOf((*_cogentcore_org_core_core_Treer)(nil)), @@ -599,6 +601,32 @@ func (W _cogentcore_org_core_core_ShouldDisplayer) ShouldDisplay(field string) b return W.WShouldDisplay(field) } +// _cogentcore_org_core_core_Tabber is an interface wrapper for Tabber type +type _cogentcore_org_core_core_Tabber struct { + IValue interface{} + WAsCoreTabs func() *core.Tabs + WCurrentTab func() (core.Widget, int) + WRecycleTab func(name string) *core.Frame + WSelectTabByName func(name string) *core.Frame + WSelectTabIndex func(idx int) *core.Frame + WTabByName func(name string) *core.Frame +} + +func (W _cogentcore_org_core_core_Tabber) AsCoreTabs() *core.Tabs { return W.WAsCoreTabs() } +func (W _cogentcore_org_core_core_Tabber) CurrentTab() (core.Widget, int) { return W.WCurrentTab() } +func (W _cogentcore_org_core_core_Tabber) RecycleTab(name string) *core.Frame { + return W.WRecycleTab(name) +} +func (W _cogentcore_org_core_core_Tabber) SelectTabByName(name string) *core.Frame { + return W.WSelectTabByName(name) +} +func (W _cogentcore_org_core_core_Tabber) SelectTabIndex(idx int) *core.Frame { + return W.WSelectTabIndex(idx) +} +func (W _cogentcore_org_core_core_Tabber) TabByName(name string) *core.Frame { + return W.WTabByName(name) +} + // _cogentcore_org_core_core_TextFieldEmbedder is an interface wrapper for TextFieldEmbedder type type _cogentcore_org_core_core_TextFieldEmbedder struct { IValue interface{} diff --git a/yaegicore/symbols/cogentcore_org-core-filetree.go b/yaegicore/symbols/cogentcore_org-core-filetree.go index 442548dcc5..fb6b391bf1 100644 --- a/yaegicore/symbols/cogentcore_org-core-filetree.go +++ 
b/yaegicore/symbols/cogentcore_org-core-filetree.go @@ -18,6 +18,7 @@ func init() { Symbols["cogentcore.org/core/filetree/filetree"] = map[string]reflect.Value{ // function, constant and variable definitions "AsNode": reflect.ValueOf(filetree.AsNode), + "AsTree": reflect.ValueOf(filetree.AsTree), "FindLocationAll": reflect.ValueOf(filetree.FindLocationAll), "FindLocationDir": reflect.ValueOf(filetree.FindLocationDir), "FindLocationFile": reflect.ValueOf(filetree.FindLocationFile), @@ -41,11 +42,13 @@ func init() { "NodeNameCount": reflect.ValueOf((*filetree.NodeNameCount)(nil)), "SearchResults": reflect.ValueOf((*filetree.SearchResults)(nil)), "Tree": reflect.ValueOf((*filetree.Tree)(nil)), + "Treer": reflect.ValueOf((*filetree.Treer)(nil)), "VCSLog": reflect.ValueOf((*filetree.VCSLog)(nil)), // interface wrapper definitions "_Filer": reflect.ValueOf((*_cogentcore_org_core_filetree_Filer)(nil)), "_NodeEmbedder": reflect.ValueOf((*_cogentcore_org_core_filetree_NodeEmbedder)(nil)), + "_Treer": reflect.ValueOf((*_cogentcore_org_core_filetree_Treer)(nil)), } } @@ -142,3 +145,11 @@ type _cogentcore_org_core_filetree_NodeEmbedder struct { } func (W _cogentcore_org_core_filetree_NodeEmbedder) AsNode() *filetree.Node { return W.WAsNode() } + +// _cogentcore_org_core_filetree_Treer is an interface wrapper for Treer type +type _cogentcore_org_core_filetree_Treer struct { + IValue interface{} + WAsFileTree func() *filetree.Tree +} + +func (W _cogentcore_org_core_filetree_Treer) AsFileTree() *filetree.Tree { return W.WAsFileTree() } diff --git a/yaegicore/symbols/cogentcore_org-core-plot-plotcore.go b/yaegicore/symbols/cogentcore_org-core-plot-plotcore.go index 41c461e613..462d579dd8 100644 --- a/yaegicore/symbols/cogentcore_org-core-plot-plotcore.go +++ b/yaegicore/symbols/cogentcore_org-core-plot-plotcore.go @@ -10,25 +10,14 @@ import ( func init() { Symbols["cogentcore.org/core/plot/plotcore/plotcore"] = map[string]reflect.Value{ // function, constant and variable 
definitions - "Bar": reflect.ValueOf(plotcore.Bar), - "FixMax": reflect.ValueOf(plotcore.FixMax), - "FixMin": reflect.ValueOf(plotcore.FixMin), - "FloatMax": reflect.ValueOf(plotcore.FloatMax), - "FloatMin": reflect.ValueOf(plotcore.FloatMin), - "NewPlot": reflect.ValueOf(plotcore.NewPlot), - "NewPlotEditor": reflect.ValueOf(plotcore.NewPlotEditor), - "NewSubPlot": reflect.ValueOf(plotcore.NewSubPlot), - "Off": reflect.ValueOf(plotcore.Off), - "On": reflect.ValueOf(plotcore.On), - "PlotTypesN": reflect.ValueOf(plotcore.PlotTypesN), - "PlotTypesValues": reflect.ValueOf(plotcore.PlotTypesValues), - "XY": reflect.ValueOf(plotcore.XY), + "NewPlot": reflect.ValueOf(plotcore.NewPlot), + "NewPlotEditor": reflect.ValueOf(plotcore.NewPlotEditor), + "NewPlotterChooser": reflect.ValueOf(plotcore.NewPlotterChooser), + "NewSubPlot": reflect.ValueOf(plotcore.NewSubPlot), // type definitions - "ColumnOptions": reflect.ValueOf((*plotcore.ColumnOptions)(nil)), - "Plot": reflect.ValueOf((*plotcore.Plot)(nil)), - "PlotEditor": reflect.ValueOf((*plotcore.PlotEditor)(nil)), - "PlotOptions": reflect.ValueOf((*plotcore.PlotOptions)(nil)), - "PlotTypes": reflect.ValueOf((*plotcore.PlotTypes)(nil)), + "Plot": reflect.ValueOf((*plotcore.Plot)(nil)), + "PlotEditor": reflect.ValueOf((*plotcore.PlotEditor)(nil)), + "PlotterChooser": reflect.ValueOf((*plotcore.PlotterChooser)(nil)), } } diff --git a/yaegicore/symbols/cogentcore_org-core-plot-plots.go b/yaegicore/symbols/cogentcore_org-core-plot-plots.go index 956dce326e..32f2179e08 100644 --- a/yaegicore/symbols/cogentcore_org-core-plot-plots.go +++ b/yaegicore/symbols/cogentcore_org-core-plot-plots.go @@ -4,120 +4,32 @@ package symbols import ( "cogentcore.org/core/plot/plots" + "go/constant" + "go/token" "reflect" ) func init() { Symbols["cogentcore.org/core/plot/plots/plots"] = map[string]reflect.Value{ // function, constant and variable definitions - "AddTableLine": reflect.ValueOf(plots.AddTableLine), - "AddTableLinePoints": 
reflect.ValueOf(plots.AddTableLinePoints), - "Box": reflect.ValueOf(plots.Box), - "Circle": reflect.ValueOf(plots.Circle), - "Cross": reflect.ValueOf(plots.Cross), - "DrawBox": reflect.ValueOf(plots.DrawBox), - "DrawCircle": reflect.ValueOf(plots.DrawCircle), - "DrawCross": reflect.ValueOf(plots.DrawCross), - "DrawPlus": reflect.ValueOf(plots.DrawPlus), - "DrawPyramid": reflect.ValueOf(plots.DrawPyramid), - "DrawRing": reflect.ValueOf(plots.DrawRing), - "DrawShape": reflect.ValueOf(plots.DrawShape), - "DrawSquare": reflect.ValueOf(plots.DrawSquare), - "DrawTriangle": reflect.ValueOf(plots.DrawTriangle), - "MidStep": reflect.ValueOf(plots.MidStep), - "NewBarChart": reflect.ValueOf(plots.NewBarChart), - "NewLabels": reflect.ValueOf(plots.NewLabels), - "NewLine": reflect.ValueOf(plots.NewLine), - "NewLinePoints": reflect.ValueOf(plots.NewLinePoints), - "NewScatter": reflect.ValueOf(plots.NewScatter), - "NewTableXYer": reflect.ValueOf(plots.NewTableXYer), - "NewXErrorBars": reflect.ValueOf(plots.NewXErrorBars), - "NewYErrorBars": reflect.ValueOf(plots.NewYErrorBars), - "NoStep": reflect.ValueOf(plots.NoStep), - "Plus": reflect.ValueOf(plots.Plus), - "PostStep": reflect.ValueOf(plots.PostStep), - "PreStep": reflect.ValueOf(plots.PreStep), - "Pyramid": reflect.ValueOf(plots.Pyramid), - "Ring": reflect.ValueOf(plots.Ring), - "ShapesN": reflect.ValueOf(plots.ShapesN), - "ShapesValues": reflect.ValueOf(plots.ShapesValues), - "Square": reflect.ValueOf(plots.Square), - "StepKindN": reflect.ValueOf(plots.StepKindN), - "StepKindValues": reflect.ValueOf(plots.StepKindValues), - "TableColumnIndex": reflect.ValueOf(plots.TableColumnIndex), - "Triangle": reflect.ValueOf(plots.Triangle), + "BarType": reflect.ValueOf(constant.MakeFromLiteral("\"Bar\"", token.STRING, 0)), + "LabelsType": reflect.ValueOf(constant.MakeFromLiteral("\"Labels\"", token.STRING, 0)), + "NewBar": reflect.ValueOf(plots.NewBar), + "NewLabels": reflect.ValueOf(plots.NewLabels), + "NewLine": 
reflect.ValueOf(plots.NewLine), + "NewScatter": reflect.ValueOf(plots.NewScatter), + "NewXErrorBars": reflect.ValueOf(plots.NewXErrorBars), + "NewXY": reflect.ValueOf(plots.NewXY), + "NewYErrorBars": reflect.ValueOf(plots.NewYErrorBars), + "XErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"XErrorBars\"", token.STRING, 0)), + "XYType": reflect.ValueOf(constant.MakeFromLiteral("\"XY\"", token.STRING, 0)), + "YErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"YErrorBars\"", token.STRING, 0)), // type definitions - "BarChart": reflect.ValueOf((*plots.BarChart)(nil)), - "Errors": reflect.ValueOf((*plots.Errors)(nil)), + "Bar": reflect.ValueOf((*plots.Bar)(nil)), "Labels": reflect.ValueOf((*plots.Labels)(nil)), - "Line": reflect.ValueOf((*plots.Line)(nil)), - "Scatter": reflect.ValueOf((*plots.Scatter)(nil)), - "Shapes": reflect.ValueOf((*plots.Shapes)(nil)), - "StepKind": reflect.ValueOf((*plots.StepKind)(nil)), - "Table": reflect.ValueOf((*plots.Table)(nil)), - "TableXYer": reflect.ValueOf((*plots.TableXYer)(nil)), "XErrorBars": reflect.ValueOf((*plots.XErrorBars)(nil)), - "XErrorer": reflect.ValueOf((*plots.XErrorer)(nil)), - "XErrors": reflect.ValueOf((*plots.XErrors)(nil)), - "XYLabeler": reflect.ValueOf((*plots.XYLabeler)(nil)), - "XYLabels": reflect.ValueOf((*plots.XYLabels)(nil)), + "XY": reflect.ValueOf((*plots.XY)(nil)), "YErrorBars": reflect.ValueOf((*plots.YErrorBars)(nil)), - "YErrorer": reflect.ValueOf((*plots.YErrorer)(nil)), - "YErrors": reflect.ValueOf((*plots.YErrors)(nil)), - - // interface wrapper definitions - "_Table": reflect.ValueOf((*_cogentcore_org_core_plot_plots_Table)(nil)), - "_XErrorer": reflect.ValueOf((*_cogentcore_org_core_plot_plots_XErrorer)(nil)), - "_XYLabeler": reflect.ValueOf((*_cogentcore_org_core_plot_plots_XYLabeler)(nil)), - "_YErrorer": reflect.ValueOf((*_cogentcore_org_core_plot_plots_YErrorer)(nil)), } } - -// _cogentcore_org_core_plot_plots_Table is an interface wrapper for Table type -type 
_cogentcore_org_core_plot_plots_Table struct { - IValue interface{} - WColumnName func(i int) string - WNumColumns func() int - WNumRows func() int - WPlotData func(column int, row int) float32 -} - -func (W _cogentcore_org_core_plot_plots_Table) ColumnName(i int) string { return W.WColumnName(i) } -func (W _cogentcore_org_core_plot_plots_Table) NumColumns() int { return W.WNumColumns() } -func (W _cogentcore_org_core_plot_plots_Table) NumRows() int { return W.WNumRows() } -func (W _cogentcore_org_core_plot_plots_Table) PlotData(column int, row int) float32 { - return W.WPlotData(column, row) -} - -// _cogentcore_org_core_plot_plots_XErrorer is an interface wrapper for XErrorer type -type _cogentcore_org_core_plot_plots_XErrorer struct { - IValue interface{} - WXError func(i int) (low float32, high float32) -} - -func (W _cogentcore_org_core_plot_plots_XErrorer) XError(i int) (low float32, high float32) { - return W.WXError(i) -} - -// _cogentcore_org_core_plot_plots_XYLabeler is an interface wrapper for XYLabeler type -type _cogentcore_org_core_plot_plots_XYLabeler struct { - IValue interface{} - WLabel func(i int) string - WLen func() int - WXY func(i int) (x float32, y float32) -} - -func (W _cogentcore_org_core_plot_plots_XYLabeler) Label(i int) string { return W.WLabel(i) } -func (W _cogentcore_org_core_plot_plots_XYLabeler) Len() int { return W.WLen() } -func (W _cogentcore_org_core_plot_plots_XYLabeler) XY(i int) (x float32, y float32) { return W.WXY(i) } - -// _cogentcore_org_core_plot_plots_YErrorer is an interface wrapper for YErrorer type -type _cogentcore_org_core_plot_plots_YErrorer struct { - IValue interface{} - WYError func(i int) (float32, float32) -} - -func (W _cogentcore_org_core_plot_plots_YErrorer) YError(i int) (float32, float32) { - return W.WYError(i) -} diff --git a/yaegicore/symbols/cogentcore_org-core-plot.go b/yaegicore/symbols/cogentcore_org-core-plot.go index 6e2b8b230f..84bbda9d0e 100644 --- 
a/yaegicore/symbols/cogentcore_org-core-plot.go +++ b/yaegicore/symbols/cogentcore_org-core-plot.go @@ -3,6 +3,7 @@ package symbols import ( + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/plot" "reflect" ) @@ -10,107 +11,163 @@ import ( func init() { Symbols["cogentcore.org/core/plot/plot"] = map[string]reflect.Value{ // function, constant and variable definitions - "CheckFloats": reflect.ValueOf(plot.CheckFloats), - "CheckNaNs": reflect.ValueOf(plot.CheckNaNs), - "CopyValues": reflect.ValueOf(plot.CopyValues), - "CopyXYZs": reflect.ValueOf(plot.CopyXYZs), - "CopyXYs": reflect.ValueOf(plot.CopyXYs), - "DefaultFontFamily": reflect.ValueOf(&plot.DefaultFontFamily).Elem(), - "ErrInfinity": reflect.ValueOf(&plot.ErrInfinity).Elem(), - "ErrNoData": reflect.ValueOf(&plot.ErrNoData).Elem(), - "New": reflect.ValueOf(plot.New), - "PlotXYs": reflect.ValueOf(plot.PlotXYs), - "Range": reflect.ValueOf(plot.Range), - "UTCUnixTime": reflect.ValueOf(&plot.UTCUnixTime).Elem(), - "UnixTimeIn": reflect.ValueOf(plot.UnixTimeIn), - "XYRange": reflect.ValueOf(plot.XYRange), + "AddStylerTo": reflect.ValueOf(plot.AddStylerTo), + "AxisScalesN": reflect.ValueOf(plot.AxisScalesN), + "AxisScalesValues": reflect.ValueOf(plot.AxisScalesValues), + "Box": reflect.ValueOf(plot.Box), + "CheckFloats": reflect.ValueOf(plot.CheckFloats), + "CheckNaNs": reflect.ValueOf(plot.CheckNaNs), + "Circle": reflect.ValueOf(plot.Circle), + "Color": reflect.ValueOf(plot.Color), + "CopyRole": reflect.ValueOf(plot.CopyRole), + "CopyValues": reflect.ValueOf(plot.CopyValues), + "Cross": reflect.ValueOf(plot.Cross), + "Default": reflect.ValueOf(plot.Default), + "DefaultFontFamily": reflect.ValueOf(&plot.DefaultFontFamily).Elem(), + "DefaultOffOnN": reflect.ValueOf(plot.DefaultOffOnN), + "DefaultOffOnValues": reflect.ValueOf(plot.DefaultOffOnValues), + "DrawBox": reflect.ValueOf(plot.DrawBox), + "DrawCircle": reflect.ValueOf(plot.DrawCircle), + "DrawCross": reflect.ValueOf(plot.DrawCross), + "DrawPlus": 
reflect.ValueOf(plot.DrawPlus), + "DrawPyramid": reflect.ValueOf(plot.DrawPyramid), + "DrawRing": reflect.ValueOf(plot.DrawRing), + "DrawSquare": reflect.ValueOf(plot.DrawSquare), + "DrawTriangle": reflect.ValueOf(plot.DrawTriangle), + "ErrInfinity": reflect.ValueOf(&plot.ErrInfinity).Elem(), + "ErrNoData": reflect.ValueOf(&plot.ErrNoData).Elem(), + "GetStylersFrom": reflect.ValueOf(plot.GetStylersFrom), + "GetStylersFromData": reflect.ValueOf(plot.GetStylersFromData), + "High": reflect.ValueOf(plot.High), + "InverseLinear": reflect.ValueOf(plot.InverseLinear), + "InverseLog": reflect.ValueOf(plot.InverseLog), + "Label": reflect.ValueOf(plot.Label), + "Linear": reflect.ValueOf(plot.Linear), + "Log": reflect.ValueOf(plot.Log), + "Low": reflect.ValueOf(plot.Low), + "MidStep": reflect.ValueOf(plot.MidStep), + "MustCopyRole": reflect.ValueOf(plot.MustCopyRole), + "New": reflect.ValueOf(plot.New), + "NewPlotter": reflect.ValueOf(plot.NewPlotter), + "NewStyle": reflect.ValueOf(plot.NewStyle), + "NewTablePlot": reflect.ValueOf(plot.NewTablePlot), + "NoRole": reflect.ValueOf(plot.NoRole), + "NoStep": reflect.ValueOf(plot.NoStep), + "Off": reflect.ValueOf(plot.Off), + "On": reflect.ValueOf(plot.On), + "PlotX": reflect.ValueOf(plot.PlotX), + "PlotY": reflect.ValueOf(plot.PlotY), + "PlotterByType": reflect.ValueOf(plot.PlotterByType), + "Plotters": reflect.ValueOf(&plot.Plotters).Elem(), + "Plus": reflect.ValueOf(plot.Plus), + "PostStep": reflect.ValueOf(plot.PostStep), + "PreStep": reflect.ValueOf(plot.PreStep), + "Pyramid": reflect.ValueOf(plot.Pyramid), + "Range": reflect.ValueOf(plot.Range), + "RangeClamp": reflect.ValueOf(plot.RangeClamp), + "RegisterPlotter": reflect.ValueOf(plot.RegisterPlotter), + "Ring": reflect.ValueOf(plot.Ring), + "RolesN": reflect.ValueOf(plot.RolesN), + "RolesValues": reflect.ValueOf(plot.RolesValues), + "SetStylerTo": reflect.ValueOf(plot.SetStylerTo), + "SetStylersTo": reflect.ValueOf(plot.SetStylersTo), + "ShapesN": 
reflect.ValueOf(plot.ShapesN), + "ShapesValues": reflect.ValueOf(plot.ShapesValues), + "Size": reflect.ValueOf(plot.Size), + "Square": reflect.ValueOf(plot.Square), + "StepKindN": reflect.ValueOf(plot.StepKindN), + "StepKindValues": reflect.ValueOf(plot.StepKindValues), + "Triangle": reflect.ValueOf(plot.Triangle), + "U": reflect.ValueOf(plot.U), + "UTCUnixTime": reflect.ValueOf(&plot.UTCUnixTime).Elem(), + "UnixTimeIn": reflect.ValueOf(plot.UnixTimeIn), + "V": reflect.ValueOf(plot.V), + "W": reflect.ValueOf(plot.W), + "X": reflect.ValueOf(plot.X), + "Y": reflect.ValueOf(plot.Y), + "Z": reflect.ValueOf(plot.Z), // type definitions "Axis": reflect.ValueOf((*plot.Axis)(nil)), + "AxisScales": reflect.ValueOf((*plot.AxisScales)(nil)), + "AxisStyle": reflect.ValueOf((*plot.AxisStyle)(nil)), "ConstantTicks": reflect.ValueOf((*plot.ConstantTicks)(nil)), - "DataRanger": reflect.ValueOf((*plot.DataRanger)(nil)), + "Data": reflect.ValueOf((*plot.Data)(nil)), + "DefaultOffOn": reflect.ValueOf((*plot.DefaultOffOn)(nil)), "DefaultTicks": reflect.ValueOf((*plot.DefaultTicks)(nil)), "InvertedScale": reflect.ValueOf((*plot.InvertedScale)(nil)), - "Labeler": reflect.ValueOf((*plot.Labeler)(nil)), + "Labels": reflect.ValueOf((*plot.Labels)(nil)), "Legend": reflect.ValueOf((*plot.Legend)(nil)), "LegendEntry": reflect.ValueOf((*plot.LegendEntry)(nil)), "LegendPosition": reflect.ValueOf((*plot.LegendPosition)(nil)), + "LegendStyle": reflect.ValueOf((*plot.LegendStyle)(nil)), "LineStyle": reflect.ValueOf((*plot.LineStyle)(nil)), "LinearScale": reflect.ValueOf((*plot.LinearScale)(nil)), "LogScale": reflect.ValueOf((*plot.LogScale)(nil)), "LogTicks": reflect.ValueOf((*plot.LogTicks)(nil)), "Normalizer": reflect.ValueOf((*plot.Normalizer)(nil)), + "PanZoom": reflect.ValueOf((*plot.PanZoom)(nil)), "Plot": reflect.ValueOf((*plot.Plot)(nil)), + "PlotStyle": reflect.ValueOf((*plot.PlotStyle)(nil)), "Plotter": reflect.ValueOf((*plot.Plotter)(nil)), + "PlotterName": 
reflect.ValueOf((*plot.PlotterName)(nil)), + "PlotterType": reflect.ValueOf((*plot.PlotterType)(nil)), + "PointStyle": reflect.ValueOf((*plot.PointStyle)(nil)), + "Roles": reflect.ValueOf((*plot.Roles)(nil)), + "Shapes": reflect.ValueOf((*plot.Shapes)(nil)), + "StepKind": reflect.ValueOf((*plot.StepKind)(nil)), + "Style": reflect.ValueOf((*plot.Style)(nil)), + "Stylers": reflect.ValueOf((*plot.Stylers)(nil)), "Text": reflect.ValueOf((*plot.Text)(nil)), "TextStyle": reflect.ValueOf((*plot.TextStyle)(nil)), "Thumbnailer": reflect.ValueOf((*plot.Thumbnailer)(nil)), "Tick": reflect.ValueOf((*plot.Tick)(nil)), "Ticker": reflect.ValueOf((*plot.Ticker)(nil)), - "TickerFunc": reflect.ValueOf((*plot.TickerFunc)(nil)), "TimeTicks": reflect.ValueOf((*plot.TimeTicks)(nil)), "Valuer": reflect.ValueOf((*plot.Valuer)(nil)), "Values": reflect.ValueOf((*plot.Values)(nil)), - "XValues": reflect.ValueOf((*plot.XValues)(nil)), - "XYValues": reflect.ValueOf((*plot.XYValues)(nil)), - "XYZ": reflect.ValueOf((*plot.XYZ)(nil)), - "XYZer": reflect.ValueOf((*plot.XYZer)(nil)), - "XYZs": reflect.ValueOf((*plot.XYZs)(nil)), - "XYer": reflect.ValueOf((*plot.XYer)(nil)), - "XYs": reflect.ValueOf((*plot.XYs)(nil)), - "YValues": reflect.ValueOf((*plot.YValues)(nil)), + "WidthStyle": reflect.ValueOf((*plot.WidthStyle)(nil)), + "XAxisStyle": reflect.ValueOf((*plot.XAxisStyle)(nil)), // interface wrapper definitions - "_DataRanger": reflect.ValueOf((*_cogentcore_org_core_plot_DataRanger)(nil)), - "_Labeler": reflect.ValueOf((*_cogentcore_org_core_plot_Labeler)(nil)), "_Normalizer": reflect.ValueOf((*_cogentcore_org_core_plot_Normalizer)(nil)), "_Plotter": reflect.ValueOf((*_cogentcore_org_core_plot_Plotter)(nil)), "_Thumbnailer": reflect.ValueOf((*_cogentcore_org_core_plot_Thumbnailer)(nil)), "_Ticker": reflect.ValueOf((*_cogentcore_org_core_plot_Ticker)(nil)), "_Valuer": reflect.ValueOf((*_cogentcore_org_core_plot_Valuer)(nil)), - "_XYZer": reflect.ValueOf((*_cogentcore_org_core_plot_XYZer)(nil)), - 
"_XYer": reflect.ValueOf((*_cogentcore_org_core_plot_XYer)(nil)), } } -// _cogentcore_org_core_plot_DataRanger is an interface wrapper for DataRanger type -type _cogentcore_org_core_plot_DataRanger struct { - IValue interface{} - WDataRange func(pt *plot.Plot) (xmin float32, xmax float32, ymin float32, ymax float32) -} - -func (W _cogentcore_org_core_plot_DataRanger) DataRange(pt *plot.Plot) (xmin float32, xmax float32, ymin float32, ymax float32) { - return W.WDataRange(pt) -} - -// _cogentcore_org_core_plot_Labeler is an interface wrapper for Labeler type -type _cogentcore_org_core_plot_Labeler struct { - IValue interface{} - WLabel func(i int) string -} - -func (W _cogentcore_org_core_plot_Labeler) Label(i int) string { return W.WLabel(i) } - // _cogentcore_org_core_plot_Normalizer is an interface wrapper for Normalizer type type _cogentcore_org_core_plot_Normalizer struct { IValue interface{} - WNormalize func(min float32, max float32, x float32) float32 + WNormalize func(min float64, max float64, x float64) float64 } -func (W _cogentcore_org_core_plot_Normalizer) Normalize(min float32, max float32, x float32) float32 { +func (W _cogentcore_org_core_plot_Normalizer) Normalize(min float64, max float64, x float64) float64 { return W.WNormalize(min, max, x) } // _cogentcore_org_core_plot_Plotter is an interface wrapper for Plotter type type _cogentcore_org_core_plot_Plotter struct { - IValue interface{} - WPlot func(pt *plot.Plot) - WXYData func() (data plot.XYer, pixels plot.XYer) + IValue interface{} + WApplyStyle func(plotStyle *plot.PlotStyle) + WData func() (data plot.Data, pixX []float32, pixY []float32) + WPlot func(pt *plot.Plot) + WStylers func() *plot.Stylers + WUpdateRange func(plt *plot.Plot, xr *minmax.F64, yr *minmax.F64, zr *minmax.F64) } -func (W _cogentcore_org_core_plot_Plotter) Plot(pt *plot.Plot) { W.WPlot(pt) } -func (W _cogentcore_org_core_plot_Plotter) XYData() (data plot.XYer, pixels plot.XYer) { - return W.WXYData() +func (W 
_cogentcore_org_core_plot_Plotter) ApplyStyle(plotStyle *plot.PlotStyle) { + W.WApplyStyle(plotStyle) +} +func (W _cogentcore_org_core_plot_Plotter) Data() (data plot.Data, pixX []float32, pixY []float32) { + return W.WData() +} +func (W _cogentcore_org_core_plot_Plotter) Plot(pt *plot.Plot) { W.WPlot(pt) } +func (W _cogentcore_org_core_plot_Plotter) Stylers() *plot.Stylers { return W.WStylers() } +func (W _cogentcore_org_core_plot_Plotter) UpdateRange(plt *plot.Plot, xr *minmax.F64, yr *minmax.F64, zr *minmax.F64) { + W.WUpdateRange(plt, xr, yr, zr) } // _cogentcore_org_core_plot_Thumbnailer is an interface wrapper for Thumbnailer type @@ -124,41 +181,21 @@ func (W _cogentcore_org_core_plot_Thumbnailer) Thumbnail(pt *plot.Plot) { W.WThu // _cogentcore_org_core_plot_Ticker is an interface wrapper for Ticker type type _cogentcore_org_core_plot_Ticker struct { IValue interface{} - WTicks func(min float32, max float32) []plot.Tick + WTicks func(min float64, max float64, nticks int) []plot.Tick } -func (W _cogentcore_org_core_plot_Ticker) Ticks(min float32, max float32) []plot.Tick { - return W.WTicks(min, max) +func (W _cogentcore_org_core_plot_Ticker) Ticks(min float64, max float64, nticks int) []plot.Tick { + return W.WTicks(min, max, nticks) } // _cogentcore_org_core_plot_Valuer is an interface wrapper for Valuer type type _cogentcore_org_core_plot_Valuer struct { - IValue interface{} - WLen func() int - WValue func(i int) float32 -} - -func (W _cogentcore_org_core_plot_Valuer) Len() int { return W.WLen() } -func (W _cogentcore_org_core_plot_Valuer) Value(i int) float32 { return W.WValue(i) } - -// _cogentcore_org_core_plot_XYZer is an interface wrapper for XYZer type -type _cogentcore_org_core_plot_XYZer struct { - IValue interface{} - WLen func() int - WXY func(i int) (float32, float32) - WXYZ func(i int) (float32, float32, float32) -} - -func (W _cogentcore_org_core_plot_XYZer) Len() int { return W.WLen() } -func (W _cogentcore_org_core_plot_XYZer) XY(i int) 
(float32, float32) { return W.WXY(i) } -func (W _cogentcore_org_core_plot_XYZer) XYZ(i int) (float32, float32, float32) { return W.WXYZ(i) } - -// _cogentcore_org_core_plot_XYer is an interface wrapper for XYer type -type _cogentcore_org_core_plot_XYer struct { - IValue interface{} - WLen func() int - WXY func(i int) (x float32, y float32) + IValue interface{} + WFloat1D func(i int) float64 + WLen func() int + WString1D func(i int) string } -func (W _cogentcore_org_core_plot_XYer) Len() int { return W.WLen() } -func (W _cogentcore_org_core_plot_XYer) XY(i int) (x float32, y float32) { return W.WXY(i) } +func (W _cogentcore_org_core_plot_Valuer) Float1D(i int) float64 { return W.WFloat1D(i) } +func (W _cogentcore_org_core_plot_Valuer) Len() int { return W.WLen() } +func (W _cogentcore_org_core_plot_Valuer) String1D(i int) string { return W.WString1D(i) } diff --git a/yaegicore/symbols/cogentcore_org-core-tensor-databrowser.go b/yaegicore/symbols/cogentcore_org-core-tensor-databrowser.go new file mode 100644 index 0000000000..839edd22d4 --- /dev/null +++ b/yaegicore/symbols/cogentcore_org-core-tensor-databrowser.go @@ -0,0 +1,134 @@ +// Code generated by 'yaegi extract cogentcore.org/core/tensor/databrowser'. DO NOT EDIT. 
+ +package symbols + +import ( + "cogentcore.org/core/core" + "cogentcore.org/core/plot/plotcore" + "cogentcore.org/core/tensor" + "cogentcore.org/core/tensor/databrowser" + "cogentcore.org/core/tensor/table" + "cogentcore.org/core/tensor/tensorcore" + "cogentcore.org/core/tensor/tensorfs" + "cogentcore.org/core/texteditor" + "reflect" +) + +func init() { + Symbols["cogentcore.org/core/tensor/databrowser/databrowser"] = map[string]reflect.Value{ + // function, constant and variable definitions + "AsDataTree": reflect.ValueOf(databrowser.AsDataTree), + "CurTabber": reflect.ValueOf(&databrowser.CurTabber).Elem(), + "FirstComment": reflect.ValueOf(databrowser.FirstComment), + "IsTableFile": reflect.ValueOf(databrowser.IsTableFile), + "NewBasic": reflect.ValueOf(databrowser.NewBasic), + "NewBasicWindow": reflect.ValueOf(databrowser.NewBasicWindow), + "NewDataTree": reflect.ValueOf(databrowser.NewDataTree), + "NewDiffBrowserDirs": reflect.ValueOf(databrowser.NewDiffBrowserDirs), + "NewFileNode": reflect.ValueOf(databrowser.NewFileNode), + "NewTabs": reflect.ValueOf(databrowser.NewTabs), + "PromptOKCancel": reflect.ValueOf(databrowser.PromptOKCancel), + "PromptString": reflect.ValueOf(databrowser.PromptString), + "PromptStruct": reflect.ValueOf(databrowser.PromptStruct), + "TensorFS": reflect.ValueOf(databrowser.TensorFS), + "TheBrowser": reflect.ValueOf(&databrowser.TheBrowser).Elem(), + "TrimOrderPrefix": reflect.ValueOf(databrowser.TrimOrderPrefix), + + // type definitions + "Basic": reflect.ValueOf((*databrowser.Basic)(nil)), + "Browser": reflect.ValueOf((*databrowser.Browser)(nil)), + "DataTree": reflect.ValueOf((*databrowser.DataTree)(nil)), + "FileNode": reflect.ValueOf((*databrowser.FileNode)(nil)), + "Tabber": reflect.ValueOf((*databrowser.Tabber)(nil)), + "Tabs": reflect.ValueOf((*databrowser.Tabs)(nil)), + "Treer": reflect.ValueOf((*databrowser.Treer)(nil)), + + // interface wrapper definitions + "_Tabber": 
reflect.ValueOf((*_cogentcore_org_core_tensor_databrowser_Tabber)(nil)), + "_Treer": reflect.ValueOf((*_cogentcore_org_core_tensor_databrowser_Treer)(nil)), + } +} + +// _cogentcore_org_core_tensor_databrowser_Tabber is an interface wrapper for Tabber type +type _cogentcore_org_core_tensor_databrowser_Tabber struct { + IValue interface{} + WAsCoreTabs func() *core.Tabs + WAsDataTabs func() *databrowser.Tabs + WCurrentTab func() (core.Widget, int) + WEditorFile func(label string, filename string) *texteditor.Editor + WEditorString func(label string, content string) *texteditor.Editor + WGoUpdatePlot func(label string) *plotcore.PlotEditor + WPlotTable func(label string, dt *table.Table) *plotcore.PlotEditor + WPlotTensorFS func(dfs *tensorfs.Node) *plotcore.PlotEditor + WRecycleTab func(name string) *core.Frame + WSelectTabByName func(name string) *core.Frame + WSelectTabIndex func(idx int) *core.Frame + WSliceTable func(label string, slc any) *core.Table + WTabByName func(name string) *core.Frame + WTensorEditor func(label string, tsr tensor.Tensor) *tensorcore.TensorEditor + WTensorGrid func(label string, tsr tensor.Tensor) *tensorcore.TensorGrid + WTensorTable func(label string, dt *table.Table) *tensorcore.Table + WUpdatePlot func(label string) *plotcore.PlotEditor +} + +func (W _cogentcore_org_core_tensor_databrowser_Tabber) AsCoreTabs() *core.Tabs { + return W.WAsCoreTabs() +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) AsDataTabs() *databrowser.Tabs { + return W.WAsDataTabs() +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) CurrentTab() (core.Widget, int) { + return W.WCurrentTab() +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) EditorFile(label string, filename string) *texteditor.Editor { + return W.WEditorFile(label, filename) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) EditorString(label string, content string) *texteditor.Editor { + return W.WEditorString(label, content) +} +func (W 
_cogentcore_org_core_tensor_databrowser_Tabber) GoUpdatePlot(label string) *plotcore.PlotEditor { + return W.WGoUpdatePlot(label) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) PlotTable(label string, dt *table.Table) *plotcore.PlotEditor { + return W.WPlotTable(label, dt) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) PlotTensorFS(dfs *tensorfs.Node) *plotcore.PlotEditor { + return W.WPlotTensorFS(dfs) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) RecycleTab(name string) *core.Frame { + return W.WRecycleTab(name) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) SelectTabByName(name string) *core.Frame { + return W.WSelectTabByName(name) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) SelectTabIndex(idx int) *core.Frame { + return W.WSelectTabIndex(idx) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) SliceTable(label string, slc any) *core.Table { + return W.WSliceTable(label, slc) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) TabByName(name string) *core.Frame { + return W.WTabByName(name) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) TensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor { + return W.WTensorEditor(label, tsr) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) TensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid { + return W.WTensorGrid(label, tsr) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) TensorTable(label string, dt *table.Table) *tensorcore.Table { + return W.WTensorTable(label, dt) +} +func (W _cogentcore_org_core_tensor_databrowser_Tabber) UpdatePlot(label string) *plotcore.PlotEditor { + return W.WUpdatePlot(label) +} + +// _cogentcore_org_core_tensor_databrowser_Treer is an interface wrapper for Treer type +type _cogentcore_org_core_tensor_databrowser_Treer struct { + IValue interface{} + WAsDataTree func() *databrowser.DataTree +} + +func (W 
_cogentcore_org_core_tensor_databrowser_Treer) AsDataTree() *databrowser.DataTree { + return W.WAsDataTree() +} diff --git a/yaegicore/symbols/cogentcore_org-core-tensor-table.go b/yaegicore/symbols/cogentcore_org-core-tensor-table.go deleted file mode 100644 index 1fd38dbb65..0000000000 --- a/yaegicore/symbols/cogentcore_org-core-tensor-table.go +++ /dev/null @@ -1,53 +0,0 @@ -// Code generated by 'yaegi extract cogentcore.org/core/tensor/table'. DO NOT EDIT. - -package symbols - -import ( - "cogentcore.org/core/tensor/table" - "reflect" -) - -func init() { - Symbols["cogentcore.org/core/tensor/table/table"] = map[string]reflect.Value{ - // function, constant and variable definitions - "AddAggName": reflect.ValueOf(table.AddAggName), - "Ascending": reflect.ValueOf(table.Ascending), - "ColumnNameOnly": reflect.ValueOf(table.ColumnNameOnly), - "Comma": reflect.ValueOf(table.Comma), - "ConfigFromDataValues": reflect.ValueOf(table.ConfigFromDataValues), - "ConfigFromHeaders": reflect.ValueOf(table.ConfigFromHeaders), - "ConfigFromTableHeaders": reflect.ValueOf(table.ConfigFromTableHeaders), - "Contains": reflect.ValueOf(table.Contains), - "DelimsN": reflect.ValueOf(table.DelimsN), - "DelimsValues": reflect.ValueOf(table.DelimsValues), - "Descending": reflect.ValueOf(table.Descending), - "Detect": reflect.ValueOf(table.Detect), - "DetectTableHeaders": reflect.ValueOf(table.DetectTableHeaders), - "Equals": reflect.ValueOf(table.Equals), - "Headers": reflect.ValueOf(table.Headers), - "IgnoreCase": reflect.ValueOf(table.IgnoreCase), - "InferDataType": reflect.ValueOf(table.InferDataType), - "NewIndexView": reflect.ValueOf(table.NewIndexView), - "NewSliceTable": reflect.ValueOf(table.NewSliceTable), - "NewTable": reflect.ValueOf(table.NewTable), - "NoHeaders": reflect.ValueOf(table.NoHeaders), - "ShapeFromString": reflect.ValueOf(table.ShapeFromString), - "Space": reflect.ValueOf(table.Space), - "Tab": reflect.ValueOf(table.Tab), - "TableColumnType": 
reflect.ValueOf(table.TableColumnType), - "TableHeaderChar": reflect.ValueOf(table.TableHeaderChar), - "TableHeaderToType": reflect.ValueOf(&table.TableHeaderToType).Elem(), - "UpdateSliceTable": reflect.ValueOf(table.UpdateSliceTable), - "UseCase": reflect.ValueOf(table.UseCase), - - // type definitions - "Delims": reflect.ValueOf((*table.Delims)(nil)), - "Filterer": reflect.ValueOf((*table.Filterer)(nil)), - "IndexView": reflect.ValueOf((*table.IndexView)(nil)), - "LessFunc": reflect.ValueOf((*table.LessFunc)(nil)), - "SplitAgg": reflect.ValueOf((*table.SplitAgg)(nil)), - "Splits": reflect.ValueOf((*table.Splits)(nil)), - "SplitsLessFunc": reflect.ValueOf((*table.SplitsLessFunc)(nil)), - "Table": reflect.ValueOf((*table.Table)(nil)), - } -} diff --git a/yaegicore/symbols/make b/yaegicore/symbols/make index 67bc088dec..ab362cdac4 100755 --- a/yaegicore/symbols/make +++ b/yaegicore/symbols/make @@ -1,11 +1,12 @@ #!/usr/bin/env cosh -yaegi extract fmt strconv strings image image/color image/draw time log/slog reflect - command extract { for _, pkg := range args { yaegi extract {"cogentcore.org/core/"+pkg} } } -extract core icons events styles styles/states styles/abilities styles/units tree keymap colors colors/gradient filetree texteditor htmlcore pages paint math32 plot plot/plots plot/plotcore tensor/table base/errors base/fsx base/reflectx base/labels base/fileinfo +yaegi extract image image/color image/draw + +extract core icons events styles styles/states styles/abilities styles/units tree keymap colors colors/gradient filetree texteditor htmlcore pages paint plot plot/plots plot/plotcore tensor/databrowser + diff --git a/yaegicore/yaegicore.go b/yaegicore/yaegicore.go index 24178dbba9..154c806d7d 100644 --- a/yaegicore/yaegicore.go +++ b/yaegicore/yaegicore.go @@ -17,6 +17,7 @@ import ( "cogentcore.org/core/events" "cogentcore.org/core/htmlcore" "cogentcore.org/core/texteditor" + "cogentcore.org/core/yaegicore/nogui" "cogentcore.org/core/yaegicore/symbols" 
"github.com/cogentcore/yaegi/interp" ) @@ -26,6 +27,7 @@ var autoPlanNameCounter uint64 func init() { htmlcore.BindTextEditor = BindTextEditor symbols.Symbols["."] = map[string]reflect.Value{} // make "." available for use + nogui.Symbols["."] = map[string]reflect.Value{} // make "." available for use } // BindTextEditor binds the given text editor to a yaegi interpreter