Skip to content

Commit

Permalink
Merge pull request #60 from saltosystems/feature/replace-cgo-with-ker…
Browse files Browse the repository at this point in the history
…nel32-syscall

delegates: replace malloc and free with HeapAlloc and HeapFree
  • Loading branch information
jagobagascon authored Sep 21, 2023
2 parents a39229b + a79d299 commit 2ab5b7d
Show file tree
Hide file tree
Showing 11 changed files with 501 additions and 374 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/peterbourgon/ff/v3 v3.1.2
github.com/stretchr/testify v1.7.5
github.com/tdakkota/win32metadata v0.1.0
golang.org/x/sys v0.0.0-20220624220833-87e55d714810
golang.org/x/tools v0.1.11
)

Expand All @@ -17,6 +18,5 @@ require (
github.com/go-logfmt/logfmt v0.5.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.0.0-20220624220833-87e55d714810 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
3 changes: 0 additions & 3 deletions internal/codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,6 @@ func (g *generator) loadCodeGenData(typeDef *winmd.TypeDef) error {
return err
}
f.Data.Delegates = append(f.Data.Delegates, delegate)

exportsFile := g.addFile(typeDef, "_exports")
exportsFile.Data.DelegateExports = append(f.Data.DelegateExports, delegate)
default:
_ = level.Info(g.logger).Log("msg", "generating class", "class", typeDef.TypeNamespace+"."+typeDef.TypeName)

Expand Down
15 changes: 7 additions & 8 deletions internal/codegen/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ type genDataFile struct {
}

type genData struct {
Package string
Imports []string
Classes []*genClass
Enums []*genEnum
Interfaces []*genInterface
Structs []*genStruct
Delegates []*genDelegate
DelegateExports []*genDelegate
Package string
Imports []string
Classes []*genClass
Enums []*genEnum
Interfaces []*genInterface
Structs []*genStruct
Delegates []*genDelegate
}

func (g *genData) ComputeImports(typeDef *winmd.TypeDef) {
Expand Down
185 changes: 144 additions & 41 deletions internal/codegen/templates/delegate.tmpl
Original file line number Diff line number Diff line change
@@ -1,36 +1,3 @@
/*
#include <stdint.h>

// Note: these functions have a different signature but because they are only
// used as function pointers (and never called) and because they use C name
// mangling, the signature doesn't really matter.
void winrt_{{.Name}}_Invoke(void);
void winrt_{{.Name}}_QueryInterface(void);
uint64_t winrt_{{.Name}}_AddRef(void);
uint64_t winrt_{{.Name}}_Release(void);

// The Vtable structure for WinRT {{.Name}} interfaces.
typedef struct {
void *QueryInterface;
void *AddRef;
void *Release;
void *Invoke;
} {{.Name}}Vtbl_t;

// The Vtable itself. It can be kept constant.
static const {{.Name}}Vtbl_t winrt_{{.Name}}Vtbl = {
(void*)winrt_{{.Name}}_QueryInterface,
(void*)winrt_{{.Name}}_AddRef,
(void*)winrt_{{.Name}}_Release,
(void*)winrt_{{.Name}}_Invoke,
};

// A small helper function to get the Vtable.
const {{.Name}}Vtbl_t * winrt_get{{.Name}}Vtbl(void) {
return &winrt_{{.Name}}Vtbl;
}
*/
import "C"

const GUID{{.Name}} string = "{{.GUID}}"
const Signature{{.Name}} string = "{{.Signature}}"
Expand All @@ -42,25 +9,47 @@ type {{.Name}} struct {
IID ole.GUID
}

type {{.Name}}Vtbl struct {
ole.IUnknownVtbl
Invoke uintptr
}

type {{.Name}}Callback func(instance *{{.Name}},{{- range .InParams -}}
{{.GoVarName}} {{template "variabletype.tmpl" . }},
{{- end -}})

var callbacks{{.Name}} = &{{.Name | toLower}}CallbacksMap {
var callbacks{{.Name}} = &{{.Name | toLower}}Callbacks {
mu: &sync.Mutex{},
callbacks: make(map[unsafe.Pointer]{{.Name}}Callback),
}

var releaseChannels{{.Name}} = &{{.Name | toLower}}ReleaseChannels {
mu: &sync.Mutex{},
chans: make(map[unsafe.Pointer]chan struct{}),
}

func New{{.Name}}(iid *ole.GUID, callback {{.Name}}Callback) *{{.Name}} {
inst := (*{{.Name}})(C.malloc(C.size_t(unsafe.Sizeof({{.Name}}{}))))
// Override all properties: the malloc may contain garbage
inst.RawVTable = (*interface{})((unsafe.Pointer)(C.winrt_get{{.Name}}Vtbl()))
size := unsafe.Sizeof(*(*{{.Name}})(nil))
instPtr := kernel32.Malloc(size)
inst := (*{{.Name}})(instPtr)
// Initialize all properties: the malloc may contain garbage
inst.RawVTable = (*interface{})(unsafe.Pointer(&{{.Name}}Vtbl{
IUnknownVtbl: ole.IUnknownVtbl{
QueryInterface: syscall.NewCallback(inst.QueryInterface),
AddRef: syscall.NewCallback(inst.AddRef),
Release: syscall.NewCallback(inst.Release),
},
Invoke: syscall.NewCallback(inst.Invoke),
}))
inst.IID = *iid // copy contents
inst.Mutex = sync.Mutex{}
inst.refs = 0

callbacks{{.Name}}.add(unsafe.Pointer(inst), callback)

// See the docs in the releaseChannels{{.Name}} struct
releaseChannels{{.Name}}.acquire(unsafe.Pointer(inst))

inst.addRef()
return inst
}
Expand All @@ -85,29 +74,143 @@ func (r *{{.Name}}) removeRef() uint64 {
return r.refs
}

type {{.Name | toLower}}CallbacksMap struct {
func (instance *{{.Name}}) QueryInterface(_, iidPtr unsafe.Pointer, ppvObject *unsafe.Pointer) uintptr {
// Checkout these sources for more information about the QueryInterface method.
// - https://docs.microsoft.com/en-us/cpp/atl/queryinterface
// - https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)

if ppvObject == nil {
// If ppvObject (the address) is nullptr, then this method returns E_POINTER.
return ole.E_POINTER
}

// This function must adhere to the QueryInterface defined here:
// https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nn-unknwn-iunknown
iid := (*ole.GUID)(iidPtr)
if ole.IsEqualGUID(iid, &instance.IID) || ole.IsEqualGUID(iid, ole.IID_IUnknown) || ole.IsEqualGUID(iid, ole.IID_IInspectable){
*ppvObject = unsafe.Pointer(instance)
} else {
*ppvObject = nil
// Return E_NOINTERFACE if the interface is not supported
return ole.E_NOINTERFACE
}

// If the COM object implements the interface, then it returns
// a pointer to that interface after calling IUnknown::AddRef on it.
(*ole.IUnknown)(*ppvObject).AddRef()

// Return S_OK if the interface is supported
return ole.S_OK
}

func (instance *{{.Name}}) Invoke(instancePtr unsafe.Pointer {{range .InParams -}}
,
{{- if .Type.IsEnum -}}
{{.GoVarName}}Raw {{.Type.UnderlyingEnumType}}
{{- else -}}
{{.GoVarName}}Ptr unsafe.Pointer
{{- end}}
{{- end -}}) uintptr {
// See the quote above.
{{range .InParams -}}
{{if .Type.IsEnum -}}
{{.GoVarName}} := ({{template "variabletype.tmpl" . }})({{.GoVarName}}Raw)
{{else -}}
{{.GoVarName}} := ({{template "variabletype.tmpl" . }})({{.GoVarName}}Ptr)
{{end -}}
{{end -}}
if callback, ok := callbacks{{.Name}}.get(instancePtr); ok {
callback(instance, {{range .InParams}}{{.GoVarName}},{{end}})
}
return ole.S_OK
}

func (instance *{{.Name}}) AddRef() uint64 {
return instance.addRef()
}

func (instance *{{.Name}}) Release() uint64 {
rem := instance.removeRef()
if rem == 0 {
// We're done.
instancePtr := unsafe.Pointer(instance)
callbacks{{.Name}}.delete(instancePtr)

// stop release channels used to avoid
// https://github.com/golang/go/issues/55015
releaseChannels{{.Name}}.release(instancePtr)

kernel32.Free(instancePtr)
}
return rem
}

type {{.Name | toLower}}Callbacks struct {
mu *sync.Mutex
callbacks map[unsafe.Pointer]{{.Name}}Callback
}

func (m *{{.Name | toLower}}CallbacksMap) add(p unsafe.Pointer, v {{.Name}}Callback) {
func (m *{{.Name | toLower}}Callbacks) add(p unsafe.Pointer, v {{.Name}}Callback) {
m.mu.Lock()
defer m.mu.Unlock()

m.callbacks[p] = v
}

func (m *{{.Name | toLower}}CallbacksMap) get(p unsafe.Pointer) ({{.Name}}Callback, bool) {
func (m *{{.Name | toLower}}Callbacks) get(p unsafe.Pointer) ({{.Name}}Callback, bool) {
m.mu.Lock()
defer m.mu.Unlock()

v, ok := m.callbacks[p]
return v, ok
}

func (m *{{.Name | toLower}}CallbacksMap) delete(p unsafe.Pointer) {
func (m *{{.Name | toLower}}Callbacks) delete(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()

delete(m.callbacks, p)
}

// typedEventHandlerReleaseChannels keeps a map with channels
// used to keep a goroutine alive during the lifecycle of this object.
// This is required to avoid causing a deadlock error.
// See this: https://github.com/golang/go/issues/55015
type {{.Name | toLower}}ReleaseChannels struct {
mu *sync.Mutex
chans map[unsafe.Pointer]chan struct{}
}

func (m *{{.Name | toLower}}ReleaseChannels) acquire(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()

c := make(chan struct{})
m.chans[p] = c

go func() {
// we need a timer to trick the go runtime into
// thinking there's still something going on here
// but we are only really interested in <-c
t := time.NewTimer(time.Minute)
for {
select {
case <-t.C:
t.Reset(time.Minute)
case <-c:
t.Stop()
return
}
}
}()
}

func (m *{{.Name | toLower}}ReleaseChannels) release(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()

if c, ok := m.chans[p]; ok {
close(c)
delete(m.chans, p)
}
}
78 changes: 0 additions & 78 deletions internal/codegen/templates/delegate_exports.tmpl

This file was deleted.

4 changes: 1 addition & 3 deletions internal/codegen/templates/file.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"unsafe"
"github.com/go-ole/go-ole"
"github.com/saltosystems/winrt-go"
"github.com/saltosystems/winrt-go/internal/kernel32"
{{range .Imports}}"{{.}}"
{{end}}
)
Expand All @@ -33,6 +34,3 @@ import (
{{range .Delegates}}
{{template "delegate.tmpl" .}}
{{end}}
{{range .DelegateExports}}
{{template "delegate_exports.tmpl" .}}
{{end}}
Loading

0 comments on commit 2ab5b7d

Please sign in to comment.