Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow master configuration for ssh key crypto system #10072

Merged
merged 17 commits into from
Oct 24, 2024
5 changes: 5 additions & 0 deletions docs/reference/deploy/master-config-reference.rst
Original file line number Diff line number Diff line change
@@ -1524,6 +1524,11 @@ Specifies configuration settings for SSH.

Number of bits to use when generating RSA keys for SSH for tasks. Maximum size is 16384.

``key_type``
============

Specifies the crypto system for SSH. Currently accepts ``RSA``, ``ECDSA`` or ``ED25519``.

``authz``
=========

8 changes: 8 additions & 0 deletions docs/release-notes/ssh-crypto-system.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
:orphan:

**Improvements**

- Master Configuration: Add support for crypto system configuration for ssh connection.
``security.key_type`` now accepts ``RSA``, ``ECDSA`` or ``ED25519``. Default key type is changed
from ``1024-bit RSA`` to ``ED25519``, since ``ED25519`` keys are faster and more secure than the
old default, and ``ED25519`` is also the default key type for ``ssh-keygen``.
11 changes: 0 additions & 11 deletions harness/determined/cli/shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import contextlib
import functools
import getpass
import os
import pathlib
import platform
@@ -22,9 +21,6 @@

def start_shell(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
data = {}
if args.passphrase:
data["passphrase"] = getpass.getpass("Enter new passphrase: ")
config = ntsc.parse_config(args.config_file, None, args.config, args.volume)
workspace_id = cli.workspace.get_workspace_id_from_args(args)

@@ -35,7 +31,6 @@ def start_shell(args: argparse.Namespace) -> None:
args.template,
context_path=args.context,
includes=args.include,
data=data,
workspace_id=workspace_id,
)
shell = bindings.v1LaunchShellResponse.from_json(resp).shell
@@ -280,12 +275,6 @@ def _open_shell(
help=ntsc.INCLUDE_DESC,
),
cli.Arg("--config", action="append", default=[], help=ntsc.CONFIG_DESC),
cli.Arg(
"-p",
"--passphrase",
action="store_true",
help="passphrase to encrypt the shell private key",
),
cli.Arg(
"--template",
type=str,
16 changes: 1 addition & 15 deletions master/internal/api_shell.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@ package internal
import (
"archive/tar"
"context"
"encoding/json"
"fmt"
"strconv"

@@ -253,20 +252,7 @@ func (a *apiServer) LaunchShell(
}
maps.Copy(launchReq.Spec.Base.ExtraEnvVars, oidcPachydermEnvVars)

var passphrase *string
if len(req.Data) > 0 {
var data map[string]interface{}
if err = json.Unmarshal(req.Data, &data); err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse data %s: %s", req.Data, err)
}
if pwd, ok := data["passphrase"]; ok {
if typed, typedOK := pwd.(string); typedOK {
passphrase = &typed
}
}
}

keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHRsaSize, passphrase)
keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHConfig)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
2 changes: 1 addition & 1 deletion master/internal/api_user_intg_test.go
Original file line number Diff line number Diff line change
@@ -118,7 +118,7 @@ func setupAPITest(t *testing.T, pgdb *db.PgDB,
TaskContainerDefaults: model.TaskContainerDefaultsConfig{},
ResourceConfig: *config.DefaultResourceConfig(),
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
allRms: map[string]rm.ResourceManager{config.DefaultClusterName: mockRM},
},
}
27 changes: 21 additions & 6 deletions master/internal/config/config.go
Original file line number Diff line number Diff line change
@@ -42,6 +42,15 @@ const (
preemptionScheduler = "preemption"
)

const (
// KeyTypeRSA uses RSA.
KeyTypeRSA = "RSA"
// KeyTypeECDSA uses ECDSA.
KeyTypeECDSA = "ECDSA"
// KeyTypeED25519 uses ED25519.
KeyTypeED25519 = "ED25519"
)

type (
// ExperimentConfigPatch is the updatedble fields for patching an experiment.
ExperimentConfigPatch struct {
@@ -108,7 +117,7 @@ func DefaultConfig() *Config {
Group: "root",
},
SSH: SSHConfig{
RsaKeySize: 1024,
KeyType: KeyTypeED25519,
},
AuthZ: *DefaultAuthZConfig(),
},
@@ -452,7 +461,8 @@ type SecurityConfig struct {

// SSHConfig is the configuration setting for SSH.
type SSHConfig struct {
RsaKeySize int `json:"rsa_key_size"`
RsaKeySize int `json:"rsa_key_size"`
KeyType string `json:"key_type"`
}

// TLSConfig is the configuration for setting up serving over TLS.
@@ -475,10 +485,15 @@ func (t *TLSConfig) Validate() []error {
// Validate implements the check.Validatable interface.
func (t *SSHConfig) Validate() []error {
var errs []error
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
if t.KeyType != KeyTypeRSA && t.KeyType != KeyTypeECDSA && t.KeyType != KeyTypeED25519 {
errs = append(errs, errors.New("Crypto system must be one of 'RSA', 'ECDSA' or 'ED25519'"))
}
if t.KeyType == KeyTypeRSA {
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
}
}
return errs
}
2 changes: 1 addition & 1 deletion master/internal/core.go
Original file line number Diff line number Diff line change
@@ -1242,7 +1242,7 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
HarnessPath: filepath.Join(m.config.Root, "wheels"),
TaskContainerDefaults: m.config.TaskContainerDefaults,
MasterCert: config.GetCertPEM(cert),
SSHRsaSize: m.config.Security.SSH.RsaKeySize,
SSHConfig: m.config.Security.SSH,
SegmentEnabled: m.config.Telemetry.Enabled && m.config.Telemetry.SegmentMasterKey != "",
SegmentAPIKey: m.config.Telemetry.SegmentMasterKey,
LogRetentionDays: m.config.RetentionPolicy.LogRetentionDays,
2 changes: 1 addition & 1 deletion master/internal/core_intg_test.go
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ func TestRun(t *testing.T) {
DefaultLoggingConfig: &model.DefaultLoggingConfig{},
},
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
}
require.NoError(t, m.config.Resolve())
m.config.DB = config.DBConfig{
2 changes: 1 addition & 1 deletion master/internal/experiment.go
Original file line number Diff line number Diff line change
@@ -158,7 +158,7 @@ func newExperiment(

taskSpec.AgentUserGroup = agentUserGroup

generatedKeys, err := ssh.GenerateKey(taskSpec.SSHRsaSize, nil)
generatedKeys, err := ssh.GenerateKey(taskSpec.SSHConfig)
if err != nil {
return nil, nil, errors.Wrap(err, "generating ssh keys for trials")
}
1 change: 0 additions & 1 deletion master/internal/trial_intg_test.go
Original file line number Diff line number Diff line change
@@ -172,7 +172,6 @@ func setup(t *testing.T) (
&model.Checkpoint{},
&tasks.TaskSpec{
AgentUserGroup: &model.AgentUserGroup{},
SSHRsaSize: 1024,
Workspace: model.DefaultWorkspaceName,
},
ssh.PrivateAndPublicKeys{},
89 changes: 74 additions & 15 deletions master/pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
package ssh

import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"

"github.com/pkg/errors"
sshlib "golang.org/x/crypto/ssh"

"github.com/determined-ai/determined/master/internal/config"
)

const (
trialPEMBlockType = "RSA PRIVATE KEY"
rsaPEMBlockType = "RSA PRIVATE KEY"
ecdsaPEMBlockType = "EC PRIVATE KEY"
)

// PrivateAndPublicKeys contains a private and public key.
@@ -21,40 +27,93 @@ type PrivateAndPublicKeys struct {
}

// GenerateKey returns a private and public SSH key.
func GenerateKey(rsaKeySize int, passphrase *string) (PrivateAndPublicKeys, error) {
func GenerateKey(conf config.SSHConfig) (PrivateAndPublicKeys, error) {
var generatedKeys PrivateAndPublicKeys
switch conf.KeyType {
case config.KeyTypeRSA:
return generateRSAKey(conf.RsaKeySize)
case config.KeyTypeECDSA:
return generateECDSAKey()
case config.KeyTypeED25519:
return generateED25519Key()
default:
return generatedKeys, errors.New("Invalid crypto system")
}
}

func generateRSAKey(rsaKeySize int) (PrivateAndPublicKeys, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate private key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA private key")
}

if err = privateKey.Validate(); err != nil {
return generatedKeys, err
return PrivateAndPublicKeys{}, err
}

block := &pem.Block{
Type: trialPEMBlockType,
Type: rsaPEMBlockType,
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}

if passphrase != nil {
// TODO: Replace usage of deprecated x509.EncryptPEMBlock.
block, err = x509.EncryptPEMBlock( //nolint: staticcheck
rand.Reader, block.Type, block.Bytes, []byte(*passphrase), x509.PEMCipherAES256)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to encrypt private key")
}
publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA public key")
}

return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateECDSAKey() (PrivateAndPublicKeys, error) {
// Curve size currently not configurable, using the NIST recommendation.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a note: I don't think there were any requests to make the curve size configurable, and P256 is the NIST recommendation, so this is fine. But one day we may have to update it or make configurable.

if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA private key")
}

privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ECDSA private key")
}

block := &pem.Block{
Type: ecdsaPEMBlockType,
Bytes: privateKeyBytes,
}

publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate public key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA public key")
}

generatedKeys = PrivateAndPublicKeys{
return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateED25519Key() (PrivateAndPublicKeys, error) {
ed25519PublicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 private key")
}

return generatedKeys, nil
// Before OpenSSH 9.6, for ED25519 keys, only the OpenSSH private key format was supported.
block, err := sshlib.MarshalPrivateKey(privateKey, "")
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ED25519 private key")
}

publicKey, err := sshlib.NewPublicKey(ed25519PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 public key")
}

return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}
39 changes: 39 additions & 0 deletions master/pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ssh

import (
"testing"

"golang.org/x/crypto/ssh"
"gotest.tools/assert"

"github.com/determined-ai/determined/master/internal/config"
)

func verifyKeys(t *testing.T, keys PrivateAndPublicKeys) {
privateKey, err := ssh.ParsePrivateKey(keys.PrivateKey)
assert.NilError(t, err)

publickKey, _, _, _, err := ssh.ParseAuthorizedKey(keys.PublicKey) //nolint:dogsled
assert.NilError(t, err)
assert.Equal(t, string(publickKey.Marshal()), string(privateKey.PublicKey().Marshal()))
}

func TestSSHKeyGenerate(t *testing.T) {
t.Run("generate RSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeRSA, RsaKeySize: 512})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ECDSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeECDSA})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ED25519 key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeED25519})
assert.NilError(t, err)
verifyKeys(t, keys)
})
}
3 changes: 2 additions & 1 deletion master/pkg/tasks/task.go
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ import (
"github.com/docker/docker/api/types/mount"
"github.com/jinzhu/copier"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/pkg/archive"
"github.com/determined-ai/determined/master/pkg/cproto"
"github.com/determined-ai/determined/master/pkg/device"
@@ -70,7 +71,7 @@ type TaskSpec struct {
ClusterID string
HarnessPath string
MasterCert []byte
SSHRsaSize int
SSHConfig config.SSHConfig

SegmentEnabled bool
SegmentAPIKey string
2 changes: 1 addition & 1 deletion proto/pkg/apiv1/shell.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion proto/src/determined/api/v1/shell.proto
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ message LaunchShellRequest {
string template_name = 2;
// The files to run with the command.
repeated determined.util.v1.File files = 3;
// Additional data.
// Deprecated: Do not use.
bytes data = 4;
// Workspace ID. Defaults to 'Uncategorized' workspace if not specified.
int32 workspace_id = 5;
Loading