diff --git a/backend/src/agent/persistence/client/fake_namespace.go b/backend/src/agent/persistence/client/fake_namespace.go new file mode 100644 index 00000000000..bbc8c8e0224 --- /dev/null +++ b/backend/src/agent/persistence/client/fake_namespace.go @@ -0,0 +1,85 @@ +package client + +import ( + "context" + "errors" + "github.com/golang/glog" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + corev1 "k8s.io/client-go/applyconfigurations/core/v1" +) + +type FakeNamespaceClient struct { + namespace string // name of the only namespace Get will return + user string // value served in that namespace's "owner" annotation +} + +func (f *FakeNamespaceClient) SetReturnValues(namespace string, user string) { + f.namespace = namespace + f.user = user +} + +func (f FakeNamespaceClient) Get(ctx context.Context, name string, opts metav1.GetOptions) (*v1.Namespace, error) { + if f.namespace == name && len(f.user) != 0 { // succeed only for the configured namespace with a non-empty owner + ns := v1.Namespace{ObjectMeta: metav1.ObjectMeta{ + Name: f.namespace, // Namespace objects are cluster-scoped: they are identified by metadata.name, not metadata.namespace + Annotations: map[string]string{ + "owner": f.user, + }, + }} + return &ns, nil + } + return nil, errors.New("failed to get namespace") +} + +func (f FakeNamespaceClient) Create(ctx context.Context, namespace *v1.Namespace, opts metav1.CreateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Update(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) UpdateStatus(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Delete(ctx context.Context, name string, opts metav1.DeleteOptions) error { + glog.Error("This fake method is not yet implemented.") + return nil +} + +func (f 
FakeNamespaceClient) List(ctx context.Context, opts metav1.ListOptions) (*v1.NamespaceList, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Patch(ctx context.Context, name string, pt types.PatchType, data []byte, opts metav1.PatchOptions, subresources ...string) (result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Apply(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) ApplyStatus(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Finalize(ctx context.Context, item *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} diff --git a/backend/src/agent/persistence/client/kubernetes_core.go b/backend/src/agent/persistence/client/kubernetes_core.go new file mode 100644 index 00000000000..aaa996e43d3 --- /dev/null +++ b/backend/src/agent/persistence/client/kubernetes_core.go @@ -0,0 +1,87 @@ +package client + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/cenkalti/backoff" + "github.com/golang/glog" + "github.com/pkg/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/rest" + + "github.com/kubeflow/pipelines/backend/src/common/util" +) + +type 
KubernetesCoreInterface interface { + NamespaceClient() v1.NamespaceInterface + GetNamespaceOwner(namespace string) (string, error) +} + +type KubernetesCore struct { + coreV1Client v1.CoreV1Interface +} + +func (c *KubernetesCore) NamespaceClient() v1.NamespaceInterface { + return c.coreV1Client.Namespaces() +} + +func (c *KubernetesCore) GetNamespaceOwner(namespace string) (string, error) { + if mode := os.Getenv("MULTIUSER"); mode == "" || mode == "false" { // MULTIUSER unset or "false": skip the lookup and report an empty owner + return "", nil + } + ns, err := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{}) + if err != nil { + return "", errors.Wrapf(err, "failed to get namespace '%v'", namespace) // namespace is a format argument so it is never interpreted as a format string + } + owner, ok := ns.Annotations["owner"] + if !ok { + return "", errors.New(fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)) + } + return owner, nil +} + +func createKubernetesCore(clientParams util.ClientParameters) (KubernetesCoreInterface, error) { + clientSet, err := getKubernetesClientset(clientParams) + if err != nil { + return nil, err + } + return &KubernetesCore{clientSet.CoreV1()}, nil +} + +// CreateKubernetesCoreOrFatal creates a Kubernetes core client, retrying with exponential backoff for up to initConnectionTimeout and calling glog.Fatalf if that fails. +func CreateKubernetesCoreOrFatal(initConnectionTimeout time.Duration, clientParams util.ClientParameters) KubernetesCoreInterface { + var client KubernetesCoreInterface + var err error + var operation = func() error { + client, err = createKubernetesCore(clientParams) + return err + } + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = initConnectionTimeout + err = backoff.Retry(operation, b) + + if err != nil { + glog.Fatalf("Failed to create namespace client. 
Error: %v", err) + } + return client +} + +func getKubernetesClientset(clientParams util.ClientParameters) (*kubernetes.Clientset, error) { + restConfig, err := rest.InClusterConfig() + if err != nil { + return nil, errors.Wrap(err, "Failed to initialize kubernetes client.") + } + restConfig.QPS = float32(clientParams.QPS) + restConfig.Burst = clientParams.Burst + + clientSet, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, errors.Wrap(err, "Failed to initialize kubernetes client set.") + } + return clientSet, nil +} diff --git a/backend/src/agent/persistence/client/kubernetes_core_fake.go b/backend/src/agent/persistence/client/kubernetes_core_fake.go new file mode 100644 index 00000000000..73fa0e34fef --- /dev/null +++ b/backend/src/agent/persistence/client/kubernetes_core_fake.go @@ -0,0 +1,37 @@ +package client + +import ( + "context" + "errors" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +type KubernetesCoreFake struct { + coreV1ClientFake *FakeNamespaceClient +} + +func (c *KubernetesCoreFake) NamespaceClient() v1.NamespaceInterface { + return c.coreV1ClientFake +} + +func (c *KubernetesCoreFake) GetNamespaceOwner(namespace string) (string, error) { + ns, err := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{}) + if err != nil { + return "", err + } + owner, ok := ns.Annotations["owner"] + if !ok { + return "", errors.New(fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)) + } + return owner, nil +} + +func NewKubernetesCoreFake() *KubernetesCoreFake { + return &KubernetesCoreFake{&FakeNamespaceClient{}} +} +func (c *KubernetesCoreFake) Set(namespaceToReturn string, userToReturn string) { + c.coreV1ClientFake.SetReturnValues(namespaceToReturn, userToReturn) +} diff --git a/backend/src/agent/persistence/client/pipeline_client.go b/backend/src/agent/persistence/client/pipeline_client.go index 3884bcd013e..dd41e748d10 
100644 --- a/backend/src/agent/persistence/client/pipeline_client.go +++ b/backend/src/agent/persistence/client/pipeline_client.go @@ -17,6 +17,9 @@ package client import ( "context" "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "google.golang.org/grpc/metadata" + "os" "time" api "github.com/kubeflow/pipelines/backend/api/go_client" @@ -33,8 +36,8 @@ const ( type PipelineClientInterface interface { ReportWorkflow(workflow *util.Workflow) error ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error - ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) - ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) + ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) + ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) } type PipelineClient struct { @@ -139,8 +142,13 @@ func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er // ReadArtifact reads artifact content from run service. If the artifact is not present, returns // nil response. -func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) { + pctx := context.Background() + if user != "" { + pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(), + getKubeflowUserIDPrefix()+user) + } + ctx, cancel := context.WithTimeout(pctx, time.Minute) defer cancel() response, err := p.runServiceClient.ReadArtifact(ctx, request) @@ -153,8 +161,13 @@ func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.Re } // ReportRunMetrics reports run metrics to run service. 
-func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) { + pctx := context.Background() + if user != "" { + pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(), + getKubeflowUserIDPrefix()+user) + } + ctx, cancel := context.WithTimeout(pctx, time.Minute) defer cancel() response, err := p.runServiceClient.ReportRunMetrics(ctx, request) @@ -166,3 +179,19 @@ func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) } return response, nil } + +// getKubeflowUserIDHeader returns the value of the environment variable named by common.KubeflowUserIDHeader, or common.GoogleIAPUserIdentityHeader when it is unset. TODO: use the config file & viper via "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDHeader()". +func getKubeflowUserIDHeader() string { + if value, ok := os.LookupEnv(common.KubeflowUserIDHeader); ok { + return value + } + return common.GoogleIAPUserIdentityHeader +} + +// getKubeflowUserIDPrefix returns the value of the environment variable named by common.KubeflowUserIDPrefix, or common.GoogleIAPUserIdentityPrefix when it is unset. TODO: use the config file & viper via "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDPrefix()". +func getKubeflowUserIDPrefix() string { + if value, ok := os.LookupEnv(common.KubeflowUserIDPrefix); ok { + return value + } + return common.GoogleIAPUserIdentityPrefix +} diff --git a/backend/src/agent/persistence/client/pipeline_client_fake.go b/backend/src/agent/persistence/client/pipeline_client_fake.go index 4215478948b..f87f43353d3 100644 --- a/backend/src/agent/persistence/client/pipeline_client_fake.go +++ b/backend/src/agent/persistence/client/pipeline_client_fake.go @@ -57,12 +57,15 @@ func (p *PipelineClientFake) ReportScheduledWorkflow(swf *util.ScheduledWorkflow return nil } -func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) { +func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest, user string) 
(*api.ReadArtifactResponse, error) { + if p.err != nil { + return nil, p.err + } p.readArtifactRequest = request return p.artifacts[request.String()], nil } -func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) { +func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) { p.reportedMetricsRequest = request return p.reportMetricsResponseStub, p.reportMetricsErrorStub } diff --git a/backend/src/agent/persistence/main.go b/backend/src/agent/persistence/main.go index 3e88065bf72..be6f96f6f38 100644 --- a/backend/src/agent/persistence/main.go +++ b/backend/src/agent/persistence/main.go @@ -63,6 +63,10 @@ const ( clientBurstFlagName = "clientBurst" ) +const ( + DefaultConnectionTimeout = 6 * time.Minute +) + func main() { flag.Parse() @@ -95,6 +99,10 @@ func main() { swfInformerFactory = swfinformers.NewFilteredSharedInformerFactory(swfClient, time.Second*30, namespace, nil) workflowInformerFactory = workflowinformers.NewFilteredSharedInformerFactory(workflowClient, time.Second*30, namespace, nil) } + k8sCoreClient := client.CreateKubernetesCoreOrFatal(DefaultConnectionTimeout, util.ClientParameters{ + QPS: clientQPS, + Burst: clientBurst, + }) pipelineClient, err := client.NewPipelineClient( initializeTimeout, @@ -111,6 +119,7 @@ func main() { swfInformerFactory, workflowInformerFactory, pipelineClient, + k8sCoreClient, util.NewRealTime()) go swfInformerFactory.Start(stopCh) diff --git a/backend/src/agent/persistence/persistence_agent.go b/backend/src/agent/persistence/persistence_agent.go index fdf0e602e24..14332f43202 100644 --- a/backend/src/agent/persistence/persistence_agent.go +++ b/backend/src/agent/persistence/persistence_agent.go @@ -47,6 +47,7 @@ func NewPersistenceAgent( swfInformerFactory swfinformers.SharedInformerFactory, workflowInformerFactory workflowinformers.SharedInformerFactory, pipelineClient 
*client.PipelineClient, + k8sCoreClient client.KubernetesCoreInterface, time util.TimeInterface) *PersistenceAgent { // obtain references to shared informers swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows() @@ -64,7 +65,7 @@ func NewPersistenceAgent( workflowWorker := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind, workflowInformer.Informer(), true, - worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish)) + worker.NewWorkflowSaver(workflowClient, pipelineClient, k8sCoreClient, ttlSecondsAfterWorkflowFinish)) agent := &PersistenceAgent{ swfClient: swfClient, diff --git a/backend/src/agent/persistence/worker/metrics_reporter.go b/backend/src/agent/persistence/worker/metrics_reporter.go index 619ba7b9118..d689dd33dcb 100644 --- a/backend/src/agent/persistence/worker/metrics_reporter.go +++ b/backend/src/agent/persistence/worker/metrics_reporter.go @@ -45,7 +45,7 @@ func NewMetricsReporter(pipelineClient client.PipelineClientInterface) *MetricsR } // ReportMetrics reports workflow metrics to pipeline server. 
-func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { +func (r MetricsReporter) ReportMetrics(workflow *util.Workflow, user string) error { if workflow.Status.Nodes == nil { return nil } @@ -57,7 +57,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { runMetrics := []*api.RunMetric{} partialFailures := []error{} for _, nodeStatus := range workflow.Status.Nodes { - nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus) + nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus, user) if err != nil { partialFailures = append(partialFailures, err) continue @@ -79,7 +79,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { reportMetricsResponse, err := r.pipelineClient.ReportRunMetrics(&api.ReportRunMetricsRequest{ RunId: runID, Metrics: runMetrics, - }) + }, user) if err != nil { return err } @@ -89,12 +89,12 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { } func (r MetricsReporter) collectNodeMetricsOrNil( - runID string, nodeStatus workflowapi.NodeStatus) ( + runID string, nodeStatus workflowapi.NodeStatus, user string) ( []*api.RunMetric, error) { if !nodeStatus.Completed() { return nil, nil } - metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus) + metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus, user) if err != nil || metricsJSON == "" { return nil, err } @@ -126,7 +126,7 @@ func (r MetricsReporter) collectNodeMetricsOrNil( return reportMetricsRequest.GetMetrics(), nil } -func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus) (string, error) { +func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus, user string) (string, error) { if nodeStatus.Outputs == nil || nodeStatus.Outputs.Artifacts == nil { return "", nil // No output artifacts, skip the reporting } @@ -146,7 +146,7 @@ func (r MetricsReporter) 
readNodeMetricsJSONOrEmpty(runID string, nodeStatus wor NodeId: nodeStatus.ID, ArtifactName: metricsArtifactName, } - artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest) + artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest, user) if err != nil { return "", err } diff --git a/backend/src/agent/persistence/worker/metrics_reporter_test.go b/backend/src/agent/persistence/worker/metrics_reporter_test.go index 35a0db5b9f4..c1e117c8ec7 100644 --- a/backend/src/agent/persistence/worker/metrics_reporter_test.go +++ b/backend/src/agent/persistence/worker/metrics_reporter_test.go @@ -16,6 +16,7 @@ package worker import ( "encoding/json" + "errors" "fmt" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -31,6 +32,11 @@ import ( "k8s.io/apimachinery/pkg/types" ) +const ( + NamespaceName = "kf-namespace" + USER = "test-user@example.com" +) + func TestReportMetrics_NoCompletedNode_NoOP(t *testing.T) { pipelineFake := client.NewPipelineClientFake() @@ -51,7 +57,7 @@ func TestReportMetrics_NoCompletedNode_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) } @@ -76,7 +82,7 @@ func TestReportMetrics_NoRunID_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -103,7 +109,7 @@ func TestReportMetrics_NoArtifact_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -133,7 +139,7 @@ func TestReportMetrics_NoMetricsArtifact_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + 
err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -176,9 +182,9 @@ func TestReportMetrics_Succeed(t *testing.T) { Results: []*api.ReportRunMetricsResponse_ReportRunMetricResult{}, }, nil) - err := reporter.ReportMetrics(workflow) + err1 := reporter.ReportMetrics(workflow, USER) - assert.Nil(t, err) + assert.Nil(t, err1) expectedMetricsRequest := &api.ReportRunMetricsRequest{ RunId: "run-1", Metrics: []*api.RunMetric{ @@ -197,7 +203,7 @@ func TestReportMetrics_Succeed(t *testing.T) { got := pipelineFake.GetReportedMetricsRequest() if diff := cmp.Diff(expectedMetricsRequest, got, cmpopts.EquateEmpty(), protocmp.Transform()); diff != "" { t.Errorf("parseRuntimeInfo() = %+v, want %+v\nDiff (-want, +got)\n%s", got, expectedMetricsRequest, diff) - s, _ := json.MarshalIndent(expectedMetricsRequest ,"", " ") + s, _ := json.MarshalIndent(expectedMetricsRequest, "", " ") fmt.Printf("Want %s", s) } } @@ -235,7 +241,7 @@ func TestReportMetrics_EmptyArchive_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -278,7 +284,7 @@ func TestReportMetrics_MultipleFilesInArchive_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -320,7 +326,7 @@ func TestReportMetrics_InvalidMetricsJSON_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -381,7 +387,7 @@ func TestReportMetrics_InvalidMetricsJSON_PartialFail(t *testing.T) { 
Data: []byte(validArtifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) // Partial failure is reported while valid metrics are reported. assert.NotNil(t, err) @@ -404,7 +410,7 @@ func TestReportMetrics_InvalidMetricsJSON_PartialFail(t *testing.T) { got := pipelineFake.GetReportedMetricsRequest() if diff := cmp.Diff(expectedMetricsRequest, got, cmpopts.EquateEmpty(), protocmp.Transform()); diff != "" { t.Errorf("parseRuntimeInfo() = %+v, want %+v\nDiff (-want, +got)\n%s", got, expectedMetricsRequest, diff) - s, _ := json.MarshalIndent(expectedMetricsRequest ,"", " ") + s, _ := json.MarshalIndent(expectedMetricsRequest, "", " ") fmt.Printf("Want %s", s) } } @@ -441,7 +447,7 @@ func TestReportMetrics_CorruptedArchiveFile_Fail(t *testing.T) { Data: []byte("invalid tgz content"), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -505,8 +511,54 @@ func TestReportMetrics_MultiplMetricErrors_TransientErrowWin(t *testing.T) { }, }, nil) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT)) } + +func TestReportMetrics_Unauthorized(t *testing.T) { + pipelineFake := client.NewPipelineClientFake() + reporter := NewMetricsReporter(pipelineFake) + k8sFake := client.NewKubernetesCoreFake() + k8sFake.Set(NamespaceName, USER) + + workflow := util.NewWorkflow(&workflowapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "MY_NAMESPACE", + Name: "MY_NAME", + UID: types.UID("run-1"), + Labels: map[string]string{util.LabelKeyWorkflowRunId: "run-1"}, + }, + Status: workflowapi.WorkflowStatus{ + Nodes: map[string]workflowapi.NodeStatus{ + "node-1": workflowapi.NodeStatus{ + ID: "node-1", + Phase: workflowapi.NodeSucceeded, + Outputs: &workflowapi.Outputs{ + Artifacts: 
[]workflowapi.Artifact{{Name: "mlpipeline-metrics"}}, + }, + }, + }, + }, + }) + metricsJSON := `{"metrics": [{"name": "accuracy", "numberValue": 0.77}, {"name": "logloss", "numberValue": 1.2}]}` + artifactData, _ := util.ArchiveTgz(map[string]string{"file": metricsJSON}) + pipelineFake.StubArtifact( + &api.ReadArtifactRequest{ + RunId: "run-1", + NodeId: "node-1", + ArtifactName: "mlpipeline-metrics", + }, + &api.ReadArtifactResponse{ + Data: []byte(artifactData), + }) + pipelineFake.StubReportRunMetrics(&api.ReportRunMetricsResponse{ + Results: []*api.ReportRunMetricsResponse_ReportRunMetricResult{}, + }, errors.New("failed to read artifacts")) + + err1 := reporter.ReportMetrics(workflow, USER) + + assert.NotNil(t, err1) + assert.Contains(t, err1.Error(), "failed to read artifacts") +} diff --git a/backend/src/agent/persistence/worker/persistence_worker_test.go b/backend/src/agent/persistence/worker/persistence_worker_test.go index bde3ef7e4e6..e29226d1407 100644 --- a/backend/src/agent/persistence/worker/persistence_worker_test.go +++ b/backend/src/agent/persistence/worker/persistence_worker_test.go @@ -53,9 +53,11 @@ func TestPersistenceWorker_Success(t *testing.T) { // Set up pipeline client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -81,11 +83,12 @@ func TestPersistenceWorker_NotFoundError(t *testing.T) { }) workflowClient := client.NewWorkflowClientFake() - // Set up pipeline client + // Set up pipeline client and kubernetes client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, 
pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -112,11 +115,12 @@ func TestPersistenceWorker_GetWorklowError(t *testing.T) { workflowClient := client.NewWorkflowClientFake() workflowClient.Put("MY_NAMESPACE", "MY_NAME", nil) - // Set up pipeline client + // Set up pipeline client and kubernetes client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -148,9 +152,12 @@ func TestPersistenceWorker_ReportWorkflowRetryableError(t *testing.T) { pipelineClient := client.NewPipelineClientFake() pipelineClient.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_TRANSIENT, "My Retriable Error")) + //Set up kubernetes client + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -181,9 +188,11 @@ func TestPersistenceWorker_ReportWorkflowNonRetryableError(t *testing.T) { pipelineClient := client.NewPipelineClientFake() pipelineClient.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, "My Permanent Error")) + // Set up kubernetes client + k8sClient := client.NewKubernetesCoreFake() // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() 
worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), diff --git a/backend/src/agent/persistence/worker/workflow_saver.go b/backend/src/agent/persistence/worker/workflow_saver.go index 9635d020194..b79fbeae092 100644 --- a/backend/src/agent/persistence/worker/workflow_saver.go +++ b/backend/src/agent/persistence/worker/workflow_saver.go @@ -26,15 +26,17 @@ import ( type WorkflowSaver struct { client client.WorkflowClientInterface pipelineClient client.PipelineClientInterface + k8sClient client.KubernetesCoreInterface metricsReporter *MetricsReporter ttlSecondsAfterWorkflowFinish int64 } func NewWorkflowSaver(client client.WorkflowClientInterface, - pipelineClient client.PipelineClientInterface, ttlSecondsAfterWorkflowFinish int64) *WorkflowSaver { + pipelineClient client.PipelineClientInterface, k8sClient client.KubernetesCoreInterface, ttlSecondsAfterWorkflowFinish int64) *WorkflowSaver { return &WorkflowSaver{ client: client, pipelineClient: pipelineClient, + k8sClient: k8sClient, metricsReporter: NewMetricsReporter(pipelineClient), ttlSecondsAfterWorkflowFinish: ttlSecondsAfterWorkflowFinish, } @@ -66,6 +68,12 @@ func (s *WorkflowSaver) Save(key string, namespace string, name string, nowEpoch log.Infof("Skip syncing Workflow (%v): workflow marked as persisted.", name) return nil } + + user, err1 := s.k8sClient.GetNamespaceOwner(namespace) + if err1 != nil { + return util.Wrapf(err1, "Failed get '%v' namespace", namespace) + } + // Save this Workflow to the database. 
err = s.pipelineClient.ReportWorkflow(wf) retry := util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT) @@ -85,5 +93,5 @@ func (s *WorkflowSaver) Save(key string, namespace string, name string, nowEpoch log.WithFields(log.Fields{ "Workflow": name, }).Infof("Syncing Workflow (%v): success, processing complete.", name) - return s.metricsReporter.ReportMetrics(wf) + return s.metricsReporter.ReportMetrics(wf, user) } diff --git a/backend/src/agent/persistence/worker/workflow_saver_test.go b/backend/src/agent/persistence/worker/workflow_saver_test.go index 358f36600c5..10a16b7ccda 100644 --- a/backend/src/agent/persistence/worker/workflow_saver_test.go +++ b/backend/src/agent/persistence/worker/workflow_saver_test.go @@ -30,6 +30,8 @@ import ( func TestWorkflow_Save_Success(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) workflow := util.NewWorkflow(&workflowapi.Workflow{ ObjectMeta: metav1.ObjectMeta{ @@ -41,7 +43,7 @@ func TestWorkflow_Save_Success(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -52,8 +54,10 @@ func TestWorkflow_Save_Success(t *testing.T) { func TestWorkflow_Save_NotFoundDuringGet(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -65,10 +69,12 @@ func TestWorkflow_Save_NotFoundDuringGet(t *testing.T) { func TestWorkflow_Save_ErrorDuringGet(t *testing.T) { 
workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) workflowFake.Put("MY_NAMESPACE", "MY_NAME", nil) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -80,6 +86,8 @@ func TestWorkflow_Save_ErrorDuringGet(t *testing.T) { func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, "My Permanent Error")) @@ -94,7 +102,7 @@ func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -106,6 +114,8 @@ func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_TRANSIENT, "My Transient Error")) @@ -120,7 +130,7 @@ func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -132,6 +142,7 @@ 
func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Add this will result in failure unless reporting is skipped pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -150,7 +161,7 @@ func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -161,6 +172,8 @@ func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Add this will result in failure unless reporting is skipped pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -182,7 +195,7 @@ func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 1) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 1) // Sleep 2 seconds to make sure workflow passed TTL time.Sleep(2 * time.Second) @@ -197,6 +210,7 @@ func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { func TestWorkflow_Save_SkippedDDueToMissingRunID(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Add this will result in failure unless reporting is skipped 
pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -211,10 +225,33 @@ func TestWorkflow_Save_SkippedDDueToMissingRunID(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) assert.Equal(t, false, util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT)) assert.Equal(t, nil, err) } + +func TestWorkflow_Save_FailedToGetUser(t *testing.T) { + workflowFake := client.NewWorkflowClientFake() + pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("ORIGINAL_NAMESPACE", USER) + + workflow := util.NewWorkflow(&workflowapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "MY_NAMESPACE", + Name: "MY_NAME", + Labels: map[string]string{util.LabelKeyWorkflowRunId: "MY_UUID"}, + }, + }) + + workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) + + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) + + err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("Failed get '%v' namespace", "MY_NAMESPACE")) +} diff --git a/backend/src/apiserver/common/const.go b/backend/src/apiserver/common/const.go index 5af33ae3a01..bc1250b2ce2 100644 --- a/backend/src/apiserver/common/const.go +++ b/backend/src/apiserver/common/const.go @@ -41,17 +41,19 @@ const ( RbacResourceTypeViewers = "viewers" RbacResourceTypeVisualizations = "visualizations" - RbacResourceVerbArchive = "archive" - RbacResourceVerbUpdate = "update" - RbacResourceVerbCreate = "create" - RbacResourceVerbDelete = "delete" - RbacResourceVerbDisable = "disable" - RbacResourceVerbEnable = "enable" - RbacResourceVerbGet = "get" - RbacResourceVerbList = "list" - RbacResourceVerbRetry = "retry" - RbacResourceVerbTerminate = "terminate" - 
RbacResourceVerbUnarchive = "unarchive" + RbacResourceVerbArchive = "archive" + RbacResourceVerbUpdate = "update" + RbacResourceVerbCreate = "create" + RbacResourceVerbDelete = "delete" + RbacResourceVerbDisable = "disable" + RbacResourceVerbEnable = "enable" + RbacResourceVerbGet = "get" + RbacResourceVerbList = "list" + RbacResourceVerbRetry = "retry" + RbacResourceVerbTerminate = "terminate" + RbacResourceVerbUnarchive = "unarchive" + RbacResourceVerbReportMetrics = "reportMetrics" + RbacResourceVerbReadArtifact = "readArtifact" ) const ( diff --git a/backend/src/apiserver/server/run_server.go b/backend/src/apiserver/server/run_server.go index 73db791da9e..da6ac916939 100644 --- a/backend/src/apiserver/server/run_server.go +++ b/backend/src/apiserver/server/run_server.go @@ -280,8 +280,13 @@ func (s *RunServer) ReportRunMetrics(ctx context.Context, request *api.ReportRun reportRunMetricsRequests.Inc() } + err := s.canAccessRun(ctx, request.RunId, &authorizationv1.ResourceAttributes{Verb: common.RbacResourceVerbReportMetrics}) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request") + } + // Makes sure run exists - _, err := s.resourceManager.GetRun(request.GetRunId()) + _, err = s.resourceManager.GetRun(request.GetRunId()) if err != nil { return nil, err } @@ -305,6 +310,11 @@ func (s *RunServer) ReadArtifact(ctx context.Context, request *api.ReadArtifactR readArtifactRequests.Inc() } + err := s.canAccessRun(ctx, request.RunId, &authorizationv1.ResourceAttributes{Verb: common.RbacResourceVerbReadArtifact}) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request") + } + content, err := s.resourceManager.ReadArtifact( request.GetRunId(), request.GetNodeId(), request.GetArtifactName()) if err != nil { @@ -365,7 +375,7 @@ func (s *RunServer) canAccessRun(ctx context.Context, runId string, resourceAttr if len(runId) > 0 { runDetail, err := s.resourceManager.GetRun(runId) if err != nil { - return util.Wrap(err, "Failed 
to authorize with the experiment ID.") + return util.Wrap(err, "Failed to authorize with the run ID.") } if len(resourceAttributes.Namespace) == 0 { if len(runDetail.Namespace) == 0 { diff --git a/backend/src/apiserver/server/run_server_test.go b/backend/src/apiserver/server/run_server_test.go index 61e127f3314..6c1251d078a 100644 --- a/backend/src/apiserver/server/run_server_test.go +++ b/backend/src/apiserver/server/run_server_test.go @@ -4,6 +4,7 @@ import ( "context" "strings" "testing" + "time" "google.golang.org/protobuf/testing/protocmp" @@ -13,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" api "github.com/kubeflow/pipelines/backend/api/go_client" kfpauth "github.com/kubeflow/pipelines/backend/src/apiserver/auth" + "github.com/kubeflow/pipelines/backend/src/apiserver/client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -21,8 +23,19 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" authorizationv1 "k8s.io/api/authorization/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" ) +var metric = &api.RunMetric{ + Name: "metric-1", + NodeId: "node-1", + Value: &api.RunMetric_NumberValue{ + NumberValue: 0.88, + }, + Format: api.RunMetric_RAW, +} + func TestCreateRun(t *testing.T) { clients, manager, experiment := initWithExperiment(t) defer clients.Close() @@ -611,23 +624,16 @@ func TestReportRunMetrics_RunNotFound(t *testing.T) { } func TestReportRunMetrics_Succeed(t *testing.T) { - httpServer := getMockServer(t) - // Close the server when test finishes - defer httpServer.Close() + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + "user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) 
clientManager, resourceManager, runDetails := initWithOneTimeRun(t) defer clientManager.Close() runServer := RunServer{resourceManager: resourceManager, options: &RunServerOptions{CollectMetrics: false}} - metric := &api.RunMetric{ - Name: "metric-1", - NodeId: "node-1", - Value: &api.RunMetric_NumberValue{ - NumberValue: 0.88, - }, - Format: api.RunMetric_RAW, - } - response, err := runServer.ReportRunMetrics(context.Background(), &api.ReportRunMetricsRequest{ + response, err := runServer.ReportRunMetrics(ctx, &api.ReportRunMetricsRequest{ RunId: runDetails.UUID, Metrics: []*api.RunMetric{metric}, }) @@ -643,13 +649,47 @@ func TestReportRunMetrics_Succeed(t *testing.T) { } assert.Equal(t, expectedResponse, response) - run, err := runServer.GetRun(context.Background(), &api.GetRunRequest{ + run, err := runServer.GetRun(ctx, &api.GetRunRequest{ RunId: runDetails.UUID, }) assert.Nil(t, err) assert.Equal(t, []*api.RunMetric{metric}, run.GetRun().GetMetrics()) } +func TestReportRunMetrics_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + userIdentity := "user@google.com" + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + userIdentity}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clientManager, resourceManager, runDetails := initWithOneTimeRun(t) + defer clientManager.Close() + clientManager.SubjectAccessReviewClientFake = client.NewFakeSubjectAccessReviewClientUnauthorized() + resourceManager = resource.NewResourceManager(clientManager) + runServer := RunServer{resourceManager: resourceManager, options: &RunServerOptions{CollectMetrics: false}} + + _, err := runServer.ReportRunMetrics(ctx, &api.ReportRunMetricsRequest{ + RunId: runDetails.UUID, + Metrics: []*api.RunMetric{metric}, + }) + assert.NotNil(t, err) + resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: runDetails.Namespace, + Verb: 
common.RbacResourceVerbReportMetrics, + Group: common.RbacPipelinesGroup, + Version: common.RbacPipelinesVersion, + Resource: common.RbacResourceTypeRuns, + Name: runDetails.Name, + } + assert.EqualError( + t, + err, + wrapFailedAuthzRequestError(wrapFailedAuthzApiResourcesError(getPermissionDeniedError(userIdentity, resourceAttributes))).Error(), + ) + +} + func TestReportRunMetrics_PartialFailures(t *testing.T) { httpServer := getMockServer(t) // Close the server when test finishes @@ -659,14 +699,7 @@ func TestReportRunMetrics_PartialFailures(t *testing.T) { defer clientManager.Close() runServer := RunServer{resourceManager: resourceManager, options: &RunServerOptions{CollectMetrics: false}} - validMetric := &api.RunMetric{ - Name: "metric-1", - NodeId: "node-1", - Value: &api.RunMetric_NumberValue{ - NumberValue: 0.88, - }, - Format: api.RunMetric_RAW, - } + validMetric := metric invalidNameMetric := &api.RunMetric{ Name: "$metric-1", NodeId: "node-1", @@ -827,3 +860,159 @@ func TestCanAccessRun_Unauthenticated(t *testing.T) { wrapFailedAuthzApiResourcesError(kfpauth.IdentityHeaderMissingError).Error(), ) } + +func TestReadArtifacts_Succeed(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + "user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + expectedContent := "test" + filePath := "test/file.txt" + resourceManager, manager, run := initWithOneTimeRun(t) + resourceManager.ObjectStore().AddFile([]byte(expectedContent), filePath) + workflow := util.NewWorkflow(&v1alpha1.Workflow{ + TypeMeta: v1.TypeMeta{ + APIVersion: "argoproj.io/v1alpha1", + Kind: "Workflow", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "workflow-name", + Namespace: "ns1", + UID: "workflow1", + Labels: map[string]string{util.LabelKeyWorkflowRunId: run.UUID}, + CreationTimestamp: 
v1.NewTime(time.Unix(11, 0).UTC()), + OwnerReferences: []v1.OwnerReference{{ + APIVersion: "kubeflow.org/v1beta1", + Kind: "Workflow", + Name: "workflow-name", + UID: types.UID(run.UUID), + }}, + }, + Status: v1alpha1.WorkflowStatus{ + Nodes: map[string]v1alpha1.NodeStatus{ + "node-1": { + Outputs: &v1alpha1.Outputs{ + Artifacts: []v1alpha1.Artifact{ + { + Name: "artifact-1", + ArtifactLocation: v1alpha1.ArtifactLocation{ + S3: &v1alpha1.S3Artifact{ + Key: filePath, + }, + }, + }, + }, + }, + }, + }, + }, + }) + err := manager.ReportWorkflowResource(context.Background(), workflow) + assert.Nil(t, err) + + runServer := RunServer{resourceManager: manager, options: &RunServerOptions{CollectMetrics: false}} + artifact := &api.ReadArtifactRequest{ + RunId: run.UUID, + NodeId: "node-1", + ArtifactName: "artifact-1", + } + response, err := runServer.ReadArtifact(ctx, artifact) + assert.Nil(t, err) + + expectedResponse := &api.ReadArtifactResponse{ + Data: []byte(expectedContent), + } + assert.Equal(t, expectedResponse, response) +} + +func TestReadArtifacts_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + userIdentity := "user@google.com" + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + userIdentity}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clientManager, resourceManager, run := initWithOneTimeRun(t) + + //make the following request unauthorized + clientManager.SubjectAccessReviewClientFake = client.NewFakeSubjectAccessReviewClientUnauthorized() + resourceManager = resource.NewResourceManager(clientManager) + + runServer := RunServer{resourceManager: resourceManager, options: &RunServerOptions{CollectMetrics: false}} + artifact := &api.ReadArtifactRequest{ + RunId: run.UUID, + NodeId: "node-1", + ArtifactName: "artifact-1", + } + _, err := runServer.ReadArtifact(ctx, artifact) + assert.NotNil(t, err) + + 
resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: run.Namespace, + Verb: common.RbacResourceVerbReadArtifact, + Group: common.RbacPipelinesGroup, + Version: common.RbacPipelinesVersion, + Resource: common.RbacResourceTypeRuns, + Name: run.Name, + } + assert.EqualError( + t, + err, + wrapFailedAuthzRequestError(wrapFailedAuthzApiResourcesError(getPermissionDeniedError(userIdentity, resourceAttributes))).Error(), + ) +} + +func TestReadArtifacts_Run_NotFound(t *testing.T) { + clientManager := resource.NewFakeClientManagerOrFatal(util.NewFakeTimeForEpoch()) + manager := resource.NewResourceManager(clientManager) + runServer := RunServer{resourceManager: manager, options: &RunServerOptions{CollectMetrics: false}} + artifact := &api.ReadArtifactRequest{ + RunId: "Wrong_RUN_UUID", + NodeId: "node-1", + ArtifactName: "artifact-1", + } + _, err := runServer.ReadArtifact(context.Background(), artifact) + assert.NotNil(t, err) + err = err.(*util.UserError) + + assert.True(t, util.IsUserErrorCodeMatch(err, codes.NotFound)) +} + +func TestReadArtifacts_Resource_NotFound(t *testing.T) { + _, manager, run := initWithOneTimeRun(t) + + workflow := util.NewWorkflow(&v1alpha1.Workflow{ + TypeMeta: v1.TypeMeta{ + APIVersion: "argoproj.io/v1alpha1", + Kind: "Workflow", + }, + ObjectMeta: v1.ObjectMeta{ + Name: "workflow-name", + Namespace: "ns1", + UID: "workflow1", + Labels: map[string]string{util.LabelKeyWorkflowRunId: run.UUID}, + CreationTimestamp: v1.NewTime(time.Unix(11, 0).UTC()), + OwnerReferences: []v1.OwnerReference{{ + APIVersion: "kubeflow.org/v1beta1", + Kind: "Workflow", + Name: "workflow-name", + UID: types.UID(run.UUID), + }}, + }, + }) + err := manager.ReportWorkflowResource(context.Background(), workflow) + assert.Nil(t, err) + + runServer := RunServer{resourceManager: manager, options: &RunServerOptions{CollectMetrics: false}} + //`artifactRequest` search for node that does not exist + artifactRequest := &api.ReadArtifactRequest{ + RunId: 
run.UUID, + NodeId: "node-1", + ArtifactName: "artifact-1", + } + _, err = runServer.ReadArtifact(context.Background(), artifactRequest) + assert.NotNil(t, err) + assert.True(t, util.IsUserErrorCodeMatch(err, codes.NotFound)) +} diff --git a/backend/src/apiserver/server/test_util.go b/backend/src/apiserver/server/test_util.go index d8c22b73283..13f63371d71 100644 --- a/backend/src/apiserver/server/test_util.go +++ b/backend/src/apiserver/server/test_util.go @@ -29,6 +29,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" authorizationv1 "k8s.io/api/authorization/v1" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -252,6 +253,12 @@ func initWithExperimentsAndTwoPipelineVersions(t *testing.T) *resource.FakeClien func initWithOneTimeRun(t *testing.T) (*resource.FakeClientManager, *resource.ResourceManager, *model.RunDetail) { clientManager, manager, exp := initWithExperiment(t) + + ctx := context.Background() + if common.IsMultiUserMode() { + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + "user@google.com"}) + ctx = metadata.NewIncomingContext(context.Background(), md) + } apiRun := &api.Run{ Name: "run1", PipelineSpec: &api.PipelineSpec{ @@ -267,7 +274,7 @@ func initWithOneTimeRun(t *testing.T) (*resource.FakeClientManager, *resource.Re }, }, } - runDetail, err := manager.CreateRun(context.Background(), apiRun) + runDetail, err := manager.CreateRun(ctx, apiRun) assert.Nil(t, err) return clientManager, manager, runDetail } diff --git a/backend/third_party_licenses/persistence_agent.csv b/backend/third_party_licenses/persistence_agent.csv index ecc3f2824a6..05301d85742 100644 --- a/backend/third_party_licenses/persistence_agent.csv +++ b/backend/third_party_licenses/persistence_agent.csv @@ -12,6 +12,7 @@ github.com/cenkalti/backoff,https://github.com/cenkalti/backoff/blob/v2.2.1/LICE 
github.com/davecgh/go-spew/spew,https://github.com/davecgh/go-spew/blob/v1.1.1/LICENSE,ISC github.com/doublerebel/bellows,https://github.com/doublerebel/bellows/blob/f177d92a03d3/LICENSE,MIT github.com/emicklei/go-restful,https://github.com/emicklei/go-restful/blob/v2.15.0/LICENSE,MIT +github.com/fsnotify/fsnotify,https://github.com/fsnotify/fsnotify/blob/v1.5.1/LICENSE,BSD-3-Clause github.com/ghodss/yaml,https://github.com/ghodss/yaml/blob/25d852aebe32/LICENSE,MIT github.com/go-logr/logr,https://github.com/go-logr/logr/blob/v1.2.2/LICENSE,Apache-2.0 github.com/go-openapi/errors,https://github.com/go-openapi/errors/blob/v0.20.2/LICENSE,Apache-2.0 @@ -30,12 +31,14 @@ github.com/google/uuid,https://github.com/google/uuid/blob/v1.3.0/LICENSE,BSD-3- github.com/googleapis/gnostic,https://github.com/googleapis/gnostic/blob/v0.5.5/LICENSE,Apache-2.0 github.com/gorilla/websocket,https://github.com/gorilla/websocket/blob/v1.5.0/LICENSE,BSD-2-Clause github.com/grpc-ecosystem/grpc-gateway,https://github.com/grpc-ecosystem/grpc-gateway/blob/v1.16.0/LICENSE.txt,BSD-3-Clause +github.com/hashicorp/hcl,https://github.com/hashicorp/hcl/blob/v1.0.0/LICENSE,MPL-2.0 github.com/huandu/xstrings,https://github.com/huandu/xstrings/blob/v1.3.2/LICENSE,MIT github.com/imdario/mergo,https://github.com/imdario/mergo/blob/v0.3.12/LICENSE,BSD-3-Clause github.com/josharian/intern,https://github.com/josharian/intern/blob/v1.0.0/license.md,MIT github.com/json-iterator/go,https://github.com/json-iterator/go/blob/v1.1.12/LICENSE,MIT github.com/kubeflow/pipelines/backend,https://github.com/kubeflow/pipelines/blob/HEAD/LICENSE,Apache-2.0 github.com/lestrrat-go/strftime,https://github.com/lestrrat-go/strftime/blob/v1.0.4/LICENSE,MIT +github.com/magiconair/properties,https://github.com/magiconair/properties/blob/v1.8.5/LICENSE.md,BSD-2-Clause github.com/mailru/easyjson,https://github.com/mailru/easyjson/blob/v0.7.7/LICENSE,MIT 
github.com/mitchellh/copystructure,https://github.com/mitchellh/copystructure/blob/v1.2.0/LICENSE,MIT github.com/mitchellh/mapstructure,https://github.com/mitchellh/mapstructure/blob/v1.4.3/LICENSE,MIT @@ -45,11 +48,16 @@ github.com/modern-go/concurrent,https://github.com/modern-go/concurrent/blob/bac github.com/modern-go/reflect2,https://github.com/modern-go/reflect2/blob/v1.0.2/LICENSE,Apache-2.0 github.com/oklog/ulid,https://github.com/oklog/ulid/blob/v1.3.1/LICENSE,Apache-2.0 github.com/oliveagle/jsonpath,https://github.com/oliveagle/jsonpath/blob/2e52cf6e6852/LICENSE,MIT +github.com/pelletier/go-toml,https://github.com/pelletier/go-toml/blob/v1.9.4/LICENSE,Apache-2.0 github.com/pkg/errors,https://github.com/pkg/errors/blob/v0.9.1/LICENSE,BSD-2-Clause github.com/shopspring/decimal,https://github.com/shopspring/decimal/blob/v1.2.0/LICENSE,MIT github.com/sirupsen/logrus,https://github.com/sirupsen/logrus/blob/v1.8.1/LICENSE,MIT +github.com/spf13/afero,https://github.com/spf13/afero/blob/v1.8.0/LICENSE.txt,Apache-2.0 github.com/spf13/cast,https://github.com/spf13/cast/blob/v1.4.1/LICENSE,MIT +github.com/spf13/jwalterweatherman,https://github.com/spf13/jwalterweatherman/blob/v1.1.0/LICENSE,MIT github.com/spf13/pflag,https://github.com/spf13/pflag/blob/v1.0.5/LICENSE,BSD-3-Clause +github.com/spf13/viper,https://github.com/spf13/viper/blob/v1.10.1/LICENSE,MIT +github.com/subosito/gotenv,https://github.com/subosito/gotenv/blob/v1.2.0/LICENSE,MIT github.com/valyala/bytebufferpool,https://github.com/valyala/bytebufferpool/blob/v1.0.0/LICENSE,MIT github.com/valyala/fasttemplate,https://github.com/valyala/fasttemplate/blob/v1.2.1/LICENSE,MIT go.mongodb.org/mongo-driver,https://github.com/mongodb/mongo-go-driver/blob/v1.8.2/LICENSE,Apache-2.0 @@ -64,6 +72,7 @@ google.golang.org/genproto,https://github.com/googleapis/go-genproto/blob/197313 google.golang.org/grpc,https://github.com/grpc/grpc-go/blob/v1.44.0/LICENSE,Apache-2.0 
google.golang.org/protobuf,https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/LICENSE,BSD-3-Clause gopkg.in/inf.v0,https://github.com/go-inf/inf/blob/v0.9.1/LICENSE,BSD-3-Clause +gopkg.in/ini.v1,https://github.com/go-ini/ini/blob/v1.66.3/LICENSE,Apache-2.0 gopkg.in/yaml.v2,https://github.com/go-yaml/yaml/blob/v2.4.0/LICENSE,Apache-2.0 gopkg.in/yaml.v3,https://github.com/go-yaml/yaml/blob/496545a6307b/LICENSE,MIT k8s.io/api,https://github.com/kubernetes/api/blob/v0.23.3/LICENSE,Apache-2.0 diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml b/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml index b3053317b53..cf3b34a82fa 100644 --- a/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml @@ -19,3 +19,9 @@ rules: - get - list - watch +- apiGroups: + - '' + resources: + - namespaces + verbs: + - get \ No newline at end of file diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml b/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml index 1e165def422..a5e7a9fc26c 100644 --- a/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml @@ -7,7 +7,14 @@ spec: spec: containers: - name: ml-pipeline-persistenceagent + envFrom: + - configMapRef: + name: persistenceagent-config env: - name: NAMESPACE value: '' valueFrom: null + - name: KUBEFLOW_USERID_HEADER + value: kubeflow-userid + - name: KUBEFLOW_USERID_PREFIX + value: "" \ No newline at end of file diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/kustomization.yaml b/manifests/kustomize/base/installs/multi-user/persistence-agent/kustomization.yaml index b1f65469e1d..560e0fc893c 100644 --- 
a/manifests/kustomize/base/installs/multi-user/persistence-agent/kustomization.yaml +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/kustomization.yaml @@ -3,3 +3,7 @@ kind: Kustomization resources: - cluster-role.yaml - cluster-role-binding.yaml +configMapGenerator: +- name: persistenceagent-config + envs: + - params.env \ No newline at end of file diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/params.env b/manifests/kustomize/base/installs/multi-user/persistence-agent/params.env new file mode 100644 index 00000000000..4c3bab70f9d --- /dev/null +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/params.env @@ -0,0 +1 @@ +MULTIUSER=true diff --git a/manifests/kustomize/base/installs/multi-user/view-edit-cluster-roles.yaml b/manifests/kustomize/base/installs/multi-user/view-edit-cluster-roles.yaml index 626e005a945..abb531ee5a0 100644 --- a/manifests/kustomize/base/installs/multi-user/view-edit-cluster-roles.yaml +++ b/manifests/kustomize/base/installs/multi-user/view-edit-cluster-roles.yaml @@ -69,6 +69,8 @@ rules: - retry - terminate - unarchive + - reportMetrics + - readArtifact - apiGroups: - pipelines.kubeflow.org resources: @@ -111,11 +113,18 @@ rules: - pipelines - pipelines/versions - experiments - - runs - jobs verbs: - get - list +- apiGroups: + - pipelines.kubeflow.org + resources: + - runs + verbs: + - get + - list + - readArtifact - apiGroups: - kubeflow.org resources: diff --git a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml index bc5032e51a8..74c19c9d793 100644 --- a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml +++ b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml @@ -25,6 +25,10 @@ spec: value: "86400" - name: NUM_WORKERS value: "2" + - name: KUBEFLOW_USERID_HEADER + value: kubeflow-userid + - name: 
KUBEFLOW_USERID_PREFIX + value: "" image: gcr.io/ml-pipeline/persistenceagent:dummy imagePullPolicy: IfNotPresent name: ml-pipeline-persistenceagent diff --git a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml index 830ee8b14e7..2a288092c19 100644 --- a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml +++ b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml @@ -18,4 +18,10 @@ rules: verbs: - get - list - - watch \ No newline at end of file + - watch +- apiGroups: + - '' + resources: + - namespaces + verbs: + - get \ No newline at end of file