diff --git a/cmd/main.go b/cmd/main.go index 489b3b5..f62143a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,7 +11,6 @@ import ( "time" "github.com/prometheus/client_golang/prometheus/promhttp" - "sigs.k8s.io/knftables" "sigs.k8s.io/kube-network-policies/pkg/networkpolicy" npaclient "sigs.k8s.io/network-policy-api/pkg/client/clientset/versioned" npainformers "sigs.k8s.io/network-policy-api/pkg/client/informers/externalversions" @@ -58,11 +57,6 @@ func main() { klog.Infof("flags: %v", flag.Args()) - nft, err := knftables.New(knftables.InetFamily, "kube-network-policies") - if err != nil { - klog.Fatalf("Error initializing nftables: %v", err) - } - if _, _, err := net.SplitHostPort(metricsBindAddress); err != nil { klog.Fatalf("error parsing metrics bind address %s : %v", metricsBindAddress, err) } @@ -138,9 +132,8 @@ func main() { utilruntime.HandleError(err) }() - networkPolicyController := networkpolicy.NewController( + networkPolicyController, err := networkpolicy.NewController( clientset, - nft, informersFactory.Networking().V1().NetworkPolicies(), informersFactory.Core().V1().Namespaces(), informersFactory.Core().V1().Pods(), @@ -150,6 +143,9 @@ func main() { banpInformer, cfg, ) + if err != nil { + klog.Fatalf("Can not start network policy controller: %v", err) + } go func() { err := networkPolicyController.Run(ctx) utilruntime.HandleError(err) diff --git a/pkg/networkpolicy/controller.go b/pkg/networkpolicy/controller.go index ab0d1f2..10211b0 100644 --- a/pkg/networkpolicy/controller.go +++ b/pkg/networkpolicy/controller.go @@ -70,6 +70,36 @@ type Config struct { // NewController returns a new *Controller. func NewController(client clientset.Interface, + networkpolicyInformer networkinginformers.NetworkPolicyInformer, + namespaceInformer coreinformers.NamespaceInformer, + podInformer coreinformers.PodInformer, + nodeInformer coreinformers.NodeInformer, + npaClient npaclient.Interface, + adminNetworkPolicyInformer policyinformers.AdminNetworkPolicyInformer, + baselineAdminNetworkPolicyInformer policyinformers.BaselineAdminNetworkPolicyInformer, + config Config, +) (*Controller, error) { + klog.V(2).Info("Initializing nftables") + nft, err := knftables.New(knftables.InetFamily, "kube-network-policies") + if err != nil { + return nil, err + } + + return newController( + client, + nft, + networkpolicyInformer, + namespaceInformer, + podInformer, + nodeInformer, + npaClient, + adminNetworkPolicyInformer, + baselineAdminNetworkPolicyInformer, + config, + ) +} + +func newController(client clientset.Interface, nft knftables.Interface, networkpolicyInformer networkinginformers.NetworkPolicyInformer, namespaceInformer coreinformers.NamespaceInformer, @@ -79,7 +109,7 @@ func NewController(client clientset.Interface, adminNetworkPolicyInformer policyinformers.AdminNetworkPolicyInformer, baselineAdminNetworkPolicyInformer policyinformers.BaselineAdminNetworkPolicyInformer, config Config, -) *Controller { +) (*Controller, error) { klog.V(2).Info("Creating event broadcaster") broadcaster := record.NewBroadcaster() broadcaster.StartStructuredLogging(0) @@ -113,7 +143,7 @@ func NewController(client clientset.Interface, }, }) if err != nil { - panic(err) + return nil, err } podIndexer := podInformer.Informer().GetIndexer() @@ -150,7 +180,7 @@ func NewController(client clientset.Interface, } err = podInformer.Informer().SetTransform(trim) if err != nil { - utilruntime.HandleError(err) + return nil, err } // process only local Pods that are affected by network policices @@ -257,7 +287,7 @@ func NewController(client clientset.Interface, c.eventBroadcaster = broadcaster c.eventRecorder = recorder - return c + return c, nil } // Controller manages selector-based networkpolicy endpoints. diff --git a/pkg/networkpolicy/controller_test.go b/pkg/networkpolicy/controller_test.go index b1f448c..42f48b1 100644 --- a/pkg/networkpolicy/controller_test.go +++ b/pkg/networkpolicy/controller_test.go @@ -85,15 +85,16 @@ type networkpolicyController struct { nodeStore cache.Store } -func newController() *networkpolicyController { +func newTestController() *networkpolicyController { client := fake.NewSimpleClientset() informersFactory := informers.NewSharedInformerFactory(client, 0) npaClient := npaclientfake.NewSimpleClientset() npaInformerFactory := npainformers.NewSharedInformerFactory(npaClient, 0) - controller := NewController(client, - &knftables.Fake{}, + controller, err := newController( + client, + knftables.NewFake(knftables.InetFamily, "kube-network-policies"), informersFactory.Networking().V1().NetworkPolicies(), informersFactory.Core().V1().Namespaces(), informersFactory.Core().V1().Pods(), @@ -106,6 +107,9 @@ func newController() *networkpolicyController { BaselineAdminNetworkPolicy: true, }, ) + if err != nil { + panic(err) + } controller.networkpoliciesSynced = alwaysReady controller.namespacesSynced = alwaysReady controller.podsSynced = alwaysReady diff --git a/pkg/networkpolicy/networkpolicy_test.go b/pkg/networkpolicy/networkpolicy_test.go index 7d14393..ae55630 100644 --- a/pkg/networkpolicy/networkpolicy_test.go +++ b/pkg/networkpolicy/networkpolicy_test.go @@ -418,7 +418,7 @@ func TestSyncPacket(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - controller := newController() + controller := newTestController() // Add objects to the Store for _, n := range tt.networkpolicy { err := controller.networkpolicyStore.Add(n) @@ -464,7 +464,7 @@ func TestController_evaluateSelectors(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := newController() + c := newTestController() // Add objects to the Store for _, n := range tt.networkpolicies { err := c.networkpolicyStore.Add(n) @@ -529,7 +529,7 @@ func TestController_evaluateIPBlocks(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := newController() + c := newTestController() if got := c.evaluateIPBlocks(tt.ipBlock, tt.ip); got != tt.want { t.Errorf("Controller.evaluateIPBlocks() = %v, want %v", got, tt.want) } @@ -625,7 +625,7 @@ func TestController_evaluatePorts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := newController() + c := newTestController() if got := c.evaluatePorts(tt.networkPolicyPorts, tt.pod, tt.port, tt.protocol); got != tt.want { t.Errorf("Controller.evaluatePorts() = %v, want %v", got, tt.want) } diff --git a/pkg/networkpolicy/networkpolicyapi_test.go b/pkg/networkpolicy/networkpolicyapi_test.go index 02253c2..acb449e 100644 --- a/pkg/networkpolicy/networkpolicyapi_test.go +++ b/pkg/networkpolicy/networkpolicyapi_test.go @@ -455,7 +455,7 @@ func Test_adminNetworkPolicyAction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - controller := newController() + controller := newTestController() // Add objects to the Store for _, n := range tt.networkpolicy { err := controller.adminNetworkpolicyStore.Add(n) @@ -720,7 +720,7 @@ func TestController_getAdminNetworkPoliciesForPod(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - controller := newController() + controller := newTestController() // Add objects to the Store err := controller.adminNetworkpolicyStore.Add(tt.networkpolicy) if err != nil {