Skip to content

Commit

Permalink
Add first-class Radius login support (#1609)
Browse files Browse the repository at this point in the history
  • Loading branch information
benashz authored Sep 20, 2022
1 parent 2366b12 commit e3be294
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 0 deletions.
8 changes: 8 additions & 0 deletions internal/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
FieldAuthLoginCert = "auth_login_cert"
FieldAuthLoginGCP = "auth_login_gcp"
FieldAuthLoginKerberos = "auth_login_kerberos"
FieldAuthLoginRadius = "auth_login_radius"
FieldIdentity = "identity"
FieldSignature = "signature"
FieldPKCS7 = "pkcs7"
Expand Down Expand Up @@ -125,6 +126,11 @@ const (
// EnvVarKRBKeytab path the keytab file.
EnvVarKRBKeytab = "KRB_KEYTAB"

// EnvVarRadiusUsername for the Radius auth login
EnvVarRadiusUsername = "RADIUS_USERNAME"
// EnvVarRadiusPassword for the Radius auth login
EnvVarRadiusPassword = "RADIUS_PASSWORD"

/*
common mount types
*/
Expand All @@ -139,6 +145,7 @@ const (
MountTypeCert = "cert"
MountTypeGCP = "gcp"
MountTypeKerberos = "kerberos"
MountTypeRadius = "radius"

/*
Vault version constants
Expand All @@ -155,6 +162,7 @@ const (
AuthMethodCert = "cert"
AuthMethodGCP = "gcp"
AuthMethodKerberos = "kerberos"
AuthMethodRadius = "radius"

/*
misc. path related constants
Expand Down
17 changes: 17 additions & 0 deletions internal/provider/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ func (l *AuthLoginCommon) init(d *schema.ResourceData) (string, map[string]inter
return path, params, nil
}

func (l *AuthLoginCommon) checkRequiredFields(d *schema.ResourceData, required ...string) error {
var missing []string
for _, f := range required {
if _, ok := l.getOk(d, f); !ok {
missing = append(missing, f)
}
}

if len(missing) > 0 {
return fmt.Errorf("required fields are unset: %v", missing)
}

return nil
}

func (l *AuthLoginCommon) getOk(d *schema.ResourceData, field string) (interface{}, bool) {
return d.GetOk(fmt.Sprintf("%s.0.%s", l.authField, field))
}
Expand All @@ -153,6 +168,8 @@ func GetAuthLogin(r *schema.ResourceData) (AuthLogin, error) {
l = &AuthLoginGCP{}
case consts.FieldAuthLoginKerberos:
l = &AuthLoginKerberos{}
case consts.FieldAuthLoginRadius:
l = &AuthLoginRadius{}
default:
return nil, nil
}
Expand Down
82 changes: 82 additions & 0 deletions internal/provider/auth_radius.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package provider

import (
"fmt"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/vault/api"

"github.com/hashicorp/terraform-provider-vault/internal/consts"
)

// GetRadiusLoginSchema for the radius authentication engine.
func GetRadiusLoginSchema(authField string) *schema.Schema {
return getLoginSchema(
authField,
"Login to vault using the radius method",
GetRadiusLoginSchemaResource,
)
}

// GetRadiusLoginSchemaResource for the radius authentication engine.
func GetRadiusLoginSchemaResource(_ string) *schema.Resource {
return mustAddLoginSchema(&schema.Resource{
Schema: map[string]*schema.Schema{
consts.FieldUsername: {
Type: schema.TypeString,
Description: "The Radius username.",
Required: true,
DefaultFunc: schema.EnvDefaultFunc(consts.EnvVarRadiusUsername, nil),
},
consts.FieldPassword: {
Type: schema.TypeString,
Required: true,
Description: "The Radius password for username.",
DefaultFunc: schema.EnvDefaultFunc(consts.EnvVarRadiusPassword, nil),
},
},
}, consts.MountTypeRadius)
}

type AuthLoginRadius struct {
AuthLoginCommon
}

// MountPath for the radius authentication engine.
func (l *AuthLoginRadius) MountPath() string {
if l.mount == "" {
return l.Method()
}
return l.mount
}

// LoginPath for the radius authentication engine.
func (l *AuthLoginRadius) LoginPath() string {
return fmt.Sprintf("auth/%s/login", l.MountPath())
}

func (l *AuthLoginRadius) Init(d *schema.ResourceData, authField string) error {
if err := l.AuthLoginCommon.Init(d, authField); err != nil {
return err
}

if err := l.checkRequiredFields(d, consts.FieldUsername, consts.FieldPassword); err != nil {
return err
}

return nil
}

// Method name for the radius authentication engine.
func (l *AuthLoginRadius) Method() string {
return consts.AuthMethodRadius
}

// Login using the radius authentication engine.
func (l *AuthLoginRadius) Login(client *api.Client) (*api.Secret, error) {
if !l.initialized {
return nil, fmt.Errorf("auth login not initialized")
}

return l.login(client, l.LoginPath(), l.copyParams(consts.FieldNamespace, consts.FieldMount))
}
201 changes: 201 additions & 0 deletions internal/provider/auth_radius_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package provider

import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"testing"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/vault/api"

"github.com/hashicorp/terraform-provider-vault/internal/consts"
)

func TestAuthLoginRadius_Init(t *testing.T) {
tests := []struct {
name string
authField string
raw map[string]interface{}
wantErr bool
expectParams map[string]interface{}
expectErr error
}{
{
name: "basic",
authField: consts.FieldAuthLoginRadius,
raw: map[string]interface{}{
consts.FieldAuthLoginRadius: []interface{}{
map[string]interface{}{
consts.FieldNamespace: "ns1",
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
},
expectParams: map[string]interface{}{
consts.FieldNamespace: "ns1",
consts.FieldMount: consts.MountTypeRadius,
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
wantErr: false,
},
{
name: "error-missing-resource",
authField: consts.FieldAuthLoginRadius,
expectParams: nil,
wantErr: true,
expectErr: fmt.Errorf("resource data missing field %q", consts.FieldAuthLoginRadius),
},
{
name: "error-missing-required",
authField: consts.FieldAuthLoginRadius,
raw: map[string]interface{}{
consts.FieldAuthLoginRadius: []interface{}{
map[string]interface{}{
consts.FieldUsername: "alice",
},
},
},
expectParams: nil,
wantErr: true,
expectErr: fmt.Errorf("required fields are unset: %v", []string{
consts.FieldPassword,
}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := map[string]*schema.Schema{
tt.authField: GetRadiusLoginSchema(tt.authField),
}

d := schema.TestResourceDataRaw(t, s, tt.raw)
l := &AuthLoginRadius{}
err := l.Init(d, tt.authField)
if (err != nil) != tt.wantErr {
t.Fatalf("Init() error = %v, wantErr %v", err, tt.wantErr)
}

if err != nil {
if tt.expectErr != nil {
if !reflect.DeepEqual(tt.expectErr, err) {
t.Errorf("Init() expected error %#v, actual %#v", tt.expectErr, err)
}
}
} else {
if !reflect.DeepEqual(tt.expectParams, l.params) {
t.Errorf("Init() expected params %#v, actual %#v", tt.expectParams, l.params)
}
}
})
}
}

func TestAuthLoginRadius_LoginPath(t *testing.T) {
type fields struct {
AuthLoginCommon AuthLoginCommon
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "default",
fields: fields{
AuthLoginCommon: AuthLoginCommon{
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
},
want: "auth/radius/login",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &AuthLoginRadius{
AuthLoginCommon: tt.fields.AuthLoginCommon,
}
if got := l.LoginPath(); got != tt.want {
t.Errorf("LoginPath() = %v, want %v", got, tt.want)
}
})
}
}

func TestAuthLoginRadius_Login(t *testing.T) {
handlerFunc := func(t *testLoginHandler, w http.ResponseWriter, req *http.Request) {
m, err := json.Marshal(
&api.Secret{},
)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

w.WriteHeader(http.StatusOK)
if _, err := w.Write(m); err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
}

tests := []authLoginTest{
{
name: "basic",
authLogin: &AuthLoginRadius{
AuthLoginCommon: AuthLoginCommon{
authField: consts.FieldAuthLoginRadius,
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
initialized: true,
},
},
handler: &testLoginHandler{
handlerFunc: handlerFunc,
},
expectReqCount: 1,
expectReqPaths: []string{"/v1/auth/radius/login"},
expectReqParams: []map[string]interface{}{
{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
want: &api.Secret{},
wantErr: false,
},
{
name: "error-uninitialized",
authLogin: &AuthLoginRadius{
AuthLoginCommon: AuthLoginCommon{
authField: consts.FieldAuthLoginRadius,
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
initialized: false,
},
},
handler: &testLoginHandler{
handlerFunc: handlerFunc,
},
expectReqCount: 0,
want: nil,
wantErr: true,
expectErr: fmt.Errorf("auth login not initialized"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testAuthLogin(t, tt)
})
}
}
7 changes: 7 additions & 0 deletions internal/provider/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type authLoginTest struct {
expectReqParams []map[string]interface{}
expectReqPaths []string
wantErr bool
expectErr error
skipFunc func(t *testing.T)
}

Expand Down Expand Up @@ -88,6 +89,12 @@ func testAuthLogin(t *testing.T, tt authLoginTest) {
return
}

if err != nil && tt.expectErr != nil {
if !reflect.DeepEqual(tt.expectErr, err) {
t.Errorf("Login() expected error %#v, actual %#v", tt.expectErr, err)
}
}

if tt.expectReqCount != tt.handler.requestCount {
t.Errorf("Login() expected %d requests, actual %d", tt.expectReqCount, tt.handler.requestCount)
}
Expand Down
2 changes: 2 additions & 0 deletions vault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ func Provider() *schema.Provider {
f = provider.GetGCPLoginSchema
case consts.FieldAuthLoginKerberos:
f = provider.GetKerberosLoginSchema
case consts.FieldAuthLoginRadius:
f = provider.GetRadiusLoginSchema
default:
continue
}
Expand Down
Loading

0 comments on commit e3be294

Please sign in to comment.