Skip to content

Commit

Permalink
Merge pull request #6 from unreality/listener-checks
Browse files Browse the repository at this point in the history
Various fixes
  • Loading branch information
unreality authored Nov 10, 2022
2 parents c100fc8 + d2780b5 commit 9ccc53d
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 52 deletions.
195 changes: 158 additions & 37 deletions keyman/keymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,16 @@ func NewKeyManager(configPath string) (*KeyManager, error) {
}

km := KeyManager{
Keys: make(map[string]*Key),
providerHandles: make(map[string]uintptr),
configPath: configPath,
config: &kmc,
hwnd: 0,
publicKeysDir: publicKeysDir,
Keys: make(map[string]*Key),
providerHandles: make(map[string]uintptr),
configPath: configPath,
config: &kmc,
hwnd: 0,
publicKeysDir: publicKeysDir,
cygwinListener: nil,
pageantListener: nil,
namedPipeListener: nil,
vSockListener: nil,
}
km.providerHandles = make(map[string]uintptr)
km.configPath = configPath
Expand All @@ -534,33 +538,48 @@ func NewKeyManager(configPath string) (*KeyManager, error) {
mu: sync.Mutex{},
}

km.namedPipeListener = new(listeners.NamedPipe)
km.pageantListener = new(listeners.Pageant)
km.cygwinListener = new(listeners.Cygwin)
km.cygwinListener.Sockfile = filepath.Join(filepath.Dir(km.configPath), "cygwin-agent.sock")
km.vSockListener = new(listeners.VSock)
return &km, nil
}

func (km *KeyManager) Start() error {
saveConfig := false
km.lctx, km.cancel = context.WithCancel(context.Background())

km.lwg = new(sync.WaitGroup)

if km.config.CygwinEnabled {
km.StartListener(km.cygwinListener)
_, disableListenerInConfig, _ := km.StartListener(listeners.TYPE_CYGWIN)
if disableListenerInConfig {
km.config.CygwinEnabled = false
saveConfig = true
}
}

if km.config.VSockEnabled {
km.StartListener(km.vSockListener)
_, disableListenerInConfig, _ := km.StartListener(listeners.TYPE_VSOCK)
if disableListenerInConfig {
km.config.VSockEnabled = false
saveConfig = true
}
}

if km.config.NamedPipeEnabled {
km.StartListener(km.namedPipeListener)
_, disableListenerInConfig, _ := km.StartListener(listeners.TYPE_NAMED_PIPE)
if disableListenerInConfig {
km.config.NamedPipeEnabled = false
saveConfig = true
}
}

if km.config.PageantEnabled {
km.StartListener(km.pageantListener)
_, disableListenerInConfig, _ := km.StartListener(listeners.TYPE_PAGEANT)
if disableListenerInConfig {
km.config.PageantEnabled = false
saveConfig = true
}
}

for _, k := range kmc.Keys {
for _, k := range km.config.Keys {
log.Printf("Loading key %s\n", k.Name)
var err error

Expand All @@ -571,7 +590,7 @@ func NewKeyManager(configPath string) (*KeyManager, error) {

_, err = km.getProviderHandle(k.ProviderName)
if err != nil {
return nil, fmt.Errorf("unable to open provider %s for %s: %w", k.ProviderName, k.Name, err)
return fmt.Errorf("unable to open provider %s for %s: %w", k.ProviderName, k.Name, err)
}

_, err = km.LoadNCryptKey(k)
Expand Down Expand Up @@ -603,10 +622,63 @@ func NewKeyManager(configPath string) (*KeyManager, error) {
}
}

return &km, nil
if saveConfig {
km.SaveConfig()
}

return nil
}

func (km *KeyManager) StartListener(listener listeners.Listener) {
// StartListener attempts top start a listenerType, returning (success, disableListenerInConfig, error)
// If disableListenerInConfig is true, the caller should disable the listener in the config and save
func (km *KeyManager) StartListener(listenerType string) (bool, bool, error) {
var err error
var listener listeners.Listener

switch listenerType {
case listeners.TYPE_VSOCK:
if km.vSockListener != nil {
km.vSockListener.Stop()
}

km.vSockListener, err = listeners.NewVSockListener()

if err != nil {
var listenerErr *listeners.ListenerError
if errors.As(err, &listenerErr) {
switch listenerErr.Code() {
case listeners.ERR_DISABLE:
log.Printf("disabled vSock listener: %s", err)
return false, true, nil
case listeners.ERR_ABORTED:
log.Printf("disabled vSock listener, but config wont be saved: %s", err)
return false, false, nil
}
} else {
log.Printf("could not create vsock listener: %s", err)
return false, false, err
}
}

if km.vSockListener == nil {
return false, false, nil
}

listener = km.vSockListener
case listeners.TYPE_CYGWIN:
km.cygwinListener = new(listeners.Cygwin)
km.cygwinListener.Sockfile = filepath.Join(filepath.Dir(km.configPath), "cygwin-agent.sock")
listener = km.cygwinListener
case listeners.TYPE_NAMED_PIPE:
km.namedPipeListener = new(listeners.NamedPipe)
listener = km.namedPipeListener
case listeners.TYPE_PAGEANT:
km.pageantListener = new(listeners.Pageant)
listener = km.pageantListener
default:
return false, false, fmt.Errorf("invalid listener type %s", listenerType)
}

km.lwg.Add(1)
go func(l listeners.Listener) {
log.Printf("Starting listener %T\n", l)
Expand All @@ -617,20 +689,8 @@ func (km *KeyManager) StartListener(listener listeners.Listener) {
}
km.lwg.Done()
}(listener)
}

func (km *KeyManager) EnsureListenerIs(listener listeners.Listener, enabled bool) {
if listener.Running() == enabled {
return
}

if listener.Running() == false && enabled == true {
km.StartListener(listener)
}

if listener.Running() == true && enabled == false {
listener.Stop()
}
return true, false, nil
}

func (km *KeyManager) LoadNCryptKey(kc *KeyConfig) (*Key, error) {
Expand Down Expand Up @@ -1186,18 +1246,75 @@ func (km *KeyManager) GetPinTimeout() int {
}

func (km *KeyManager) EnableListener(listenerType string, enabled bool) {

running := false

switch listenerType {
case listeners.TYPE_PAGEANT:
if km.pageantListener != nil {
running = km.pageantListener.Running()
}
case listeners.TYPE_CYGWIN:
if km.cygwinListener != nil {
running = km.cygwinListener.Running()
}
case listeners.TYPE_VSOCK:
if km.vSockListener != nil {
running = km.vSockListener.Running()
}
case listeners.TYPE_NAMED_PIPE:
if km.namedPipeListener != nil {
running = km.namedPipeListener.Running()
}
default:
return
}

if running == enabled {
return
}

if running == true && enabled == false {
switch listenerType {
case listeners.TYPE_PAGEANT:
if km.pageantListener != nil {
km.pageantListener.Stop()
}
case listeners.TYPE_CYGWIN:
if km.cygwinListener != nil {
km.cygwinListener.Stop()
}
case listeners.TYPE_VSOCK:
if km.vSockListener != nil {
km.vSockListener.Stop()
}
case listeners.TYPE_NAMED_PIPE:
if km.namedPipeListener != nil {
km.namedPipeListener.Stop()
}
default:
return
}
}

if running == false && enabled == true {
listenerStarted, disableListenerInConfig, _ := km.StartListener(listenerType)

if disableListenerInConfig {
enabled = false
}

enabled = listenerStarted
}

switch listenerType {
case listeners.TYPE_PAGEANT:
km.EnsureListenerIs(km.pageantListener, enabled)
km.config.PageantEnabled = enabled
case listeners.TYPE_CYGWIN:
km.EnsureListenerIs(km.cygwinListener, enabled)
km.config.CygwinEnabled = enabled
case listeners.TYPE_VSOCK:
km.EnsureListenerIs(km.vSockListener, enabled)
km.config.VSockEnabled = enabled
case listeners.TYPE_NAMED_PIPE:
km.EnsureListenerIs(km.namedPipeListener, enabled)
km.config.NamedPipeEnabled = enabled
}
}
Expand Down Expand Up @@ -1228,7 +1345,11 @@ func (km *KeyManager) Notify(n NotifyMsg) {
}

func (km *KeyManager) CygwinSocketLocation() string {
return km.cygwinListener.Sockfile
if km.cygwinListener != nil {
return km.cygwinListener.Sockfile
}

return ""
}

func (km *KeyManager) LoadWebAuthNKey(kc *KeyConfig) (*Key, error) {
Expand Down
88 changes: 88 additions & 0 deletions keyman/listeners/genericdialog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package listeners

import (
"github.com/lxn/walk"
)

type GenericDlg struct {
*walk.Dialog
MessageLabel *walk.LinkLabel
ButtonOne *walk.PushButton
ButtonTwo *walk.PushButton
}

//func RunGenericDialog(owner walk.Form, keyName string) (bool, bool) {
// dlg, err := NewGenericDialog(owner, keyName)
// if showError(err, owner) {
// return false, false
// }
//
// if dlg.Run() == walk.DlgCmdOK {
// return true, dlg.doDeleteFromKeystore
// }
//
// return false, false
//}

func NewGenericDialog(owner walk.Form, title string, message string, buttonOneText string, buttonTwoText string) (*GenericDlg, error) {
var err error
var disposables walk.Disposables
defer disposables.Treat()

dlg := new(GenericDlg)

layout := walk.NewGridLayout()
layout.SetSpacing(6)
layout.SetMargins(walk.Margins{10, 10, 10, 10})
layout.SetColumnStretchFactor(1, 3)

if dlg.Dialog, err = walk.NewDialog(owner); err != nil {
return nil, err
}
disposables.Add(dlg)
// dlg.SetIcon(owner.Icon())
dlg.SetTitle(title)
dlg.SetLayout(layout)
dlg.SetMinMaxSize(walk.Size{500, 200}, walk.Size{0, 0})
//if icon, err := ui.loadSystemIcon("imageres", 109, 32); err == nil {
// dlg.SetIcon(icon)
//}

dlg.MessageLabel, err = walk.NewLinkLabel(dlg)
if err != nil {
return nil, err
}
layout.SetRange(dlg.MessageLabel, walk.Rectangle{0, 0, 2, 2})
//dlg.MessageLabel.SetTextAlignment(walk.AlignHNearVCenter)
dlg.MessageLabel.SetText(message)
dlg.MessageLabel.SetAlignment(walk.AlignHNearVNear)

dlg.MessageLabel.SetVisible(true)

buttonsContainer, err := walk.NewComposite(dlg)
if err != nil {
return nil, err
}
layout.SetRange(buttonsContainer, walk.Rectangle{0, 3, 2, 1})
buttonsContainer.SetLayout(walk.NewHBoxLayout())
buttonsContainer.Layout().SetMargins(walk.Margins{})

walk.NewHSpacer(buttonsContainer)
if dlg.ButtonOne, err = walk.NewPushButton(buttonsContainer); err != nil {
return nil, err
}
dlg.ButtonOne.SetText(buttonOneText)

dlg.ButtonTwo, err = walk.NewPushButton(buttonsContainer)
if err != nil {
return nil, err
}
dlg.ButtonTwo.SetText(buttonTwoText)

dlg.SetCancelButton(dlg.ButtonTwo)
dlg.SetDefaultButton(dlg.ButtonOne)

disposables.Spare()

return dlg, nil
}
11 changes: 11 additions & 0 deletions keyman/listeners/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
const (
STATUS_OK = "ok"
STATUS_STOPPED = "stopped"

ERR_DISABLE = 0
ERR_ABORTED = 1
)

type Listener interface {
Expand All @@ -17,3 +20,11 @@ type Listener interface {
LastError() error
Running() bool
}

type ListenerError struct {
msg string // description of error
code int
}

func (e *ListenerError) Error() string { return e.msg }
func (e *ListenerError) Code() int { return e.code }
Loading

0 comments on commit 9ccc53d

Please sign in to comment.