Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first-class Radius login support #1609

Merged
merged 1 commit into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const (
FieldAuthLoginCert = "auth_login_cert"
FieldAuthLoginGCP = "auth_login_gcp"
FieldAuthLoginKerberos = "auth_login_kerberos"
FieldAuthLoginRadius = "auth_login_radius"
FieldIdentity = "identity"
FieldSignature = "signature"
FieldPKCS7 = "pkcs7"
Expand Down Expand Up @@ -125,6 +126,11 @@ const (
// EnvVarKRBKeytab path the keytab file.
EnvVarKRBKeytab = "KRB_KEYTAB"

// EnvVarRadiusUsername for the Radius auth login
EnvVarRadiusUsername = "RADIUS_USERNAME"
// EnvVarRadiusPassword for the Radius auth login
EnvVarRadiusPassword = "RADIUS_PASSWORD"

/*
common mount types
*/
Expand All @@ -139,6 +145,7 @@ const (
MountTypeCert = "cert"
MountTypeGCP = "gcp"
MountTypeKerberos = "kerberos"
MountTypeRadius = "radius"

/*
Vault version constants
Expand All @@ -155,6 +162,7 @@ const (
AuthMethodCert = "cert"
AuthMethodGCP = "gcp"
AuthMethodKerberos = "kerberos"
AuthMethodRadius = "radius"

/*
misc. path related constants
Expand Down
17 changes: 17 additions & 0 deletions internal/provider/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ func (l *AuthLoginCommon) init(d *schema.ResourceData) (string, map[string]inter
return path, params, nil
}

func (l *AuthLoginCommon) checkRequiredFields(d *schema.ResourceData, required ...string) error {
var missing []string
for _, f := range required {
if _, ok := l.getOk(d, f); !ok {
missing = append(missing, f)
}
}

if len(missing) > 0 {
return fmt.Errorf("required fields are unset: %v", missing)
}

return nil
}

func (l *AuthLoginCommon) getOk(d *schema.ResourceData, field string) (interface{}, bool) {
return d.GetOk(fmt.Sprintf("%s.0.%s", l.authField, field))
}
Expand All @@ -153,6 +168,8 @@ func GetAuthLogin(r *schema.ResourceData) (AuthLogin, error) {
l = &AuthLoginGCP{}
case consts.FieldAuthLoginKerberos:
l = &AuthLoginKerberos{}
case consts.FieldAuthLoginRadius:
l = &AuthLoginRadius{}
default:
return nil, nil
}
Expand Down
82 changes: 82 additions & 0 deletions internal/provider/auth_radius.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package provider

import (
"fmt"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/vault/api"

"github.com/hashicorp/terraform-provider-vault/internal/consts"
)

// GetRadiusLoginSchema for the radius authentication engine.
func GetRadiusLoginSchema(authField string) *schema.Schema {
return getLoginSchema(
authField,
"Login to vault using the radius method",
GetRadiusLoginSchemaResource,
)
}

// GetRadiusLoginSchemaResource for the radius authentication engine.
func GetRadiusLoginSchemaResource(_ string) *schema.Resource {
return mustAddLoginSchema(&schema.Resource{
Schema: map[string]*schema.Schema{
consts.FieldUsername: {
Type: schema.TypeString,
Description: "The Radius username.",
Required: true,
DefaultFunc: schema.EnvDefaultFunc(consts.EnvVarRadiusUsername, nil),
},
consts.FieldPassword: {
Type: schema.TypeString,
Required: true,
Description: "The Radius password for username.",
DefaultFunc: schema.EnvDefaultFunc(consts.EnvVarRadiusPassword, nil),
},
},
}, consts.MountTypeRadius)
}

type AuthLoginRadius struct {
AuthLoginCommon
}

// MountPath for the radius authentication engine.
func (l *AuthLoginRadius) MountPath() string {
if l.mount == "" {
return l.Method()
}
return l.mount
}

// LoginPath for the radius authentication engine.
func (l *AuthLoginRadius) LoginPath() string {
return fmt.Sprintf("auth/%s/login", l.MountPath())
}

func (l *AuthLoginRadius) Init(d *schema.ResourceData, authField string) error {
if err := l.AuthLoginCommon.Init(d, authField); err != nil {
return err
}

if err := l.checkRequiredFields(d, consts.FieldUsername, consts.FieldPassword); err != nil {
return err
}

return nil
}

// Method name for the radius authentication engine.
func (l *AuthLoginRadius) Method() string {
return consts.AuthMethodRadius
}

// Login using the radius authentication engine.
func (l *AuthLoginRadius) Login(client *api.Client) (*api.Secret, error) {
if !l.initialized {
return nil, fmt.Errorf("auth login not initialized")
}

return l.login(client, l.LoginPath(), l.copyParams(consts.FieldNamespace, consts.FieldMount))
}
201 changes: 201 additions & 0 deletions internal/provider/auth_radius_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package provider

import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"testing"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/vault/api"

"github.com/hashicorp/terraform-provider-vault/internal/consts"
)

func TestAuthLoginRadius_Init(t *testing.T) {
tests := []struct {
name string
authField string
raw map[string]interface{}
wantErr bool
expectParams map[string]interface{}
expectErr error
}{
{
name: "basic",
authField: consts.FieldAuthLoginRadius,
raw: map[string]interface{}{
consts.FieldAuthLoginRadius: []interface{}{
map[string]interface{}{
consts.FieldNamespace: "ns1",
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
},
expectParams: map[string]interface{}{
consts.FieldNamespace: "ns1",
consts.FieldMount: consts.MountTypeRadius,
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
wantErr: false,
},
{
name: "error-missing-resource",
authField: consts.FieldAuthLoginRadius,
expectParams: nil,
wantErr: true,
expectErr: fmt.Errorf("resource data missing field %q", consts.FieldAuthLoginRadius),
},
{
name: "error-missing-required",
authField: consts.FieldAuthLoginRadius,
raw: map[string]interface{}{
consts.FieldAuthLoginRadius: []interface{}{
map[string]interface{}{
consts.FieldUsername: "alice",
},
},
},
expectParams: nil,
wantErr: true,
expectErr: fmt.Errorf("required fields are unset: %v", []string{
consts.FieldPassword,
}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := map[string]*schema.Schema{
tt.authField: GetRadiusLoginSchema(tt.authField),
}

d := schema.TestResourceDataRaw(t, s, tt.raw)
l := &AuthLoginRadius{}
err := l.Init(d, tt.authField)
if (err != nil) != tt.wantErr {
t.Fatalf("Init() error = %v, wantErr %v", err, tt.wantErr)
}

if err != nil {
if tt.expectErr != nil {
if !reflect.DeepEqual(tt.expectErr, err) {
t.Errorf("Init() expected error %#v, actual %#v", tt.expectErr, err)
}
}
} else {
if !reflect.DeepEqual(tt.expectParams, l.params) {
t.Errorf("Init() expected params %#v, actual %#v", tt.expectParams, l.params)
}
}
})
}
}

func TestAuthLoginRadius_LoginPath(t *testing.T) {
type fields struct {
AuthLoginCommon AuthLoginCommon
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "default",
fields: fields{
AuthLoginCommon: AuthLoginCommon{
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
},
want: "auth/radius/login",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &AuthLoginRadius{
AuthLoginCommon: tt.fields.AuthLoginCommon,
}
if got := l.LoginPath(); got != tt.want {
t.Errorf("LoginPath() = %v, want %v", got, tt.want)
}
})
}
}

func TestAuthLoginRadius_Login(t *testing.T) {
handlerFunc := func(t *testLoginHandler, w http.ResponseWriter, req *http.Request) {
m, err := json.Marshal(
&api.Secret{},
)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

w.WriteHeader(http.StatusOK)
if _, err := w.Write(m); err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
}

tests := []authLoginTest{
{
name: "basic",
authLogin: &AuthLoginRadius{
AuthLoginCommon: AuthLoginCommon{
authField: consts.FieldAuthLoginRadius,
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
initialized: true,
},
},
handler: &testLoginHandler{
handlerFunc: handlerFunc,
},
expectReqCount: 1,
expectReqPaths: []string{"/v1/auth/radius/login"},
expectReqParams: []map[string]interface{}{
{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
},
want: &api.Secret{},
wantErr: false,
},
{
name: "error-uninitialized",
authLogin: &AuthLoginRadius{
AuthLoginCommon: AuthLoginCommon{
authField: consts.FieldAuthLoginRadius,
params: map[string]interface{}{
consts.FieldUsername: "alice",
consts.FieldPassword: "password1",
},
initialized: false,
},
},
handler: &testLoginHandler{
handlerFunc: handlerFunc,
},
expectReqCount: 0,
want: nil,
wantErr: true,
expectErr: fmt.Errorf("auth login not initialized"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testAuthLogin(t, tt)
})
}
}
7 changes: 7 additions & 0 deletions internal/provider/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type authLoginTest struct {
expectReqParams []map[string]interface{}
expectReqPaths []string
wantErr bool
expectErr error
skipFunc func(t *testing.T)
}

Expand Down Expand Up @@ -88,6 +89,12 @@ func testAuthLogin(t *testing.T, tt authLoginTest) {
return
}

if err != nil && tt.expectErr != nil {
if !reflect.DeepEqual(tt.expectErr, err) {
t.Errorf("Login() expected error %#v, actual %#v", tt.expectErr, err)
}
}

if tt.expectReqCount != tt.handler.requestCount {
t.Errorf("Login() expected %d requests, actual %d", tt.expectReqCount, tt.handler.requestCount)
}
Expand Down
2 changes: 2 additions & 0 deletions vault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ func Provider() *schema.Provider {
f = provider.GetGCPLoginSchema
case consts.FieldAuthLoginKerberos:
f = provider.GetKerberosLoginSchema
case consts.FieldAuthLoginRadius:
f = provider.GetRadiusLoginSchema
default:
continue
}
Expand Down
Loading