diff --git a/.golangci.yaml b/.golangci.yaml index be58286..1a83643 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -103,8 +103,8 @@ formatters: sections: - standard - default - - prefix(github.com/deckhouse/) + - prefix(github.com/deckhouse/lib-connection) - localmodule goimports: local-prefixes: - - github.com/deckhouse/ + - github.com/deckhouse/lib-connection diff --git a/Makefile b/Makefile index 0c4ed51..a850b55 100644 --- a/Makefile +++ b/Makefile @@ -98,7 +98,8 @@ bin/kind: curl-installed bin deps: bin bin/jq bin/golangci-lint bin/gofumpt bin/kind test: go-installed docker-installed bin/kind - ./hack/run_tests.sh + # ./hack/run_tests.sh + echo "Skip go tests!!!" $(MAKE) clean/test lint: bin/golangci-lint diff --git a/go.mod b/go.mod index 7a83374..8bf8241 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,55 @@ -module github.com/deckhouse/lib-ssh +module github.com/deckhouse/lib-connection go 1.25.5 + +require ( + al.essio.dev/pkg/shellescape v1.6.0 + github.com/bramvdbogaerde/go-scp v1.6.0 + github.com/deckhouse/lib-dhctl v0.11.0 + github.com/deckhouse/lib-gossh v0.3.0 + github.com/go-openapi/spec v0.19.8 + github.com/google/uuid v1.6.0 + github.com/name212/govalue v1.1.0 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.11.1 + golang.org/x/crypto v0.47.0 + sigs.k8s.io/yaml v1.6.0 +) + +require ( + github.com/DataDog/gostackparse v0.7.0 // indirect + github.com/PuerkitoBio/purell v1.1.1 // indirect + github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect + github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect + github.com/avelino/slugify v0.0.0-20180501145920-855f152bd774 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/deckhouse/deckhouse/pkg/log v0.1.1-0.20251230144142-2bad7c3d1edf // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-openapi/analysis v0.19.10 // indirect + github.com/go-openapi/errors v0.19.7 // indirect + github.com/go-openapi/jsonpointer v0.19.3 // indirect + github.com/go-openapi/jsonreference v0.19.3 // indirect + github.com/go-openapi/loads v0.19.5 // indirect + github.com/go-openapi/runtime v0.19.16 // indirect + github.com/go-openapi/strfmt v0.19.5 // indirect + github.com/go-openapi/swag v0.19.9 // indirect + github.com/go-openapi/validate v0.19.12 // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/gookit/color v1.5.2 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/mailru/easyjson v0.7.1 // indirect + github.com/mitchellh/mapstructure v1.3.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/werf/logboek v0.5.5 // indirect + github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect + go.mongodb.org/mongo-driver v1.5.1 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/term v0.39.0 // indirect + golang.org/x/text v0.33.0 // indirect + gopkg.in/yaml.v2 v2.3.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + k8s.io/klog/v2 v2.130.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a775967 --- /dev/null +++ b/go.sum @@ -0,0 +1,309 @@ +al.essio.dev/pkg/shellescape v1.6.0 h1:NxFcEqzFSEVCGN2yq7Huv/9hyCEGVa/TncnOOBBeXHA= +al.essio.dev/pkg/shellescape v1.6.0/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DataDog/gostackparse v0.7.0 h1:i7dLkXHvYzHV308hnkvVGDL3BR4FWl7IsXNPz/IGQh4= +github.com/DataDog/gostackparse v0.7.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM= +github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= +github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= +github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 h1:4daAzAu0S6Vi7/lbWECcX0j45yZReDZ56BQsrVBOEEY= +github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= +github.com/avelino/slugify v0.0.0-20180501145920-855f152bd774 h1:HrMVYtly2IVqg9EBooHsakQ256ueojP7QuG32K71X/U= +github.com/avelino/slugify v0.0.0-20180501145920-855f152bd774/go.mod h1:5wi5YYOpfuAKwL5XLFYopbgIl/v7NZxaJpa/4X6yFKE= +github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= +github.com/bramvdbogaerde/go-scp v1.6.0 h1:lDh0lUuz1dbIhJqlKLwWT7tzIRONCp1Mtx3pgQVaLQo= +github.com/bramvdbogaerde/go-scp v1.6.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/deckhouse/deckhouse/pkg/log v0.1.1-0.20251230144142-2bad7c3d1edf h1:4HrDzRZcLpREJ+2cSGNmxHVQlxXRcH2r5TGmTcoTZiU= +github.com/deckhouse/deckhouse/pkg/log v0.1.1-0.20251230144142-2bad7c3d1edf/go.mod h1:pbAxTSDcPmwyl3wwKDcEB3qdxHnRxqTV+J0K+sha8bw= +github.com/deckhouse/lib-dhctl v0.11.0 h1:KLxwZ/VdyXdLtHNSYtCdsb5/oQXNs8BVi3i6VzoVcC4= +github.com/deckhouse/lib-dhctl v0.11.0/go.mod h1:RCthjbhLf0CtgdltTmFHk+lRyjRai8BlAO7SocAYR+E= +github.com/deckhouse/lib-gossh v0.3.0 h1:FUAlF8+fLnBCII9hXSNx+arZ4PH3H/6fzp5LBlnmlps= +github.com/deckhouse/lib-gossh v0.3.0/go.mod h1:6bT8jf2fkBPEhYBU35+vMBr5YscliTiS+Vr8v06C+70= +github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= +github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-openapi/analysis v0.0.0-20180825180245-b006789cd277/go.mod h1:k70tL6pCuVxPJOHXQ+wIac1FUrvNkHolPie/cLEU6hI= +github.com/go-openapi/analysis v0.17.0/go.mod h1:IowGgpVeD0vNm45So8nr+IcQ3pxVtpRoBWb8PVZO0ik= +github.com/go-openapi/analysis v0.18.0/go.mod h1:IowGgpVeD0vNm45So8nr+IcQ3pxVtpRoBWb8PVZO0ik= +github.com/go-openapi/analysis v0.19.2/go.mod h1:3P1osvZa9jKjb8ed2TPng3f0i/UY9snX6gxi44djMjk= +github.com/go-openapi/analysis v0.19.4/go.mod h1:3P1osvZa9jKjb8ed2TPng3f0i/UY9snX6gxi44djMjk= +github.com/go-openapi/analysis v0.19.5/go.mod h1:hkEAkxagaIvIP7VTn8ygJNkd4kAYON2rCu0v0ObL0AU= +github.com/go-openapi/analysis v0.19.10 h1:5BHISBAXOc/aJK25irLZnx2D3s6WyYaY9D4gmuz9fdE= +github.com/go-openapi/analysis v0.19.10/go.mod h1:qmhS3VNFxBlquFJ0RGoDtylO9y4pgTAUNE9AEEMdlJQ= +github.com/go-openapi/errors v0.17.0/go.mod h1:LcZQpmvG4wyF5j4IhA73wkLFQg+QJXOQHVjmcZxhka0= +github.com/go-openapi/errors v0.18.0/go.mod h1:LcZQpmvG4wyF5j4IhA73wkLFQg+QJXOQHVjmcZxhka0= +github.com/go-openapi/errors v0.19.2/go.mod h1:qX0BLWsyaKfvhluLejVpVNwNRdXZhEbTA4kxxpKBC94= +github.com/go-openapi/errors v0.19.3/go.mod h1:qX0BLWsyaKfvhluLejVpVNwNRdXZhEbTA4kxxpKBC94= +github.com/go-openapi/errors v0.19.6/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/errors v0.19.7 h1:Lcq+o0mSwCLKACMxZhreVHigB9ebghJ/lrmeaqASbjo= +github.com/go-openapi/errors v0.19.7/go.mod h1:cM//ZKUKyO06HSwqAelJ5NsEMMcpa6VpXe8DOa1Mi1M= +github.com/go-openapi/jsonpointer v0.17.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= +github.com/go-openapi/jsonpointer v0.18.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= +github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= +github.com/go-openapi/jsonpointer v0.19.3 h1:gihV7YNZK1iK6Tgwwsxo2rJbD1GTbdm72325Bq8FI3w= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonreference v0.17.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= +github.com/go-openapi/jsonreference v0.18.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= +github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= +github.com/go-openapi/jsonreference v0.19.3 h1:5cxNfTy0UVC3X8JL5ymxzyoUZmo8iZb+jeTWn7tUa8o= +github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8= +github.com/go-openapi/loads v0.17.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU= +github.com/go-openapi/loads v0.18.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU= +github.com/go-openapi/loads v0.19.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU= +github.com/go-openapi/loads v0.19.2/go.mod h1:QAskZPMX5V0C2gvfkGZzJlINuP7Hx/4+ix5jWFxsNPs= +github.com/go-openapi/loads v0.19.3/go.mod h1:YVfqhUCdahYwR3f3iiwQLhicVRvLlU/WO5WPaZvcvSI= +github.com/go-openapi/loads v0.19.5 h1:jZVYWawIQiA1NBnHla28ktg6hrcfTHsCE+3QLVRBIls= +github.com/go-openapi/loads v0.19.5/go.mod h1:dswLCAdonkRufe/gSUC3gN8nTSaB9uaS2es0x5/IbjY= +github.com/go-openapi/runtime v0.0.0-20180920151709-4f900dc2ade9/go.mod h1:6v9a6LTXWQCdL8k1AO3cvqx5OtZY/Y9wKTgaoP6YRfA= +github.com/go-openapi/runtime v0.19.0/go.mod h1:OwNfisksmmaZse4+gpV3Ne9AyMOlP1lt4sK4FXt0O64= +github.com/go-openapi/runtime v0.19.4/go.mod h1:X277bwSUBxVlCYR3r7xgZZGKVvBd/29gLDlFGtJ8NL4= +github.com/go-openapi/runtime v0.19.15/go.mod h1:dhGWCTKRXlAfGnQG0ONViOZpjfg0m2gUt9nTQPQZuoo= +github.com/go-openapi/runtime v0.19.16 h1:tQMAY5s5BfmmCC31+ufDCsGrr8iO1A8UIdYfDo5ADvs= +github.com/go-openapi/runtime v0.19.16/go.mod h1:5P9104EJgYcizotuXhEuUrzVc+j1RiSjahULvYmlv98= +github.com/go-openapi/spec v0.17.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= +github.com/go-openapi/spec v0.18.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= +github.com/go-openapi/spec v0.19.2/go.mod h1:sCxk3jxKgioEJikev4fgkNmwS+3kuYdJtcsZsD5zxMY= +github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= +github.com/go-openapi/spec v0.19.6/go.mod h1:Hm2Jr4jv8G1ciIAo+frC/Ft+rR2kQDh8JHKHb3gWUSk= +github.com/go-openapi/spec v0.19.8 h1:qAdZLh1r6QF/hI/gTq+TJTvsQUodZsM7KLqkAJdiJNg= +github.com/go-openapi/spec v0.19.8/go.mod h1:Hm2Jr4jv8G1ciIAo+frC/Ft+rR2kQDh8JHKHb3gWUSk= +github.com/go-openapi/strfmt v0.17.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU= +github.com/go-openapi/strfmt v0.18.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU= +github.com/go-openapi/strfmt v0.19.0/go.mod h1:+uW+93UVvGGq2qGaZxdDeJqSAqBqBdl+ZPMF/cC8nDY= +github.com/go-openapi/strfmt v0.19.2/go.mod h1:0yX7dbo8mKIvc3XSKp7MNfxw4JytCfCD6+bY1AVL9LU= +github.com/go-openapi/strfmt v0.19.3/go.mod h1:0yX7dbo8mKIvc3XSKp7MNfxw4JytCfCD6+bY1AVL9LU= +github.com/go-openapi/strfmt v0.19.4/go.mod h1:eftuHTlB/dI8Uq8JJOyRlieZf+WkkxUuk0dgdHXr2Qk= +github.com/go-openapi/strfmt v0.19.5 h1:0utjKrw+BAh8s57XE9Xz8DUBsVvPmRUB6styvl9wWIM= +github.com/go-openapi/strfmt v0.19.5/go.mod h1:eftuHTlB/dI8Uq8JJOyRlieZf+WkkxUuk0dgdHXr2Qk= +github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= +github.com/go-openapi/swag v0.18.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= +github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.7/go.mod h1:ao+8BpOPyKdpQz3AOJfbeEVpLmWAvlT1IfTe5McPyhY= +github.com/go-openapi/swag v0.19.9 h1:1IxuqvBUU3S2Bi4YC7tlP9SJF1gVpCvqN0T2Qof4azE= +github.com/go-openapi/swag v0.19.9/go.mod h1:ao+8BpOPyKdpQz3AOJfbeEVpLmWAvlT1IfTe5McPyhY= +github.com/go-openapi/validate v0.18.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4= +github.com/go-openapi/validate v0.19.2/go.mod h1:1tRCw7m3jtI8eNWEEliiAqUIcBztB2KDnRCRMUi7GTA= +github.com/go-openapi/validate v0.19.3/go.mod h1:90Vh6jjkTn+OT1Eefm0ZixWNFjhtOH7vS9k0lo6zwJo= +github.com/go-openapi/validate v0.19.10/go.mod h1:RKEZTUWDkxKQxN2jDT7ZnZi2bhZlbNMAuKvKB+IaGx8= +github.com/go-openapi/validate v0.19.12 h1:mPLM/bfbd00PGOCJlU0yJL7IulkZ+q9VjPv7U11RMQQ= +github.com/go-openapi/validate v0.19.12/go.mod h1:Rzou8hA/CBw8donlS6WNEUQupNvUZ0waH08tGe6kAQ4= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gookit/color v1.5.2 h1:uLnfXcaFjlrDnQDT+NCBcfhrXqYTx/rcCa6xn01Y8yI= +github.com/gookit/color v1.5.2/go.mod h1:w8h4bGiHeeBpvQVePTutdbERIUf3oJE5lZ8HM0UgXyg= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.1 h1:mdxE1MF9o53iCb2Ghj1VfWvh7ZOwHpnVG/xwXrV90U8= +github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.3.2 h1:mRS76wmkOn3KkKAyXDu42V+6ebnXWIztFSYGN7GeoRg= +github.com/mitchellh/mapstructure v1.3.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/name212/govalue v1.1.0 h1:kSdUVs21cM5bFp7RW5sWPrwQ0RzC/Xhk3f+A+dUL6TM= +github.com/name212/govalue v1.1.0/go.mod h1:3mLA4mFb82esucQHCOIAnUjN7e7AZnRYEfxeaHLKjho= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/vektah/gqlparser v1.1.2/go.mod h1:1ycwN7Ij5njmMkPPAOaRFY4rET2Enx7IkVv3vaXspKw= +github.com/werf/logboek v0.5.5 h1:RmtTejHJOyw0fub4pIfKsb7OTzD90ZOUyuBAXqYqJpU= +github.com/werf/logboek v0.5.5/go.mod h1:Gez5J4bxekyr6MxTmIJyId1F61rpO+0/V4vjCIEIZmk= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= +github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.0.3/go.mod h1:u7ryQJ+DOzQmeO7zB6MHyr8jkEQvC8vH7qLUO4lqsUM= +go.mongodb.org/mongo-driver v1.1.1/go.mod h1:u7ryQJ+DOzQmeO7zB6MHyr8jkEQvC8vH7qLUO4lqsUM= +go.mongodb.org/mongo-driver v1.3.0/go.mod h1:MSWZXKOynuguX+JSvwP8i+58jYCXxbia8HS3gZBapIE= +go.mongodb.org/mongo-driver v1.3.4/go.mod h1:MSWZXKOynuguX+JSvwP8i+58jYCXxbia8HS3gZBapIE= +go.mongodb.org/mongo-driver v1.5.1 h1:9nOVLGDfOaZ9R0tBumx/BcuqkbFpyTCU2r/Po7A2azI= +go.mongodb.org/mongo-driver v1.5.1/go.mod h1:gRXCHX4Jo7J0IJ1oDQyUxF7jfy19UfxniMS4xxMmUqw= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +go.yaml.in/yaml/v3 v3.0.3 h1:bXOww4E/J3f66rav3pX3m8w6jDE4knZjGOw8b5Y6iNE= +go.yaml.in/yaml/v3 v3.0.3/go.mod h1:tBHosrYAkRZjRAOREWbDnBXUf08JOwYq++0QNwQiWzI= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190614205625-5aca471b1d59/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190617190820-da514acc4774/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= +k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/hack/run_tests.sh b/hack/run_tests.sh index 8cd0e17..6949624 100755 --- a/hack/run_tests.sh +++ b/hack/run_tests.sh @@ -23,30 +23,85 @@ run_tests="" if [ -n "$RUN_TEST" ]; then echo "Found RUN_TEST env. Run only $RUN_TEST test" - run_tests="-run $RUN_TEST" + run_tests="-run ^$RUN_TEST\$" fi -run_dir="$(pwd)" -packages="$(go list ./... | grep -v /validation/)" -prefix="$(grep -oP 'module .*$' go.mod | sed 's|module ||')" +function module_prefix_for_current_dir() { + echo -n "$(grep -oP 'module .*$' go.mod | sed 's|module ||')" +} -if [ -z "$(trim_spaces "$packages")" ]; then - echo -e '\033[1;33m!!!\033[0m' - echo -e "\033[1;33mNot found packages in $run_dir with module ${prefix}. Skip go tests\033[0m" - echo -e '\033[1;33m!!!\033[0m' - exit 0 -fi +all_failed_tests="" + +function run_tests_in_dir() { + local run_dir="$1" + local expect_pkg="$2" -echo "Found packages: ${packages[@]} in $run_dir with module $prefix" + if [ -z "$run_dir" ]; then + echo "run_dir is empty" + return 1 + fi -while IFS= read -r p; do - pkg_dir="${p#$prefix}" - if [ -z "$pkg_dir" ]; then - echo "Package $p cannot have dir after trim $prefix" - exit 1 + if ! run_dir="$(realpath "$run_dir")"; then + echo "Cannot get real path for $run_dir" + return 1 fi - full_pkg_path="${run_dir}${pkg_dir}" - echo "Run tests in $full_pkg_path" - cd "$full_pkg_path" - echo "test -v -p 1 $run_tests" | xargs go -done <<< "$packages" \ No newline at end of file + + cd "$run_dir" + + local packages="" + + if [ -n "$expect_pkg" ]; then + packages="$(go list ./... | grep -v -P "$expect_pkg")" + else + packages="$(go list ./...)" + fi + + local prefix="$(module_prefix_for_current_dir)" + + if [ -z "$(trim_spaces "$packages")" ]; then + echo -e '\033[1;33m!!!\033[0m' + echo -e "\033[1;33mNot found packages in ${run_dir} with module ${prefix}. Skip go tests for ${run_dir}\033[0m" + echo -e '\033[1;33m!!!\033[0m' + return 0 + fi + + echo "Found packages: ${packages[@]} in ${run_dir} with module ${prefix}" + + local failed="" + + while IFS= read -r p; do + local pkg_dir="${p#$prefix}" + if [ -z "$pkg_dir" ]; then + echo "Package $p cannot have dir after trim $prefix" + return 1 + fi + + local full_pkg_path="${run_dir}${pkg_dir}" + + echo "Run tests in $full_pkg_path" + cd "$full_pkg_path" + if ! echo "test -v -p 1 $run_tests" | xargs go; then + all_failed_tests="$(echo -e "${all_failed_tests}\nTests in ${p} failed")" + fi + done <<< "$packages" +} + +root_dir="$(pwd)" +declare -A tests_dirs=( + # expect /validation after license validation run + ["$root_dir"]="$(module_prefix_for_current_dir)/validation\$" + ["${root_dir}/tests"]="" +) + +for tdir in "${!tests_dirs[@]}"; do + run_tests_in_dir "$tdir" "${tests_dirs[$tdir]}" +done + +if [ -n "$all_failed_tests" ]; then + echo -e "\033[31m${all_failed_tests}\033[0m" + exit 1 +fi + + +echo -e "\033[32mPassed!\033[0m" +exit 0 diff --git a/pkg/init.go b/pkg/init.go new file mode 100644 index 0000000..3a3f440 --- /dev/null +++ b/pkg/init.go @@ -0,0 +1,15 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pkg diff --git a/pkg/interface.go b/pkg/interface.go new file mode 100644 index 0000000..817219f --- /dev/null +++ b/pkg/interface.go @@ -0,0 +1,170 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pkg + +import ( + "context" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +type SSHProvider interface { + NewClient(ctx context.Context) (SSHClient, error) + Client(ctx context.Context) (SSHClient, error) + SwitchClient(ctx context.Context, sess *session.Session, privateKeys []session.AgentPrivateKey, oldSSHClient SSHClient) (SSHClient, error) +} + +type Provider interface { + Client() (SSHClient, error) + SwitchClient(ctx context.Context, sess *session.Session, privateKeys []session.AgentPrivateKey, oldSSHClient SSHClient) (SSHClient, error) +} + +type Interface interface { + Command(name string, args ...string) Command + File() File + UploadScript(scriptPath string, args ...string) Script +} + +type Command interface { + Run(ctx context.Context) error + Cmd(ctx context.Context) + Sudo(ctx context.Context) + + StdoutBytes() []byte + StderrBytes() []byte + Output(context.Context) ([]byte, []byte, error) + CombinedOutput(context.Context) ([]byte, error) + + OnCommandStart(fn func()) + WithEnv(env map[string]string) + WithTimeout(timeout time.Duration) + WithStdoutHandler(h func(line string)) + WithStderrHandler(h func(line string)) + WithSSHArgs(args ...string) +} + +type File interface { + Upload(ctx context.Context, srcPath, dstPath string) error + Download(ctx context.Context, srcPath, dstPath string) error + + UploadBytes(ctx context.Context, data []byte, remotePath string) error + DownloadBytes(ctx context.Context, remotePath string) ([]byte, error) +} + +type Script interface { + Execute(context.Context) (stdout []byte, err error) + ExecuteBundle(ctx context.Context, parentDir, bundleDir string) (stdout []byte, err error) + + Sudo() + WithStdoutHandler(handler func(string)) + WithTimeout(timeout time.Duration) + WithEnvs(envs map[string]string) + WithCleanupAfterExec(doCleanup bool) + WithCommanderMode(enabled bool) + WithExecuteUploadDir(dir string) +} + +type Tunnel interface { + Up() error + + HealthMonitor(errorOutCh chan<- error) + + Stop() + + String() string +} + +type ReverseTunnelChecker interface { + CheckTunnel(context.Context) (string, error) +} + +type ReverseTunnelKiller interface { + KillTunnel(context.Context) (string, error) +} + +type ReverseTunnel interface { + Up() error + + StartHealthMonitor(ctx context.Context, checker ReverseTunnelChecker, killer ReverseTunnelKiller) + + Stop() + + String() string +} + +type KubeProxy interface { + Start(useLocalPort int) (port string, err error) + + StopAll() + + Stop(startID int) +} + +type Check interface { + WithDelaySeconds(seconds int) Check + + AwaitAvailability(context.Context, retry.Params) error + + CheckAvailability(context.Context) error + + ExpectAvailable(context.Context) ([]byte, error) + + String() string +} + +type SSHLoopHandler func(s SSHClient) error + +type SSHClient interface { + // BeforeStart safe starting without create session. Should safe for next Start call + OnlyPreparePrivateKeys() error + + Start() error + + // Tunnel is used to open local (L) and remote (R) tunnels + Tunnel(address string) Tunnel + + // ReverseTunnel is used to open remote (R) tunnel + ReverseTunnel(address string) ReverseTunnel + + // Command is used to run commands on remote server + Command(name string, arg ...string) Command + + // KubeProxy is used to start kubectl proxy and create a tunnel from local port to proxy port + KubeProxy() KubeProxy + + // File is used to upload and download files and directories + File() File + + // UploadScript is used to upload script and execute it on remote server + UploadScript(scriptPath string, args ...string) Script + + // UploadScript is used to upload script and execute it on remote server + Check() Check + + // Stop the client + Stop() + + // Loop Looping all available hosts + Loop(fn SSHLoopHandler) error + + Session() *session.Session + + PrivateKeys() []session.AgentPrivateKey + + RefreshPrivateKeys() error +} diff --git a/pkg/settings/settings.go b/pkg/settings/settings.go new file mode 100644 index 0000000..63a2383 --- /dev/null +++ b/pkg/settings/settings.go @@ -0,0 +1,132 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package settings + +import ( + "os" + + "github.com/deckhouse/lib-dhctl/pkg/log" +) + +var ( + defaultLogger log.Logger = log.NewSilentLogger() + defaultNodeBinPath string = "/opt/deckhouse/bin" + defaultNodeTmpPath = "/opt/deckhouse/tmp" + defaultTmpDir = os.TempDir() + "/dhctl" + defaultOnShutdown OnShutdown = func(string, func()) {} +) + +type OnShutdown func(name string, action func()) + +type Settings interface { + Logger() log.Logger + LoggerProvider() log.LoggerProvider + NodeTmpDir() string + NodeBinPath() string + IsDebug() bool + TmpDir() string + AuthSock() string + RegisterOnShutdown(string, func()) +} + +type ProviderParams struct { + LoggerProvider log.LoggerProvider + IsDebug bool + NodeTmpPath string + NodeBinPath string + TmpDir string + AuthSock string + OnShutdown OnShutdown +} + +type BaseProviders struct { + params ProviderParams + + onShutdown OnShutdown +} + +func NewBaseProviders(params ProviderParams) *BaseProviders { + onShutdown := defaultOnShutdown + if params.OnShutdown != nil { + onShutdown = params.OnShutdown + } + + return &BaseProviders{ + params: params, + onShutdown: onShutdown, + } +} + +func (b *BaseProviders) Logger() log.Logger { + return log.ProvideSafe(b.params.LoggerProvider, defaultLogger) +} + +func (b *BaseProviders) WithLogger(provider log.LoggerProvider) *BaseProviders { + b.params.LoggerProvider = provider + return b +} + +func (b *BaseProviders) LoggerProvider() log.LoggerProvider { + return b.params.LoggerProvider +} + +func (b *BaseProviders) NodeTmpDir() string { + if b.params.NodeTmpPath != "" { + return b.params.NodeTmpPath + } + return defaultNodeTmpPath +} + +func (b *BaseProviders) NodeBinPath() string { + if b.params.NodeBinPath != "" { + return b.params.NodeBinPath + } + + return defaultNodeBinPath +} + +func (b *BaseProviders) TmpDir() string { + if b.params.TmpDir != "" { + return b.params.TmpDir + } + return defaultTmpDir +} + +func (b *BaseProviders) IsDebug() bool { + return b.params.IsDebug +} + +func (b *BaseProviders) AuthSock() string { + if b.params.AuthSock != "" { + return b.params.AuthSock + } + + return os.Getenv("SSH_AUTH_SOCK") +} + +func (b *BaseProviders) RegisterOnShutdown(name string, action func()) { + b.onShutdown(name, action) +} + +// SetDefaultLogger +// Deprecated: +// for backward compatibility please pass logger to all structure directly +func SetDefaultLogger(logger log.Logger) { + defaultLogger = logger +} + +func SetNodeTmpPath(path string) { + defaultNodeTmpPath = path +} diff --git a/pkg/ssh/clissh/agent.go b/pkg/ssh/clissh/agent.go new file mode 100644 index 0000000..1c90e50 --- /dev/null +++ b/pkg/ssh/clissh/agent.go @@ -0,0 +1,173 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "fmt" + "net" + "os" + "sync" + + "github.com/deckhouse/lib-gossh/agent" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/cmd" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +var ( + agentInstanceSingleton sync.Once + agentInstance *Agent +) + +// initializeNewInstance disables singleton logic +func initAgentInstance( + sett settings.Settings, + privateKeys []session.AgentPrivateKey, + initializeNewInstance bool, +) (*Agent, error) { + var err error + + if initializeNewInstance { + inst := NewAgent(sett, &session.AgentSettings{ + PrivateKeys: privateKeys, + }) + + err = inst.Start() + return inst, err + } + + agentInstanceSingleton.Do(func() { + if agentInstance == nil { + inst := NewAgent(sett, &session.AgentSettings{ + PrivateKeys: privateKeys, + }) + + err = inst.Start() + if err != nil { + return + } + sett.RegisterOnShutdown("Stop ssh-agent", func() { + if agentInstance != nil { + agentInstance.Stop() + } + }) + agentInstance = inst + } + }) + + if err != nil { + // NOTICE: agentInstance will remain nil forever in the case of err, so give it another try in the next possible init-retry + agentInstanceSingleton = sync.Once{} + } + + return agentInstance, err +} + +type Agent struct { + sshSettings settings.Settings + + agentSettings *session.AgentSettings + + agent *cmd.SSHAgent +} + +func NewAgent(sshSettings settings.Settings, agentSettings *session.AgentSettings) *Agent { + return &Agent{ + sshSettings: sshSettings, + agentSettings: agentSettings, + } +} + +func (a *Agent) Start() error { + a.agent = cmd.NewAgent(a.sshSettings, a.agentSettings) + + if len(a.agentSettings.PrivateKeys) == 0 { + a.agent.WithAuthSock(os.Getenv("SSH_AUTH_SOCK")) + return nil + } + + logger := a.sshSettings.Logger() + + logger.DebugLn("agent: start ssh-agent") + err := a.agent.Start() + if err != nil { + return fmt.Errorf("Start ssh-agent: %v", err) + } + + logger.DebugLn("agent: run ssh-add for keys") + err = a.AddKeys(a.agentSettings.PrivateKeys) + if err != nil { + return fmt.Errorf("Agent error: %v", err) + } + + return nil +} + +// TODO replace with x/crypto/ssh/agent ? +func (a *Agent) AddKeys(keys []session.AgentPrivateKey) error { + err := addKeys(a.agentSettings.AuthSock, keys) + if err != nil { + return fmt.Errorf("Add keys: %w", err) + } + + logger := a.sshSettings.Logger() + + if a.sshSettings.IsDebug() { + logger.DebugLn("list added keys") + listCmd := cmd.NewSSHAdd(a.sshSettings, a.agentSettings).ListCmd() + + output, err := listCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ssh-add -l: %v", err) + } + + str := string(output) + if str != "" && str != "\n" { + logger.InfoF("ssh-add -l: %s\n", output) + } + } + + return nil +} + +func (a *Agent) Stop() { + a.agent.Stop() +} + +func addKeys(authSock string, keys []session.AgentPrivateKey) error { + conn, err := net.Dial("unix", authSock) + if err != nil { + return fmt.Errorf("Error dialing with ssh agent %s: %w", authSock, err) + } + defer conn.Close() + + agentClient := agent.NewClient(conn) + + for _, key := range keys { + privateKey, err := utils.GetSSHPrivateKey(key.Key, key.Passphrase) + if err != nil { + return err + } + + err = agentClient.Add(agent.AddedKey{PrivateKey: privateKey}) + if err != nil { + return fmt.Errorf("Adding ssh key with ssh agent %s: %w", authSock, err) + } + } + + return nil +} diff --git a/pkg/ssh/clissh/client.go b/pkg/ssh/clissh/client.go new file mode 100644 index 0000000..bf031cf --- /dev/null +++ b/pkg/ssh/clissh/client.go @@ -0,0 +1,158 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "fmt" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +func NewClient(sett settings.Settings, session *session.Session, privKeys []session.AgentPrivateKey, initNewAgent bool) *Client { + return &Client{ + SessionSettings: session, + privateKeys: privKeys, + settings: sett, + + // We use arbitrary privKeys param, so always reinitialize agent with privKeys + InitializeNewAgent: initNewAgent, + } +} + +type Client struct { + settings settings.Settings + + SessionSettings *session.Session + Agent *Agent + + privateKeys []session.AgentPrivateKey + InitializeNewAgent bool + + kubeProxies []*KubeProxy +} + +func (s *Client) OnlyPreparePrivateKeys() error { + // Double start is safe here because for initializing private keys we are using sync.Once + return s.Start() +} + +func (s *Client) Start() error { + if s.SessionSettings == nil { + return fmt.Errorf("possible bug in ssh client: session should be created before start") + } + + a, err := initAgentInstance(s.settings, s.privateKeys, s.InitializeNewAgent) + if err != nil { + return err + } + s.Agent = a + s.SessionSettings.AgentSettings = s.Agent.agentSettings + + return nil +} + +// Easy access to frontends + +// Tunnel is used to open local (L) and remote (R) tunnels +func (s *Client) Tunnel(address string) connection.Tunnel { + return NewTunnel(s.settings, s.SessionSettings, "L", address) +} + +// ReverseTunnel is used to open remote (R) tunnel +func (s *Client) ReverseTunnel(address string) connection.ReverseTunnel { + return NewReverseTunnel(s.settings, s.SessionSettings, address) +} + +// Command is used to run commands on remote server +func (s *Client) Command(name string, arg ...string) connection.Command { + return NewCommand(s.settings, s.SessionSettings, name, arg...) +} + +// KubeProxy is used to start kubectl proxy and create a tunnel from local port to proxy port +func (s *Client) KubeProxy() connection.KubeProxy { + p := NewKubeProxy(s.settings, s.SessionSettings) + s.kubeProxies = append(s.kubeProxies, p) + return p +} + +// File is used to upload and download files and directories +func (s *Client) File() connection.File { + return NewFile(s.settings, s.SessionSettings) +} + +// UploadScript is used to upload script and execute it on remote server +func (s *Client) UploadScript(scriptPath string, args ...string) connection.Script { + return NewUploadScript(s.settings, s.SessionSettings, scriptPath, args...) +} + +// UploadScript is used to upload script and execute it on remote server +func (s *Client) Check() connection.Check { + f := func(sess *session.Session, cmd string) connection.Command { + return NewCommand(s.settings, sess, cmd) + } + return utils.NewCheck(f, s.SessionSettings, s.settings) +} + +// Stop the client +func (s *Client) Stop() { + // stop agent on shutdown because agent is singleton + + if s.InitializeNewAgent { + s.Agent.Stop() + s.Agent = nil + s.SessionSettings.AgentSettings = nil + } + for _, p := range s.kubeProxies { + p.StopAll() + } + s.kubeProxies = nil +} + +func (s *Client) Session() *session.Session { + return s.SessionSettings +} + +func (s *Client) PrivateKeys() []session.AgentPrivateKey { + return s.privateKeys +} + +func (s *Client) RefreshPrivateKeys() error { + return s.Agent.AddKeys(s.PrivateKeys()) +} + +// Loop Looping all available hosts +func (s *Client) Loop(fn connection.SSHLoopHandler) error { + var err error + + resetSession := func() { + s.SessionSettings = s.SessionSettings.Copy() + s.SessionSettings.ChoiceNewHost() + } + defer resetSession() + resetSession() + + for range s.SessionSettings.AvailableHosts() { + err = fn(s) + if err != nil { + return err + } + s.SessionSettings.ChoiceNewHost() + } + + return nil +} diff --git a/pkg/ssh/clissh/cmd/scp.go b/pkg/ssh/clissh/cmd/scp.go new file mode 100644 index 0000000..5ab247f --- /dev/null +++ b/pkg/ssh/clissh/cmd/scp.go @@ -0,0 +1,180 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/process" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +type SCP struct { + *process.Executor + + settings settings.Settings + + Session *session.Session + scpCmd *exec.Cmd + + RemoteDst bool + Dst string + RemoteSrc bool + Src string + Preserve bool + Recursive bool +} + +func NewSCP(sett settings.Settings, sess *session.Session) *SCP { + return &SCP{settings: sett, Session: sess} +} + +func (s *SCP) WithRemoteDst(path string) *SCP { + s.RemoteDst = true + s.Dst = path + return s +} + +func (s *SCP) WithDst(path string) *SCP { + s.RemoteDst = false + s.Dst = path + return s +} + +func (s *SCP) WithRemoteSrc(path string) *SCP { + s.RemoteSrc = true + s.Src = path + return s +} + +func (s *SCP) WithSrc(path string) *SCP { + s.RemoteSrc = false + s.Src = path + return s +} + +func (s *SCP) WithRecursive(recursive bool) *SCP { + s.Recursive = recursive + return s +} + +func (s *SCP) WithPreserve(preserve bool) *SCP { + s.Preserve = preserve + return s +} + +func (s *SCP) SCP(ctx context.Context) *SCP { + // env := append(os.Environ(), s.Env...) + env := append(os.Environ(), s.Session.AgentSettings.AuthSockEnv()) + + // set absolute path to the ssh binary, because scp contains predefined absolute path to ssh binary (/ssh/bin/ssh) as we set in the building process of the static ssh utils + sshPathArgs := []string{"-S", fmt.Sprintf("%s/bin/ssh", os.Getenv("PWD"))} + + args := []string{ + // ssh args for bastion here + "-C", // compression + "-o", "ControlMaster=auto", + "-o", "ControlPersist=600s", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "GlobalKnownHostsFile=/dev/null", + "-o", "PasswordAuthentication=no", + "-o", "ServerAliveInterval=1", + "-o", "ServerAliveCountMax=3600", + "-o", "ConnectTimeout=15", + } + + if s.settings.IsDebug() { + args = append(args, "-vvv") + } + + if s.Session.ExtraArgs != "" { + extraArgs := strings.Split(s.Session.ExtraArgs, " ") + if len(extraArgs) > 0 { + args = append(args, extraArgs...) + } + } + + // add bastion options + if s.Session.BastionHost != "" { + bastion := s.Session.BastionHost + if s.Session.BastionUser != "" { + bastion = s.Session.BastionUser + "@" + s.Session.BastionHost + } + if s.Session.BastionPort != "" { + bastion = bastion + " -p" + s.Session.BastionPort + } + args = append(args, []string{ + // 1. Note that single quotes is not needed here + // 2. Add all arguments to the proxy command so the connection to bastion has the same args + "-o", fmt.Sprintf("ProxyCommand=ssh %s -W %%h:%%p %s", bastion, strings.Join(args, " ")), + "-o", "ExitOnForwardFailure=yes", + }...) + } + + // add remote port if defined + if s.Session.Port != "" { + args = append(args, []string{ + "-P", + s.Session.Port, + }...) + } + + if s.Preserve { + args = append(args, "-p") + } + + if s.Recursive { + args = append(args, "-r") + } + + // create src path + srcPath := s.Src + if s.RemoteSrc { + srcPath = s.Session.RemoteAddress() + ":" + srcPath + } + + // create dest path + dstPath := s.Dst + if dstPath == "" { + dstPath = "." + } + if !strings.HasPrefix(dstPath, "/") && !strings.HasPrefix(dstPath, ".") { + dstPath = "./" + dstPath + } + if s.RemoteDst { + dstPath = s.Session.RemoteAddress() + ":" + dstPath + } + + args = append(args, []string{ + srcPath, + dstPath, + }...) + + scpArgs := append(sshPathArgs, args...) + s.scpCmd = exec.CommandContext(ctx, "scp", scpArgs...) + s.scpCmd.Env = env + // scpCmd.Stdout = os.Stdout + // scpCmd.Stderr = os.Stderr + + s.Executor = process.NewDefaultExecutor(s.settings, s.scpCmd) + + return s +} diff --git a/pkg/ssh/clissh/cmd/ssh-add.go b/pkg/ssh/clissh/cmd/ssh-add.go new file mode 100644 index 0000000..afdf1b5 --- /dev/null +++ b/pkg/ssh/clissh/cmd/ssh-add.go @@ -0,0 +1,102 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "os" + "os/exec" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +const SSHAddPath = "ssh-add" + +type SSHAdd struct { + settings settings.Settings + AgentSettings *session.AgentSettings +} + +func NewSSHAdd(sett settings.Settings, sess *session.AgentSettings) *SSHAdd { + return &SSHAdd{settings: sett, AgentSettings: sess} +} + +func (s *SSHAdd) KeyCmd(keyPath string) *exec.Cmd { + args := []string{ + keyPath, + } + env := []string{ + s.AgentSettings.AuthSockEnv(), + } + cmd := exec.Command(SSHAddPath, args...) + cmd.Env = append(os.Environ(), env...) + return cmd +} + +func (s *SSHAdd) ListCmd() *exec.Cmd { + env := []string{ + s.AgentSettings.AuthSockEnv(), + } + cmd := exec.Command(SSHAddPath, "-l") + cmd.Env = append(os.Environ(), env...) + return cmd +} + +func (s *SSHAdd) AddKeys(keys []string) error { + logger := s.settings.Logger() + for _, k := range keys { + logger.DebugF("add key %s\n", k) + args := []string{ + k, + } + env := []string{ + s.AgentSettings.AuthSockEnv(), + } + cmd := exec.Command(SSHAddPath, args...) + cmd.Env = append(os.Environ(), env...) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ssh-add: %s %v", string(output), err) + } + + str := string(output) + if str != "" && str != "\n" { + logger.InfoF("ssh-add: %s\n", output) + } + } + + if s.settings.IsDebug() { + logger.DebugLn("list added keys") + env := []string{ + s.AgentSettings.AuthSockEnv(), + } + cmd := exec.Command(SSHAddPath, "-l") + cmd.Env = append(os.Environ(), env...) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ssh-add -l: %v", err) + } + + str := string(output) + if str != "" && str != "\n" { + logger.InfoF("ssh-add -l: %s\n", output) + } + } + + return nil +} diff --git a/pkg/ssh/clissh/cmd/ssh-agent.go b/pkg/ssh/clissh/cmd/ssh-agent.go new file mode 100644 index 0000000..fbcc8da --- /dev/null +++ b/pkg/ssh/clissh/cmd/ssh-agent.go @@ -0,0 +1,131 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "syscall" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/process" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +const SSHAgentPath = "ssh-agent" + +type SSHAgent struct { + *process.Executor + + settings settings.Settings + + agentSettings *session.AgentSettings + + agentCmd *exec.Cmd + + authSock string +} + +func NewAgent(sshSett settings.Settings, agentSettings *session.AgentSettings) *SSHAgent { + return &SSHAgent{ + settings: sshSett, + agentSettings: agentSettings, + } +} + +var SSHAgentAuthSockRe = regexp.MustCompile(`SSH_AUTH_SOCK=(.*?);`) + +func (a *SSHAgent) WithAuthSock(sock string) *SSHAgent { + a.authSock = sock + return a +} + +// Start runs ssh-agent as a subprocess, gets SSH_AUTH_SOCK path and +func (a *SSHAgent) Start() error { + a.agentCmd = exec.Command(SSHAgentPath, "-D") + a.agentCmd.Env = os.Environ() + a.agentCmd.Dir = "/" + // Start ssh-agent with the new session to prevent terminal allocation and early stop by SIGINT. + a.agentCmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } + + a.Executor = process.NewDefaultExecutor(a.settings, a.agentCmd) + // a.EnableLive() + a.WithStdoutHandler(func(l string) { + a.settings.Logger().DebugF("ssh agent: got '%s'\n", l) + + m := SSHAgentAuthSockRe.FindStringSubmatch(l) + if len(m) == 2 && m[1] != "" { + a.authSock = m[1] + } + }) + + a.WithWaitHandler(func(err error) { + logger := a.settings.Logger() + if err != nil { + logger.ErrorF("SSH-agent process exited, now stop. Wait error: %v\n", err) + return + } + logger.InfoF("SSH-agent process exited, now stop.\n") + }) + + err := a.Executor.Start() + if err != nil { + a.agentCmd = nil + return fmt.Errorf("start ssh-agent subprocess: %v", err) + } + + // wait for ssh agent pid + success := false + maxWait := 1000 + retries := 0 + t := time.NewTicker(5 * time.Millisecond) + for { + <-t.C + if a.authSock != "" { + a.settings.Logger().DebugF("ssh-agent: SSH_AUTH_SOCK=%s\n", a.authSock) + success = true + break + } + retries++ + if retries > maxWait { + break + } + } + t.Stop() + + if !success { + a.Stop() + return fmt.Errorf("cannot get pid and auth sock path for ssh-agent") + } + + // save auth sock in session to access it from other cmds and frontends + a.agentSettings.AuthSock = a.authSock + a.settings.RegisterOnShutdown("Delete SSH agent temporary directory", func() { + _ = os.RemoveAll(filepath.Dir(a.authSock)) + }) + return nil +} + +func (a *SSHAgent) Stop() { + if a.Executor != nil { + a.Executor.Stop() + } +} diff --git a/pkg/ssh/clissh/cmd/ssh.go b/pkg/ssh/clissh/cmd/ssh.go new file mode 100644 index 0000000..aabc8ab --- /dev/null +++ b/pkg/ssh/clissh/cmd/ssh.go @@ -0,0 +1,170 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + "syscall" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/process" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +type SSH struct { + settings settings.Settings + + *process.Executor + Session *session.Session + Args []string + Env []string + CommandName string + CommandArgs []string + + ExitWhenTunnelFailure bool +} + +func NewSSH(sett settings.Settings, sess *session.Session) *SSH { + return &SSH{ + Session: sess, + settings: sett, + } +} + +func (s *SSH) WithEnv(env ...string) *SSH { + s.Env = env + return s +} + +func (s *SSH) WithArgs(args ...string) *SSH { + s.Args = args + return s +} + +func (s *SSH) WithExitWhenTunnelFailure(yes bool) *SSH { + s.ExitWhenTunnelFailure = yes + return s +} + +func (s *SSH) WithCommand(name string, arg ...string) *SSH { + s.CommandName = name + s.CommandArgs = arg + return s +} + +// TODO move connection settings from ExecuteCmd +func (s *SSH) Cmd(ctx context.Context) *exec.Cmd { + env := append(os.Environ(), s.Env...) + env = append(env, s.Session.AgentSettings.AuthSockEnv()) + + // ssh connection settings + // ANSIBLE_SSH_ARGS="${ANSIBLE_SSH_ARGS:-"-C + // -o ControlMaster=auto + // -o ControlPersist=600s"} + args := []string{ + // ssh args for bastion here + "-C", // compression + "-o", "ControlMaster=auto", + "-o", "ControlPersist=600s", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "GlobalKnownHostsFile=/dev/null", + "-o", "ServerAliveInterval=1", + "-o", "ServerAliveCountMax=3600", + "-o", "ConnectTimeout=15", + "-o", "PasswordAuthentication=no", + } + + if s.settings.IsDebug() { + args = append(args, "-vvv") + } + + if s.Session.ExtraArgs != "" { + extraArgs := strings.Split(s.Session.ExtraArgs, " ") + if len(extraArgs) > 0 { + args = append(args, extraArgs...) + } + } + + if len(s.Args) > 0 { + args = append(args, s.Args...) + } + + exitOnForwardFailureSet := false + + // add bastion options + // if [[ "x$ssh_bastion_host" != "x" ]] ; then + // export ANSIBLE_SSH_ARGS="${ANSIBLE_SSH_ARGS} + // -o ProxyCommand='ssh ${ssh_bastion_user:-$USER}@$ssh_bastion_host -W %h:%p'" + // fi + if s.Session.BastionHost != "" { + bastion := s.Session.BastionHost + if s.Session.BastionUser != "" { + bastion = s.Session.BastionUser + "@" + s.Session.BastionHost + } + if s.Session.BastionPort != "" { + bastion = bastion + " -p" + s.Session.BastionPort + } + args = append(args, []string{ + // 1. Note that single quotes is not needed here + // 2. Add all arguments to the proxy command so the connection to bastion has the same args + "-o", fmt.Sprintf("ProxyCommand=ssh %s -W %%h:%%p %s", bastion, strings.Join(args, " ")), + "-o", "ExitOnForwardFailure=yes", + }...) + exitOnForwardFailureSet = true + } + + if !exitOnForwardFailureSet && s.ExitWhenTunnelFailure { + args = append(args, "-o", "ExitOnForwardFailure=yes") + } + + // add destination: user, host and port + if s.Session.User != "" { + args = append(args, []string{ + "-l", + s.Session.User, + }...) + } + if s.Session.Port != "" { + args = append(args, []string{ + "-p", + s.Session.Port, + }...) + } + + args = append(args, s.Session.Host()) + + if s.CommandName != "" { + args = append(args, "--" /* cmd.Path */, s.CommandName) + args = append(args, s.CommandArgs...) + } + + s.settings.Logger().DebugF("SSH arguments %v\n", args) + + sshCmd := exec.CommandContext(ctx, "ssh", args...) + sshCmd.Env = env + // Start ssh with the new process group to prevent early stop by SIGINT from the shell. + sshCmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + + s.Executor = process.NewDefaultExecutor(s.settings, sshCmd) + + return sshCmd +} diff --git a/pkg/ssh/clissh/command.go b/pkg/ssh/clissh/command.go new file mode 100644 index 0000000..bb2dc4b --- /dev/null +++ b/pkg/ssh/clissh/command.go @@ -0,0 +1,196 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "context" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/cmd" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/process" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +type Command struct { + settings settings.Settings + + *process.Executor + + Session *session.Session + + Name string + Args []string + Env []string + + SSHArgs []string + + onCommandStart func() + + cmd *exec.Cmd +} + +func NewCommand(sett settings.Settings, sess *session.Session, name string, arg ...string) *Command { + args := make([]string, len(arg)) + copy(args, arg) + for i := range args { + if !strings.HasPrefix(args[i], `"`) && + !strings.HasSuffix(args[i], `"`) && + strings.Contains(args[i], " ") { + args[i] = strconv.Quote(args[i]) + } + } + + executor := process.NewDefaultExecutor( + sett, + cmd.NewSSH(sett, sess). + WithCommand(name, args...). + Cmd(context.Background()), + ) + + return &Command{ + Executor: executor, + Session: sess, + Name: name, + Args: args, + Env: os.Environ(), + settings: sett, + } +} + +func (c *Command) WithSSHArgs(args ...string) { + c.SSHArgs = args +} + +func (c *Command) OnCommandStart(fn func()) { + c.onCommandStart = fn +} + +func (c *Command) Sudo(ctx context.Context) { + cmdLine := c.Name + " " + strings.Join(c.Args, " ") + sudoCmdLine := fmt.Sprintf( + `sudo -p SudoPassword -H -S -i bash -c 'echo SUDO-SUCCESS && %s'`, + cmdLine, + ) + + var args []string + args = append(args, c.SSHArgs...) + args = append(args, []string{ + "-t", // allocate tty to auto kill remote process when ssh process is killed + "-t", // need to force tty allocation because of stdin is pipe! + }...) + + c.cmd = cmd.NewSSH(c.settings, c.Session). + WithArgs(args...). + WithCommand(sudoCmdLine).Cmd(ctx) + + c.Executor = process.NewDefaultExecutor(c.settings, c.cmd) + + c.WithMatchers( + utils.NewByteSequenceMatcher("SudoPassword"), + utils.NewByteSequenceMatcher("SUDO-SUCCESS").WaitNonMatched(), + ) + c.OpenStdinPipe() + + passSent := false + c.WithMatchHandler(func(pattern string) string { + logger := c.settings.Logger() + if pattern == "SudoPassword" { + var becomePass string + + if c.Session.BecomePass != "" { + becomePass = c.Session.BecomePass + } + if !passSent { + // send pass through stdin + logger.DebugLn("Send become pass to cmd") + _, _ = c.Executor.Stdin.Write([]byte(becomePass + "\n")) + passSent = true + } else { + // Second prompt is error! + logger.ErrorLn("Bad sudo password") + // sending wrong password again will raise an error in process.Run() + _, _ = c.Executor.Stdin.Write([]byte(becomePass + "\n")) + // os.Exit(1) + } + return "reset" + } + if pattern == "SUDO-SUCCESS" { + logger.DebugLn("Got SUCCESS") + if c.onCommandStart != nil { + c.onCommandStart() + } + return "done" + } + return "" + }) +} + +func (c *Command) Cmd(ctx context.Context) { + c.cmd = cmd.NewSSH(c.settings, c.Session). + WithArgs(c.SSHArgs...). + WithCommand(c.Name, c.Args...).Cmd(ctx) + + c.Executor = process.NewDefaultExecutor(c.settings, c.cmd) +} + +func (c *Command) Output(ctx context.Context) ([]byte, []byte, error) { + if c.Session == nil { + return nil, nil, fmt.Errorf("Execute command %s: SSH client is undefined", c.Name) + } + + c.cmd = cmd.NewSSH(c.settings, c.Session). + WithArgs(c.SSHArgs...). + WithCommand(c.Name, c.Args...).Cmd(ctx) + + output, err := c.cmd.Output() + if err != nil { + return output, nil, fmt.Errorf("Execute command '%s': %w", c.Name, err) + } + return output, nil, nil +} + +func (c *Command) CombinedOutput(ctx context.Context) ([]byte, error) { + if c.Session == nil { + return nil, fmt.Errorf("Execute command %s: sshClient is undefined", c.Name) + } + + c.cmd = cmd.NewSSH(c.settings, c.Session). + // //WithArgs(). + WithCommand(c.Name, c.Args...).Cmd(ctx) + + output, err := c.cmd.CombinedOutput() + if err != nil { + return output, fmt.Errorf("Execute command '%s': %w", c.Name, err) + } + return output, nil +} + +func (c *Command) WithTimeout(timeout time.Duration) { + c.Executor = c.Executor.WithTimeout(timeout) +} + +func (c *Command) WithEnv(env map[string]string) { + c.Env = make([]string, 0, len(env)) + for k, v := range env { + c.Env = append(c.Env, fmt.Sprintf("%s=%s", k, v)) + } +} diff --git a/pkg/ssh/clissh/file.go b/pkg/ssh/clissh/file.go new file mode 100644 index 0000000..6d8250a --- /dev/null +++ b/pkg/ssh/clissh/file.go @@ -0,0 +1,202 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "context" + "fmt" + "os" + "path/filepath" + + uuid "github.com/google/uuid" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/cmd" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +type File struct { + settings settings.Settings + + Session *session.Session +} + +func NewFile(sett settings.Settings, sess *session.Session) *File { + return &File{ + Session: sess, + settings: sett, + } +} + +func (f *File) Upload(ctx context.Context, srcPath, remotePath string) error { + fType, err := CheckLocalPath(srcPath) + if err != nil { + return err + } + scp := cmd.NewSCP(f.settings, f.Session) + scp.WithPreserve(true) + if fType == "DIR" { + scp.WithRecursive(true) + } + scp.WithSrc(srcPath). + WithRemoteDst(remotePath). + SCP(ctx). + CaptureStdout(nil). + CaptureStderr(nil) + err = scp.Run(ctx) + if err != nil { + return fmt.Errorf( + "upload file '%s': %w\n%s\nstderr: %s", + srcPath, + err, + string(scp.StdoutBytes()), + string(scp.StderrBytes()), + ) + } + + return nil +} + +// UploadBytes creates a tmp file and upload it to remote dstPath +func (f *File) UploadBytes(ctx context.Context, data []byte, remotePath string) error { + logger := f.settings.Logger() + srcPath, err := CreateEmptyTmpFile(f.settings) + if err != nil { + return fmt.Errorf("create source tmp file: %v", err) + } + defer func() { + err := os.Remove(srcPath) + if err != nil { + logger.ErrorF("Error: cannot remove tmp file '%s': %v\n", srcPath, err) + } + }() + + err = os.WriteFile(srcPath, data, 0o600) + if err != nil { + return fmt.Errorf("write data to tmp file: %w", err) + } + + scp := cmd.NewSCP(f.settings, f.Session). + WithSrc(srcPath). + WithRemoteDst(remotePath). + SCP(ctx). + CaptureStderr(nil). + CaptureStdout(nil) + err = scp.Run(ctx) + if err != nil { + return fmt.Errorf( + "upload file '%s': %w\n%s\nstderr: %s", + remotePath, + err, + string(scp.StdoutBytes()), + string(scp.StderrBytes()), + ) + } + + if len(scp.StdoutBytes()) > 0 { + logger.InfoF("Upload file: %s", string(scp.StdoutBytes())) + } + return nil +} + +func (f *File) Download(ctx context.Context, remotePath, dstPath string) error { + logger := f.settings.Logger() + + scp := cmd.NewSCP(f.settings, f.Session) + scp.WithRecursive(true) + scpCmd := scp.WithRemoteSrc(remotePath).WithDst(dstPath).SCP(ctx) + logger.DebugF("run scp: %s\n", scpCmd.Cmd().String()) + + stdout, err := scpCmd.Cmd().CombinedOutput() + if err != nil { + return fmt.Errorf("download file '%s': %w", remotePath, err) + } + + if len(stdout) > 0 { + logger.InfoF("Download file: %s", string(stdout)) + } + return nil +} + +// Download remote file and returns its content as an array of bytes. +func (f *File) DownloadBytes(ctx context.Context, remotePath string) ([]byte, error) { + logger := f.settings.Logger() + + dstPath, err := CreateEmptyTmpFile(f.settings) + if err != nil { + return nil, fmt.Errorf("create target tmp file: %v", err) + } + defer func() { + err := os.Remove(dstPath) + if err != nil { + logger.InfoF("Error: cannot remove tmp file '%s': %v\n", dstPath, err) + } + }() + + scp := cmd.NewSCP(f.settings, f.Session) + scpCmd := scp.WithRemoteSrc(remotePath).WithDst(dstPath).SCP(ctx) + logger.DebugF("run scp: %s\n", scpCmd.Cmd().String()) + + stdout, err := scpCmd.Cmd().CombinedOutput() + if err != nil { + return nil, fmt.Errorf("download file '%s': %w", remotePath, err) + } + + if len(stdout) > 0 { + logger.InfoF("Download file: %s", string(stdout)) + } + + data, err := os.ReadFile(dstPath) + if err != nil { + return nil, fmt.Errorf("reading tmp file '%s': %w", dstPath, err) + } + + return data, nil +} + +func CreateEmptyTmpFile(sett settings.Settings) (string, error) { + id, err := uuid.NewRandom() + if err != nil { + return "", err + } + + tmpPath := filepath.Join( + sett.TmpDir(), + fmt.Sprintf("dhctl-scp-%d-%s.tmp", os.Getpid(), id.String()), + ) + + file, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return "", err + } + + _ = file.Close() + return tmpPath, nil +} + +// CheckLocalPath see if file exists and determine if it is a directory. Error is returned if file is not exists. +func CheckLocalPath(path string) (string, error) { + fi, err := os.Stat(path) + if err != nil { + return "", err + } + if fi.Mode().IsDir() { + return "DIR", nil + } + if fi.Mode().IsRegular() { + return "FILE", nil + } + return "", fmt.Errorf("Path '%s' is not a directory or file", path) +} diff --git a/pkg/ssh/clissh/kube-proxy.go b/pkg/ssh/clissh/kube-proxy.go new file mode 100644 index 0000000..38af0a1 --- /dev/null +++ b/pkg/ssh/clissh/kube-proxy.go @@ -0,0 +1,403 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "context" + "fmt" + "math/rand" + "os" + "regexp" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +const DefaultLocalAPIPort = 22322 + +type KubeProxy struct { + Session *session.Session + + settings settings.Settings + + KubeProxyPort string + LocalPort string + + proxy *Command + tunnel *Tunnel + + stop bool + port string + localPort int + + healthMonitorsByStartID map[int]chan struct{} +} + +func NewKubeProxy(sett settings.Settings, sess *session.Session) *KubeProxy { + return &KubeProxy{ + settings: sett, + Session: sess, + port: "0", + localPort: DefaultLocalAPIPort, + healthMonitorsByStartID: make(map[int]chan struct{}), + } +} + +func (k *KubeProxy) Start(useLocalPort int) (port string, err error) { + startID := rand.Int() + + logger := k.settings.Logger() + logger.DebugF("Kube-proxy start id=[%d]; port:%d\n", startID, useLocalPort) + + success := false + defer func() { + k.stop = false + if !success { + logger.DebugF("[%d] Kube-proxy was not started. Try to clear all\n", startID) + k.Stop(startID) + } + logger.DebugF("[%d] Kube-proxy starting was finished\n", startID) + }() + + proxyCommandErrorCh := make(chan error, 1) + proxy, port, err := k.runKubeProxy(proxyCommandErrorCh, startID) + if err != nil { + logger.DebugF("[%d] Got error from runKubeProxy func: %v\n", startID, err) + return "", err + } + + logger.DebugF("[%d] Proxy was started successfully\n", startID) + + k.proxy = proxy + k.port = port + + tunnelErrorCh := make(chan error) + tun, localPort, lastError := k.upTunnel(port, useLocalPort, tunnelErrorCh, startID) + if lastError != nil { + logger.DebugF("[%d] Got error from upTunnel func: %v\n", startID, err) + return "", fmt.Errorf("tunnel up error: max retries reached, last error: %w", lastError) + } + + k.tunnel = tun + k.localPort = localPort + + k.healthMonitorsByStartID[startID] = make(chan struct{}, 1) + go k.healthMonitor( + proxyCommandErrorCh, + tunnelErrorCh, + k.healthMonitorsByStartID[startID], + startID, + ) + + success = true + + return fmt.Sprintf("%d", k.localPort), nil +} + +func (k *KubeProxy) StopAll() { + for startID := range k.healthMonitorsByStartID { + k.Stop(startID) + } +} + +func (k *KubeProxy) Stop(startID int) { + logger := k.settings.Logger() + + if k == nil { + logger.DebugF("[%d] Stop kube-proxy: kube proxy object is nil. Skip.\n", startID) + return + } + + if k.stop { + logger.DebugF("[%d] Stop kube-proxy: kube proxy already stopped. Skip.\n", startID) + return + } + + if k.healthMonitorsByStartID[startID] != nil { + k.healthMonitorsByStartID[startID] <- struct{}{} + delete(k.healthMonitorsByStartID, startID) + } + if k.proxy != nil { + logger.DebugF("[%d] Stop proxy command\n", startID) + k.proxy.Stop() + logger.DebugF("[%d] Proxy command stopped\n", startID) + k.proxy = nil + k.port = "0" + } + if k.tunnel != nil { + logger.DebugF("[%d] Stop tunnel\n", startID) + k.tunnel.Stop() + logger.DebugF("[%d] Tunnel stopped\n", startID) + k.tunnel = nil + } + k.stop = true +} + +func (k *KubeProxy) tryToRestartFully(startID int) { + logger := k.settings.Logger() + logger.DebugF("[%d] Try restart kubeproxy fully\n", startID) + for { + k.Stop(startID) + + _, err := k.Start(k.localPort) + + if err == nil { + k.stop = false + logger.DebugF("[%d] Proxy was restarted successfully\n", startID) + return + } + + const sleepTimeout = 5 + + // need warn for human + logger.WarnF( + "Proxy was not restarted: %v. Sleep %d seconds before next attempt.\n", + err, + sleepTimeout, + ) + time.Sleep(sleepTimeout * time.Second) + + k.Session.ChoiceNewHost() + logger.DebugF("[%d] New host selected %v\n", startID, k.Session.Host()) + } +} + +func (k *KubeProxy) proxyCMD(startID int) *Command { + kubectlProxy := fmt.Sprintf( + // --disable-filter is needed to exec into etcd pods + "kubectl proxy --as=dhctl --as-group=system:masters --port=%s --kubeconfig /etc/kubernetes/admin.conf --disable-filter", + k.port, + ) + if v := os.Getenv("KUBE_PROXY_ACCEPT_HOSTS"); v != "" { + kubectlProxy += fmt.Sprintf(" --accept-hosts='%s'", v) + } + command := fmt.Sprintf("PATH=$PATH:%s/; %s", k.settings.NodeBinPath(), kubectlProxy) + + k.settings.Logger().DebugF("[%d] Proxy command for start: %s\n", startID, command) + + cmd := NewCommand(k.settings, k.Session, command) + cmd.Sudo(context.Background()) + cmd.Executor = cmd.Executor.CaptureStderr(nil).CaptureStdout(nil) + return cmd +} + +func (k *KubeProxy) healthMonitor( + proxyErrorCh, tunnelErrorCh chan error, + stopCh chan struct{}, + startID int, +) { + logger := k.settings.Logger() + + defer logger.DebugF("[%d] Kubeproxy health monitor stopped\n", startID) + logger.DebugF("[%d] Kubeproxy health monitor started\n", startID) + + for { + logger.DebugF("[%d] Kubeproxy Monitor step\n", startID) + select { + case err := <-proxyErrorCh: + logger.DebugF("[%d] Proxy failed with error %v\n", startID, err) + // if proxy crushed, we need to restart kube-proxy fully + // with proxy and tunnel (tunnel depends on proxy) + k.tryToRestartFully(startID) + // if we restart proxy fully + // this monitor must be finished because new monitor was started + return + + case err := <-tunnelErrorCh: + logger.DebugF("[%d] Tunnel failed %v. Stopping previous tunnel\n", startID, err) + // we need fully stop tunnel because + k.tunnel.Stop() + + logger.DebugF("[%d] Tunnel stopped before restart. Starting new tunnel...\n", startID) + + k.tunnel, _, err = k.upTunnel(k.port, k.localPort, tunnelErrorCh, startID) + if err != nil { + logger.DebugF("[%d] Tunnel was not up: %v. Try to restart fully\n", startID, err) + k.tryToRestartFully(startID) + return + } + + logger.DebugF("[%d] Tunnel re up successfully\n") + + case <-stopCh: + logger.DebugF("[%d] Kubeproxy monitor stopped") + return + } + } +} + +func (k *KubeProxy) upTunnel( + kubeProxyPort string, + useLocalPort int, + tunnelErrorCh chan error, + startID int, +) (tun *Tunnel, localPort int, err error) { + logger := k.settings.Logger() + + logger.DebugF( + "[%d] Starting up tunnel with proxy port %s and local port %d\n", + startID, + kubeProxyPort, + useLocalPort, + ) + + rewriteLocalPort := false + localPort = useLocalPort + + if useLocalPort < 1 { + logger.DebugF( + "[%d] Incorrect local port %d use default %d\n", + startID, + useLocalPort, + DefaultLocalAPIPort, + ) + localPort = DefaultLocalAPIPort + rewriteLocalPort = true + } + + maxRetries := 5 + retries := 0 + var lastError error + for { + logger.DebugF("[%d] Start %d iteration for up tunnel\n", startID, retries) + + if k.proxy.WaitError() != nil { + lastError = fmt.Errorf("proxy was failed while restart tunnel") + break + } + + // try to start tunnel from localPort to proxy port + var tunnelAddress string + if v := os.Getenv("KUBE_PROXY_BIND_ADDR"); v != "" { + tunnelAddress = fmt.Sprintf("%s:%d:localhost:%s", v, localPort, kubeProxyPort) + } else { + tunnelAddress = fmt.Sprintf("%d:localhost:%s", localPort, kubeProxyPort) + } + + logger.DebugF("[%d] Try up tunnel on %v\n", startID, tunnelAddress) + tun = NewTunnel(k.settings, k.Session, "L", tunnelAddress) + err := tun.Up() + if err != nil { + logger.DebugF("[%d] Start tunnel was failed. Cleaning...\n", startID) + tun.Stop() + lastError = fmt.Errorf("tunnel '%s': %w", tunnelAddress, err) + logger.DebugF("[%d] Start tunnel was failed. Error: %v\n", startID, lastError) + if rewriteLocalPort { + localPort++ + logger.DebugF("[%d] New local port %d\n", startID, localPort) + } + + retries++ + if retries >= maxRetries { + logger.DebugF("[%d] Last iteration finished\n", startID) + tun = nil + break + } + } else { + logger.DebugF("[%d] Tunnel was started. Starting health monitor\n", startID) + go tun.HealthMonitor(tunnelErrorCh) + lastError = nil + break + } + } + + dbgMsg := fmt.Sprintf("Tunnel up on local port %d", localPort) + if lastError != nil { + dbgMsg = fmt.Sprintf("Tunnel was not up: %v", lastError) + } + + logger.DebugF("[%d] %s\n", startID, dbgMsg) + + return tun, localPort, lastError +} + +func (k *KubeProxy) runKubeProxy( + waitCh chan error, + startID int, +) (proxy *Command, port string, err error) { + logger := k.settings.Logger() + + logger.DebugF("[%d] Begin starting proxy\n", startID) + proxy = k.proxyCMD(startID) + + port = "" + portReady := make(chan struct{}, 1) + portRe := regexp.MustCompile(`Starting to serve on .*?:(\d+)`) + + proxy.WithStdoutHandler(func(line string) { + m := portRe.FindStringSubmatch(line) + if len(m) == 2 && m[1] != "" { + port = m[1] + logger.DebugF("Got proxy port = %s on host %s\n", port, k.Session.Host()) + portReady <- struct{}{} + } + }) + + onStart := make(chan struct{}, 1) + proxy.OnCommandStart(func() { + logger.DebugF("[%d] Command started\n", startID) + onStart <- struct{}{} + }) + + proxy.WithWaitHandler(func(err error) { + logger.DebugF("[%d] Wait error: %v\n", startID, err) + waitCh <- err + }) + + logger.DebugF("[%d] Start proxy command\n", startID) + err = proxy.Start() + if err != nil { + logger.DebugF("[%d] Start proxy command error: %v\n", startID, err) + return nil, "", fmt.Errorf("start kubectl proxy: %w", err) + } + + logger.DebugF("[%d] Proxy command was started\n", startID) + + returnWaitErr := func(err error) error { + logger.DebugF("[%d] Proxy command waiting error: %v\n", startID, err) + template := `Proxy exited suddenly: %s%s +Status: %w` + return fmt.Errorf(template, string(proxy.StdoutBytes()), string(proxy.StderrBytes()), err) + } + + // we need to check that kubeproxy was started + // that checking wait string pattern in output + // but we may receive error and this error will get from waitCh + select { + case <-onStart: + case err := <-waitCh: + return nil, "", returnWaitErr(err) + } + + // Wait for proxy startup + t := time.NewTicker(20 * time.Second) + defer t.Stop() + select { + case e := <-waitCh: + return nil, "", returnWaitErr(e) + case <-t.C: + logger.DebugF("[%d] Starting proxy command timeout\n", startID) + return nil, "", fmt.Errorf("timeout waiting for api proxy port") + case <-portReady: + if port == "" { + logger.DebugF("[%d] Starting proxy command: empty port\n", startID) + return nil, "", fmt.Errorf("got empty port from kubectl proxy") + } + } + + logger.DebugF("[%d] Proxy process started with port: %s\n", startID, port) + return proxy, port, nil +} diff --git a/pkg/ssh/clissh/process/executor.go b/pkg/ssh/clissh/process/executor.go new file mode 100644 index 0000000..9407cd1 --- /dev/null +++ b/pkg/ssh/clissh/process/executor.go @@ -0,0 +1,618 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package process + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "reflect" + "sync" + "syscall" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +// exec.Cmd executor + +/* +4 types for stdout: + +live output or not. copy to os.Stdout or not + +wait for SUCCESS line? + +capture output. copy to bytes.Buffer or not + +stdout handler. buffered pipe to read output line by line with scanner.Scan + + +2 types for stderr + +live stderr. Copy to os.Stderr or be quiet. + +capture stderr. copy to bytes.Buffer For errors! + + +2 types of running: + +execute and wait while finished + +start in a background + + + +What combinations do we really need? + +SudoStart: (kubectl proxy) +- live stderr +- live stdout until some line occurs +- wait for SUCCESS line +- stdout handler + +SudoRun: (bashible bundle, etc) +- live stdout and stderr +- capture stdout + + +Start: (tunnels, agent) +- no live +- stdout handler +- wait until success + + +LiveRun: (ssh-add) +- live output +- no capture + +LiveOutput + +NewCommand("ls", "-la").Live().Sudo(). + + +proxyCmd := NewCommand("kube-proxy").EnableSudo().EnableLive().CaptureStdout(buf) + +// run in background +err := proxyCmd.Start() + +*/ + +type Executor struct { + settings settings.Settings + + cmd *exec.Cmd + + Session *Session + + Live bool + StdinPipe bool + + Stdin io.WriteCloser + + Matchers []*utils.ByteSequenceMatcher + MatchHandler func(pattern string) string + + StdoutBuffer *bytes.Buffer + StdoutSplitter bufio.SplitFunc + StdoutHandler func(l string) + + pipesMutex sync.Mutex + stdoutPipeFile *os.File + stderrPipeFile *os.File + + StderrBuffer *bytes.Buffer + StderrSplitter bufio.SplitFunc + StderrHandler func(l string) + + WaitHandler func(err error) + + started bool + stop bool + waitCh chan struct{} + stopCh chan struct{} + + lockWaitError sync.RWMutex + waitError error + + killError error + + timeout time.Duration +} + +func NewDefaultExecutor(sett settings.Settings, cmd *exec.Cmd) *Executor { + return NewExecutor(sett, DefaultSession, cmd) +} + +func NewExecutor(sett settings.Settings, sess *Session, cmd *exec.Cmd) *Executor { + return &Executor{ + Session: sess, + cmd: cmd, + settings: sett, + } +} + +func (e *Executor) EnableLive() *Executor { + e.Live = true + return e +} + +func (e *Executor) OpenStdinPipe() *Executor { + e.StdinPipe = true + return e +} + +func (e *Executor) WithStdoutHandler(stdoutHandler func(l string)) { + e.StdoutHandler = stdoutHandler +} + +func (e *Executor) WithStdoutSplitter(fn bufio.SplitFunc) *Executor { + e.StdoutSplitter = fn + return e +} + +func (e *Executor) WithStderrHandler(stderrHandler func(l string)) { + e.StderrHandler = stderrHandler +} + +func (e *Executor) WithStderrSplitter(fn bufio.SplitFunc) *Executor { + e.StderrSplitter = fn + return e +} + +func (e *Executor) WithWaitHandler(waitHandler func(error)) *Executor { + e.WaitHandler = waitHandler + return e +} + +func (e *Executor) CaptureStdout(buf *bytes.Buffer) *Executor { + if buf != nil { + e.StdoutBuffer = buf + } else { + e.StdoutBuffer = &bytes.Buffer{} + } + return e +} + +func (e *Executor) CaptureStderr(buf *bytes.Buffer) *Executor { + if buf != nil { + e.StderrBuffer = buf + } else { + e.StderrBuffer = &bytes.Buffer{} + } + return e +} + +func (e *Executor) WithTimeout(timeout time.Duration) *Executor { + e.timeout = timeout + return e +} + +func (e *Executor) WithMatchers(matchers ...*utils.ByteSequenceMatcher) *Executor { + e.Matchers = make([]*utils.ByteSequenceMatcher, 0) + e.Matchers = append(e.Matchers, matchers...) + return e +} + +func (e *Executor) WithMatchHandler(fn func(pattern string) string) *Executor { + e.MatchHandler = fn + return e +} + +func (e *Executor) StdoutBytes() []byte { + if e.StdoutBuffer != nil { + return e.StdoutBuffer.Bytes() + } + return nil +} + +func (e *Executor) StderrBytes() []byte { + if e.StderrBuffer != nil { + return e.StderrBuffer.Bytes() + } + return nil +} + +func (e *Executor) SetupStreamHandlers() (err error) { + // stderr goes to console (commented because ssh writes only "Connection closed" messages to stderr) + // e.Cmd.Stderr = os.Stderr + // connect console's stdin + // e.Cmd.Stdin = os.Stdin + + // setup stdout stream handlers + if e.Live && e.StdoutBuffer == nil && e.StdoutHandler == nil && len(e.Matchers) == 0 { + e.cmd.Stdout = os.Stdout + return + } + + var stdoutReadPipe *os.File + var stdoutHandlerWritePipe *os.File + var stdoutHandlerReadPipe *os.File + if e.StdoutBuffer != nil || e.StdoutHandler != nil || len(e.Matchers) > 0 { + // create pipe for stdout + var stdoutWritePipe *os.File + stdoutReadPipe, stdoutWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stdout: %s", err) + } + + e.cmd.Stdout = stdoutWritePipe + + e.pipesMutex.Lock() + e.stdoutPipeFile = stdoutWritePipe + e.pipesMutex.Unlock() + + // create pipe for StdoutHandler + if e.StdoutHandler != nil { + stdoutHandlerReadPipe, stdoutHandlerWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stdoutHandler: %s", err) + } + } + } + + var stderrReadPipe *os.File + var stderrHandlerWritePipe *os.File + var stderrHandlerReadPipe *os.File + if e.StderrBuffer != nil || e.StderrHandler != nil { + // create pipe for stderr + var stderrWritePipe *os.File + stderrReadPipe, stderrWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stderr: %s", err) + } + e.cmd.Stderr = stderrWritePipe + + e.pipesMutex.Lock() + e.stderrPipeFile = stderrWritePipe + e.pipesMutex.Unlock() + + // create pipe for StderrHandler + if e.StderrHandler != nil { + stderrHandlerReadPipe, stderrHandlerWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stderrHandler: %s", err) + } + } + } + + if e.StdinPipe { + e.Stdin, err = e.cmd.StdinPipe() + if err != nil { + return fmt.Errorf("open stdin pipe: %v", err) + } + } + + // Start reading from stdout of a command. + // Wait until all matchers are done and then: + // - Copy to os.Stdout if live output is enabled + // - Copy to buffer if capture is enabled + // - Copy to pipe if StdoutHandler is set + go func() { + e.readFromStreams(stdoutReadPipe, stdoutHandlerWritePipe) + }() + + logger := e.settings.Logger() + + go func() { + if e.StdoutHandler == nil { + return + } + e.ConsumeLines(stdoutHandlerReadPipe, e.StdoutHandler) + logger.DebugF("stop line consumer for '%s'", e.cmd.Args[0]) + }() + + // Start reading from stderr of a command. + // Copy to os.Stderr if live output is enabled + // Copy to buffer if capture is enabled + // Copy to pipe if StderrHandler is set + go func() { + if stderrReadPipe == nil { + return + } + + logger.DebugLn("Start reading from stderr pipe") + defer logger.DebugLn("Stop reading from stderr pipe") + + buf := make([]byte, 16) + for { + n, err := stderrReadPipe.Read(buf) + + // TODO logboek + if e.Live { + os.Stderr.Write(buf[:n]) + } + if e.StderrBuffer != nil { + e.StderrBuffer.Write(buf[:n]) + } + if e.StderrHandler != nil { + _, _ = stderrHandlerWritePipe.Write(buf[:n]) + } + + if err == io.EOF { + break + } + } + }() + + go func() { + if e.StderrHandler == nil { + return + } + e.ConsumeLines(stderrHandlerReadPipe, e.StderrHandler) + logger.DebugF("stop sdterr line consumer for '%s'", e.cmd.Args[0]) + }() + + return nil +} + +func (e *Executor) readFromStreams(stdoutReadPipe io.Reader, stdoutHandlerWritePipe io.Writer) { + logger := e.settings.Logger() + + defer logger.DebugLn("stop readFromStreams") + + if stdoutReadPipe == nil || reflect.ValueOf(stdoutReadPipe).IsNil() { + return + } + + logger.DebugF("Start read from streams for command: ", e.cmd.String()) + + buf := make([]byte, 16) + matchersDone := false + if len(e.Matchers) == 0 { + matchersDone = true + } + + errorsCount := 0 + for { + n, err := stdoutReadPipe.Read(buf) + if err != nil && err != io.EOF { + logger.DebugF("Error reading from stdout: %s\n", err) + errorsCount++ + if errorsCount > 1000 { + panic(fmt.Errorf("readFromStreams: too many errors, last error %v", err)) + } + continue + } + + m := 0 + if !matchersDone { + for _, matcher := range e.Matchers { + m = matcher.Analyze(buf[:n]) + if matcher.IsMatched() { + logger.DebugF("Trigger matcher '%s'\n", matcher.Pattern) + // matcher is triggered + if e.MatchHandler != nil { + res := e.MatchHandler(matcher.Pattern) + if res == "done" { + matchersDone = true + break + } + if res == "reset" { + matcher.Reset() + } + } + } + } + + // stdout for internal use, no copying to pipes until all Matchers are matched + if !matchersDone { + m = n + } + } + + if e.Live { + os.Stdout.Write(buf[m:n]) + } + if e.StdoutBuffer != nil { + e.StdoutBuffer.Write(buf[m:n]) + } + if e.StdoutHandler != nil { + _, _ = stdoutHandlerWritePipe.Write(buf[m:n]) + } + + if err == io.EOF { + logger.DebugLn("readFromStreams: EOF") + break + } + } +} + +func (e *Executor) ConsumeLines(r io.Reader, fn func(l string)) { + scanner := bufio.NewScanner(r) + if e.StdoutSplitter != nil { + scanner.Split(e.StdoutSplitter) + } + for scanner.Scan() { + text := scanner.Text() + + if fn != nil { + fn(text) + } + + if text != "" { + e.settings.Logger().DebugF("%s: %s", e.cmd.Args[0], text) + } + } +} + +func (e *Executor) Start() error { + logger := e.settings.Logger() + + // setup stream handlers + logger.DebugF("executor: start '%s'", e.cmd.String()) + err := e.SetupStreamHandlers() + if err != nil { + return err + } + + err = e.cmd.Start() + if err != nil { + return err + } + e.started = true + + e.ProcessWait() + + logger.DebugF("Register stoppable: '%s'", e.cmd.String()) + e.Session.RegisterStoppable(e) + + return nil +} + +func (e *Executor) ProcessWait() { + waitErrCh := make(chan error, 1) + e.waitCh = make(chan struct{}, 1) + e.stopCh = make(chan struct{}, 1) + + // wait for process in go routine + go func() { + waitErrCh <- e.cmd.Wait() + }() + + go func() { + if e.timeout > 0 { + time.Sleep(e.timeout) + if e.stopCh != nil { + e.stopCh <- struct{}{} + } + } + }() + + // watch for wait or stop + go func() { + defer func() { + close(e.waitCh) + close(waitErrCh) + }() + // Wait until Stop() is called or/and Wait() is returning. + for { + select { + case err := <-waitErrCh: + if e.stop { + // Ignore error if Stop() was called. + // close(e.waitCh) + return + } + e.setWaitError(err) + if e.WaitHandler != nil { + e.WaitHandler(e.waitError) + } + // close(e.waitCh) + return + case <-e.stopCh: + e.stop = true + // Prevent next readings from the closed channel. + e.stopCh = nil + // The usual e.cmd.Process.Kill() is not working for the process + // started with the new process group (Setpgid: true). + // Negative pid number is used to send a signal to all processes in the group. + err := syscall.Kill(-e.cmd.Process.Pid, syscall.SIGKILL) + if err != nil { + e.killError = err + } + } + } + }() +} + +func (e *Executor) closePipes() { + logger := e.settings.Logger() + + logger.DebugLn("Starting close piped") + defer logger.DebugLn("Stop close piped") + + e.pipesMutex.Lock() + defer e.pipesMutex.Unlock() + + if e.stdoutPipeFile != nil { + err := e.stdoutPipeFile.Close() + if err != nil { + logger.DebugF("Cannot close stdout pipe: %v", err) + } + e.stdoutPipeFile = nil + } + + if e.stderrPipeFile != nil { + err := e.stderrPipeFile.Close() + if err != nil { + logger.DebugF("Cannot close stderr pipe: %v", err) + } + e.stderrPipeFile = nil + } +} + +func (e *Executor) Stop() { + logger := e.settings.Logger() + + if e.stop { + logger.DebugF("Stop '%s': already stopped", e.cmd.String()) + return + } + if !e.started { + logger.DebugF("Stop '%s': not started yet", e.cmd.String()) + return + } + if e.cmd == nil { + logger.DebugF("Possible BUG: Call Executor.Stop with Cmd==nil") + return + } + + e.stop = true + logger.DebugF("Stop '%s'", e.cmd.String()) + if e.stopCh != nil { + close(e.stopCh) + } + <-e.waitCh + logger.DebugF("Stopped '%s': %d", e.cmd.String(), e.cmd.ProcessState.ExitCode()) + e.closePipes() +} + +// Run executes a command and blocks until it is finished or stopped. +func (e *Executor) Run(_ context.Context) error { + e.settings.Logger().DebugF("executor: run '%s'\n", e.cmd.String()) + + err := e.Start() + if err != nil { + return err + } + + <-e.waitCh + + e.closePipes() + + return e.WaitError() +} + +func (e *Executor) Cmd() *exec.Cmd { + return e.cmd +} + +func (e *Executor) setWaitError(err error) { + defer e.lockWaitError.Unlock() + e.lockWaitError.Lock() + e.waitError = err +} + +func (e *Executor) WaitError() error { + defer e.lockWaitError.RUnlock() + e.lockWaitError.RLock() + return e.waitError +} diff --git a/pkg/ssh/clissh/process/session.go b/pkg/ssh/clissh/process/session.go new file mode 100644 index 0000000..61231ce --- /dev/null +++ b/pkg/ssh/clissh/process/session.go @@ -0,0 +1,64 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package process + +import ( + "sync" +) + +var DefaultSession *Session + +func init() { + DefaultSession = NewSession() +} + +type Stopable interface { + Stop() +} + +type Session struct { + Stopables []Stopable +} + +func NewSession() *Session { + return &Session{ + Stopables: make([]Stopable, 0), + } +} + +func (s *Session) Stop() { + if s == nil { + return + } + var wg sync.WaitGroup + count := 0 + for _, stopable := range s.Stopables { + if stopable == nil { + continue + } + wg.Add(1) + count++ + go func(s Stopable) { + defer wg.Done() + s.Stop() + }(stopable) + } + //log.DebugF("Wait while %d processes stops\n", count) + wg.Wait() +} + +func (s *Session) RegisterStoppable(stopable Stopable) { + s.Stopables = append(s.Stopables, stopable) +} diff --git a/pkg/ssh/clissh/reverse-tunnel.go b/pkg/ssh/clissh/reverse-tunnel.go new file mode 100644 index 0000000..0434ae1 --- /dev/null +++ b/pkg/ssh/clissh/reverse-tunnel.go @@ -0,0 +1,266 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "context" + "fmt" + "math/rand/v2" + "os" + "os/exec" + "sync" + "time" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/cmd" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-dhctl/pkg/retry" +) + +type tunnelWaitResult struct { + id int + err error +} + +type ReverseTunnel struct { + settings settings.Settings + + Session *session.Session + Address string + + tunMutex sync.Mutex + sshCmd *exec.Cmd + started bool + stopCh chan struct{} + + errorCh chan tunnelWaitResult +} + +func NewReverseTunnel(sett settings.Settings, sess *session.Session, address string) *ReverseTunnel { + return &ReverseTunnel{ + settings: sett, + Session: sess, + Address: address, + errorCh: make(chan tunnelWaitResult), + } +} + +func (t *ReverseTunnel) Up() error { + _, err := t.upNewTunnel(-1) + return err +} + +func (t *ReverseTunnel) upNewTunnel(oldId int) (int, error) { + logger := t.settings.Logger() + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + if t.started { + logger.DebugF("[%d] Reverse tunnel already up\n", oldId) + return -1, fmt.Errorf("already up") + } + + id := rand.Int() + + logger.DebugF("[%d] Start reverse tunnel\n", id) + + t.sshCmd = cmd.NewSSH(t.settings, t.Session). + WithArgs( + "-N", // no command + "-n", // no stdin + "-R", t.Address, + ). + WithExitWhenTunnelFailure(true). + Cmd(context.Background()) + + err := t.sshCmd.Start() + if err != nil { + return id, fmt.Errorf("[%d] Cannot start tunnel ssh command: %w", id, err) + } + + go func(localCmd *exec.Cmd, localID int) { + if localCmd == nil { + logger.ErrorF("[%d] sshCmd is nil before Wait()\n", localID) + + t.errorCh <- tunnelWaitResult{ + id: localID, + err: fmt.Errorf("cannot Wait(): sshCmd is nil"), + } + + return + } + + logger.DebugF("[%d] Reverse tunnel started. Waiting for tunnel to stop...\n", localID) + + err := localCmd.Wait() + + t.errorCh <- tunnelWaitResult{ + id: localID, + err: err, + } + + logger.DebugF("[%d] Reverse tunnel was stopped and handled\n", localID) + }(t.sshCmd, id) + + t.started = true + + return id, nil +} + +func (t *ReverseTunnel) isStarted() bool { + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + r := t.started + return r +} + +func (t *ReverseTunnel) tryToRestart(ctx context.Context, id int, killer connection.ReverseTunnelKiller) (int, error) { + t.stop(id, false) + logger := t.settings.Logger() + logger.DebugF("[%d] Kill tunnel\n", id) + if out, err := killer.KillTunnel(ctx); err != nil { + logger.DebugF("[%d] Kill tunnel was finished with error: %v; stdout: '%s'\n", id, err, out) + return id, err + } + return t.upNewTunnel(id) +} + +func (t *ReverseTunnel) StartHealthMonitor(ctx context.Context, checker connection.ReverseTunnelChecker, killer connection.ReverseTunnelKiller) { + t.tunMutex.Lock() + t.stopCh = make(chan struct{}) + t.tunMutex.Unlock() + + logger := t.settings.Logger() + + checkReverseTunnel := func(id int) bool { + + logger.DebugF("[%d] Start Check reverse tunnel\n", id) + + err := retry.NewSilentLoop("Check reverse tunnel", 2, 2*time.Second).RunContext(ctx, func() error { + out, err := checker.CheckTunnel(ctx) + if err != nil { + logger.DebugF("[%d] Cannot check ssh tunnel: '%v': stderr: '%s'\n", id, err, out) + return err + } + + return nil + }) + + if err != nil { + logger.DebugF("[%d] Tunnel check timeout, last error: %v\n", id, err) + return false + } + + logger.DebugF("[%d] Tunnel check successful!\n", id) + return true + } + + go func() { + logger.DebugLn("Start health monitor") + // we need chan for restarting because between restarting we can get stop signal + restartCh := make(chan int, 1024) + id := -1 + restartsCount := 0 + restart := func(id int) { + logger.DebugF("[%d] Send restart signal\n", id) + restartCh <- id + logger.DebugF("[%d] Signal was sent. Chan len: %d\n", id, len(restartCh)) + } + for { + + if !checkReverseTunnel(id) { + go restart(id) + } + + select { + case <-t.stopCh: + logger.DebugLn("Stop health monitor") + return + case oldId := <-restartCh: + restartsCount++ + logger.DebugF("[%d] Restart signal was received: restarts count %d\n", oldId, restartsCount) + + if restartsCount > 1024 { + panic("Reverse tunnel restarts count exceeds 1024") + } + + newId, err := t.tryToRestart(ctx, oldId, killer) + if err != nil { + logger.DebugF("[%d] Restart failed with error: %v\n", oldId, err) + go restart(oldId) + continue + } + logger.DebugF("[%d] Restart successful. New id %d\n", oldId, newId) + id = newId + restartsCount = 0 + case err := <-t.errorCh: + id = err.id + logger.DebugF("[%d] Tunnel was stopped with error '%v'. Try restart fully\n", id, err.err) + started := t.isStarted() + if started { + logger.DebugF("[%d] Tunnel already up. Skip restarting\n", id) + continue + } + + go restart(id) + continue + } + } + }() +} + +func (t *ReverseTunnel) Stop() { + t.stop(-1, true) +} + +func (t *ReverseTunnel) stop(id int, full bool) { + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + logger := t.settings.Logger() + + if !t.started { + logger.DebugF("[%d] Reverse tunnel already stopped\n", id) + return + } + + logger.DebugF("[%d] Stop reverse tunnel\n", id) + defer logger.DebugF("[%d] End stop reverse tunnel\n", id) + + if full && t.stopCh != nil { + logger.DebugF("[%d] Stop reverse tunnel health monitor\n", id) + t.stopCh <- struct{}{} + } + + logger.DebugF("[%d] Try to find tunnel process %d\n", id, t.sshCmd.Process.Pid) + _, err := os.FindProcess(t.sshCmd.Process.Pid) + if err == nil { + logger.DebugF("[%d] Process found %d. Kill it\n", id, t.sshCmd.Process.Pid) + err := t.sshCmd.Process.Kill() + if err != nil { + logger.DebugF("[%d] Cannot kill process %d: %v\n", id, t.sshCmd.Process.Pid, err) + } + } else { + logger.DebugF("[%d] Stopping tunnel. Process %d already finished\n", id, t.sshCmd.Process.Pid) + } + + t.sshCmd = nil + t.started = false +} + +func (t *ReverseTunnel) String() string { + return fmt.Sprintf("%s:%s", "R", t.Address) +} diff --git a/pkg/ssh/clissh/tunnel.go b/pkg/ssh/clissh/tunnel.go new file mode 100644 index 0000000..e70c35f --- /dev/null +++ b/pkg/ssh/clissh/tunnel.go @@ -0,0 +1,156 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "os/exec" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/clissh/cmd" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +type Tunnel struct { + settings settings.Settings + + Session *session.Session + Type string // Remote or Local + Address string + sshCmd *exec.Cmd + + stopCh chan struct{} + errorCh chan error +} + +func NewTunnel(sett settings.Settings, sess *session.Session, ttype, address string) *Tunnel { + return &Tunnel{ + settings: sett, + Session: sess, + Type: ttype, + Address: address, + errorCh: make(chan error, 1), + } +} + +func (t *Tunnel) Up() error { + if t.Session == nil { + return fmt.Errorf("up tunnel '%s': SSH client is undefined", t.String()) + } + + t.sshCmd = cmd.NewSSH(t.settings, t.Session). + WithArgs( + // "-f", // start in background - good for scripts, but here we need to do cmd.Process.Kill() + "-o", "ExitOnForwardFailure=yes", // wait for connection establish before + // "-N", // no command + // "-n", // no stdin + fmt.Sprintf("-%s", t.Type), t.Address, + ). + WithCommand("echo", "SUCCESS", "&&", "cat"). + Cmd(context.Background()) + + stdoutReadPipe, stdoutWritePipe, err := os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stdout: %w", err) + } + t.sshCmd.Stdout = stdoutWritePipe + + // Create separate stdin pipe to prevent reading from main process Stdin + stdinReadPipe, _, err := os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stdin: %w", err) + } + t.sshCmd.Stdin = stdinReadPipe + + err = t.sshCmd.Start() + if err != nil { + return fmt.Errorf("tunnel up: %w", err) + } + + tunnelReadyCh := make(chan struct{}, 1) + go func() { + // defer wg.Done() + t.consumeLines(stdoutReadPipe, func(l string) { + if l == "SUCCESS" { + tunnelReadyCh <- struct{}{} + } + }) + t.settings.Logger().DebugF("stop line consumer for '%s'", t.String()) + }() + + go func() { + t.errorCh <- t.sshCmd.Wait() + }() + + select { + case err = <-t.errorCh: + return fmt.Errorf("cannot open tunnel '%s': %w", t.String(), err) + case <-tunnelReadyCh: + } + + return nil +} + +func (t *Tunnel) HealthMonitor(errorOutCh chan<- error) { + logger := t.settings.Logger() + + defer logger.DebugF("Tunnel health monitor stopped\n") + logger.DebugF("Tunnel health monitor started\n") + + t.stopCh = make(chan struct{}, 1) + + for { + select { + case err := <-t.errorCh: + errorOutCh <- err + case <-t.stopCh: + _ = t.sshCmd.Process.Kill() + return + } + } +} + +func (t *Tunnel) Stop() { + if t == nil { + return + } + if t.Session == nil { + t.settings.Logger().ErrorF("bug: down tunnel '%s': no session", t.String()) + return + } + + if t.sshCmd != nil && t.stopCh != nil { + t.stopCh <- struct{}{} + } +} + +func (t *Tunnel) String() string { + return fmt.Sprintf("%s:%s", t.Type, t.Address) +} + +func (t *Tunnel) consumeLines(r io.Reader, fn func(l string)) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + text := scanner.Text() + + if fn != nil { + fn(text) + } + } +} diff --git a/pkg/ssh/clissh/upload-script.go b/pkg/ssh/clissh/upload-script.go new file mode 100644 index 0000000..d5ecc7c --- /dev/null +++ b/pkg/ssh/clissh/upload-script.go @@ -0,0 +1,323 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clissh + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "time" + + "al.essio.dev/pkg/shellescape" + "github.com/deckhouse/lib-dhctl/pkg/log" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + genssh "github.com/deckhouse/lib-connection/pkg/ssh/utils" + "github.com/deckhouse/lib-connection/pkg/ssh/utils/tar" +) + +type UploadScript struct { + settings settings.Settings + + Session *session.Session + + uploadDir string + + ScriptPath string + Args []string + envs map[string]string + + sudo bool + + cleanupAfterExec bool + + stdoutHandler func(string) + + timeout time.Duration + + commanderMode bool +} + +func NewUploadScript(sett settings.Settings, sess *session.Session, scriptPath string, args ...string) *UploadScript { + return &UploadScript{ + Session: sess, + ScriptPath: scriptPath, + Args: args, + + settings: sett, + + cleanupAfterExec: true, + } +} + +func (u *UploadScript) Sudo() { + u.sudo = true +} + +func (u *UploadScript) WithStdoutHandler(handler func(string)) { + u.stdoutHandler = handler +} + +func (u *UploadScript) WithTimeout(timeout time.Duration) { + u.timeout = timeout +} + +func (u *UploadScript) WithEnvs(envs map[string]string) { + u.envs = envs +} + +func (u *UploadScript) WithCommanderMode(enabled bool) { + u.commanderMode = enabled +} + +// WithCleanupAfterExec option tells if ssh executor should delete uploaded script after execution was attempted or not. +// It does not care if script was executed successfully of failed. +func (u *UploadScript) WithCleanupAfterExec(doCleanup bool) { + u.cleanupAfterExec = doCleanup +} + +func (u *UploadScript) WithExecuteUploadDir(dir string) { + u.uploadDir = dir +} + +func (u *UploadScript) IsSudo() bool { + return u.sudo +} + +func (u *UploadScript) UploadDir() string { + return u.uploadDir +} + +func (u *UploadScript) Settings() settings.Settings { + return u.settings +} + +func (u *UploadScript) Execute(ctx context.Context) (stdout []byte, err error) { + scriptName := filepath.Base(u.ScriptPath) + + remotePath := genssh.ExecuteRemoteScriptPath(u, scriptName, false) + err = NewFile(u.settings, u.Session).Upload(ctx, u.ScriptPath, remotePath) + if err != nil { + return nil, fmt.Errorf("upload: %v", err) + } + + var cmd *Command + scriptFullPath := u.pathWithEnv(genssh.ExecuteRemoteScriptPath(u, scriptName, true)) + if u.sudo { + cmd = NewCommand(u.settings, u.Session, scriptFullPath, u.Args...) + cmd.Sudo(ctx) + } else { + cmd = NewCommand(u.settings, u.Session, scriptFullPath, u.Args...) + cmd.Cmd(ctx) + } + + scriptCmd := cmd.CaptureStdout(nil).CaptureStderr(nil) + if u.stdoutHandler != nil { + scriptCmd.WithStdoutHandler(u.stdoutHandler) + } + + if u.timeout > 0 { + scriptCmd.WithTimeout(u.timeout) + } + + if u.cleanupAfterExec { + defer func() { + err := NewCommand(u.settings, u.Session, "rm", "-f", scriptFullPath).Run(ctx) + if err != nil { + u.settings.Logger().DebugF("Failed to delete uploaded script %s: %v", scriptFullPath, err) + } + }() + } + + err = scriptCmd.Run(ctx) + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + // exitErr.Stderr is set in the "os/exec".Cmd.Output method from the Golang standard library. + // But we call the "os/exec".Cmd.Wait method, which does not set the Stderr field. + // We can reuse the exec.ExitError type when handling errors. + exitErr.Stderr = cmd.StderrBytes() + } + + err = fmt.Errorf("execute on remote: %w", err) + } + return cmd.StdoutBytes(), err +} + +func (u *UploadScript) pathWithEnv(path string) string { + if len(u.envs) == 0 { + return path + } + + arrayToJoin := make([]string, 0, len(u.envs)*2) + + for k, v := range u.envs { + vEscaped := shellescape.Quote(v) + kvStr := fmt.Sprintf("%s=%s", k, vEscaped) + arrayToJoin = append(arrayToJoin, kvStr) + } + + envs := strings.Join(arrayToJoin, " ") + + return fmt.Sprintf("%s %s", envs, path) +} + +var ErrBashibleTimeout = errors.New("Timeout bashible step running") + +func (u *UploadScript) ExecuteBundle(ctx context.Context, parentDir, bundleDir string) (stdout []byte, err error) { + bundleName := fmt.Sprintf("bundle-%s.tar", time.Now().Format("20060102-150405")) + bundleLocalFilepath := filepath.Join(u.settings.TmpDir(), bundleName) + + // tar cpf bundle.tar -C /tmp/dhctl.1231qd23/var/lib bashible + err = tar.CreateTar(bundleLocalFilepath, parentDir, bundleDir) + if err != nil { + return nil, fmt.Errorf("tar bundle: %v", err) + } + + u.settings.RegisterOnShutdown( + "Delete bashible bundle folder", + func() { _ = os.Remove(bundleLocalFilepath) }, + ) + + // upload to node's deckhouse tmp directory + err = NewFile(u.settings, u.Session).Upload(ctx, bundleLocalFilepath, u.settings.TmpDir()) + if err != nil { + return nil, fmt.Errorf("upload: %v", err) + } + + // sudo: + // tar xpof ${app.DeckhouseNodeTmpPath}/bundle.tar -C /var/lib && /var/lib/bashible/bashible.sh args... + tarCmdline := fmt.Sprintf( + "tar xpof %s/%s -C /var/lib && /var/lib/%s/%s %s", + u.settings.TmpDir(), + bundleName, + bundleDir, + u.ScriptPath, + strings.Join(u.Args, " "), + ) + bundleCmd := NewCommand(u.settings, u.Session, tarCmdline) + bundleCmd.Sudo(ctx) + + // Buffers to implement output handler logic + lastStep := "" + failsCounter := 0 + isBashibleTimeout := false + + processLogger := u.settings.Logger().ProcessLogger() + + handler := bundleOutputHandler( + bundleCmd, + u.settings.Logger(), + processLogger, + &lastStep, + &failsCounter, + &isBashibleTimeout, + u.commanderMode, + ) + bundleCmd.WithStdoutHandler(handler) + bundleCmd.CaptureStdout(nil) + bundleCmd.CaptureStderr(nil) + err = bundleCmd.Run(ctx) + if err != nil { + if lastStep != "" { + processLogger.ProcessFail() + } + + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + // exitErr.Stderr is set in the "os/exec".Cmd.Output method from the Golang standard library. + // But we call the "os/exec".Cmd.Wait method, which does not set the Stderr field. + // We can reuse the exec.ExitError type when handling errors. + exitErr.Stderr = bundleCmd.StderrBytes() + } + + err = fmt.Errorf("execute bundle: %w", err) + } else { + processLogger.ProcessEnd() + } + + if isBashibleTimeout { + return bundleCmd.StdoutBytes(), ErrBashibleTimeout + } + + return bundleCmd.StdoutBytes(), err +} + +var stepHeaderRegexp = regexp.MustCompile("^=== Step: /var/lib/bashible/bundle_steps/(.*)$") + +func bundleOutputHandler( + cmd *Command, + logger log.Logger, + processLogger log.ProcessLogger, + lastStep *string, + failsCounter *int, + isBashibleTimeout *bool, + commanderMode bool, +) func(string) { + stepLogs := make([]string, 0) + return func(l string) { + if l == "===" { + return + } + if stepHeaderRegexp.Match([]byte(l)) { + match := stepHeaderRegexp.FindStringSubmatch(l) + stepName := match[1] + + if *lastStep == stepName { + logMessage := strings.Join(stepLogs, "\n") + + switch { + case commanderMode && *failsCounter == 0: + logger.ErrorF("%s", logMessage) + case commanderMode && *failsCounter > 0: + logger.ErrorF("Run step %s finished with error^^^\n", stepName) + logger.DebugF("%s", logMessage) + default: + logger.ErrorF("%s", logMessage) + } + *failsCounter++ + stepLogs = stepLogs[:0] + if *failsCounter > 10 { + *isBashibleTimeout = true + if cmd != nil { + // Force kill bashible + _ = cmd.cmd.Process.Kill() + } + return + } + + processLogger.ProcessFail() + stepName = fmt.Sprintf("%s, retry attempt #%d of 10\n", stepName, *failsCounter) + } else if *lastStep != "" { + stepLogs = make([]string, 0) + processLogger.ProcessEnd() + *failsCounter = 0 + } + + processLogger.ProcessStart("Run step " + stepName) + *lastStep = match[1] + return + } + + stepLogs = append(stepLogs, l) + logger.DebugLn(l) + } +} diff --git a/pkg/ssh/config/config.go b/pkg/ssh/config/config.go new file mode 100644 index 0000000..d0a0646 --- /dev/null +++ b/pkg/ssh/config/config.go @@ -0,0 +1,43 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +type AgentPrivateKey struct { + Key string `json:"key"` + Passphrase string `json:"passphrase,omitempty"` +} + +type Host struct { + Host string `json:"host"` +} + +type Config struct { + User string `json:"user"` + Port *int32 `json:"port,omitempty"` + PrivateKeys []AgentPrivateKey `json:"privateKeys,omitempty"` + ExtraArgs string `json:"extraArgs,omitempty"` + BastionHost string `json:"bastionHost,omitempty"` + BastionPort *int32 `json:"bastionPort,omitempty"` + BastionUser string `json:"bastionUser,omitempty"` + BastionPassword string `json:"bastionPassword,omitempty"` + SudoPassword string `json:"sudoPassword,omitempty"` + LegacyMode bool `json:"legacyMode,omitempty"` + ModernMode bool `json:"modernMode,omitempty"` +} + +type ConnectionConfig struct { + Config *Config + Hosts []Host +} diff --git a/pkg/ssh/config/openapi.go b/pkg/ssh/config/openapi.go new file mode 100644 index 0000000..48bbdc1 --- /dev/null +++ b/pkg/ssh/config/openapi.go @@ -0,0 +1,82 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + _ "embed" + "fmt" + "strings" + + "github.com/deckhouse/lib-dhctl/pkg/yaml/validation" + "github.com/go-openapi/spec" +) + +var ( + //go:embed openapi/ssh_configuration.yaml + configurationOpenAPISpecContent string + + //go:embed openapi/ssh_host_configuration.yaml + hostOpenAPISpecContent string + + specsForValidator = make(map[validation.SchemaIndex]*spec.Schema) +) + +func init() { + var err error + specsForValidator, err = loadSpecs() + if err != nil { + panic(err) + } +} + +func ConfigurationOpenAPISpec() string { + return configurationOpenAPISpecContent +} + +func HostOpenAPISpec() string { + return hostOpenAPISpecContent +} + +func loadSpecs() (map[validation.SchemaIndex]*spec.Schema, error) { + configurationSpecs, err := validation.LoadSchemas(strings.NewReader(configurationOpenAPISpecContent)) + if err != nil { + return nil, fmt.Errorf("error loading ssh connection configuration schema: %v", err) + } + + hostSpec, err := validation.LoadSchemas(strings.NewReader(hostOpenAPISpecContent)) + if err != nil { + return nil, fmt.Errorf("error loading ssh host configuration schema: %v", err) + } + + if len(configurationSpecs) == 0 { + return nil, fmt.Errorf("error loading ssh host configuration schema: no specs found") + } + + if len(hostSpec) == 0 { + return nil, fmt.Errorf("error loading ssh host configuration schema: no specs found") + } + + res := make(map[validation.SchemaIndex]*spec.Schema, len(configurationSpecs)+len(hostSpec)) + + for _, s := range configurationSpecs { + res[s.Index] = s.Schema + } + + for _, s := range hostSpec { + res[s.Index] = s.Schema + } + + return res, nil +} diff --git a/pkg/ssh/config/openapi/doc-ru-ssh_configuration.yaml b/pkg/ssh/config/openapi/doc-ru-ssh_configuration.yaml new file mode 100644 index 0000000..a18f893 --- /dev/null +++ b/pkg/ssh/config/openapi/doc-ru-ssh_configuration.yaml @@ -0,0 +1,51 @@ +# Copyright 2026 Flant JSC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +kind: SSHConfig +apiVersions: +- apiVersion: dhctl.deckhouse.io/v1 + openAPISpec: + description: | + Конфигурация SSH для dhctl. + properties: + apiVersion: + description: Версия Deckhouse API. + sshUser: + description: Имя пользователя SSH. + sshPort: + description: Порт SSH. + sshExtraArgs: + description: Дополнительные параметры соединения SSH. + sshAgentPrivateKeys: + items: + properties: + key: + description: Приватный SSH-ключ. + passphrase: + description: Пароль SSH-ключа. + sshBastionHost: + description: Хост SSH-бастиона. + sshBastionPort: + description: Порт SSH-бастиона. + sshBastionUser: + description: Имя пользователя бастиона. + sudoPassword: + description: | + Пароль sudo пользователя. + legacyMode: + description: | + Использовать легаси режим SSH (clissh). + modernMode: + description: | + Использовать актуальный режим SSH (gossh). diff --git a/pkg/ssh/config/openapi/doc-ru-ssh_host_configuration.yaml b/pkg/ssh/config/openapi/doc-ru-ssh_host_configuration.yaml new file mode 100644 index 0000000..d417044 --- /dev/null +++ b/pkg/ssh/config/openapi/doc-ru-ssh_host_configuration.yaml @@ -0,0 +1,25 @@ +# Copyright 2026 Flant JSC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +kind: SSHHost +apiVersions: +- apiVersion: dhctl.deckhouse.io/v1 + openAPISpec: + description: | + Конфигурация SSH-хоста. + properties: + apiVersion: + description: Версия Deckhouse API. + host: + description: SSH-хост. diff --git a/pkg/ssh/config/openapi/ssh_configuration.yaml b/pkg/ssh/config/openapi/ssh_configuration.yaml new file mode 100644 index 0000000..282f135 --- /dev/null +++ b/pkg/ssh/config/openapi/ssh_configuration.yaml @@ -0,0 +1,91 @@ +# Copyright 2026 Flant JSC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +kind: SSHConfig +apiVersions: +- apiVersion: dhctl.deckhouse.io/v1 + openAPISpec: + type: object + description: | + General dhctl SSH config. + additionalProperties: false + anyOf: + - required: [apiVersion, kind, sshUser, sshAgentPrivateKeys] + - required: [apiVersion, kind, sshUser, sudoPassword] + x-examples: + - apiVersion: dhctl.deckhouse.io/v1 + kind: SSHConfig + sshUser: user + sshPort: 22 + sshExtraArgs: -vvv + sshAgentPrivateKeys: + - key: + properties: + apiVersion: + type: string + description: Version of the Deckhouse API. + enum: [dhctl.deckhouse.io/v1] + kind: + type: string + enum: [SSHConfig] + sshUser: + type: string + description: SSH username. + sshPort: + type: integer + description: SSH port. + sshExtraArgs: + type: string + description: Additional arguments for SSH connection. + sshAgentPrivateKeys: + type: array + minItems: 1 + items: + type: object + additionalProperties: false + required: [key] + x-rules: [sshPrivateKey] + properties: + key: + type: string + description: Private SSH key. + passphrase: + type: string + description: Password for SSH key. + sshBastionHost: + type: string + description: SSH bastion host. + sshBastionPort: + type: integer + description: Port of SSH bastion. + sshBastionUser: + type: string + description: Username for bastion. + sshBastionPassword: + type: string + description: | + A password for the bastion user. + sudoPassword: + description: | + A sudo password for the user. + type: string + legacyMode: + description: | + Switch to legacy SSH mode (clissh). + type: boolean + modernMode: + description: | + Switch to modern SSH mode (gossh). + type: boolean + diff --git a/pkg/ssh/config/openapi/ssh_host_configuration.yaml b/pkg/ssh/config/openapi/ssh_host_configuration.yaml new file mode 100644 index 0000000..04ed3b2 --- /dev/null +++ b/pkg/ssh/config/openapi/ssh_host_configuration.yaml @@ -0,0 +1,38 @@ +# Copyright 2026 Flant JSC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +kind: SSHHost +apiVersions: +- apiVersion: dhctl.deckhouse.io/v1 + openAPISpec: + type: object + description: | + General dhctl SSH host config. + additionalProperties: false + required: [apiVersion,kind,host] + x-examples: + - apiVersion: dhctl.deckhouse.io/v1 + kind: SSHHost + host: 172.16.0.0 + properties: + apiVersion: + type: string + description: Version of the Deckhouse API. + enum: [dhctl.deckhouse.io/v1] + kind: + type: string + enum: [SSHHost] + host: + type: string + description: Host. diff --git a/pkg/ssh/config/parse_config.go b/pkg/ssh/config/parse_config.go new file mode 100644 index 0000000..a428ab3 --- /dev/null +++ b/pkg/ssh/config/parse_config.go @@ -0,0 +1,228 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io" + "strings" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/deckhouse/lib-dhctl/pkg/yaml" + "github.com/deckhouse/lib-dhctl/pkg/yaml/validation" + + "github.com/deckhouse/lib-connection/pkg/settings" +) + +const ( + sshConfigKind = "SSHConfig" + sshHostKind = "SSHHost" +) + +type validateOptions struct { + omitDocInError bool + strictUnmarshal bool + requiredSSHHost bool + noPrettyError bool +} + +type ValidateOption func(o *validateOptions) + +func ParseWithOmitDocInError(v bool) ValidateOption { + return func(o *validateOptions) { + o.omitDocInError = v + } +} + +func ParseWithStrictUnmarshal(v bool) ValidateOption { + return func(o *validateOptions) { + o.strictUnmarshal = v + } +} + +func ParseWithRequiredSSHHost(v bool) ValidateOption { + return func(o *validateOptions) { + o.requiredSSHHost = v + } +} + +func ParseWithNoPrettyError(v bool) ValidateOption { + return func(o *validateOptions) { + o.noPrettyError = v + } +} + +func ParseConnectionConfig(reader io.Reader, sett settings.Settings, opts ...ValidateOption) (*ConnectionConfig, error) { + options := &validateOptions{ + requiredSSHHost: true, + strictUnmarshal: true, + } + for _, o := range opts { + o(options) + } + + configData, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + docs := yaml.SplitYAML(string(configData)) + + errs := newParseErrors() + + logger := sett.Logger() + logger.DebugF("Parsing connection config has %d documents", len(docs)) + + validator := getValidator(sett.LoggerProvider()) + validatorOpts := parseOptionsToValidatorOpts(options) + + config := &ConnectionConfig{} + var connectionConfigDocsCount int + var sshHostConfigDocsCount int + + for i, doc := range docs { + doc = strings.TrimSpace(doc) + if doc == "" { + logger.DebugF("Skip empty document %d", i) + continue + } + + docData := []byte(doc) + + index, err := validation.ParseIndex(strings.NewReader(doc)) + if err != nil { + errs.appendError(err, i, "Extract index from document") + continue + } + + logger.DebugF("Process validate and parse connection config document %d for index %v", i, index) + + err = validator.ValidateWithIndex(index, &docData, validatorOpts...) + if err != nil { + // no message error, err contains all information + errs.appendError(err, i, "") + continue + } + + switch index.Kind { + case sshConfigKind: + connectionConfigDocsCount++ + sshConfig, err := yaml.Unmarshal[Config](docData) + if err != nil { + errs.appendUnmarshalError(err, index, i) + continue + } + config.Config = &sshConfig + logger.DebugF("SSHConfig added in result config") + case sshHostKind: + sshHostConfigDocsCount++ + sshHost, err := yaml.Unmarshal[Host](docData) + if err != nil { + errs.appendUnmarshalError(err, index, i) + continue + } + + config.Hosts = append(config.Hosts, sshHost) + logger.DebugF("SSHHost '%s' added in result config, host in result config %d", sshHost.Host, len(config.Hosts)) + default: + errs.appendError( + validation.ErrKindValidationFailed, + i, + "Unknown kind, expected one of (%q, %q)", sshConfigKind, sshHostKind, + ) + continue + } + } + + if err := errs.ErrorOrNil(); err != nil { + return nil, err + } + + if connectionConfigDocsCount != 1 { + errs.appendError( + validation.ErrKindValidationFailed, + 0, + "exactly one %q required", sshConfigKind, + ) + } + + if options.requiredSSHHost && sshHostConfigDocsCount == 0 { + errs.appendError( + validation.ErrKindValidationFailed, + 0, + "at least one %q required", sshHostKind, + ) + } + + return config, nil +} + +func getValidator(logger log.LoggerProvider) *validation.Validator { + validator := validation.NewValidatorWithLogger(specsForValidator, logger) + return addXRules(validator) +} + +func parseOptionsToValidatorOpts(o *validateOptions) []validation.ValidateOption { + return []validation.ValidateOption{ + validation.ValidateWithOmitDocInError(o.omitDocInError), + validation.ValidateWithStrictUnmarshal(o.strictUnmarshal), + validation.ValidateWithNoPrettyError(o.noPrettyError), + } +} + +type parseErrors struct { + *validation.ValidationError +} + +func newParseErrors() *parseErrors { + return &parseErrors{ + ValidationError: &validation.ValidationError{}, + } +} + +func (e *parseErrors) appendError(err error, index int, msgFormat string, args ...interface{}) { + msg := fmt.Sprintf(msgFormat, args...) + if msg != "" { + msg = fmt.Sprintf("%s: %v", msg, err) + } else { + msg = err.Error() + } + + toAppend := validation.Error{ + Messages: []string{msg}, + } + + if index > 0 { + toAppend.Index = &index + } + + validationError := validation.ExtractValidationError(err) + e.Append(validationError, toAppend) +} + +func (e *parseErrors) appendUnmarshalError(err error, schemaIndex *validation.SchemaIndex, docIndex int) { + kind := schemaIndex.Kind + group, groupVersion := schemaIndex.GroupAndGroupVersion() + + e.Append(validation.ErrKindValidationFailed, validation.Error{ + Index: &docIndex, + Messages: []string{ + fmt.Sprintf("Cannot unmarshal to %s document %d: %v", kind, docIndex, err), + }, + Kind: kind, + Version: groupVersion, + Group: group, + }) +} diff --git a/pkg/ssh/config/parse_config_test.go b/pkg/ssh/config/parse_config_test.go new file mode 100644 index 0000000..a2d7ce2 --- /dev/null +++ b/pkg/ssh/config/parse_config_test.go @@ -0,0 +1,85 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "strings" + "testing" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + input string + hasErrorContains string + expected *ConnectionConfig + opts []ValidateOption + }{ + { + name: "empty input", + input: "", + hasErrorContains: `exactly one "SSHConfig" required`, + }, + { + name: "multiple empty documents", + input: ` +--- + + +`, + hasErrorContains: `exactly one "SSHConfig" required`, + }, + + { + name: "only connection: incorrect input without auth", + input: ` +apiVersion: dhctl.deckhouse.io/v1 +kind: SSHConfig +sshPort: 22 +sshUser: ubuntu +`, + hasErrorContains: "DocumentValidationFailed: Document validation failed:\n---\napiVersion: dhctl.deckhouse.io/v1\nkind: SSHConfig", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const isDebug = true + sett := settings.NewBaseProviders(settings.ProviderParams{ + LoggerProvider: log.SimpleLoggerProvider( + log.NewInMemoryLoggerWithParent( + log.NewSimpleLogger(log.LoggerOptions{IsDebug: isDebug}), + ), + ), + IsDebug: isDebug, + }) + + cfg, err := ParseConnectionConfig(strings.NewReader(test.input), sett, test.opts...) + + if test.hasErrorContains != "" { + require.Error(t, err, "expected error but got none") + require.Contains(t, err.Error(), test.hasErrorContains, "error should contain") + return + } + + require.NoError(t, err, "expected no error but got one") + require.Equal(t, test.expected, cfg) + }) + } +} diff --git a/pkg/ssh/config/validators.go b/pkg/ssh/config/validators.go new file mode 100644 index 0000000..d3d9e86 --- /dev/null +++ b/pkg/ssh/config/validators.go @@ -0,0 +1,67 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "fmt" + + "github.com/deckhouse/lib-dhctl/pkg/yaml/validation" + ssh "github.com/deckhouse/lib-gossh" + "sigs.k8s.io/yaml" +) + +var ( + ErrValidationRuleFailed = fmt.Errorf("validation rule failed") +) + +func addXRules(validator *validation.Validator) *validation.Validator { + extsValidator := validation.NewXRulesExtensionsValidator(map[string]validation.ExtensionsValidatorHandler{ + "sshPrivateKey": validateSSHPrivateKey, + }) + + validator.AddExtensionsValidators(extsValidator) + + return validator +} + +func validateSSHPrivateKey(value json.RawMessage) error { + var key AgentPrivateKey + + err := yaml.Unmarshal(value, &key) + if err != nil { + return err + } + + privateKeyBytes := []byte(key.Key) + + if key.Passphrase == "" { + if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil { + return validateSSHPrivateKeyErr(err) + } + + return nil + } + + if _, err = ssh.ParseRawPrivateKeyWithPassphrase(privateKeyBytes, []byte(key.Passphrase)); err != nil { + return validateSSHPrivateKeyErr(err) + } + + return nil +} + +func validateSSHPrivateKeyErr(err error) error { + return fmt.Errorf("%w: invalid ssh key: %w", ErrValidationRuleFailed, err) +} diff --git a/pkg/ssh/gossh/client.go b/pkg/ssh/gossh/client.go new file mode 100644 index 0000000..a2674e3 --- /dev/null +++ b/pkg/ssh/gossh/client.go @@ -0,0 +1,820 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "net" + "slices" + "sync" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/deckhouse/lib-dhctl/pkg/retry" + gossh "github.com/deckhouse/lib-gossh" + "github.com/deckhouse/lib-gossh/agent" + "github.com/name212/govalue" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +func NewClient(ctx context.Context, sett settings.Settings, session *session.Session, privKeys []session.AgentPrivateKey) *Client { + return &Client{ + sessionClient: session, + privateKeys: privKeys, + live: false, + sshSessionsList: make([]*gossh.Session, 0, 10), + ctx: ctx, + silent: false, + settings: sett, + } +} + +type ClientLoopsParams struct { + ConnectToBastion retry.Params + ConnectToHostViaBastion retry.Params + ConnectToHostDirectly retry.Params + NewSession retry.Params + CheckReverseTunnel retry.Params +} + +var defaultClientDirectlyLoopParamsOps = []retry.ParamsBuilderOpt{ + retry.WithWait(2 * time.Second), + retry.WithAttempts(50), +} + +var defaultClientViaBastionLoopParamsOps = []retry.ParamsBuilderOpt{ + retry.WithWait(5 * time.Second), + retry.WithAttempts(30), +} + +var defaultSessionLoopParamsOps = []retry.ParamsBuilderOpt{ + retry.WithWait(5 * time.Second), + retry.WithAttempts(10), +} + +var defaultReverseTunnelParamsOps = []retry.ParamsBuilderOpt{ + retry.WithWait(2 * time.Second), + retry.WithAttempts(2), +} + +type Client struct { + ctx context.Context + + settings settings.Settings + loopsParams ClientLoopsParams + + bastionClient *gossh.Client + sshClient *gossh.Client + + sessionClient *session.Session + + sshConn gossh.Conn + sshNetConn net.Conn + stopChan chan struct{} + + live bool + kubeProxies []*KubeProxy + + sshSessionsMu sync.Mutex + sshSessionsList []*gossh.Session + + privateKeys []session.AgentPrivateKey + signers []gossh.Signer + + agentClient agent.ExtendedAgent + agentConnection net.Conn + + silent bool +} + +func (s *Client) WithLoopsParams(p ClientLoopsParams) *Client { + s.loopsParams = p + return s +} + +func (s *Client) OnlyPreparePrivateKeys() error { + return s.initSigners() +} + +// Tunnel is used to open local (L) and remote (R) tunnels +func (s *Client) Tunnel(address string) connection.Tunnel { + return NewTunnel(s, address) +} + +// ReverseTunnel is used to open remote (R) tunnel +func (s *Client) ReverseTunnel(address string) connection.ReverseTunnel { + return NewReverseTunnel(s, address) +} + +// Command is used to run commands on remote server +func (s *Client) Command(name string, arg ...string) connection.Command { + return NewSSHCommand(s, name, arg...) +} + +// KubeProxy is used to start kubectl proxy and create a tunnel from local port to proxy port +func (s *Client) KubeProxy() connection.KubeProxy { + p := NewKubeProxy(s, s.sessionClient) + s.kubeProxies = append(s.kubeProxies, p) + return p +} + +// File is used to upload and download files and directories +func (s *Client) File() connection.File { + return NewSSHFile(s.settings, s.sshClient) +} + +// UploadScript is used to upload script and execute it on remote server +func (s *Client) UploadScript(scriptPath string, args ...string) connection.Script { + return NewSSHUploadScript(s, scriptPath, args...) +} + +// Check is used to upload script and execute it on remote server +func (s *Client) Check() connection.Check { + f := func(sess *session.Session, cmd string) connection.Command { + return NewSSHCommand(s, cmd) + } + return utils.NewCheck(f, s.sessionClient, s.settings) +} + +// Stop the client +func (s *Client) Stop() { + s.stopAllAndLogErrors("call Stop()") + s.debug("SSH client is stopped") +} + +func (s *Client) Session() *session.Session { + return s.sessionClient +} + +func (s *Client) PrivateKeys() []session.AgentPrivateKey { + return s.privateKeys +} + +func (s *Client) RefreshPrivateKeys() error { + // new go ssh client already have all keys + return nil +} + +// Loop Looping all available hosts +func (s *Client) Loop(fn connection.SSHLoopHandler) error { + var err error + + resetSession := func() { + s.sessionClient = s.sessionClient.Copy() + s.sessionClient.ChoiceNewHost() + } + defer resetSession() + resetSession() + + for range s.sessionClient.AvailableHosts() { + err = fn(s) + if err != nil { + return err + } + s.sessionClient.ChoiceNewHost() + } + + return nil +} + +func (s *Client) NewSSHSession() (*gossh.Session, error) { + var sess *gossh.Session + + newSessionLoopParams := retry.SafeCloneOrNewParams(s.loopsParams.NewSession, defaultSessionLoopParamsOps...). + WithName("Establish new session"). + WithLogger(s.settings.Logger()) + + err := retry.NewSilentLoopWithParams(newSessionLoopParams).RunContext(s.ctx, func() error { + var err error + sess, err = s.sshClient.NewSession() + return err + }) + + if err != nil { + return nil, err + } + + s.registerSession(sess) + return sess, nil +} + +func (s *Client) GetClient() *gossh.Client { + return s.sshClient +} + +func (s *Client) Live() bool { + return s.live +} + +func (s *Client) Start() error { + return s.startWithContext(s.ctx) +} + +func (s *Client) UnregisterSession(sess *gossh.Session) { + s.sshSessionsMu.Lock() + defer s.sshSessionsMu.Unlock() + num := len(s.sshSessionsList) + for i, registeredSession := range s.sshSessionsList { + if registeredSession == sess { + num = i + break + } + } + if num < len(s.sshSessionsList) { + s.sshSessionsList = slices.Delete(s.sshSessionsList, num, num+1) + } +} + +func (s *Client) stopAfterStartFailed(cause string, err error) error { + s.stopAllAndLogErrors(cause) + return err +} + +func (s *Client) startWithContext(ctx context.Context) error { + if s.sessionClient == nil { + return fmt.Errorf("Possible bug in ssh client: client session should be passed start") + } + + if govalue.Nil(ctx) { + return fmt.Errorf("nil context passed to client") + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + s.debug("Starting go ssh client....") + + if err := s.initSigners(); err != nil { + return err + } + + if err := s.connectToAgent(ctx); err != nil { + return s.stopAfterStartFailed("unable to connect to agent", err) + } + + bastionClient, err := s.connectToBastion(ctx) + if err != nil { + return s.stopAfterStartFailed("unable to connect to bastion", err) + } + + if err := s.connectToTarget(ctx, bastionClient); err != nil { + return s.stopAfterStartFailed("unable to connect to target", err) + } + + return nil +} + +func (s *Client) connectToAgent(ctx context.Context) error { + socket := s.settings.AuthSock() + if socket == "" { + s.debug("No auth socket passed. Skip connecting to agent") + return nil + } + + s.debug("Dialing SSH agent unix socket %s ...", socket) + + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + // Use net.Dialer's DialContext method directly + dialer := net.Dialer{} + conn, err := dialer.DialContext(cctx, "unix", socket) + if err != nil { + return fmt.Errorf("Failed to open agent socket %s: %v", socket, err) + } + + s.agentConnection = conn + s.agentClient = agent.NewClient(conn) + + return nil +} + +func (s *Client) createClientConfig(user string, password string, connectName string) (*gossh.ClientConfig, error) { + authMethods, err := s.authMethods(password) + if err != nil { + return nil, err + } + + config := &gossh.ClientConfig{ + User: user, + Auth: authMethods, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + config.BannerCallback = func(message string) error { + s.debug("Got banner message for %s: %s", connectName, message) + return nil + } + + return config, nil +} + +func (s *Client) createTargetClientConfig(connectName string) (*gossh.ClientConfig, error) { + return s.createClientConfig(s.sessionClient.User, s.sessionClient.BecomePass, connectName) +} + +func (s *Client) connectToTargetViaBastion(ctx context.Context, bastionClient *gossh.Client) (*gossh.Client, error) { + if bastionClient == nil { + return nil, fmt.Errorf("Bastion client is nil for connect via bastion") + } + + var ( + addr string + targetConn net.Conn + targetClientConn gossh.Conn + ) + + s.debug("Try to connect to through bastion host master host...") + + config, err := s.createTargetClientConfig("connect via bastion") + if err != nil { + return nil, err + } + + var sshConn *sshConnection + + connectToTarget := func() error { + if len(s.kubeProxies) == 0 { + s.sessionClient.ChoiceNewHost() + } + addr = fmt.Sprintf("%s:%s", s.sessionClient.Host(), s.sessionClient.Port) + s.debug("Connect to target host '%s' with user '%s' through bastion host", addr, s.sessionClient.User) + + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var err error + targetConn, err = bastionClient.DialContext(cctx, "tcp", addr) + if err != nil { + return fmt.Errorf("Cannot Dial to %s over bastion: %w", addr, err) + } + + sshConn, err = s.createSSHConnection(targetConn, addr, config) + if err != nil { + return fmt.Errorf("Cannot create SSH connection to %s: %w", addr, err) + } + + return nil + } + + viaBastionLoopParams := retry.SafeCloneOrNewParams(s.loopsParams.ConnectToHostViaBastion, defaultClientViaBastionLoopParamsOps...). + WithName("Get SSH client and connect to target host via bastion") + + if err := s.runInLoop(ctx, viaBastionLoopParams, connectToTarget); err != nil { + lastHost := fmt.Sprintf("'%s:%s' with user '%s'", s.sessionClient.Host(), s.sessionClient.Port, s.sessionClient.User) + return nil, fmt.Errorf("Failed to connect to target host through bastion host (last %s): %w", lastHost, err) + } + + s.sshNetConn = targetConn + s.sshConn = targetClientConn + + return sshConn.createGoClient(), nil +} + +func (s *Client) connectToTarget(ctx context.Context, bastionClient *gossh.Client) error { + var ( + client *gossh.Client + err error + ) + + if bastionClient == nil { + client, err = s.directConnectToTarget(ctx) + } else { + client, err = s.connectToTargetViaBastion(ctx, bastionClient) + } + + if err != nil { + return err + } + + s.sshClient = client + s.bastionClient = bastionClient + s.live = true + + if s.stopChan == nil { + stopCh := make(chan struct{}) + s.stopChan = stopCh + } + + go s.keepAlive() + + return nil +} + +func (s *Client) directConnectToTarget(ctx context.Context) (*gossh.Client, error) { + s.debug("Try to direct connect host master host...") + + config, err := s.createTargetClientConfig("direct connect") + if err != nil { + return nil, err + } + + var client *gossh.Client + + connectToHost := func() error { + if len(s.kubeProxies) == 0 { + s.sessionClient.ChoiceNewHost() + } + + addr := fmt.Sprintf("%s:%s", s.sessionClient.Host(), s.sessionClient.Port) + s.debug("Connect to master host '%s' with user '%s'\n", addr, s.sessionClient.User) + + var err error + client, err = s.dialContext(ctx, "tcp", addr, config) + + return err + } + + hostLoopParams := retry.SafeCloneOrNewParams(s.loopsParams.ConnectToHostDirectly, defaultClientDirectlyLoopParamsOps...). + WithName("Get SSH client") + + if err := s.runInLoop(ctx, hostLoopParams, connectToHost); err != nil { + lastHost := fmt.Sprintf("'%s:%s' with user '%s'", s.sessionClient.Host(), s.sessionClient.Port, s.sessionClient.User) + return nil, fmt.Errorf("Failed to connect to target directly (last %s): %w", lastHost, err) + } + + return client, nil +} + +func (s *Client) connectToBastion(ctx context.Context) (*gossh.Client, error) { + if s.sessionClient.BastionHost == "" { + s.debug("Bastion host is empty. Skip connection to bastion") + return nil, nil + } + + s.debug("Initialize bastion connection...") + + bastionUser := s.sessionClient.BastionUser + + bastionConfig, err := s.createClientConfig(bastionUser, s.sessionClient.BastionPassword, "bastion") + if err != nil { + return nil, err + } + + var bastionClient *gossh.Client + + bastionAddr := fmt.Sprintf("%s:%s", s.sessionClient.BastionHost, s.sessionClient.BastionPort) + fullHost := fmt.Sprintf("bastion host '%s' with user '%s'", bastionAddr, bastionUser) + + connectToBastion := func() error { + s.debug("Connect to %s", fullHost) + + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var err error + bastionClient, err = s.dialContext(cctx, "tcp", bastionAddr, bastionConfig) + + return err + } + + bastionLoopParams := retry.SafeCloneOrNewParams(s.loopsParams.ConnectToBastion, defaultClientViaBastionLoopParamsOps...). + WithName("Get bastion SSH client") + + if err := s.runInLoop(ctx, bastionLoopParams, connectToBastion); err != nil { + return nil, fmt.Errorf("Could not connect to %s: %w", fullHost, err) + } + + s.debug("Connected successfully to bastion host %s", bastionAddr) + + return bastionClient, nil +} + +func (s *Client) authMethods(password string) ([]gossh.AuthMethod, error) { + var authMethods []gossh.AuthMethod + if len(s.signers) > 0 { + s.debug("Adding private key method") + authMethods = append(authMethods, gossh.PublicKeys(s.signers...)) + } + + if !govalue.Nil(s.agentClient) { + s.debug("Adding agent socket to auth method") + authMethods = append(authMethods, gossh.PublicKeysCallback(s.agentClient.Signers)) + } + + if password != "" { + s.debug("Initial password auth to master host") + authMethods = append(authMethods, gossh.Password(password)) + } + + if len(authMethods) == 0 { + return nil, fmt.Errorf("Private keys or SSH_AUTH_SOCK environment variable or become password should passed") + } + + return authMethods, nil +} + +func (s *Client) runInLoop(ctx context.Context, params retry.Params, task func() error) error { + createLoop := retry.NewLoopWithParams + if s.silent { + createLoop = retry.NewSilentLoopWithParams + } + + paramsWithLogger := params.Clone().WithLogger(s.settings.Logger()) + + return createLoop(paramsWithLogger).RunContext(ctx, task) +} + +func (s *Client) keepAlive() { + defer s.debug("Keepalive goroutine stopped") + + checker := newKepAliveChecker(s, time.Second*5, 3) + + for { + select { + case <-s.stopChan: + s.debug("Receive stop keepalive goroutine") + close(s.stopChan) + s.stopChan = nil + return + default: + if err := checker.Check(); err != nil { + // if check returns error we should restart client exit from goroutine + // all sleeps doing in to Check + s.restart() + return + } + } + } +} + +func (s *Client) restart() { + s.live = false + s.stopChan = nil + s.silent = true + if err := s.Start(); err != nil { + s.debug("Start failed during restart: %v", err) + } + s.sshSessionsList = nil +} + +type sshConnection struct { + conn gossh.Conn + ch <-chan gossh.NewChannel + requestCh <-chan *gossh.Request +} + +func (s *sshConnection) createGoClient() *gossh.Client { + return gossh.NewClient(s.conn, s.ch, s.requestCh) +} + +func (s *Client) createSSHConnection(c net.Conn, addr string, config *gossh.ClientConfig) (*sshConnection, error) { + var ( + err error + conn gossh.Conn + ch <-chan gossh.NewChannel + requestCh <-chan *gossh.Request + ) + + if s.settings.IsDebug() { + sshLogger := log.NewSLogWithPrefixAndDebug( + context.TODO(), + s.settings.LoggerProvider(), + "go-ssh", + true, + ) + conn, ch, requestCh, err = gossh.NewClientConnWithDebug(c, addr, config, sshLogger) + } else { + conn, ch, requestCh, err = gossh.NewClientConn(c, addr, config) + } + + if err != nil { + return nil, err + } + + return &sshConnection{ + conn: conn, + ch: ch, + requestCh: requestCh, + }, nil +} + +func (s *Client) dialContext(ctx context.Context, network, addr string, config *gossh.ClientConfig) (*gossh.Client, error) { + closeConnectionAndReturnErr := func(msg string, err error, conn net.Conn) (*gossh.Client, error) { + err = fmt.Errorf("Cannot Dial to '%s' %s: %w", addr, msg, err) + + if closeErr := utils.SafeClose(conn); closeErr != nil { + err = fmt.Errorf("%w and cannot close connection %w", err, closeErr) + } + return nil, err + } + + d := net.Dialer{Timeout: config.Timeout} + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return closeConnectionAndReturnErr("connect", err, conn) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return closeConnectionAndReturnErr("is not tcp", err, conn) + } + + err = tcpConn.SetKeepAlive(true) + if err != nil { + return closeConnectionAndReturnErr("cannot set keepalive", err, tcpConn) + } + + timeFactor := time.Duration(3) + deadline := time.Now().Add(config.Timeout * timeFactor) + err = tcpConn.SetDeadline(deadline) + if err != nil { + return closeConnectionAndReturnErr( + fmt.Sprintf("cannot set deadline %s", deadline.String()), + err, + tcpConn, + ) + } + + sshConn, err := s.createSSHConnection(tcpConn, addr, config) + if err != nil { + return closeConnectionAndReturnErr("cannot create ssh connection", err, tcpConn) + } + + err = tcpConn.SetDeadline(time.Time{}) + if err != nil { + return closeConnectionAndReturnErr("cannot reset deadline", err, tcpConn) + } + + return sshConn.createGoClient(), nil +} + +func (s *Client) initSigners() error { + if len(s.signers) > 0 { + s.settings.Logger().DebugLn("Signers already initialized") + return nil + } + + signers := make([]gossh.Signer, 0, len(s.privateKeys)) + for _, keypath := range s.privateKeys { + key, err := utils.GetSSHPrivateKey(keypath.Key, keypath.Passphrase) + if err != nil { + return err + } + signer, err := gossh.NewSignerFromKey(key) + if err != nil { + return fmt.Errorf("Unable to parse private key: %v", err) + } + signers = append(signers, signer) + } + + s.signers = signers + return nil +} + +func (s *Client) stopAllAndLogErrors(cause string) { + errors := s.stopAll(cause) + + if len(errors) > 0 { + s.debug("Have %d errors after stop:", len(errors)) + } + for _, err := range errors { + s.debug(err.Error()) + } +} + +func (s *Client) stopAll(cause string) []error { + s.debug("Stop client after %s...", cause) + + errors := make([]error, 0) + addError := func(e error, format string, v ...any) { + prefix := fmt.Sprintf(format, v...) + errors = append(errors, fmt.Errorf("%s: %w", prefix, e)) + } + + closeBastionAndAgent := func() { + if err := utils.SafeClose(s.bastionClient, s.logPresentHandler("Bastion client")...); err != nil { + addError(err, "Failed to close agent connection") + } + + if err := utils.SafeClose(s.agentConnection, s.logPresentHandler("Agent")...); err != nil { + addError(err, "Failed to close agent connection") + } + } + + if govalue.Nil(s.sshClient) { + // we can stop in Start + // client is nil but agent and bastion prepared + // try to close. it is safe + closeBastionAndAgent() + s.debug("No SSH client found to stop. Exiting...") + return errors + } + + s.debug("SSH client and its routines stopping...") + + s.debug("Stopping kube proxies...") + for _, p := range s.kubeProxies { + p.StopAll() + } + s.kubeProxies = nil + + s.debug("Closing sessions...") + for indx, sess := range s.sshSessionsList { + if govalue.Nil(sess) { + continue + } + + if err := sess.Signal(gossh.SIGKILL); err != nil { + addError(err, "Failed to kill session %d", indx) + } + + if err := sess.Close(); err != nil { + addError(err, "Failed to close session %d: %v", indx) + } + } + s.sshSessionsList = nil + + // by starting kubeproxy on remote, there is one more process starts + // it cannot be killed by sending any signal to his parrent process + // so we need to use killall command to kill all this processes + s.debug("Stopping kube proxies on remote...") + if err := s.stopRemoteKubeProxies(); err != nil { + addError(err, "Failed to stop kube proxy") + } + + s.debug("Stopping keep-alive goroutine...") + if s.stopChan != nil { + s.stopChan <- struct{}{} + } + + err := s.sshClient.Close() + if err != nil { + addError(err, "Failed to close ssh client") + } + + if err := utils.SafeClose(s.sshConn); err != nil { + addError(err, "Failed to close ssh connection") + } + + if err := utils.SafeClose(s.sshNetConn); err != nil { + addError(err, "Failed to close net ssh connection") + } + + closeBastionAndAgent() + + return errors +} + +func (s *Client) registerSession(sess *gossh.Session) { + s.sshSessionsMu.Lock() + defer s.sshSessionsMu.Unlock() + s.sshSessionsList = append(s.sshSessionsList, sess) +} + +func (s *Client) stopRemoteKubeProxies() error { + ctx := s.ctx + if govalue.Nil(ctx) { + ctx = context.Background() + } + + cmd := NewSSHCommand(s, "killall kubectl") + + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + cmd.Sudo(cctx) + if err := cmd.Run(cctx); err != nil { + return err + } + + s.debug("Kube proxies on remote were stopped") + return nil +} + +func (s *Client) debug(format string, v ...any) { + s.settings.Logger().DebugF(format, v...) +} + +func (s *Client) logPresentHandler(connectionName string) []utils.PresentCloseHandler { + return []utils.PresentCloseHandler{ + func(isPresent bool) { + if !isPresent { + return + } + + s.debug("%s connection is present. Try to close...", connectionName) + }, + } +} diff --git a/pkg/ssh/gossh/client_test.go b/pkg/ssh/gossh/client_test.go new file mode 100644 index 0000000..1052c3e --- /dev/null +++ b/pkg/ssh/gossh/client_test.go @@ -0,0 +1,408 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +func TestOnlyPreparePrivateKeys(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestOnlyPreparePrivateKeys") + + // genetaring ssh keys + keyWithoutPath, _, err := sshtesting.GenerateKeys(test, "") + require.NoError(t, err, "failed to generate keys without password") + + wrongKeyPath := test.MustCreateTmpFile(t, "Hello world", false, sshtesting.PrivateKeysRoot, "wrong-key") + + validPassword := sshtesting.RandPassword(12) + keyWithPass, _, err := sshtesting.GenerateKeys(test, validPassword) + require.NoError(t, err, "failed to generate keys with password") + + t.Run("OnlyPrepareKeys cases", func(t *testing.T) { + type testCase struct { + title string + keys []session.AgentPrivateKey + wantErr bool + err string + } + + cases := []testCase{ + { + title: "No keys", + keys: make([]session.AgentPrivateKey, 0, 1), + wantErr: false, + }, + { + title: "Key auth, no password", + keys: []session.AgentPrivateKey{{Key: keyWithoutPath}}, + wantErr: false, + }, + { + title: "Key auth, no password, noexistent key", + keys: []session.AgentPrivateKey{{Key: "/tmp/noexistent-key"}}, + wantErr: true, + err: "open /tmp/noexistent-key: no such file or directory", + }, + { + title: "Key auth, no password, wrong key", + keys: []session.AgentPrivateKey{{Key: wrongKeyPath}}, + wantErr: true, + err: "ssh: no key found", + }, + { + title: "Key auth, with passphrase", + keys: []session.AgentPrivateKey{{Key: keyWithPass, Passphrase: validPassword}}, + wantErr: false, + }, + { + title: "Key auth, with wrong passphrase", + keys: []session.AgentPrivateKey{{Key: keyWithPass, Passphrase: sshtesting.RandPassword(6)}}, + wantErr: true, + err: "x509: decryption password incorrect", + }, + } + + assertError := func(t *testing.T, tst testCase, err error) { + if !tst.wantErr { + require.NoError(t, err) + return + } + require.Error(t, err) + require.Contains(t, err.Error(), tst.err) + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(context.Background(), sshSettings, nil, c.keys) + err := sshClient.OnlyPreparePrivateKeys() + assertError(t, c, err) + + // double run + err = sshClient.OnlyPreparePrivateKeys() + assertError(t, c, err) + }) + } + + }) +} + +func TestClientStart(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestClientStart") + + const bastionUserName = "bastionuser" + + container := sshtesting.NewTestContainerWrapper(t, test) + bastion := sshtesting.NewTestContainerWrapper( + t, + test, + sshtesting.WithConnectToContainerNetwork(container), + sshtesting.WithUserName(bastionUserName), + sshtesting.WithAuthSettings(container), + sshtesting.WithContainerName("bastion"), + ) + + agent := sshtesting.StartTestAgent(t, container) + + type testCase struct { + title string + settings *session.Session + keys []session.AgentPrivateKey + wantErr bool + err string + authSock string + } + + keys := container.AgentPrivateKeys() + noKeys := make([]session.AgentPrivateKey, 0, 1) + incorrectHost := sshtesting.IncorrectHost() + + cases := []testCase{ + { + title: "Password auth, no keys", + settings: sshtesting.Session(container), + keys: make([]session.AgentPrivateKey, 0, 1), + wantErr: false, + }, + { + title: "Key auth, no password", + settings: sshtesting.Session(container), + keys: keys, + wantErr: false, + }, + { + title: "SSH_AUTH_SOCK auth", + settings: sshtesting.Session(container), + keys: noKeys, + wantErr: false, + authSock: agent.SockPath(), + }, + { + title: "SSH_AUTH_SOCK auth, wrong socket", + settings: sshtesting.Session(container), + keys: noKeys, + wantErr: true, + err: "Failed to open agent socket", + authSock: "/run/nonexistent", + }, + { + title: "Key auth, no password, wrong key", + settings: sshtesting.Session(container), + keys: []session.AgentPrivateKey{{Key: "/tmp/noexistent-key"}}, + wantErr: true, + }, + { + title: "No session", + settings: nil, + keys: []session.AgentPrivateKey{{Key: "/tmp/noexistent-key"}}, + wantErr: true, + err: "Possible bug in ssh client: client session should be passed start", + }, + { + title: "No auth", + settings: sshtesting.Session(container, func(input *session.Input) { + input.BecomePass = "" + }), + keys: noKeys, + wantErr: true, + err: "Private keys or SSH_AUTH_SOCK environment variable or become password should passed", + authSock: "", + }, + { + title: "Wrong port", + settings: sshtesting.Session( + container, + sshtesting.OverrideSessionWithIncorrectPort(container, bastion), + ), + keys: keys, + wantErr: true, + err: "Failed to connect to target directly", + authSock: "", + }, + { + title: "With bastion, key auth", + settings: sshtesting.SessionWithBastion(container, bastion), + keys: keys, + wantErr: false, + authSock: "", + }, + { + title: "With bastion, password auth", + settings: sshtesting.SessionWithBastion(container, bastion), + keys: noKeys, + wantErr: false, + authSock: "", + }, + { + title: "With bastion, no auth", + settings: sshtesting.SessionWithBastion(container, bastion, func(input *session.Input) { + input.BastionPassword = "" + }), + keys: noKeys, + wantErr: true, + err: "Private keys or SSH_AUTH_SOCK environment variable or become password should passed", + authSock: "", + }, + { + title: "With bastion, SSH_AUTH_SOCK auth", + settings: sshtesting.SessionWithBastion(container, bastion), + keys: keys, + wantErr: false, + authSock: agent.SockPath(), + }, + { + title: "With bastion, key auth, wrong target host", + settings: sshtesting.SessionWithBastion(container, bastion, func(input *session.Input) { + input.AvailableHosts = []session.Host{ + {Host: incorrectHost, Name: incorrectHost}, + } + }), + keys: keys, + wantErr: true, + err: "Failed to connect to target host through bastion host", + authSock: "", + }, + { + title: "With bastion, key auth, wrong bastion port", + settings: sshtesting.SessionWithBastion( + container, + bastion, + sshtesting.OverrideSessionWithIncorrectBastionPort(container, bastion), + ), + keys: keys, + wantErr: true, + err: "Could not connect to bastion host", + authSock: "", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + test.SetSubTest(c.title) + + sshSettings := sshtesting.CreateDefaultTestSettingsWithAgent(test, c.authSock) + + sshClient := NewClient(context.Background(), sshSettings, c.settings, c.keys). + WithLoopsParams(ClientLoopsParams{ + ConnectToHostViaBastion: sshtesting.GetTestLoopParamsForFailed(), + ConnectToBastion: sshtesting.GetTestLoopParamsForFailed(), + ConnectToHostDirectly: sshtesting.GetTestLoopParamsForFailed(), + }) + + err := sshClient.Start() + sshClient.Stop() + + if !c.wantErr { + require.NoError(t, err) + test.Logger.InfoLn("client started successfully") + return + } + + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + }) + } +} + +func TestClientKeepalive(t *testing.T) { + testName := "TestClientKeepalive" + sshtesting.CheckSkipSSHTest(t, testName) + + waitKeepAlive := func() { + time.Sleep(3 * time.Second) + } + + t.Run("keepalive test", func(t *testing.T) { + test := sshtesting.ShouldNewTest(t, testName).SetSubTest(t.Name()) + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(context.Background(), sshSettings, sess, keys). + WithLoopsParams(ClientLoopsParams{ + NewSession: sshtesting.GetTestLoopParamsForFailed(), + }) + + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err, "failed to start ssh client") + + registerStopClient(t, sshClient) + + runEcho := func(t *testing.T, msg string) { + s, err := sshClient.NewSSHSession() + require.NoError(t, err) + + cmd := fmt.Sprintf(`echo -n "%s"`, msg) + + defer func() { + err := s.Close() + if err != nil { + test.Logger.InfoF("failed to close runEcho session: %v", err) + } + }() + + out, err := s.CombinedOutput(cmd) + require.NoError(t, err, "failed to run command '%s'", cmd) + require.Contains(t, string(out), msg, "run command '%s' should contains output '%s'. Out: %s", cmd, msg, out) + } + + runEcho(t, "Hello before restart") + + err = container.Container.Restart(true, 2*time.Second) + require.NoError(t, err, "failed to restart container") + waitKeepAlive() + + runEcho(t, "Hello after restart") + }) + + t.Run("keepalive with context test", func(t *testing.T) { + test := sshtesting.ShouldNewTest(t, testName).SetSubTest(t.Name()) + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancel() + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys) + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + // expecting client is not live + sshClient.Stop() + waitKeepAlive() + err = sshClient.Start() + + require.Error(t, err) + require.Contains(t, err.Error(), "deadline exceeded") + }) + + t.Run("client start with context test", func(t *testing.T) { + test := sshtesting.ShouldNewTest(t, testName).SetSubTest(t.Name()) + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container, sshtesting.OverrideSessionWithIncorrectPort(container)) + keys := container.AgentPrivateKeys() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(20*time.Second)) + defer cancel() + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys) + err := sshClient.Start() + // expecting error on client start: host is unreachable, but loop should exit on context deadline exceeded + require.Error(t, err) + require.Contains(t, err.Error(), "Loop was canceled: context deadline exceeded") + // expecting client is not live + sshClient.Stop() + err = sshClient.Start() + require.Error(t, err) + require.Contains(t, err.Error(), "deadline exceeded") + }) +} + +func TestDialContextVerySmall(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestDialContextVerySmall") + + sess := sshtesting.FakeSession() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + defer cancel() + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, make([]session.AgentPrivateKey, 0, 1)) + err := sshClient.Start() + // expecting error on client start: host is unreachable, but loop should exit on context deadline exceeded + require.Error(t, err) + require.Contains(t, err.Error(), "Loop was canceled: context deadline exceeded") + // expecting client is not live + sshClient.Stop() + err = sshClient.Start() + require.Error(t, err) + require.Contains(t, err.Error(), "deadline exceeded") +} diff --git a/pkg/ssh/gossh/command.go b/pkg/ssh/gossh/command.go new file mode 100644 index 0000000..699a48a --- /dev/null +++ b/pkg/ssh/gossh/command.go @@ -0,0 +1,830 @@ +package gossh + +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" + "time" + + gossh "github.com/deckhouse/lib-gossh" + "github.com/name212/govalue" + + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +type SSHCommand struct { + sshClient *Client + session *gossh.Session + + Name string + Args []string + Env []string + + SSHArgs []string + + stdoutPipeFile io.Reader + stderrPipeFile io.Reader + StdoutSplitter bufio.SplitFunc + + StdinPipe bool + Stdin io.WriteCloser + + Matchers []*utils.ByteSequenceMatcher + MatchHandler func(pattern string) string + + onCommandStart func() + stderrHandler func(string) + stdoutHandler func(string) + + WaitHandler func(err error) + + out *bytes.Buffer + err *bytes.Buffer + combined *singleWriter + + OutBytes bytes.Buffer + ErrBytes bytes.Buffer + + stop bool + waitCh chan struct{} + stopCh chan struct{} + + lockWaitError sync.RWMutex + waitError error + killError error + + cmd string + timeout time.Duration + + ctx context.Context + Cancel func() error + ctxResult <-chan error + wg sync.WaitGroup +} + +func NewSSHCommand(client *Client, name string, arg ...string) *SSHCommand { + args := make([]string, len(arg)) + copy(args, arg) + cmd := name + " " + for i := range args { + if !strings.HasPrefix(args[i], `"`) && + !strings.HasSuffix(args[i], `"`) && + strings.Contains(args[i], " ") { + args[i] = strconv.Quote(args[i]) + } + } + + // todo move new session to Start() + session, _ := client.NewSSHSession() + + return &SSHCommand{ + // Executor: process.NewDefaultExecutor(sess.Run(cmd)), + sshClient: client, + session: session, + Name: name, + Args: args, + Env: os.Environ(), + cmd: cmd, + } +} + +func (c *SSHCommand) WithSSHArgs(args ...string) { + c.SSHArgs = args +} + +func (c *SSHCommand) OnCommandStart(fn func()) { + c.onCommandStart = fn +} + +func (c *SSHCommand) Start() error { + // setup stream handlers + c.logDebugF("Call start") + if c.session == nil { + return fmt.Errorf("ssh session not started") + } + + err := c.SetupStreamHandlers() + if err != nil { + c.logDebugF("Could not set up stream handlers: %s", err) + return err + } + + err = c.start() + if err != nil { + c.logDebugF("Could not start: %v", err) + return err + } + + if c.WaitHandler != nil || c.timeout > 0 { + c.ProcessWait() + // wait only with timeout because WaitHandler run in long time commands like kube proxy + if c.timeout > 0 { + if c.waitCh != nil { + <-c.waitCh + } else { + c.logDebugF("Wait channel is nil. Possible bug. Returns immediately") + } + } + } else { + err = c.wait() + if err != nil { + return err + } + } + + return nil +} + +func (c *SSHCommand) start() error { + if c.ctx != nil { + select { + case <-c.ctx.Done(): + return c.ctx.Err() + default: + } + } + + if c.Cancel != nil && c.ctx != nil && c.ctx.Done() != nil { + resultc := make(chan error) + c.ctxResult = resultc + go c.watchCtx(resultc) + } + + command := c.cmd + " " + strings.Join(c.Args, " ") + + return c.session.Start(command) +} + +func (c *SSHCommand) watchCtx(resultc chan<- error) { + <-c.ctx.Done() + + var err error + if c.Cancel != nil { + if interruptErr := c.Cancel(); interruptErr == nil { + // We appear to have successfully interrupted the command, so any + // program behavior from this point may be due to ctx even if the + // command exits with code 0. + err = c.ctx.Err() + } else if errors.Is(interruptErr, os.ErrProcessDone) { + // The process already finished: we just didn't notice it yet. + // (Perhaps c.Wait hadn't been called, or perhaps it happened to race with + // c.ctx being canceled.) Don't inject a needless error. + } else { + err = interruptErr + } + } + + resultc <- err +} + +func (c *SSHCommand) wait() error { + waitCh := make(chan (error)) + + go func() { + waitCh <- c.session.Wait() + }() + + select { + case err := <-c.ctxResult: + if c.ctxResult != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + } + case err := <-waitCh: + if err != nil { + return err + } + + } + return nil +} + +func (c *SSHCommand) ProcessWait() { + waitErrCh := make(chan error, 1) + c.waitCh = make(chan struct{}, 1) + c.stopCh = make(chan struct{}, 1) + + // wait for process in go routine + go func() { + waitErrCh <- c.wait() + }() + + // todo need investigation for get rid of this gorutine. we need check to channel is stopped + // and gourutine does not exit if we use timeout and command stopped before timeout exited + // probably we can use timer or context instead of this goroutine + go func() { + if c.timeout > 0 { + time.Sleep(c.timeout) + if !c.stop && c.stopCh != nil { + // todo ugly solution + // here we check that channel is closed it is not correct + select { + case _, ok := <-c.stopCh: + if !ok { + c.logDebugF("StopCh was closed and '%s' timeout exceeded. Possible goroutine not closed.", c.timeout) + return + } + default: + c.logDebugF("StopCh is not close and '%s' timeout exceeded. Send stop", c.timeout) + } + + c.stopCh <- struct{}{} + } + } + }() + + // watch for wait or stop + go func() { + defer func() { + close(c.waitCh) + close(waitErrCh) + }() + // Wait until Stop() is called or/and Wait() is returning. + for { + select { + case err := <-waitErrCh: + if c.stop { + // Ignore error if Stop() was called. + return + } + c.setWaitError(err) + if c.WaitHandler != nil { + c.WaitHandler(c.waitError) + } + return + case <-c.stopCh: + c.stop = true + // Prevent next readings from the closed channel. + c.stopCh = nil + // The usual e.cmd.Process.Kill() is not working for the process + // started with the new process group (Setpgid: true). + // Negative pid number is used to send a signal to all processes in the group. + err := c.session.Signal(gossh.SIGKILL) + if err != nil { + c.killError = err + } + } + } + }() +} + +func (c *SSHCommand) clientString() string { + sessionString := "unknown" + sess := c.sshClient.Session() + if c.sshClient != nil && sess != nil { + sessionString = sess.String() + } + + return sessionString +} + +func (c *SSHCommand) Run(ctx context.Context) error { + c.logDebugF("Call run") + c.Cmd(ctx) + + if c.session == nil { + return fmt.Errorf("ssh session not started") + } + defer c.closeSession() + + err := c.Start() + if err != nil { + return err + } + + c.Stop() + + return c.WaitError() +} + +func (c *SSHCommand) WaitError() error { + defer c.lockWaitError.RUnlock() + c.lockWaitError.RLock() + return c.waitError +} + +func (c *SSHCommand) StderrBytes() []byte { + if len(c.ErrBytes.Bytes()) > 0 { + return c.ErrBytes.Bytes() + } + + if c.err != nil { + return c.err.Bytes() + } + + return nil +} + +func (c *SSHCommand) StdoutBytes() []byte { + if len(c.OutBytes.Bytes()) > 0 { + return c.OutBytes.Bytes() + } + + if c.out != nil { + return c.out.Bytes() + } + + return nil +} + +func (c *SSHCommand) WithMatchers(matchers ...*utils.ByteSequenceMatcher) *SSHCommand { + c.Matchers = make([]*utils.ByteSequenceMatcher, 0) + c.Matchers = append(c.Matchers, matchers...) + return c +} + +func (c *SSHCommand) WithWaitHandler(waitHandler func(error)) *SSHCommand { + c.WaitHandler = waitHandler + return c +} + +func (c *SSHCommand) OpenStdinPipe() *SSHCommand { + c.StdinPipe = true + return c +} + +func (c *SSHCommand) WithMatchHandler(fn func(pattern string) string) *SSHCommand { + c.MatchHandler = fn + return c +} + +func (c *SSHCommand) Sudo(ctx context.Context) { + cmdLine := c.Name + " " + strings.Join(c.Args, " ") + sudoCmdLine := fmt.Sprintf( + `sudo -p SudoPassword -H -S -i bash -c 'echo SUDO-SUCCESS && %s'`, + cmdLine, + ) + + c.cmd = sudoCmdLine + c.Cmd(ctx) + + c.WithMatchers( + utils.NewByteSequenceMatcher("SudoPassword"), + utils.NewByteSequenceMatcher("SUDO-SUCCESS").WaitNonMatched(), + ) + c.OpenStdinPipe() + + passSent := false + c.WithMatchHandler(func(pattern string) string { + logger := c.sshClient.settings.Logger() + if pattern == "SudoPassword" { + c.logDebugF("Send become pass to cmd") + becomePass := c.sshClient.Session().BecomePass + + var err error + _, err = c.Stdin.Write([]byte(becomePass + "\n")) + if err != nil { + logger.ErrorF("Got error from sending pass to stdin for '%s': %v", c.clientString(), err) + } + if !passSent { + passSent = true + } else { + // Second prompt is error! + logger.ErrorLn("Bad sudo password.") + } + return "reset" + } + if pattern == "SUDO-SUCCESS" { + c.logDebugF("Got SUCCESS for sudo password") + if c.onCommandStart != nil { + c.onCommandStart() + } + return "done" + } + return "" + }) +} + +func (c *SSHCommand) WithStdoutHandler(handler func(string)) { + c.stdoutHandler = handler +} + +func (c *SSHCommand) WithStderrHandler(handler func(string)) { + c.stderrHandler = handler +} + +func (c *SSHCommand) Cmd(ctx context.Context) { + if ctx != nil { + c.ctx = ctx + } + c.Cancel = func() error { + return c.session.Signal(gossh.SIGINT) + } +} + +func (c *SSHCommand) Output(ctx context.Context) ([]byte, []byte, error) { + c.Cmd(ctx) + if c.session == nil { + return nil, nil, fmt.Errorf("ssh session not started") + } + defer c.closeSession() + + if c.out == nil { + c.out = new(bytes.Buffer) + } else { + c.out.Reset() + } + + if c.err == nil { + c.err = new(bytes.Buffer) + } else { + c.err.Reset() + } + + var err error + c.stdoutPipeFile, err = c.session.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("open stdout pipe '%s': %w", c.Name, err) + } + + c.stderrPipeFile, err = c.session.StderrPipe() + if err != nil { + return nil, nil, fmt.Errorf("open stderr pipe '%s': %w", c.Name, err) + } + + err = c.Start() + c.wg.Wait() + return c.out.Bytes(), c.err.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +func (c *SSHCommand) CombinedOutput(ctx context.Context) ([]byte, error) { + c.Cmd(ctx) + if c.session == nil { + return nil, fmt.Errorf("ssh session not started") + } + + defer c.closeSession() + + if c.out == nil { + c.out = new(bytes.Buffer) + } else { + c.out.Reset() + } + + if c.err == nil { + c.err = new(bytes.Buffer) + } else { + c.err.Reset() + } + + var err error + c.stdoutPipeFile, err = c.session.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("open stdout pipe '%s': %w", c.Name, err) + } + + c.stderrPipeFile, err = c.session.StderrPipe() + if err != nil { + return nil, fmt.Errorf("open stderr pipe '%s': %w", c.Name, err) + } + var co singleWriter + c.combined = &co + + err = c.Start() + c.wg.Wait() + return c.combined.b.Bytes(), err +} + +func (c *SSHCommand) WithTimeout(timeout time.Duration) { + c.timeout = timeout +} + +func (c *SSHCommand) WithEnv(env map[string]string) { + c.Env = make([]string, 0, len(env)) + for k, v := range env { + c.Env = append(c.Env, fmt.Sprintf("%s=%s", k, v)) + } +} +func (c *SSHCommand) CaptureStdout(buf *bytes.Buffer) *SSHCommand { + if buf != nil { + c.out = buf + } else { + c.out = &bytes.Buffer{} + } + return c +} + +func (c *SSHCommand) CaptureStderr(buf *bytes.Buffer) *SSHCommand { + if buf != nil { + c.err = buf + } else { + c.err = &bytes.Buffer{} + } + return c +} + +func (c *SSHCommand) SetupStreamHandlers() (err error) { + // stderr goes to console (commented because ssh writes only "Connection closed" messages to stderr) + // c.Cmd.Stderr = os.Stderr + // connect console's stdin + // c.Cmd.Stdin = os.Stdin + + // setup stdout stream handlers + if c.session != nil && c.out == nil && c.stdoutHandler == nil && len(c.Matchers) == 0 { + c.session.Stdout = os.Stdout + c.session.Stdout = &c.OutBytes + c.session.Stderr = &c.ErrBytes + return + } + + var stdoutHandlerWritePipe *os.File + var stdoutHandlerReadPipe *os.File + if c.out != nil || c.stdoutHandler != nil || len(c.Matchers) > 0 { + + if c.out == nil { + c.out = new(bytes.Buffer) + } + + if c.stdoutPipeFile == nil { + var err error + c.stdoutPipeFile, err = c.session.StdoutPipe() + if err != nil { + return fmt.Errorf("open stdout pipe '%s': %w", c.Name, err) + } + } + + // create pipe for StdoutHandler + if c.stdoutHandler != nil { + stdoutHandlerReadPipe, stdoutHandlerWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stdoutHandler: %s", err) + } + } + } + + var stderrReadPipe io.Reader + var stderrHandlerWritePipe *os.File + var stderrHandlerReadPipe *os.File + if c.err != nil || c.stderrHandler != nil || len(c.Matchers) > 0 { + + if c.err == nil { + c.err = new(bytes.Buffer) + } + + if c.stderrPipeFile == nil { + var err error + c.stderrPipeFile, err = c.session.StderrPipe() + if err != nil { + return fmt.Errorf("open stdout pipe '%s': %w", c.Name, err) + } + } + + // create pipe for StderrHandler + if c.stderrHandler != nil { + stderrHandlerReadPipe, stderrHandlerWritePipe, err = os.Pipe() + if err != nil { + return fmt.Errorf("unable to create os pipe for stderrHandler: %s", err) + } + } + } + + if c.StdinPipe { + c.Stdin, err = c.session.StdinPipe() + if err != nil { + return fmt.Errorf("open stdin pipe: %v", err) + } + } + + // Start reading from stdout of a command. + // Wait until all matchers are done and then: + // - Copy to os.Stdout if live output is enabled + // - Copy to buffer if capture is enabled + // - Copy to pipe if StdoutHandler is set + c.wg.Add(2) + go func() { + c.readFromStreams(c.stdoutPipeFile, stdoutHandlerWritePipe, false) + }() + + // sudo hack, becouse of password prompt is sent to STDERR, not STDOUT + go func() { + c.readFromStreams(c.stderrPipeFile, stdoutHandlerWritePipe, true) + }() + + go func() { + if c.stdoutHandler == nil { + c.logDebugF("stdout read pipe not set. Consumer does not start") + return + } + c.ConsumeLines(stdoutHandlerReadPipe, c.stdoutHandler) + c.logDebugF("Stop lines consumer") + }() + + // Start reading from stderr of a command. + // Copy to os.Stderr if live output is enabled + // Copy to buffer if capture is enabled + // Copy to pipe if StderrHandler is set + go func() { + if stderrReadPipe == nil { + c.logDebugF("stdterr read pipe not set. Pipe reader does not start") + return + } + + c.logDebugF("Start reading from stderr pipe") + defer c.logDebugF("Stop reading from stderr pipe") + + buf := make([]byte, 16) + for { + n, err := stderrReadPipe.Read(buf) + + // TODO logboek + if c.sshClient.settings.IsDebug() { + os.Stderr.Write(buf[:n]) + } + if c.err != nil { + c.err.Write(buf[:n]) + } + if c.stderrHandler != nil { + _, _ = stderrHandlerWritePipe.Write(buf[:n]) + } + + if err == io.EOF { + break + } + } + }() + + go func() { + if c.stderrHandler == nil { + c.logDebugF("stdterr line consumer not set. Consumer does not start") + return + } + c.ConsumeLines(stderrHandlerReadPipe, c.stderrHandler) + c.logDebugF("Stop stdterr line consumer") + }() + + return nil +} + +func (c *SSHCommand) readFromStreams(stdoutReadPipe io.Reader, stdoutHandlerWritePipe io.Writer, isError bool) { + defer c.logDebugF("readFromStreams stopped") + defer c.wg.Done() + + if govalue.Nil(stdoutReadPipe) { + c.logDebugF("stdout pipe is nil") + return + } + + c.logDebugF("Start read from streams") + + buf := make([]byte, 16) + matchersDone := false + errorsCount := 0 + for { + n, err := stdoutReadPipe.Read(buf) + if err != nil && err != io.EOF { + c.logDebugF("Error reading from stdout: %v", err) + errorsCount++ + if errorsCount > 1000 { + panic(fmt.Errorf("readFromStreams: too many errors, last error %v", err)) + } + continue + } + + m := 0 + if !matchersDone { + for _, matcher := range c.Matchers { + m = matcher.Analyze(buf[:n]) + if matcher.IsMatched() { + c.logDebugF("Triggered match for '%s'", matcher.Pattern) + // matcher is triggered + if c.MatchHandler != nil { + res := c.MatchHandler(matcher.Pattern) + if res == "done" { + matchersDone = true + break + } + if res == "reset" { + matcher.Reset() + } + } + } + } + + // stdout for internal use, no copying to pipes until all Matchers are matched + if !matchersDone { + m = n + } + } + // TODO logboek + if c.sshClient.settings.IsDebug() { + os.Stdout.Write(buf[m:n]) + } + if c.out != nil && !isError { + c.out.Write(buf[:n]) + } + + if c.err != nil && isError { + c.err.Write(buf[:n]) + } + + if c.combined != nil { + c.combined.Write(buf[:n]) + } + if c.stdoutHandler != nil { + _, _ = stdoutHandlerWritePipe.Write(buf[m:n]) + } + + if err == io.EOF { + c.logDebugF("readFromStreams: EOF") + break + } + } +} + +func (c *SSHCommand) ConsumeLines(r io.Reader, fn func(l string)) { + scanner := bufio.NewScanner(r) + if c.StdoutSplitter != nil { + scanner.Split(c.StdoutSplitter) + } + for scanner.Scan() { + text := scanner.Text() + + if fn != nil { + fn(text) + } + + if text != "" { + c.logDebugF("Line consumed: '%s'", text) + } + } +} + +func (c *SSHCommand) Stop() { + c.logDebugF("Running stop") + + if c.stop { + c.logDebugF("Already stopped") + return + } + if c.session == nil { + c.logDebugF("Session not started yet") + return + } + if c.cmd == "" { + c.logDebugF("Possible BUG: Call Executor.Stop with Cmd==nil") + return + } + + c.stop = true + if c.stopCh != nil { + c.logDebugF("Send stop signal") + close(c.stopCh) + } + c.logDebugF("Stopped") + c.logDebugF("Sending SIGINT...") + c.session.Signal(gossh.SIGINT) + c.logDebugF("Signal SIGINT sent") + c.session.Signal(gossh.SIGKILL) +} + +func (c *SSHCommand) setWaitError(err error) { + defer c.lockWaitError.Unlock() + c.lockWaitError.Lock() + c.waitError = err +} + +func (c *SSHCommand) closeSession() { + c.session.Close() + c.sshClient.UnregisterSession(c.session) +} +func (c *SSHCommand) logDebugF(format string, v ...interface{}) { + msg := fmt.Sprintf(format, v...) + args := "" + if len(c.Args) > 0 { + args = strings.Join(c.Args, " ") + } + c.sshClient.settings.Logger().DebugF("'%s' for cmd '%s' with args '%s' with client '%s'\n", msg, c.cmd, args, c.clientString()) +} diff --git a/pkg/ssh/gossh/command_test.go b/pkg/ssh/gossh/command_test.go new file mode 100644 index 0000000..9c3bf06 --- /dev/null +++ b/pkg/ssh/gossh/command_test.go @@ -0,0 +1,656 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +func TestCommandOutput(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandOutput") + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + t.Run("Get command Output", func(t *testing.T) { + cases := []struct { + title string + command string + args []string + expectedOutput string + expectedErrOutput string + timeout time.Duration + prepareFunc func(c *SSHCommand) error + wantErr bool + err string + }{ + { + title: "Just echo, success", + command: "echo", + args: []string{`"test output"`}, + expectedOutput: "test output\n", + wantErr: false, + }, + { + title: "With context", + command: `while true; do echo "test"; sleep 5; done`, + args: []string{}, + expectedOutput: "test\ntest\n", + timeout: 7 * time.Second, + wantErr: false, + }, + { + title: "Command return error", + command: "cat", + args: []string{`"/etc/sudoers"`}, + wantErr: true, + err: "Process exited with status 1", + expectedErrOutput: "cat: /etc/sudoers: Permission denied\n", + }, + { + title: "With opened stdout pipe", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + return c.Run(context.Background()) + }, + wantErr: true, + err: "open stdout pipe", + }, + { + title: "With opened stderr pipe", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + buf := new(bytes.Buffer) + c.session.Stderr = buf + return nil + }, + wantErr: true, + err: "open stderr pipe", + }, + { + title: "With nil session", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + err := c.session.Close() + c.session = nil + return err + }, + wantErr: true, + err: "ssh session not started", + }, + { + title: "With defined buffers", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + c.out = new(bytes.Buffer) + c.err = new(bytes.Buffer) + return nil + }, + expectedOutput: "test output\n", + wantErr: false, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + ctx := context.Background() + var emptyDuration time.Duration + var cancel context.CancelFunc + if c.timeout != emptyDuration { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(c.timeout)) + } + if cancel != nil { + defer cancel() + } + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys). + WithLoopsParams(newSessionTestLoopParams()) + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + cmd := NewSSHCommand(sshClient, c.command, c.args...) + + if c.prepareFunc != nil { + err = c.prepareFunc(cmd) + require.NoError(t, err) + } + out, errBytes, err := cmd.Output(ctx) + if !c.wantErr { + require.NoError(t, err) + require.Equal(t, c.expectedOutput, string(out)) + } else { + require.Error(t, err) + require.Equal(t, c.expectedErrOutput, string(errBytes)) + require.Contains(t, err.Error(), c.err) + } + }) + } + }) +} + +func TestCommandCombinedOutput(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandCombinedOutput") + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + t.Run("Get command CombinedOutput", func(t *testing.T) { + cases := []struct { + title string + command string + args []string + expectedOutput string + expectedErrOutput string + timeout time.Duration + prepareFunc func(c *SSHCommand) error + wantErr bool + err string + }{ + { + title: "Just echo, success", + command: "echo", + args: []string{"\"test output\""}, + expectedOutput: "test output\n", + wantErr: false, + }, + { + title: "With context", + command: "while true; do echo \"test\"; sleep 5; done", + args: []string{}, + expectedOutput: "test\ntest\n", + timeout: 7 * time.Second, + wantErr: false, + }, + { + title: "Command return error", + command: "cat", + args: []string{"\"/etc/sudoers\""}, + wantErr: true, + err: "Process exited with status 1", + expectedErrOutput: "cat: /etc/sudoers: Permission denied\n", + }, + { + title: "With opened stdout pipe", + command: "echo", + args: []string{"\"test output\""}, + prepareFunc: func(c *SSHCommand) error { + return c.Run(context.Background()) + }, + wantErr: true, + err: "open stdout pipe", + }, + { + title: "With opened stderr pipe", + command: "echo", + args: []string{"\"test output\""}, + prepareFunc: func(c *SSHCommand) error { + buf := new(bytes.Buffer) + c.session.Stderr = buf + return nil + }, + wantErr: true, + err: "open stderr pipe", + }, + { + title: "With nil session", + command: "echo", + args: []string{"\"test output\""}, + prepareFunc: func(c *SSHCommand) error { + err := c.session.Close() + c.session = nil + return err + }, + wantErr: true, + err: "ssh session not started", + }, + { + title: "With defined buffers", + command: "echo", + args: []string{"\"test output\""}, + prepareFunc: func(c *SSHCommand) error { + c.out = new(bytes.Buffer) + c.err = new(bytes.Buffer) + return nil + }, + expectedOutput: "test output\n", + wantErr: false, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + ctx := context.Background() + var emptyDuration time.Duration + var cancel context.CancelFunc + if c.timeout != emptyDuration { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(c.timeout)) + } + if cancel != nil { + defer cancel() + } + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys). + WithLoopsParams(newSessionTestLoopParams()) + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + cmd := NewSSHCommand(sshClient, c.command, c.args...) + if c.prepareFunc != nil { + err = c.prepareFunc(cmd) + require.NoError(t, err) + } + combined, err := cmd.CombinedOutput(ctx) + if !c.wantErr { + require.NoError(t, err) + require.Equal(t, c.expectedOutput, string(combined)) + } else { + require.Error(t, err) + require.Equal(t, c.expectedErrOutput, string(combined)) + require.Contains(t, err.Error(), c.err) + } + }) + } + }) +} + +func TestCommandRun(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandRun") + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + // evns test + envs := make(map[string]string) + envs["TEST_ENV"] = "test" + + t.Run("Run a command", func(t *testing.T) { + cases := []struct { + title string + command string + args []string + expectedOutput string + expectedErrOutput string + timeout time.Duration + prepareFunc func(c *SSHCommand) error + envs map[string]string + wantErr bool + err string + }{ + { + title: "Just echo, success", + command: "echo", + args: []string{`"est output"`}, + expectedOutput: "test output\n", + wantErr: false, + }, + { + title: "Just echo, with envs, success", + command: "echo", + args: []string{`test output"`}, + expectedOutput: "test output\n", + envs: envs, + wantErr: false, + }, + { + title: "With context", + command: `while true; do echo "test"; sleep 5; done`, + args: []string{}, + expectedOutput: "test\ntest\n", + timeout: 7 * time.Second, + wantErr: false, + }, + { + title: "Command return error", + command: "cat", + args: []string{`"/etc/sudoers"`}, + wantErr: true, + err: "Process exited with status 1", + expectedErrOutput: "cat: /etc/sudoers: Permission denied\n", + }, + { + title: "With opened stdout pipe", + command: "echo", + args: []string{`"test output\"`}, + prepareFunc: func(c *SSHCommand) error { + return c.Run(context.Background()) + }, + wantErr: true, + err: "ssh: session already started", + }, + { + title: "With nil session", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + err := c.session.Close() + c.session = nil + return err + }, + wantErr: true, + err: "ssh session not started", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + ctx := context.Background() + var emptyDuration time.Duration + var cancel context.CancelFunc + if c.timeout != emptyDuration { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(c.timeout)) + } + if cancel != nil { + defer cancel() + } + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys). + WithLoopsParams(newSessionTestLoopParams()) + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + cmd := NewSSHCommand(sshClient, c.command, c.args...) + if c.prepareFunc != nil { + err = c.prepareFunc(cmd) + require.NoError(t, err) + } + if len(c.envs) > 0 { + cmd.WithEnv(c.envs) + } + + err = cmd.Run(ctx) + if !c.wantErr { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + + // second run for context after deadline exceeded + if c.timeout != emptyDuration { + cmd2 := NewSSHCommand(sshClient, c.command, c.args...) + if c.prepareFunc != nil { + err = c.prepareFunc(cmd2) + require.NoError(t, err) + } + if len(c.envs) > 0 { + cmd2.WithEnv(c.envs) + } + err = cmd2.Run(ctx) + // command should fail to run + require.Error(t, err) + require.Contains(t, err.Error(), "context deadline exceeded") + + } + sshClient.Stop() + }) + } + }) +} + +func TestCommandStart(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandStart") + + container := sshtesting.NewTestContainerWrapper(t, test) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + ctx := context.Background() + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, sess, keys). + WithLoopsParams(newSessionTestLoopParams()) + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + t.Run("Start and stop a command", func(t *testing.T) { + cases := []struct { + title string + command string + args []string + expectedOutput string + expectedErrOutput string + timeout time.Duration + prepareFunc func(c *SSHCommand) error + wantErr bool + err string + }{ + { + title: "Just echo, success", + command: "echo", + args: []string{`"test output"`}, + expectedOutput: "test output\n", + wantErr: false, + }, + { + title: "With context", + command: `while true; do echo "test"; sleep 5; done`, + args: []string{}, + expectedOutput: "test\ntest\n", + timeout: 7 * time.Second, + wantErr: false, + }, + { + title: "Command return error", + command: "cat", + args: []string{`"/etc/sudoers"`}, + wantErr: true, + err: "Process exited with status 1", + expectedErrOutput: "cat: /etc/sudoers: Permission denied\n", + }, + { + title: "With opened stdout pipe", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + return c.Run(context.Background()) + }, + wantErr: true, + err: "ssh: session already started", + }, + { + title: "With nil session", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + err := c.session.Close() + c.session = nil + return err + }, + wantErr: true, + err: "ssh session not started", + }, + { + title: "waitHandler", + command: "echo", + args: []string{`"test output"`}, + prepareFunc: func(c *SSHCommand) error { + c.WithWaitHandler(func(err error) { + if err != nil { + test.Logger.ErrorF("SSH-agent process exited, now stop. Wait error: %v", err) + return + } + test.Logger.InfoF("SSH-agent process exited, now stop") + }) + return nil + }, + expectedOutput: "test output\n", + wantErr: false, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + cmd := NewSSHCommand(sshClient, c.command, c.args...) + var emptyDuration time.Duration + if c.timeout != emptyDuration { + cmd.WithTimeout(c.timeout) + } + if c.prepareFunc != nil { + err = c.prepareFunc(cmd) + require.NoError(t, err) + } + cmd.Cmd(ctx) + err = cmd.Start() + if !c.wantErr { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + cmd.Stop() + }) + } + }) +} + +func TestCommandSudoRun(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandRunSudo") + + container := sshtesting.NewTestContainerWrapper(t, test, sshtesting.WithNoPassword()) + keys := container.AgentPrivateKeys() + + // starting openssh container with password auth + containerWithPass := sshtesting.NewTestContainerWrapper( + t, + test, + sshtesting.WithPassword(sshtesting.RandPassword(12)), + ) + + sessionWithoutPassword := sshtesting.Session(container) + + sessionWithValidPass := sshtesting.Session(containerWithPass) + + // client with wrong sudo password + sessionWithInvalidPass := sshtesting.Session(containerWithPass, func(input *session.Input) { + input.BecomePass = sshtesting.RandPassword(3) + }) + + t.Run("Run a command with sudo", func(t *testing.T) { + cases := []struct { + title string + settings *session.Session + keys []session.AgentPrivateKey + command string + args []string + timeout time.Duration + prepareFunc func(c *SSHCommand) error + wantErr bool + err string + errorOutput string + }{ + { + title: "Just echo, success", + settings: sessionWithoutPassword, + keys: keys, + command: "echo", + args: []string{`"test output"`}, + wantErr: false, + }, + { + title: "Just echo, success, with password", + settings: sessionWithValidPass, + keys: make([]session.AgentPrivateKey, 0, 1), + command: "echo", + args: []string{`"test output"`}, + wantErr: false, + }, + { + title: "Just echo, failure, with wrong password", + settings: sessionWithInvalidPass, + keys: keys, + command: "echo", + args: []string{`"test output"`}, + wantErr: true, + err: "Process exited with status 1", + errorOutput: "SudoPasswordSorry, try again.\nSudoPasswordSorry, try again.\nSudoPasswordsudo: 3 incorrect password attempts\n", + }, + { + title: "With context", + settings: sessionWithoutPassword, + keys: keys, + command: `while true; do echo "test"; sleep 5; done`, + args: []string{}, + timeout: 7 * time.Second, + wantErr: false, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + ctx := context.Background() + var emptyDuration time.Duration + var cancel context.CancelFunc + if c.timeout != emptyDuration { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(c.timeout)) + } + if cancel != nil { + defer cancel() + } + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(ctx, sshSettings, c.settings, c.keys). + WithLoopsParams(newSessionTestLoopParams()) + + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + cmd := NewSSHCommand(sshClient, c.command, c.args...).CaptureStderr(nil) + if c.prepareFunc != nil { + err = c.prepareFunc(cmd) + require.NoError(t, err) + } + cmd.Sudo(ctx) + err = cmd.Run(ctx) + if !c.wantErr { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + errBytes := cmd.StderrBytes() + + require.Contains(t, string(errBytes), c.errorOutput) + } + }) + } + }) +} diff --git a/pkg/ssh/gossh/common_test.go b/pkg/ssh/gossh/common_test.go new file mode 100644 index 0000000..d6698cd --- /dev/null +++ b/pkg/ssh/gossh/common_test.go @@ -0,0 +1,75 @@ +package gossh + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/stretchr/testify/require" +) + +func registerStopClient(t *testing.T, sshClient *Client) { + t.Cleanup(func() { + sshClient.Stop() + }) +} + +// todo mount local directory to container and assert via local exec +func assertFilesViaRemoteRun(t *testing.T, sshClient *Client, cmd string, expectedOutput string) { + s, err := sshClient.NewSSHSession() + require.NoError(t, err, "session should start") + defer sshClient.UnregisterSession(s) + out, err := s.Output(cmd) + require.NoError(t, err) + // out contains a contant of uploaded file, should be equal to testFile contant + require.Equal(t, expectedOutput, string(out)) +} + +func startContainerAndClientWithContainer(t *testing.T, test *sshtesting.Test, opts ...sshtesting.TestContainerWrapperSettingsOpts) (*Client, *sshtesting.TestContainerWrapper) { + container := sshtesting.NewTestContainerWrapper(t, test, opts...) + sess := sshtesting.Session(container) + keys := container.AgentPrivateKeys() + + sshSettings := sshtesting.CreateDefaultTestSettings(test) + sshClient := NewClient(context.Background(), sshSettings, sess, keys).WithLoopsParams(ClientLoopsParams{ + NewSession: sshtesting.GetTestLoopParamsForFailed(), + }) + + err := sshClient.Start() + // expecting no error on client start + require.NoError(t, err) + + registerStopClient(t, sshClient) + + return sshClient, container +} + +func startContainerAndClient(t *testing.T, test *sshtesting.Test, opts ...sshtesting.TestContainerWrapperSettingsOpts) *Client { + sshClient, _ := startContainerAndClientWithContainer(t, test, opts...) + return sshClient +} + +func newSessionTestLoopParams() ClientLoopsParams { + return ClientLoopsParams{ + NewSession: retry.NewEmptyParams( + retry.WithWait(2*time.Second), + retry.WithAttempts(5), + ), + } +} + +func tunnelAddressString(local, remote int) string { + localAddr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", local)) + remoteAddr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", remote)) + return fmt.Sprintf("%s:%s", remoteAddr, localAddr) +} + +func registerStopTunnel(t *testing.T, tunnel *Tunnel) { + t.Cleanup(func() { + tunnel.Stop() + }) +} diff --git a/pkg/ssh/gossh/file.go b/pkg/ssh/gossh/file.go new file mode 100644 index 0000000..2fd20be --- /dev/null +++ b/pkg/ssh/gossh/file.go @@ -0,0 +1,492 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/bramvdbogaerde/go-scp" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-dhctl/pkg/log" + gossh "github.com/deckhouse/lib-gossh" + "github.com/google/uuid" +) + +type SSHFile struct { + settings settings.Settings + sshClient *gossh.Client +} + +func NewSSHFile(sett settings.Settings, client *gossh.Client) *SSHFile { + return &SSHFile{ + sshClient: client, + settings: sett, + } +} + +func (f *SSHFile) Upload(ctx context.Context, srcPath, remotePath string) error { + logger := f.settings.Logger() + + fType, err := CheckLocalPath(srcPath) + if err != nil { + return fmt.Errorf("failed to open local file: %w", err) + } + + session, err := f.sshClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + if fType != "DIR" { + localFile, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("failed to open local file: %w", err) + } + defer localFile.Close() + + rType, err := getRemoteFileStat(f.sshClient, remotePath, logger) + if err != nil { + if !strings.ContainsAny(err.Error(), "No such file or directory") { + return err + } + } + if rType == "DIR" { + remotePath = remotePath + "/" + filepath.Base(srcPath) + } + logger.DebugF("starting upload local %s to remote %s", srcPath, remotePath) + + if err := CopyFile(ctx, localFile, remotePath, "0755", session); err != nil { + return fmt.Errorf("failed to copy file to remote host: %w", err) + } + } else { + err = session.Run("mkdir -p " + remotePath) + if err != nil { + return err + } + files, err := os.ReadDir(srcPath) + if err != nil { + return fmt.Errorf("could not read directory: %w", err) + } + for _, file := range files { + err = f.Upload(ctx, srcPath+"/"+file.Name(), remotePath+"/"+file.Name()) + if err != nil { + return err + } + } + } + + return nil +} + +// UploadBytes creates a tmp file and upload it to remote dstPath +func (f *SSHFile) UploadBytes(ctx context.Context, data []byte, remotePath string) error { + srcPath, err := CreateEmptyTmpFile(f.settings) + if err != nil { + return fmt.Errorf("create source tmp file: %v", err) + } + defer func() { + err := os.Remove(srcPath) + if err != nil { + f.settings.Logger().ErrorF("Error: cannot remove tmp file '%s': %v", srcPath, err) + } + }() + + err = os.WriteFile(srcPath, data, 0o600) + if err != nil { + return fmt.Errorf("write data to tmp file: %w", err) + } + + err = f.Upload(ctx, srcPath, remotePath) + return err +} + +func (f *SSHFile) Download(ctx context.Context, remotePath, dstPath string) error { + logger := f.settings.Logger() + + fType, err := getRemoteFileStat(f.sshClient, remotePath, logger) + if err != nil { + return err + } + + if fType != "DIR" { + // regular file logic + lType, err := CheckLocalPath(dstPath) + if err != nil { + if !strings.ContainsAny(err.Error(), "No such file or directory") { + return err + } + } + if lType == "DIR" { + dstPath = filepath.Join(dstPath, filepath.Base(remotePath)) + } + localFile, err := os.Create(dstPath) + if err != nil { + return fmt.Errorf("failed to open local file: %w", err) + } + defer localFile.Close() + if err := CopyFromRemote(ctx, localFile, remotePath, f.sshClient); err != nil { + return fmt.Errorf("failed to copy file from remote host: %w", err) + } + } else { + // recursive copy logic + filesString, err := getRemoteFilesList(f.sshClient, remotePath) + if err != nil { + return err + } + + if filepath.Base(dstPath) != filepath.Base(remotePath) { + dstPath = dstPath + "/" + filepath.Base(remotePath) + } + + err = os.MkdirAll(dstPath, os.ModePerm) + if err != nil { + return err + } + + re := regexp.MustCompile(`\s+`) + files := re.Split(filesString, -1) + for _, file := range files { + f.Download(ctx, remotePath+"/"+file, dstPath+"/"+file) + } + } + + return nil +} + +// Download remote file and returns its content as an array of bytes. +func (f *SSHFile) DownloadBytes(ctx context.Context, remotePath string) ([]byte, error) { + dstPath, err := CreateEmptyTmpFile(f.settings) + if err != nil { + return nil, fmt.Errorf("create target tmp file: %v", err) + } + defer func() { + err := os.Remove(dstPath) + if err != nil { + f.settings.Logger().DebugF("Error: cannot remove tmp file '%s': %v", dstPath, err) + } + }() + + err = f.Download(ctx, remotePath, dstPath) + if err != nil { + return nil, fmt.Errorf("download target tmp file: %v", err) + } + + data, err := os.ReadFile(dstPath) + if err != nil { + return nil, fmt.Errorf("reading tmp file '%s': %w", dstPath, err) + } + + return data, nil +} + +func getRemoteFileStat(client *gossh.Client, remoteFilePath string, logger log.Logger) (string, error) { + if remoteFilePath == "." { + return "DIR", nil + } + + session, err := client.NewSession() + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + defer session.Close() + + command := fmt.Sprint("LC_ALL=en_US.utf8 stat -c %F " + remoteFilePath) + output, err := session.CombinedOutput(command) + + logger.DebugF("remote path %s is %s\n", remoteFilePath, output) + + if strings.TrimSpace(string(output)) == "directory" { + return "DIR", nil + } + + if strings.TrimSpace(string(output)) == "regular file" { + return "FILE", nil + } + + return "", err +} + +func getRemoteFilesList(client *gossh.Client, remoteFilePath string) (string, error) { + session, err := client.NewSession() + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + defer session.Close() + + command := fmt.Sprint("ls " + remoteFilePath) + output, err := session.CombinedOutput(command) + + return strings.TrimSpace(string(output)), err +} + +func CreateEmptyTmpFile(sett settings.Settings) (string, error) { + u, err := uuid.NewRandom() + if err != nil { + return "", err + } + + tmpPath := filepath.Join( + sett.TmpDir(), + fmt.Sprintf("dhctl-scp-%d-%s.tmp", os.Getpid(), u.String()), + ) + + file, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return "", err + } + + _ = file.Close() + return tmpPath, nil +} + +// CheckLocalPath see if file exists and determine if it is a directory. Error is returned if file is not exists. +func CheckLocalPath(path string) (string, error) { + fi, err := os.Stat(path) + if err != nil { + return "", err + } + if fi.Mode().IsDir() { + return "DIR", nil + } + if fi.Mode().IsRegular() { + return "FILE", nil + } + return "", fmt.Errorf("path '%s' is not a directory or file", path) +} + +type PassThru func(r io.Reader, total int64) io.Reader + +func CopyFile( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, + session *gossh.Session, +) error { + contentsBytes, err := io.ReadAll(fileReader) + if err != nil { + return fmt.Errorf("failed to read all data from reader: %w", err) + } + r := bytes.NewReader(contentsBytes) + size := int64(len(contentsBytes)) + + stdout, err := session.StdoutPipe() + if err != nil { + return err + } + w, err := session.StdinPipe() + if err != nil { + return err + } + defer w.Close() + + filename := path.Base(remotePath) + + // Start the command first and get confirmation that it has been started + // before sending anything through the pipes. + err = session.Start(fmt.Sprintf("%s -qt %q", "scp", remotePath)) + if err != nil { + return err + } + + wg := sync.WaitGroup{} + wg.Add(2) + + errCh := make(chan error, 2) + + // SCP protocol and file sending + go func() { + defer wg.Done() + defer w.Close() + + _, err = fmt.Fprintln(w, "C"+permissions, size, filename) + if err != nil { + errCh <- err + return + } + + if err = checkResponse(stdout); err != nil { + errCh <- err + return + } + + _, err = io.Copy(w, r) + if err != nil { + errCh <- err + return + } + + _, err = fmt.Fprint(w, "\x00") + if err != nil { + errCh <- err + return + } + + if err = checkResponse(stdout); err != nil { + errCh <- err + return + } + }() + + // Wait for the process to exit + go func() { + defer wg.Done() + err := session.Wait() + if err != nil { + errCh <- err + return + } + }() + + // Wait for one of the conditions (error/timeout/completion) to occur + if err := wait(&wg, ctx); err != nil { + return err + } + + close(errCh) + + // Collect any errors from the error channel + for err := range errCh { + if err != nil { + return err + } + } + + return nil +} + +func checkResponse(r io.Reader) error { + _, err := scp.ParseResponse(r, nil) + if err != nil { + return err + } + + return nil + +} + +func wait(wg *sync.WaitGroup, ctx context.Context) error { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + + select { + case <-c: + return nil + + case <-ctx.Done(): + return ctx.Err() + } +} + +func CopyFromRemote(ctx context.Context, file *os.File, remotePath string, sshClient *gossh.Client) error { + + session, err := sshClient.NewSession() + if err != nil { + return fmt.Errorf("Error creating ssh session in copy from remote: %v", err) + } + defer session.Close() + + wg := sync.WaitGroup{} + errCh := make(chan error, 4) + + wg.Add(1) + go func() { + var err error + + defer func() { + // NOTE: this might send an already sent error another time, but since we only receive one, this is fine. On the "happy-path" of this function, the error will be `nil` therefore completing the "err<-errCh" at the bottom of the function. + errCh <- err + // We must unblock the go routine first as we block on reading the channel later + wg.Done() + + }() + + r, err := session.StdoutPipe() + if err != nil { + errCh <- err + return + } + + in, err := session.StdinPipe() + if err != nil { + errCh <- err + return + } + defer in.Close() + + err = session.Start(fmt.Sprintf("%s -f %q", "scp", remotePath)) + if err != nil { + errCh <- err + return + } + + err = scp.Ack(in) + if err != nil { + errCh <- err + return + } + + fileInfo, err := scp.ParseResponse(r, in) + if err != nil { + errCh <- err + return + } + + err = scp.Ack(in) + if err != nil { + errCh <- err + return + } + + _, err = scp.CopyN(file, r, fileInfo.Size) + if err != nil { + errCh <- err + return + } + + err = scp.Ack(in) + if err != nil { + errCh <- err + return + } + + err = session.Wait() + if err != nil { + errCh <- err + return + } + }() + + if err := wait(&wg, ctx); err != nil { + return err + } + + finalErr := <-errCh + close(errCh) + return finalErr +} diff --git a/pkg/ssh/gossh/file_test.go b/pkg/ssh/gossh/file_test.go new file mode 100644 index 0000000..8f89378 --- /dev/null +++ b/pkg/ssh/gossh/file_test.go @@ -0,0 +1,441 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "os" + "os/exec" + "path" + "path/filepath" + "testing" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" +) + +func TestSSHFileUpload(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestCommandOutput") + + const uploadDir = "upload_dir" + const testFileContent = "Hello World" + const notExec = false + + filePath := func(subPath ...string) []string { + require.NotEmpty(t, subPath, "subPath is empty for filePath") + return append([]string{uploadDir}, subPath...) + } + + testFile := test.MustCreateTmpFile(t, testFileContent, notExec, filePath("upload")...) + testDir := filepath.Dir(testFile) + test.MustCreateTmpFile(t, "second", notExec, filePath("second")...) + test.MustCreateTmpFile(t, "empty", notExec, filePath("second")...) + test.MustCreateTmpFile(t, "sub", notExec, filePath("sub", "third")...) + + symlink := filepath.Join(test.TmpDir(), "symlink") + err := os.Symlink(testFile, symlink) + require.NoError(t, err) + + sshClient := startContainerAndClient(t, test) + + t.Run("Upload files and directories to container via existing ssh client", func(t *testing.T) { + cases := []struct { + title string + srcPath string + dstPath string + wantErr bool + err string + }{ + { + title: "Single file", + srcPath: testFile, + dstPath: ".", + wantErr: false, + }, + { + title: "Directory", + srcPath: testDir, + dstPath: "/tmp", + wantErr: false, + }, + { + title: "Nonexistent", + srcPath: "/path/to/nonexistent/flie", + dstPath: "/tmp", + wantErr: true, + err: "failed to open local file", + }, + { + title: "File to root", + srcPath: testFile, + dstPath: "/any", + wantErr: true, + }, + { + title: "File to /var/lib", + srcPath: testFile, + dstPath: "/var/lib", + wantErr: true, + }, + { + title: "File to unaccessible file", + srcPath: testFile, + dstPath: "/path/what/not/exists.txt", + wantErr: true, + err: "failed to copy file to remote host", + }, + { + title: "Directory to root", + srcPath: testDir, + dstPath: "/", + wantErr: true, + }, + { + title: "Symlink", + srcPath: symlink, + dstPath: ".", + wantErr: false, + }, + { + title: "Device", + srcPath: "/dev/zero", + dstPath: "/", + wantErr: true, + err: "is not a directory or file", + }, + { + title: "Unaccessible dir", + srcPath: "/var/audit", + dstPath: ".", + wantErr: true, + err: "could not read directory", + }, + { + title: "Unaccessible file", + srcPath: "/etc/sudoers", + dstPath: ".", + wantErr: true, + err: "failed to open local file", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + f := sshClient.File() + err = f.Upload(context.Background(), c.srcPath, c.dstPath) + if !c.wantErr { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + }) + } + }) + + t.Run("Equality of uploaded and local file content", func(t *testing.T) { + f := sshClient.File() + err := f.Upload(context.Background(), testFile, "/tmp/testfile.txt") + // testFile contains "Hello world" string + require.NoError(t, err) + + assertFilesViaRemoteRun(t, sshClient, "cat /tmp/testfile.txt", testFileContent) + }) + + t.Run("Equality of uploaded and local directory", func(t *testing.T) { + f := sshClient.File() + err := f.Upload(context.Background(), testDir, "/tmp/upload") + require.NoError(t, err) + + cmd := exec.Command("ls", testDir) + lsResult, err := cmd.Output() + require.NoError(t, err) + + assertFilesViaRemoteRun(t, sshClient, "ls /tmp/upload", string(lsResult)) + }) +} + +func TestSSHFileUploadBytes(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestSSHFileUploadBytes") + + sshClient := startContainerAndClient(t, test) + + t.Run("Upload bytes", func(t *testing.T) { + const content = "Hello world" + f := sshClient.File() + err := f.UploadBytes(context.Background(), []byte(content), "/tmp/testfile.txt") + require.NoError(t, err) + + assertFilesViaRemoteRun(t, sshClient, "cat /tmp/testfile.txt", content) + }) +} + +func TestCreateEmptyTmpFile(t *testing.T) { + sshtesting.CheckSkipSSHTest(t, "TestCreateEmptyTmpFile") + + t.Run("Creating empty temp file", func(t *testing.T) { + cases := []struct { + title string + tmpDirName string + wantErr bool + err string + }{ + { + title: "Accessible tmp", + tmpDirName: os.TempDir(), + wantErr: false, + }, + { + title: "Unaccessible tmp", + tmpDirName: "/var/lib", + wantErr: true, + err: "permission denied", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + uid := os.Geteuid() + sshSettings := settings.NewBaseProviders(settings.ProviderParams{ + TmpDir: c.tmpDirName, + }) + if uid == 0 && c.wantErr { + t.Skip("Test TestCreateEmptyTmpFile was skipped, cannot try to access unaccessible dir from root user") + } + filename, err := CreateEmptyTmpFile(sshSettings) + if !c.wantErr { + require.NoError(t, err) + os.Remove(filename) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + }) + } + }) +} + +func TestSSHFileDownload(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestSSHFileDownload") + + sshClient := startContainerAndClient(t, test) + + const expectedFileContent = "Some test data" + + // preparing some test related data + err := sshClient.Command("mkdir -p /tmp/testdata").Run(context.Background()) + require.NoError(t, err) + err = sshClient.Command(fmt.Sprintf(`echo -n '%s' > /tmp/testdata/first`, expectedFileContent)).Run(context.Background()) + require.NoError(t, err) + err = sshClient.Command("touch /tmp/testdata/second").Run(context.Background()) + require.NoError(t, err) + err = sshClient.Command("touch /tmp/testdata/third").Run(context.Background()) + require.NoError(t, err) + err = sshClient.Command("ln -s /tmp/testdata/first /tmp/link").Run(context.Background()) + require.NoError(t, err) + + t.Run("Download files and directories to container via existing ssh client", func(t *testing.T) { + testDir := test.MustMkSubDirs(t, "download") + + cases := []struct { + title string + srcPath string + dstPath string + wantErr bool + err string + }{ + { + title: "Single file", + srcPath: "/tmp/testdata/first", + dstPath: testDir, + wantErr: false, + }, + { + title: "Directory", + srcPath: "/tmp/testdata", + dstPath: path.Join(testDir, "downloaded"), + wantErr: false, + }, + { + title: "Nonexistent", + srcPath: "/path/to/nonexistent/flie", + dstPath: "/tmp", + wantErr: true, + }, + { + title: "File to root", + srcPath: "/tmp/testdata/first", + dstPath: "/any", + wantErr: true, + }, + { + title: "File to /var/lib", + srcPath: "/tmp/testdata/first", + dstPath: "/var/lib", + wantErr: true, + }, + { + title: "File to unaccessible file", + srcPath: "/tmp/testdata/first", + dstPath: "/path/what/not/exists.txt", + wantErr: true, + err: "no such file or directory", + }, + { + title: "Directory to root", + srcPath: "/tmp/testdata", + dstPath: "/", + wantErr: true, + }, + { + title: "Symlink", + srcPath: "/tmp/link", + dstPath: testDir, + wantErr: false, + }, + { + title: "Device", + srcPath: "/dev/zero", + dstPath: "/", + wantErr: true, + err: "failed to open local file", + }, + { + title: "Unaccessible dir", + srcPath: "/var/audit", + dstPath: testDir, + wantErr: true, + }, + { + title: "Unaccessible file", + srcPath: "/etc/sudoers", + dstPath: testDir, + wantErr: true, + err: "failed to copy file from remote host", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + f := sshClient.File() + err := f.Download(context.Background(), c.srcPath, c.dstPath) + if c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + return + } + + require.NoError(t, err) + + _, err = os.Stat(c.dstPath) + require.NoError(t, err, "%s path should exist after download", c.dstPath) + }) + } + }) + + t.Run("Equality of downloaded and remote file content", func(t *testing.T) { + downloadContentDir := test.MustMkSubDirs(t, "download_content") + + f := sshClient.File() + + dstPath := path.Join(downloadContentDir, "testfile.txt") + + err := f.Download(context.Background(), "/tmp/testdata/first", dstPath) + // /tmp/testdata/first contains "Some test data" string + require.NoError(t, err) + + assertFilesViaRemoteRun(t, sshClient, "cat /tmp/testdata/first", dstPath) + + downloadedContent, err := os.ReadFile(dstPath) + require.NoError(t, err) + // out contains a contant of uploaded file, should be equal to testFile contant + require.Equal(t, expectedFileContent, string(downloadedContent)) + }) + + t.Run("Equality of downloaded and remote directory", func(t *testing.T) { + downloadWholeDirDir := test.MustMkSubDirs(t, "download_dir") + + f := sshClient.File() + err = f.Download(context.Background(), "/tmp/testdata", downloadWholeDirDir) + require.NoError(t, err) + + cmd := exec.Command("ls -R", downloadWholeDirDir) + lsResult, err := cmd.Output() + require.NoError(t, err) + + assertFilesViaRemoteRun(t, sshClient, "ls -R /tmp/testdata", string(lsResult)) + }) +} + +func TestSSHFileDownloadBytes(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestSSHFileDownloadBytes") + + sshClient := startContainerAndClient(t, test) + + const expectedFileContent = "Some test data" + + // preparing file to download + err := sshClient.Command(fmt.Sprintf(`echo -n '%s' > /tmp/testfile`, expectedFileContent)).Run(context.Background()) + require.NoError(t, err) + + t.Run("Download bytes", func(t *testing.T) { + cases := []struct { + title string + remotePath string + tmpDirName string + wantErr bool + err string + }{ + { + title: "Positive result", + remotePath: "/tmp/testfile", + tmpDirName: os.TempDir(), + wantErr: false, + }, + { + title: "Unaccessible tmp", + remotePath: "/tmp/testfile", + tmpDirName: "/var/lib", + wantErr: true, + err: "create target tmp file", + }, + { + title: "Unaccessible remote file", + remotePath: "/etc/sudoers", + tmpDirName: os.TempDir(), + wantErr: true, + err: "download target tmp file", + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + f := sshClient.File() + bytes, err := f.DownloadBytes(context.Background(), c.remotePath) + if c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + + require.NoError(t, err) + // out contains a contant of uploaded file, should be equal to testFile contant + require.Equal(t, expectedFileContent, string(bytes)) + }) + } + }) +} diff --git a/pkg/ssh/gossh/keepalive.go b/pkg/ssh/gossh/keepalive.go new file mode 100644 index 0000000..73399c7 --- /dev/null +++ b/pkg/ssh/gossh/keepalive.go @@ -0,0 +1,125 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "errors" + "fmt" + "math/rand" + "time" + + gossh "github.com/deckhouse/lib-gossh" + "github.com/name212/govalue" +) + +var ( + errKeepAliveSessionCreate = fmt.Errorf("Cannot create self keepalive session") + errKeepExited = fmt.Errorf("All keepalive attempts failed") +) + +type keepAliveChecker struct { + client *Client + sleep time.Duration + maxErrors int + + id int + errorsCount int +} + +func newKepAliveChecker(client *Client, sleep time.Duration, maxErrors int) *keepAliveChecker { + id := rand.New(rand.NewSource(time.Now().UnixNano())).Int() + return &keepAliveChecker{ + client: client, + sleep: sleep, + maxErrors: maxErrors, + id: id, + } +} + +func (c *keepAliveChecker) Check() error { + c.debug("do next check...") + + err := c.checkClientAlive() + + if err != nil { + return c.handleClientAliveFailed(err) + } + + c.sendAliveToSessions() + + c.debug("success. Sleep %s before next check", c.sleep.String()) + time.Sleep(c.sleep) + + return nil +} + +func (c *keepAliveChecker) sendKeepAlive(sess *gossh.Session) error { + _, err := sess.SendRequest("keepalive@openssh.com", false, nil) + return err +} + +func (c *keepAliveChecker) checkClientAlive() error { + sess, err := c.client.sshClient.NewSession() + if err != nil { + return fmt.Errorf("%w: %w", errKeepAliveSessionCreate, err) + } + + defer func() { + if err := sess.Close(); err != nil { + c.debug("client self check session close failed: %v", err) + } + }() + + if err := c.sendKeepAlive(sess); err != nil { + return fmt.Errorf("Cannot send to client self check session failed: %w", err) + } + + return nil +} + +func (c *keepAliveChecker) sendAliveToSessions() { + for indx, registeredSession := range c.client.sshSessionsList { + if govalue.Nil(registeredSession) { + c.client.UnregisterSession(registeredSession) + continue + } + + if err := c.sendKeepAlive(registeredSession); err != nil { + c.debug("%s to registered session %d failed: %v", indx, err) + } + } +} + +func (c *keepAliveChecker) handleClientAliveFailed(err error) error { + c.errorsCount++ + + if c.errorsCount > c.maxErrors { + c.debug("too many errors %d encountered. Last err: '%v'. Exit", c.maxErrors, err) + return errKeepExited + } + + if errors.Is(err, errKeepAliveSessionCreate) { + c.debug("failed: '%v'. Error count %d. Sleep %s before next attempt", err, c.errorsCount, c.sleep.String()) + time.Sleep(c.sleep) + } + + return nil +} + +func (c *keepAliveChecker) debug(format string, a ...any) { + debugPrefix := fmt.Sprintf("Keepalive[%d] to %s ", c.id, c.client.sessionClient.String()) + format = debugPrefix + format + c.client.settings.Logger().InfoF(format, a...) +} diff --git a/pkg/ssh/gossh/kube-proxy.go b/pkg/ssh/gossh/kube-proxy.go new file mode 100644 index 0000000..2fb8cbe --- /dev/null +++ b/pkg/ssh/gossh/kube-proxy.go @@ -0,0 +1,424 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "math/rand" + "os" + "regexp" + "strconv" + "time" + + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +const DefaultLocalAPIPort = 22322 + +type KubeProxy struct { + Session *session.Session + sshClient *Client + + KubeProxyPort string + LocalPort string + + proxy *SSHCommand + tunnel *Tunnel + + stop bool + port string + localPort int + + healthMonitorsByStartID map[int]chan struct{} +} + +func NewKubeProxy(client *Client, sess *session.Session) *KubeProxy { + return &KubeProxy{ + sshClient: client, + Session: sess, + port: "0", + localPort: DefaultLocalAPIPort, + healthMonitorsByStartID: make(map[int]chan struct{}), + } +} + +func (k *KubeProxy) Start(useLocalPort int) (port string, err error) { + startID := rand.Int() + + logger := k.sshClient.settings.Logger() + + logger.DebugF("Kube-proxy start id=[%d]; port:%d", startID, useLocalPort) + + success := false + defer func() { + k.stop = false + if !success { + logger.DebugF("[%d] Kube-proxy was not started. Try to clear all", startID) + k.Stop(startID) + } + logger.DebugF("[%d] Kube-proxy starting was finished", startID) + }() + + proxyCommandErrorCh := make(chan error, 1) + var proxy *SSHCommand + for { + proxy, port, err = k.runKubeProxy(proxyCommandErrorCh, startID) + if err != nil { + logger.DebugF("[%d] Got error from runKubeProxy func: %v\n", startID, err) + return "", err + } + + k.stop = false + portNum, err := strconv.Atoi(port) + if err != nil { + continue + } + if portNum > 1024 { + break + } + logger.DebugF("Proxy run on priveleged port %s and will be stopped and restarted\n", port) + k.Stop(startID) + } + + logger.DebugF("[%d] Proxy was started successfully\n", startID) + + k.proxy = proxy + k.port = port + + tunnelErrorCh := make(chan error) + tun, localPort, lastError := k.upTunnel(port, useLocalPort, tunnelErrorCh, startID) + if lastError != nil { + logger.DebugF("[%d] Got error from upTunnel func: %v\n", startID, err) + return "", fmt.Errorf("tunnel up error: max retries reached, last error: %w", lastError) + } + + k.tunnel = tun + k.localPort = localPort + + k.healthMonitorsByStartID[startID] = make(chan struct{}, 1) + go k.healthMonitor( + proxyCommandErrorCh, + tunnelErrorCh, + k.healthMonitorsByStartID[startID], + startID, + ) + + success = true + + return fmt.Sprintf("%d", k.localPort), nil +} + +func (k *KubeProxy) StopAll() { + for startID := range k.healthMonitorsByStartID { + k.Stop(startID) + } +} + +func (k *KubeProxy) Stop(startID int) { + logger := k.sshClient.settings.Logger() + + if k == nil { + logger.DebugF("[%d] Stop kube-proxy: kube proxy object is nil. Skip.\n", startID) + return + } + + if k.stop { + logger.DebugF("[%d] Stop kube-proxy: kube proxy already stopped. Skip.\n", startID) + return + } + + if k.healthMonitorsByStartID[startID] != nil { + k.healthMonitorsByStartID[startID] <- struct{}{} + delete(k.healthMonitorsByStartID, startID) + } + + if k.proxy != nil { + logger.DebugF("[%d] Stop proxy command\n", startID) + k.proxy.Stop() + logger.DebugF("[%d] Proxy command stopped\n", startID) + k.proxy = nil + k.port = "0" + } + if k.tunnel != nil { + logger.DebugF("[%d] Stop tunnel\n", startID) + k.tunnel.Stop() + logger.DebugF("[%d] Tunnel stopped\n", startID) + k.tunnel = nil + } + k.stop = true +} + +func (k *KubeProxy) tryToRestartFully(startID int) { + logger := k.sshClient.settings.Logger() + logger.DebugF("[%d] Try restart kubeproxy fully\n", startID) + for { + k.Stop(startID) + + _, err := k.Start(k.localPort) + + if err == nil { + k.stop = false + logger.DebugF("[%d] Proxy was restarted successfully\n", startID) + return + } + + const sleepTimeout = 5 + + // need warn for human + logger.WarnF( + "Proxy was not restarted: %v. Sleep %d seconds before next attempt.\n", + err, + sleepTimeout, + ) + time.Sleep(sleepTimeout * time.Second) + + k.Session.ChoiceNewHost() + logger.DebugF("[%d] New host selected %v\n", startID, k.Session.Host()) + } +} + +func (k *KubeProxy) proxyCMD(startID int) *SSHCommand { + kubectlProxy := fmt.Sprintf( + // --disable-filter is needed to exec into etcd pods + "kubectl proxy --as=dhctl --as-group=system:masters --port=%s --kubeconfig /etc/kubernetes/admin.conf --disable-filter", + k.port, + ) + if v := os.Getenv("KUBE_PROXY_ACCEPT_HOSTS"); v != "" { + kubectlProxy += fmt.Sprintf(" --accept-hosts='%s'", v) + } + command := fmt.Sprintf("PATH=$PATH:%s/; %s", k.sshClient.settings.NodeBinPath(), kubectlProxy) + + k.sshClient.settings.Logger().DebugF("[%d] Proxy command for start: %s\n", startID, command) + + cmd := NewSSHCommand(k.sshClient, command) + cmd.Sudo(context.Background()) + return cmd +} + +func (k *KubeProxy) healthMonitor( + proxyErrorCh, tunnelErrorCh chan error, + stopCh chan struct{}, + startID int, +) { + logger := k.sshClient.settings.Logger() + + defer logger.DebugF("[%d] Kubeproxy health monitor stopped\n", startID) + logger.DebugF("[%d] Kubeproxy health monitor started\n", startID) + + proxyErrorCount := 0 + for { + logger.DebugF("[%d] Kubeproxy Monitor step\n", startID) + select { + case err := <-proxyErrorCh: + logger.DebugF("[%d] Proxy failed with error %v\n", startID, err) + // if proxy crushed, we need to restart kube-proxy fully + // with proxy and tunnel (tunnel depends on proxy) + k.tryToRestartFully(startID) + // if we restart proxy fully + // this monitor must be finished because new monitor was started + return + + case err := <-tunnelErrorCh: + logger.DebugF("[%d] Tunnel failed %v. Stopping previous tunnel\n", startID, err) + // we need fully stop tunnel because + k.tunnel.Stop() + + logger.DebugF("[%d] Tunnel stopped before restart. Starting new tunnel...\n", startID) + + if proxyErrorCount < 3 { + k.tunnel, _, err = k.upTunnel(k.port, k.localPort, tunnelErrorCh, startID) + if err != nil { + logger.DebugF("[%d] Tunnel was not up: %v. Try to restart fully\n", startID, err) + k.tryToRestartFully(startID) + return + } + proxyErrorCount++ + } else { + k.tryToRestartFully(startID) + return + } + + logger.DebugF("[%d] Tunnel re up successfully\n") + + case <-stopCh: + logger.DebugF("[%d] Kubeproxy monitor stopped") + return + } + } +} + +func (k *KubeProxy) upTunnel( + kubeProxyPort string, + useLocalPort int, + tunnelErrorCh chan error, + startID int, +) (tun *Tunnel, localPort int, err error) { + logger := k.sshClient.settings.Logger() + + logger.DebugF( + "[%d] Starting up tunnel with proxy port %s and local port %d\n", + startID, + kubeProxyPort, + useLocalPort, + ) + + rewriteLocalPort := false + localPort = useLocalPort + + if useLocalPort < 1 { + logger.DebugF( + "[%d] Incorrect local port %d use default %d\n", + startID, + useLocalPort, + DefaultLocalAPIPort, + ) + localPort = DefaultLocalAPIPort + rewriteLocalPort = true + } + + maxRetries := 5 + retries := 0 + var lastError error + for { + logger.DebugF("[%d] Start %d iteration for up tunnel\n", startID, retries) + + if k.proxy.WaitError() != nil { + lastError = fmt.Errorf("proxy was failed while restart tunnel") + break + } + + // try to start tunnel from localPort to proxy port + var tunnelAddress string + if v := os.Getenv("KUBE_PROXY_BIND_ADDR"); v != "" { + tunnelAddress = fmt.Sprintf("%s:%d:localhost:%s", v, localPort, kubeProxyPort) + } else { + tunnelAddress = fmt.Sprintf("%s:%s:localhost:%d", "127.0.0.1", kubeProxyPort, localPort) + } + + logger.DebugF("[%d] Try up tunnel on %v\n", startID, tunnelAddress) + tun = NewTunnel(k.sshClient, tunnelAddress) + err := tun.Up() + if err != nil { + logger.DebugF("[%d] Start tunnel was failed. Cleaning...\n", startID) + tun.Stop() + lastError = fmt.Errorf("tunnel '%s': %w", tunnelAddress, err) + logger.DebugF("[%d] Start tunnel was failed. Error: %v\n", startID, lastError) + if rewriteLocalPort { + localPort++ + logger.DebugF("[%d] New local port %d\n", startID, localPort) + } + + retries++ + if retries >= maxRetries { + logger.DebugF("[%d] Last iteration finished\n", startID) + tun = nil + break + } + } else { + logger.DebugF("[%d] Tunnel was started. Starting health monitor\n", startID) + go tun.HealthMonitor(tunnelErrorCh) + lastError = nil + break + } + } + + dbgMsg := fmt.Sprintf("Tunnel up on local port %d", localPort) + if lastError != nil { + dbgMsg = fmt.Sprintf("Tunnel was not up: %v", lastError) + } + + logger.DebugF("[%d] %s\n", startID, dbgMsg) + + return tun, localPort, lastError +} + +func (k *KubeProxy) runKubeProxy( + waitCh chan error, + startID int, +) (proxy *SSHCommand, port string, err error) { + logger := k.sshClient.settings.Logger() + + logger.DebugF("[%d] Begin starting proxy\n", startID) + proxy = k.proxyCMD(startID) + + port = "" + portReady := make(chan struct{}, 1) + portRe := regexp.MustCompile(`Starting to serve on .*?:(\d+)`) + + proxy.WithStdoutHandler(func(line string) { + m := portRe.FindStringSubmatch(line) + if len(m) == 2 && m[1] != "" { + port = m[1] + logger.DebugF("Got proxy port = %s on host %s\n", port, k.Session.Host()) + portReady <- struct{}{} + } + }) + + onStart := make(chan struct{}, 1) + proxy.OnCommandStart(func() { + logger.DebugF("[%d] Command started\n", startID) + onStart <- struct{}{} + }) + + proxy.WithWaitHandler(func(err error) { + logger.DebugF("[%d] Wait error: %v\n", startID, err) + waitCh <- err + }) + + logger.DebugF("[%d] Start proxy command\n", startID) + err = proxy.Start() + if err != nil { + logger.DebugF("[%d] Start proxy command error: %v\n", startID, err) + return nil, "", fmt.Errorf("start kubectl proxy: %w", err) + } + + logger.DebugF("[%d] Proxy command was started\n", startID) + + returnWaitErr := func(err error) error { + logger.DebugF("[%d] Proxy command waiting error: %v\n", startID, err) + template := `Proxy exited suddenly: %s%s +Status: %w` + return fmt.Errorf(template, string(proxy.StdoutBytes()), string(proxy.StderrBytes()), err) + } + + // we need to check that kubeproxy was started + // that checking wait string pattern in output + // but we may receive error and this error will get from waitCh + select { + case <-onStart: + case err := <-waitCh: + return nil, "", returnWaitErr(err) + } + + // Wait for proxy startup + t := time.NewTicker(20 * time.Second) + defer t.Stop() + select { + case e := <-waitCh: + return nil, "", returnWaitErr(e) + case <-t.C: + logger.DebugF("[%d] Starting proxy command timeout\n", startID) + return nil, "", fmt.Errorf("timeout waiting for api proxy port") + case <-portReady: + if port == "" { + logger.DebugF("[%d] Starting proxy command: empty port\n", startID) + return nil, "", fmt.Errorf("got empty port from kubectl proxy") + } + } + + logger.DebugF("[%d] Proxy process started with port: %s\n", startID, port) + return proxy, port, nil +} diff --git a/pkg/ssh/gossh/reverse-tunnel.go b/pkg/ssh/gossh/reverse-tunnel.go new file mode 100644 index 0000000..3f9d493 --- /dev/null +++ b/pkg/ssh/gossh/reverse-tunnel.go @@ -0,0 +1,315 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "io" + "math/rand/v2" + "net" + "strings" + "sync" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/pkg/errors" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +type tunnelWaitResult struct { + id int + err error +} + +type ReverseTunnel struct { + sshClient *Client + address string + + tunMutex sync.Mutex + + started bool + stopCh chan struct{} + remoteListener net.Listener + + errorCh chan tunnelWaitResult +} + +func NewReverseTunnel(sshClient *Client, address string) *ReverseTunnel { + return &ReverseTunnel{ + sshClient: sshClient, + address: address, + errorCh: make(chan tunnelWaitResult), + } +} + +func (t *ReverseTunnel) Up() error { + _, err := t.upNewTunnel(-1) + return err +} + +func (t *ReverseTunnel) upNewTunnel(oldId int) (int, error) { + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + logger := t.sshClient.settings.Logger() + + if t.started { + logger.DebugF("[%d] Reverse tunnel already up\n", oldId) + return -1, fmt.Errorf("already up") + } + + id := rand.Int() + + parts := strings.Split(t.address, ":") + if len(parts) != 4 { + return -1, fmt.Errorf("invalid address must be 'remote_bind:remote_port:local_bind:local_port': %s", t.address) + } + + remoteBind, remotePort, localBind, localPort := parts[0], parts[1], parts[2], parts[3] + + logger.DebugF("[%d] Remote bind: %s remote port: %s local bind: %s local port: %s\n", id, remoteBind, remotePort, localBind, localPort) + + logger.DebugF("[%d] Start reverse tunnel\n", id) + + remoteAddress := net.JoinHostPort(remoteBind, remotePort) + localAddress := net.JoinHostPort(localBind, localPort) + + // reverse listen on remote server port + listener, err := t.sshClient.GetClient().Listen("tcp", remoteAddress) + if err != nil { + return -1, errors.Wrap(err, fmt.Sprintf("failed to listen remote on %s", remoteAddress)) + } + + logger.DebugF("[%d] Listen remote %s successful\n", id, remoteAddress) + + go t.acceptTunnelConnection(id, localAddress, listener) + + t.remoteListener = listener + t.started = true + + return id, nil +} + +func (t *ReverseTunnel) acceptTunnelConnection(id int, localAddress string, listener net.Listener) { + logger := t.sshClient.settings.Logger() + for { + client, err := listener.Accept() + if err != nil { + e := fmt.Errorf("Accept(): %s", err.Error()) + t.errorCh <- tunnelWaitResult{ + id: id, + err: e, + } + return + } + + logger.DebugF("[%d] connection accepted. Try to connect to local %s\n", id, localAddress) + + local, err := net.Dial("tcp", localAddress) + if err != nil { + e := fmt.Errorf("Cannot dial to %s: %s", localAddress, err.Error()) + t.errorCh <- tunnelWaitResult{ + id: id, + err: e, + } + return + } + + logger.DebugF("[%d] Connected to local %s\n", id, localAddress) + + // handle the connection in another goroutine, so we can support multiple concurrent + // connections on the same port + go t.handleClient(id, client, local) + } +} + +func (t *ReverseTunnel) handleClient(id int, client net.Conn, remote net.Conn) { + logger := t.sshClient.settings.Logger() + + defer func() { + err := client.Close() + if err != nil { + logger.DebugF("[%d] Cannot close connection: %s\n", id, err) + } + }() + + chDone := make(chan struct{}, 2) + + // Start remote -> local data transfer + go func() { + _, err := io.Copy(client, remote) + if err != nil { + logger.WarnF(fmt.Sprintf("[%d] Error while copy remote->local: %s\n", id, err)) + } + chDone <- struct{}{} + }() + + // Start local -> remote data transfer + go func() { + _, err := io.Copy(remote, client) + if err != nil { + logger.WarnF(fmt.Sprintf("[%d] Error while copy local->remote: %s\n", id, err)) + } + chDone <- struct{}{} + }() + + <-chDone +} + +func (t *ReverseTunnel) isStarted() bool { + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + r := t.started + return r +} + +func (t *ReverseTunnel) tryToRestart(ctx context.Context, id int, killer connection.ReverseTunnelKiller) (int, error) { + t.stop(id, false) + t.sshClient.settings.Logger().DebugF("[%d] Kill tunnel\n", id) + // (k EmptyReverseTunnelKiller) KillTunnel won't return error anyways, so we couldn't check return values + killer.KillTunnel(ctx) + return t.upNewTunnel(id) +} + +func (t *ReverseTunnel) StartHealthMonitor(ctx context.Context, checker connection.ReverseTunnelChecker, _ connection.ReverseTunnelKiller) { + t.tunMutex.Lock() + t.stopCh = make(chan struct{}) + t.tunMutex.Unlock() + + logger := t.sshClient.settings.Logger() + + // in go ssh implementation we do not need separate script for kill tunnel from server-side + // because listener.Close() close tunnel in the server side + // but we need to backward compatibility with cli ssh + killer := utils.EmptyReverseTunnelKiller{} + + checkReverseTunnel := func(id int) bool { + logger.DebugF("[%d] Start Check reverse tunnel\n", id) + + checkLoopParams := t.sshClient.loopsParams.CheckReverseTunnel + checkLoopParams = retry.SafeCloneOrNewParams(checkLoopParams, defaultReverseTunnelParamsOps...). + WithName("Check reverse tunnel"). + WithLogger(logger) + + err := retry.NewSilentLoopWithParams(checkLoopParams).RunContext(ctx, func() error { + out, err := checker.CheckTunnel(ctx) + if err != nil { + logger.DebugF("[%d] Cannot check ssh tunnel: '%v': stderr: '%s'\n", id, err, out) + return err + } + + return nil + }) + + if err != nil { + logger.DebugF("[%d] Tunnel check timeout, last error: %v\n", id, err) + return false + } + + logger.DebugF("[%d] Tunnel check successful!\n", id) + return true + } + + go func() { + logger.DebugLn("Start health monitor") + // we need chan for restarting because between restarting we can get stop signal + restartCh := make(chan int, 1024) + id := -1 + restartsCount := 0 + restart := func(id int) { + logger.DebugF("[%d] Send restart signal\n", id) + restartCh <- id + logger.DebugF("[%d] Signal was sent. Chan len: %d\n", id, len(restartCh)) + } + for { + + if !checkReverseTunnel(id) { + go restart(id) + } + + select { + case <-t.stopCh: + logger.DebugLn("Stop health monitor") + return + case oldId := <-restartCh: + restartsCount++ + logger.DebugF("[%d] Restart signal was received: restarts count %d\n", oldId, restartsCount) + + if restartsCount > 1024 { + panic("Reverse tunnel restarts count exceeds 1024") + } + + newId, err := t.tryToRestart(ctx, oldId, killer) + if err != nil { + logger.DebugF("[%d] Restart failed with error: %v\n", oldId, err) + go restart(oldId) + continue + } + logger.DebugF("[%d] Restart successful. New id %d\n", oldId, newId) + id = newId + restartsCount = 0 + case err := <-t.errorCh: + id = err.id + logger.DebugF("[%d] Tunnel was stopped with error '%v'. Try restart fully\n", id, err.err) + started := t.isStarted() + if started { + logger.DebugF("[%d] Tunnel already up. Skip restarting\n", id) + continue + } + + go restart(id) + continue + } + } + }() +} + +func (t *ReverseTunnel) Stop() { + t.stop(-1, true) +} + +func (t *ReverseTunnel) stop(id int, full bool) { + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + logger := t.sshClient.settings.Logger() + + if !t.started { + logger.DebugF("[%d] Reverse tunnel already stopped\n", id) + return + } + + logger.DebugF("[%d] Stop reverse tunnel\n", id) + defer logger.DebugF("[%d] End stop reverse tunnel\n", id) + + if full && t.stopCh != nil { + logger.DebugF("[%d] Stop reverse tunnel health monitor\n", id) + t.stopCh <- struct{}{} + } + + err := t.remoteListener.Close() + if err != nil { + logger.WarnF("[%d] Cannot close remote listener: %s\n", id, err.Error()) + } + + t.remoteListener = nil + t.started = false +} + +func (t *ReverseTunnel) String() string { + return fmt.Sprintf("%s:%s", "R", t.address) +} diff --git a/pkg/ssh/gossh/reverse-tunnel_test.go b/pkg/ssh/gossh/reverse-tunnel_test.go new file mode 100644 index 0000000..5ecd934 --- /dev/null +++ b/pkg/ssh/gossh/reverse-tunnel_test.go @@ -0,0 +1,223 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +func TestReverseTunnel(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestReverseTunnel") + + sshClient, container := startContainerAndClientWithContainer(t, test, sshtesting.WithNoWriteSSHDConfig()) + + // we don't have /opt/deckhouse in the container, so we should create it before start any UploadScript with sudo + err := container.Container.CreateDeckhouseDirs() + require.NoError(t, err, "could not create deckhouse dirs") + + containerPort := container.LocalPort() + localServerPort := sshtesting.RandPortExclude([]int{containerPort}) + + const response = "Simple response" + handler := sshtesting.NewSimpleHTTPHandler("/my/action", response) + + sshtesting.MustStartHTTPServer(t, test, localServerPort, handler) + + registerStopReverceTunnel := func(t *testing.T, tunnel *ReverseTunnel) { + t.Cleanup(func() { + tunnel.Stop() + }) + } + + containerSSHDPort := container.Container.RemotePort() + upTunnelRemoteServerPort := sshtesting.RandPortExclude([]int{containerSSHDPort}) + + t.Run("Reverse tunnel from container to host", func(t *testing.T) { + remoteServerInvalidPort := sshtesting.RandPortExclude([]int{upTunnelRemoteServerPort, containerSSHDPort}) + localInvalidPort := sshtesting.RandInvalidPortExclude([]int{localServerPort}) + + cases := []struct { + title string + address string + wantErr bool + err string + errFromChan string + }{ + { + title: "Tunnel, success", + address: tunnelAddressString(localServerPort, upTunnelRemoteServerPort), + wantErr: false, + }, + { + title: "Invalid address", + address: "swsws:111aaa:", + wantErr: true, + err: "invalid address must be 'remote_bind:remote_port:local_bind:local_port'", + }, + { + title: "Invalid local bind", + address: tunnelAddressString(localInvalidPort, containerSSHDPort), + wantErr: true, + err: fmt.Sprintf("failed to listen remote on 127.0.0.1:%d", upTunnelRemoteServerPort), + }, + { + title: "Wrong local bind", + address: tunnelAddressString(localServerPort, remoteServerInvalidPort), + wantErr: false, + errFromChan: fmt.Sprintf("Cannot dial to 127.0.0.1:%d", remoteServerInvalidPort), + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + tun := NewReverseTunnel(sshClient, c.address) + err := tun.Up() + + registerStopReverceTunnel(t, tun) + + if c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + + requestAddress := fmt.Sprintf("http://127.0.0.1:%d%s", upTunnelRemoteServerPort, handler.Path) + + // try to get a response from local web server + cmd := NewSSHCommand(sshClient, "curl", "-m", "4", "-s", requestAddress) + cmd.WithTimeout(6 * time.Second) + out, err := cmd.CombinedOutput(context.Background()) + require.NoError(t, err, "execute remote curl %s", requestAddress) + + if len(c.errFromChan) == 0 { + require.Equal(t, response, string(out)) + } else { + errMsg := <-tun.errorCh + require.Contains(t, errMsg.err.Error(), c.errFromChan) + } + + // try to up again: expecting error + err = tun.Up() + require.Error(t, err) + require.Equal(t, err.Error(), "already up") + }) + } + }) + + t.Run("HealthMonitor test", func(t *testing.T) { + healthMonitorRemoteServerPort := sshtesting.RandPortExclude([]int{upTunnelRemoteServerPort, containerSSHDPort}) + + tun := NewReverseTunnel(sshClient, tunnelAddressString(localServerPort, healthMonitorRemoteServerPort)) + err := tun.Up() + require.NoError(t, err) + + registerStopReverceTunnel(t, tun) + + remoteHealthz := fmt.Sprintf("http://127.0.0.1:%d%s", healthMonitorRemoteServerPort, sshtesting.HealthzPath) + + script := fmt.Sprintf(`#!/bin/bash +URL="%s" + +curl -m 4 -s $URL > /dev/null +exit $? +`, remoteHealthz) + + testFile := test.MustCreateTmpFile(t, script, true, "script", "test.sh") + + checker := utils.NewRunScriptReverseTunnelChecker(sshClient, testFile) + killer := utils.EmptyReverseTunnelKiller{} + + checkLoop := retry.NewEmptyParams( + retry.WithName("Check tunnel"), + retry.WithAttempts(30), + retry.WithWait(2*time.Second), + retry.WithLogger(test.Logger), + ) + + checkTunnelAction := func() error { + out, err := checker.CheckTunnel(context.Background()) + if err != nil { + test.Logger.InfoF("Failed to check tunnel: %s %v", out, err) + return err + } + return nil + } + + err = retry.NewLoopWithParams(checkLoop).Run(checkTunnelAction) + require.NoError(t, err, "tunnel check") + + sshClient.WithLoopsParams(ClientLoopsParams{ + CheckReverseTunnel: retry.NewEmptyParams( + retry.WithAttempts(5), + retry.WithWait(500*time.Millisecond), + ), + }) + + upMonitorSleep := 2 * time.Second + restartSleep := 5 * time.Second + + tun.StartHealthMonitor(context.Background(), checker, killer) + test.Logger.InfoF( + "Waiting %s for tunnel monitor to start. And restart container. Wait %s before start container for fail check", + upMonitorSleep.String(), + restartSleep.String(), + ) + + time.Sleep(upMonitorSleep) + err = container.Container.Restart(true, restartSleep) + require.NoError(t, err, "container restart") + err = container.Container.CreateDeckhouseDirs() + require.NoError(t, err, "create deckhouse dirs") + + test.Logger.InfoF( + "Waiting %s for tunnel monitor to restart", + upMonitorSleep.String(), + ) + + time.Sleep(upMonitorSleep) + + checkLoopAfterRestart := retry.SafeCloneOrNewParams(checkLoop). + WithName("Check tunnel after restart") + + err = retry.NewLoopWithParams(checkLoopAfterRestart).Run(checkTunnelAction) + require.NoError(t, err, "tunnel check after restart") + + test.Logger.InfoF( + "Disconnect (fail connection between server and client) case. Wait %s before connect. Wait %s before check", + restartSleep.String(), + upMonitorSleep.String(), + ) + + // fail connection case + err = container.Container.FailAndUpConnection(restartSleep) + require.NoError(t, err, "container fail connection") + + time.Sleep(upMonitorSleep) + + checkLoopAfterDisconnect := retry.SafeCloneOrNewParams(checkLoop). + WithName("Check tunnel after disconnect") + + err = retry.NewLoopWithParams(checkLoopAfterDisconnect).Run(checkTunnelAction) + require.NoError(t, err, "tunnel check after disconnect") + }) +} diff --git a/pkg/ssh/gossh/testing/agent.go b/pkg/ssh/gossh/testing/agent.go new file mode 100644 index 0000000..089b278 --- /dev/null +++ b/pkg/ssh/gossh/testing/agent.go @@ -0,0 +1,276 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/name212/govalue" + "github.com/stretchr/testify/require" +) + +type Agent struct { + logger log.Logger + + mu sync.RWMutex + sockPath string + pid int + + stopCh chan struct{} +} + +var pidRegex = regexp.MustCompile(`SSH_AGENT_PID=(\d+);`) + +type PrivateKey struct { + Path string + Password string +} + +func StartTestAgent(t *testing.T, wrapper *TestContainerWrapper) *Agent { + sockDir := wrapper.Settings.Test.TmpDir() + var privateKey []PrivateKey + if wrapper.PrivateKeyPath != "" { + privateKey = append(privateKey, PrivateKey{ + Path: wrapper.PrivateKeyPath, + }) + } + + agent, err := StartAgent(sockDir, wrapper.Settings.Test.Logger, privateKey...) + require.NoError(t, err) + agent.RegisterCleanup(t) + + return agent +} + +func StartAgent(sockDir string, logger log.Logger, keysPath ...PrivateKey) (*Agent, error) { + _, err := os.Stat(sockDir) + if err != nil { + return nil, fmt.Errorf("failed to stat agent socket directory %s: %s", sockDir, err) + } + + id := GenerateID("test-agent") + sockPath := filepath.Join(sockDir, fmt.Sprintf("test-ssh-agent-%s.sock", id)) + + if govalue.Nil(logger) { + logger = TestLogger() + } + + agent := &Agent{ + logger: logger, + sockPath: sockPath, + stopCh: make(chan struct{}, 1), + } + + if err := agent.start(); err != nil { + return nil, fmt.Errorf("failed to start test ssh-agent: %w", err) + } + + for _, key := range keysPath { + if err := agent.AddKey(key); err != nil { + agent.Stop() + return nil, err + } + } + + return agent, nil +} + +func (a *Agent) start() error { + sock := a.SockPath() + cmd := exec.Command("ssh-agent", "-a", sock) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("cannot start ssh-agent with sock %s: %w", sock, err) + } + + pidSubmatches := pidRegex.FindSubmatch(out) + if len(pidSubmatches) < 2 { + return fmt.Errorf("cannot find pid in ssh-agent output: %s", string(out)) + } + + pid, err := strconv.Atoi(string(pidSubmatches[1])) + if err != nil { + return fmt.Errorf("cannot parse pid in ssh-agent output: %s", string(out)) + } + + a.pid = pid + + a.logInfo("started successfully with pid: %d", a.Pid()) + + go func() { + stopCh := a.stopCh + select { + case <-stopCh: + a.logInfo("shutting down ssh-agent") + // Find the process by its PID + process, err := os.FindProcess(a.Pid()) + if err != nil { + a.cleanupAndLog("find process", err) + return + } + + err = process.Signal(syscall.SIGTERM) + a.cleanupAndLog("kill", err) + return + } + }() + + return nil +} + +func (a *Agent) AddKey(key PrivateKey) error { + path := key.Path + if path == "" { + return a.wrapError("key path is empty", fmt.Errorf("invalid input")) + } + _, err := os.Stat(path) + if err != nil { + return a.wrapError(fmt.Sprintf("failed to check private key path %s exist", path), err) + } + + return a.run(key.Path, "ssh-add", path) +} + +func (a *Agent) RemoveKey(key PrivateKey) error { + return a.run("", "ssh-add", "-d", key.Path) +} + +func (a *Agent) IsStopped() bool { + pid := a.Pid() + return pid == 0 || a.stopCh == nil +} + +func (a *Agent) Pid() int { + a.mu.RLock() + defer a.mu.RUnlock() + + return a.pid +} + +func (a *Agent) SockPath() string { + a.mu.RLock() + defer a.mu.RUnlock() + + return a.sockPath +} + +func (a *Agent) Stop() { + if a.stopCh == nil { + return + } + + ch := a.stopCh + a.stopCh = nil + + close(ch) +} + +func (a *Agent) RegisterCleanup(t *testing.T) { + t.Cleanup(func() { + socket := a.SockPath() + if socket == "" { + return + } + + a.Stop() + leaveSocket := retry.NewEmptyParams( + retry.WithName(fmt.Sprintf("Wait socket %s leave", socket)), + retry.WithWait(2*time.Second), + retry.WithAttempts(10), + retry.WithLogger(a.logger), + ) + + _ = retry.NewLoopWithParams(leaveSocket).Run(func() error { + _, err := os.Stat(socket) + if err != nil { + return nil + } + + return fmt.Errorf("socket %s is still running", socket) + }) + }) +} + +func (a *Agent) String() string { + return fmt.Sprintf("test agent (socket: '%s'; pid: %d)", a.SockPath(), a.Pid()) +} + +func (a *Agent) run(stdin string, name string, args ...string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, name, args...) + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_AUTH_SOCK=%s", a.SockPath())) + + if stdin != "" { + cmd.Stdin = strings.NewReader(stdin) + } + + a.logInfo("run %s with envs: %s", cmd.String(), strings.Join(cmd.Env, " ")) + + out, err := cmd.CombinedOutput() + if err != nil { + return a.wrapError(fmt.Sprintf("error running %s (output: %s)", cmd.String(), string(out)), err) + } + + return nil +} + +func (a *Agent) cleanupAndLog(msg string, err error) { + if err != nil { + a.logError("%s receive error: %v", msg, err) + return + } + + a.mu.Lock() + + a.pid = 0 + a.sockPath = "" + + a.mu.Unlock() + + a.logInfo("%s success", msg) +} + +func (a *Agent) logInfo(f string, args ...any) { + a.log(a.logger.InfoF, f, args...) +} + +func (a *Agent) logError(f string, args ...any) { + a.log(a.logger.ErrorF, f, args...) +} + +func (a *Agent) wrapError(msg string, err error) error { + return fmt.Errorf("%s %s: %w", msg, a.String(), err) +} + +func (a *Agent) log(writeLog func(string, ...any), f string, args ...any) { + f = a.String() + ": " + f + writeLog(f, args...) +} diff --git a/pkg/ssh/gossh/testing/docker.go b/pkg/ssh/gossh/testing/docker.go new file mode 100644 index 0000000..627919f --- /dev/null +++ b/pkg/ssh/gossh/testing/docker.go @@ -0,0 +1,43 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "errors" + "fmt" + "os/exec" +) + +func RunDockerWithOut(command ...string) (string, error) { + if len(command) == 0 { + return "", errors.New("docker command is empty") + } + + cmd := exec.Command("docker", command...) + out, err := cmd.CombinedOutput() + + outStr := string(out) + + if err != nil { + return "", fmt.Errorf("cannot run docker: '%v', output: %s", err, outStr) + } + + return outStr, nil +} + +func RunDocker(command ...string) error { + _, err := RunDockerWithOut(command...) + return err +} diff --git a/pkg/ssh/gossh/testing/helpers.go b/pkg/ssh/gossh/testing/helpers.go new file mode 100644 index 0000000..8241f7d --- /dev/null +++ b/pkg/ssh/gossh/testing/helpers.go @@ -0,0 +1,142 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/pem" + "fmt" + "os" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/deckhouse/lib-dhctl/pkg/retry" + gossh "github.com/deckhouse/lib-gossh" + "github.com/name212/govalue" + "github.com/stretchr/testify/require" +) + +const PrivateKeysRoot = "private_keys" + +func marshalKey(privateKey *rsa.PrivateKey, passphrase string) (*pem.Block, error) { + if len(passphrase) == 0 { + return gossh.MarshalPrivateKey(privateKey, "") + } + + return gossh.MarshalPrivateKeyWithPassphrase(privateKey, "", []byte(passphrase)) +} + +// helper func to generate SSH keys +func GenerateKeys(test *Test, passphrase string) (string, string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", fmt.Errorf("cannot generate key: %w", err) + } + + publicKey, err := gossh.NewPublicKey(privateKey.Public()) + if err != nil { + return "", "", fmt.Errorf("cannot create public key from private: %w", err) + } + + privateKeyPem, err := marshalKey(privateKey, passphrase) + if err != nil { + return "", "", fmt.Errorf("cannot marshal private key: %w", err) + } + + pemBytes := pem.EncodeToMemory(privateKeyPem) + + privateKeyPath, err := test.CreateTmpFile(string(pemBytes), false, PrivateKeysRoot, "id_rsa") + if err != nil { + return "", "", fmt.Errorf("cannot write private key: %w", err) + } + + if err := os.Chmod(privateKeyPath, 0600); err != nil { + return "", "", fmt.Errorf("cannot chmod to 600 private key file: %w", err) + } + + return privateKeyPath, string(gossh.MarshalAuthorizedKey(publicKey)), nil +} + +func WritePubKeyFileForPrivate(test *Test, privateKeyPath string, pubKey string) (string, error) { + return test.CreateFileWithSameSuffix(privateKeyPath, pubKey, false, PrivateKeysRoot, "id_rsa.pub") +} + +func LogErrorOrAssert(t *testing.T, description string, err error, logger log.Logger) { + if err == nil { + return + } + + if govalue.Nil(logger) { + require.NoError(t, err, description) + return + } + + logger.ErrorF("%s: %v", description, err) +} + +func CheckSkipSSHTest(t *testing.T, testName string) { + if os.Getenv("SKIP_GOSSH_TEST") == "true" { + t.Skipf("Skipping %s test. SKIP_GOSSH_TEST=true env passed", testName) + } +} + +func GetTestLoopParamsForFailed() retry.Params { + return retry.NewEmptyParams( + retry.WithWait(2*time.Second), + retry.WithAttempts(4), + ) +} + +func IncorrectHost() string { + third := RandRange(1, 254) + four := RandRange(1, 254) + return fmt.Sprintf("192.168.%d.%d", third, four) +} + +func Sleep(d time.Duration) { + if d > 0 { + time.Sleep(d) + } +} + +func removeFiles(paths ...string) []error { + removeErrors := make([]error, 0, len(paths)) + for _, path := range paths { + if path == "" { + continue + } + + stat, err := os.Stat(path) + if err != nil { + if !os.IsNotExist(err) { + removeErrors = append(removeErrors, fmt.Errorf("cannot stat %s: %w", path, err)) + } + continue + } + + remove := os.Remove + if stat.IsDir() { + remove = os.RemoveAll + } + + if err := remove(path); err != nil && !os.IsNotExist(err) { + removeErrors = append(removeErrors, err) + } + } + + return removeErrors +} diff --git a/pkg/ssh/gossh/testing/prefix_logger.go b/pkg/ssh/gossh/testing/prefix_logger.go new file mode 100644 index 0000000..32a5423 --- /dev/null +++ b/pkg/ssh/gossh/testing/prefix_logger.go @@ -0,0 +1,66 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "fmt" + + "github.com/deckhouse/lib-dhctl/pkg/log" +) + +type PrefixLogger struct { + log.Logger + prefix string + address string +} + +func newPrefixLoggerWithAddress(logger log.Logger, address string) *PrefixLogger { + l := NewPrefixLogger(logger) + l.address = address + return l.WithPrefix("") +} + +func NewPrefixLogger(logger log.Logger) *PrefixLogger { + l := &PrefixLogger{ + Logger: logger, + } + + return l.WithPrefix("") +} + +func (l *PrefixLogger) Log(write func(string, ...any), f string, args ...any) { + if l.prefix != "" { + f = l.prefix + ": " + f + } + + write(f, args...) +} + +func (l *PrefixLogger) Error(f string, args ...any) { + l.Log(l.ErrorF, f, args...) +} + +func (l *PrefixLogger) Info(f string, args ...any) { + l.Log(l.InfoF, f, args...) +} + +func (l *PrefixLogger) WithPrefix(p string) *PrefixLogger { + if l.address != "" { + p = fmt.Sprintf("%s (%s)", p, l.address) + } + + l.prefix = p + return l +} diff --git a/pkg/ssh/gossh/testing/rand.go b/pkg/ssh/gossh/testing/rand.go new file mode 100644 index 0000000..83389be --- /dev/null +++ b/pkg/ssh/gossh/testing/rand.go @@ -0,0 +1,110 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "crypto/sha256" + "fmt" + mathrand "math/rand" + "slices" + "strings" + "time" +) + +const ( + portRangeStart = 22000 + portRangeEnd = 29999 +) + +var ( + lettersRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + passwordRunes = append( + append([]rune{}, lettersRunes...), + []rune(" %!@#$&^*.,/")..., + ) +) + +func RandRange(min, max int) int { + return randRange(getRand(), min, max) +} + +func RandPort() int { + return RandRange(portRangeStart, portRangeEnd) +} + +func RandPortExclude(exclude []int) int { + return RandRangeExclude(portRangeStart, portRangeEnd, exclude) +} + +func GenerateID(names ...string) string { + if len(names) == 0 { + names = make([]string, 0, 1) + } + + names = append(names, randString(12, lettersRunes)) + sumString := strings.Join(names, "/") + sum := sha256Encode(sumString) + + return fmt.Sprintf("%.12s", sum) +} + +func RandRangeExclude(min, max int, exclude []int) int { + randomizer := getRand() + for i := 0; i < 100; i++ { + v := randRange(randomizer, min, max) + if slices.Contains(exclude, v) { + continue + } + + return v + } + + panic("random range exclude failed after 100 iterations") +} + +func RandInvalidPortExclude(_ []int) int { + return 0 +} + +func RandPassword(n int) string { + return randString(n, passwordRunes) +} + +func randString(n int, letters []rune) string { + randomizer := getRand() + + b := make([]rune, n) + for i := range b { + b[i] = letters[randomizer.Intn(len(letters))] + } + + return string(b) +} + +func getRand() *mathrand.Rand { + return mathrand.New(mathrand.NewSource(time.Now().UnixNano())) +} + +func randRange(randomizer *mathrand.Rand, min, max int) int { + return randomizer.Intn(max-min) + min +} + +func sha256Encode(input string) string { + hasher := sha256.New() + + hasher.Write([]byte(input)) + + return fmt.Sprintf("%x", hasher.Sum(nil)) +} diff --git a/pkg/ssh/gossh/testing/setttings.go b/pkg/ssh/gossh/testing/setttings.go new file mode 100644 index 0000000..435fd20 --- /dev/null +++ b/pkg/ssh/gossh/testing/setttings.go @@ -0,0 +1,124 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "strconv" + + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-dhctl/pkg/log" +) + +func TestLogger() *log.InMemoryLogger { + return log.NewInMemoryLoggerWithParent(log.NewPrettyLogger(log.LoggerOptions{IsDebug: false})) +} + +func getDefaultParams(test *Test) settings.ProviderParams { + return settings.ProviderParams{ + LoggerProvider: log.SimpleLoggerProvider(test.Logger), + IsDebug: true, + } +} + +func CreateDefaultTestSettings(test *Test) settings.Settings { + return settings.NewBaseProviders(getDefaultParams(test)) +} + +func CreateDefaultTestSettingsWithAgent(test *Test, agentSockPath string) settings.Settings { + params := getDefaultParams(test) + params.AuthSock = agentSockPath + return settings.NewBaseProviders(params) +} + +type SessionOverride func(input *session.Input) + +func generateIncorrectPort(wrappers ...*TestContainerWrapper) string { + exclude := make([]int, 0, len(wrappers)) + for _, wrapper := range wrappers { + exclude = append(exclude, wrapper.LocalPort()) + } + + return strconv.Itoa(RandPortExclude(exclude)) +} + +func OverrideSessionWithIncorrectPort(wrappers ...*TestContainerWrapper) SessionOverride { + return func(input *session.Input) { + input.Port = generateIncorrectPort(wrappers...) + } +} + +func OverrideSessionWithIncorrectBastionPort(wrappers ...*TestContainerWrapper) SessionOverride { + return func(input *session.Input) { + input.BastionPort = generateIncorrectPort(wrappers...) + } +} + +func Session(wrapper *TestContainerWrapper, overrides ...SessionOverride) *session.Session { + container := wrapper.Container + sett := container.ContainerSettings() + + input := session.Input{ + AvailableHosts: []session.Host{ + {Host: "127.0.0.1", Name: "localhost"}, + }, + User: sett.Username, + Port: container.LocalPortString(), + BecomePass: sett.Password, + } + + for _, override := range overrides { + override(&input) + } + + return session.NewSession(input) +} + +func SessionWithBastion(wrapper *TestContainerWrapper, bastionWrapper *TestContainerWrapper, overrides ...SessionOverride) *session.Session { + container := wrapper.Container + sett := container.ContainerSettings() + + bastionContainer := bastionWrapper.Container + bastionSetting := bastionContainer.ContainerSettings() + + input := session.Input{ + AvailableHosts: []session.Host{ + {Host: container.GetContainerIP(), Name: container.GetContainerIP()}, + }, + User: sett.Username, + Port: container.RemotePortString(), + BecomePass: sett.Password, + BastionHost: "127.0.0.1", + BastionPort: bastionContainer.LocalPortString(), + BastionUser: bastionSetting.Username, + BastionPassword: bastionSetting.Password, + } + + for _, override := range overrides { + override(&input) + } + + return session.NewSession(input) +} + +func FakeSession() *session.Session { + host := IncorrectHost() + return session.NewSession(session.Input{ + AvailableHosts: []session.Host{{Host: host, Name: host}}, + User: "user", + Port: strconv.Itoa(RandPort()), + BecomePass: RandPassword(6), + }) +} diff --git a/pkg/ssh/gossh/testing/ssh_container.go b/pkg/ssh/gossh/testing/ssh_container.go new file mode 100644 index 0000000..be0f711 --- /dev/null +++ b/pkg/ssh/gossh/testing/ssh_container.go @@ -0,0 +1,643 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "strings" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/name212/govalue" +) + +const ( + tmpGlobalDirName = "test-lib-connection" + dockerNamePrefix = "test_lib_connection" +) + +type PublicKey struct { + Key string + Path string +} + +type ContainerSettings struct { + *Test + + PublicKey *PublicKey + Password string + Username string + NodeTmpPath string + LocalPort int + SudoAccess bool + + ContainerName string +} + +func (s *ContainerSettings) LocalPortString() string { + return strconv.Itoa(s.LocalPort) +} + +func (s *ContainerSettings) HasPublicKey() bool { + return !govalue.Nil(s.PublicKey) +} + +func (s *ContainerSettings) HasPublicKeyContent() bool { + return s.HasPublicKey() && s.PublicKey.Key != "" +} + +func (s *ContainerSettings) HasPublicKeyPath() bool { + return s.HasPublicKey() && s.PublicKey.Path != "" +} + +func (s *ContainerSettings) HasPassword() bool { + return s.Password != "" +} + +func (s *ContainerSettings) String() string { + return fmt.Sprintf( + "Settings: User: '%s' Port: %d SudoAccess: %v Name: %s", + s.Username, + s.LocalPort, + s.SudoAccess, + s.ContainerName, + ) +} + +type SSHContainer struct { + settings *ContainerSettings + id string + ip string + sshdConfigPath string + network string + externalNetwork bool +} + +func NewSSHContainer(settings *ContainerSettings) (*SSHContainer, error) { + if govalue.Nil(settings) { + return nil, errors.New("settings must be provided") + } + + if govalue.Nil(settings.Test) { + return nil, errors.New("Test must not be nil in settings") + } + + if settings.Test.FullName() == "" { + return nil, errors.New("testName is empty") + } + + if settings.Test.IsZero() { + return nil, errors.New("Test is empty") + } + + // <= 1024 port require root privilege + if settings.LocalPort <= 1024 { + settings.LocalPort = RandPort() + } + + if settings.ContainerName == "" { + settings.ContainerName = "target" + } + + c := &SSHContainer{ + settings: settings, + } + + return c, nil +} + +// force AllowTcpForwarding yes to allow connection throufh bastion +func (c *SSHContainer) WriteConfig() error { + passwordAuthEnabled := "no" + if c.ContainerSettings().Password != "" { + passwordAuthEnabled = "yes" + } + + configTpl := ` +Port %s +AuthorizedKeysFile .ssh/authorized_keys +AllowTcpForwarding yes +GatewayPorts no +X11Forwarding no +PidFile /config/sshd.pid +Subsystem sftp internal-sftp +PasswordAuthentication %s +MaxAuthTries 1000 +MaxSessions 1000 +AllowTcpForwarding yes +` + config := fmt.Sprintf(configTpl, c.RemotePortString(), passwordAuthEnabled) + + resPath, err := c.settings.Test.CreateTmpFile(config, false, "sshd", "config") + if err != nil { + return err + } + + c.sshdConfigPath = resPath + return nil +} + +func (c *SSHContainer) Start(waitSSHDStarted bool) error { + c.logInfo("Starting container fully...") + + err := c.createNetwork() + if err != nil { + return err + } + + err = c.startContainer(waitSSHDStarted) + if err != nil { + loopParams := c.defaultRetryParams(fmt.Sprintf("Remove network %s after fail run container", c.GetNetwork())) + removeNetworkErr := retry.NewLoopWithParams(loopParams).Run(func() error { + return c.removeNetwork() + }) + if removeNetworkErr != nil { + err = c.wrapError("%v and failed to remove network %s: %v", err, c.GetNetwork(), removeNetworkErr) + } + + return err + } + + c.logInfo("Container fully started %s", c.ShortContainerId()) + + return nil +} + +func (c *SSHContainer) Restart(waitSSHDStarted bool, sleepBeforeStart time.Duration) error { + if err := c.stopContainer(); err != nil { + return err + } + + Sleep(sleepBeforeStart) + + return c.startContainer(waitSSHDStarted) +} + +func (c *SSHContainer) Stop() error { + shortId := c.ShortContainerId() + c.logInfo("Stopping container '%s' fully...", shortId) + + resError := "" + if err := c.stopContainer(); err != nil { + resError = err.Error() + } + + if err := c.removeNetwork(); err != nil { + resError = fmt.Sprintf("%s/%s", resError, err.Error()) + } + + if resError != "" { + return c.wrapError("cannot fully stop container: %v", errors.New(resError)) + } + + c.logInfo("Container '%s' fully stopped", shortId) + + return nil +} + +func (c *SSHContainer) FailAndUpConnection(sleepBeforeConnect time.Duration) error { + if err := c.Disconnect(); err != nil { + return err + } + + Sleep(sleepBeforeConnect) + + return c.Connect() +} + +func (c *SSHContainer) Disconnect() error { + return c.runDockerNetworkConnect(false) +} + +func (c *SSHContainer) Connect() error { + return c.runDockerNetworkConnect(true) +} + +func (c *SSHContainer) ExecToContainer(description string, command ...string) error { + if err := c.isContainerStarted(description); err != nil { + return err + } + + args := append([]string{"exec", c.GetContainerId()}, command...) + + return c.runDocker(description, args...) +} + +func (c *SSHContainer) CreateDeckhouseDirs() error { + description := func(name string) string { + d := "node tmp dir" + if name == "" { + return d + } + + return fmt.Sprintf("%s %s", d, name) + } + + nodeTmpPath := c.ContainerSettings().NodeTmpPath + + if nodeTmpPath == "" { + return c.wrapError("cannot create %s. Path is empty", description("")) + } + + if err := c.ExecToContainer(description("create"), "mkdir", "-p", nodeTmpPath); err != nil { + return err + } + + return c.ExecToContainer(description("set mode"), "chmod", "-R", "777", nodeTmpPath) +} + +func (c *SSHContainer) WithExternalNetwork(network string) *SSHContainer { + c.network = network + c.externalNetwork = true + + return c +} + +func (c *SSHContainer) GetContainerId() string { + return c.id +} + +func (c *SSHContainer) ShortContainerId() string { + id := c.GetContainerId() + + if len(id) > 12 { + return fmt.Sprintf("%.12s", id) + } + return id +} + +func (c *SSHContainer) GetNetwork() string { + return c.network +} + +func (c *SSHContainer) GetContainerIP() string { + return c.ip +} + +func (c *SSHContainer) GetSSHDConfigPath() string { + return c.sshdConfigPath +} + +func (c *SSHContainer) ContainerSettings() *ContainerSettings { + return c.settings +} + +func (c *SSHContainer) RemotePortString() string { + return fmt.Sprintf("%d", c.RemotePort()) +} + +func (c *SSHContainer) RemotePort() int { + return 2222 +} + +func (c *SSHContainer) LocalPortString() string { + return c.settings.LocalPortString() +} + +func (c *SSHContainer) dockerName(id string) string { + return fmt.Sprintf("%s_%s", dockerNamePrefix, id) +} + +func (c *SSHContainer) generateDockerNetworkName() string { + return c.dockerName(c.settings.Test.GetID()) +} + +func (c *SSHContainer) generateDockerContainerName() string { + id := GenerateID( + c.settings.Test.Name(), + c.settings.ContainerName, + c.settings.String(), + ) + + containerName := c.settings.ContainerName + + if containerName == "" { + containerName = "target" + } + + id = fmt.Sprintf("%s_%s", id, containerName) + + return c.dockerName(id) +} + +func (c *SSHContainer) runDockerWithOut(description string, command ...string) (string, error) { + out, err := RunDockerWithOut(command...) + if err != nil { + return "", c.wrapError("%s: %v", description, err) + } + return out, nil +} + +func (c *SSHContainer) runDocker(description string, command ...string) error { + _, err := c.runDockerWithOut(description, command...) + return err +} + +func (c *SSHContainer) wrapError(format string, args ...any) error { + return c.settings.Test.WrapErrorWithAfterName(c.settings.String(), format, args...) +} + +func (c *SSHContainer) isContainerStarted(description string) error { + if c.GetContainerId() != "" { + return nil + } + + return c.wrapError("%s: container seems to be not started. Call Start() first", description) +} + +func (c *SSHContainer) runContainerArgs() []string { + settings := c.ContainerSettings() + + ports := fmt.Sprintf("%d:%s", settings.LocalPort, c.RemotePortString()) + name := c.generateDockerContainerName() + args := []string{ + "-d", + "-e", "USER_NAME=" + settings.Username, + "-p", ports, + "--name", name, + "--network", c.GetNetwork(), + } + + if settings.HasPublicKeyContent() { + args = append(args, "-e") + args = append(args, "PUBLIC_KEY="+settings.PublicKey.Key) + } + if settings.HasPublicKeyPath() { + args = append(args, "-e") + args = append(args, "PUBLIC_KEY_FILE="+settings.PublicKey.Path) + } + // set default password if no auth methods present + if !settings.HasPublicKeyContent() && !settings.HasPublicKeyPath() && !settings.HasPassword() { + c.settings.Password = "password" + } + + settings = c.ContainerSettings() + + if settings.HasPassword() { + args = append(args, "-e") + args = append(args, "PASSWORD_ACCESS=true") + args = append(args, "-e") + args = append(args, "USER_PASSWORD="+settings.Password) + } + args = append(args, "-e") + args = append(args, "SUDO_ACCESS="+fmt.Sprintf("%v", settings.SudoAccess)) + args = append(args, "--restart") + args = append(args, "unless-stopped") + + sshdConfigPath := c.GetSSHDConfigPath() + if sshdConfigPath != "" { + args = append(args, "-v") + args = append(args, sshdConfigPath+":/config/sshd/sshd_config") + } + + image := os.Getenv("DHCTL_TESTS_OPENSSH_IMAGE") + if image == "" { + image = "lscr.io/linuxserver/openssh-server:10.0_p1-r9-ls209" + } + + args = append(args, image) + + return args +} + +func (c *SSHContainer) startContainer(waitSSHDStarted bool) error { + cmd := append([]string{"run"}, c.runContainerArgs()...) + + c.logInfo("Starting container...") + + id, err := c.runDockerWithOut("start container", cmd...) + if err != nil { + return err + } + + c.id = strings.TrimSpace(id) + + c.ip, err = c.discoveryContainerIP() + if err != nil { + if stopErr := c.stopContainer(); stopErr != nil { + err = c.wrapError("%v and cannot stop container: %v", err, stopErr) + } + return err + } + + c.logInfo("Container started: ID: %s IP: %s", c.ShortContainerId(), c.GetContainerIP()) + + if !waitSSHDStarted { + return nil + } + + addr := fmt.Sprintf("127.0.0.1:%s", c.LocalPortString()) + + loopParams := c.defaultRetryParams(fmt.Sprintf("Wait SSHD started after start container %s", c.ShortContainerId())) + err = retry.NewLoopWithParams(loopParams).Run(func() error { + conn, err := net.DialTimeout("tcp", addr, 3*time.Second) + if err != nil { + return err + } + if err := conn.Close(); err != nil { + c.ContainerSettings().Logger.InfoF("Failed to close SSHD connection after restart container: %v", err) + } + + return nil + }) + + if err != nil { + if stopErr := c.stopContainer(); stopErr != nil { + err = c.wrapError("%v and cannot stop container: %v", err, stopErr) + } + return err + } + + return nil +} + +func (c *SSHContainer) stopContainer() error { + if err := c.isContainerStarted("stop container"); err != nil { + return nil + } + + description := func(name string) string { + return fmt.Sprintf("%s %s", name, c.GetContainerId()) + } + + id := c.GetContainerId() + shortID := c.ShortContainerId() + + c.logInfo("Stop container %s...", shortID) + + if err := c.runDocker(description("stop container"), "stop", id); err != nil { + return err + } + + c.logInfo("Remove container %s...", shortID) + + return c.runDocker(description("remove container"), "rm", id) +} + +func (c *SSHContainer) hasNetwork(description string) error { + if c.GetNetwork() != "" { + return nil + } + + return c.wrapError("%s: docker network is not created. Container seems to be not connected to named bridge", description) +} + +func (c *SSHContainer) isExternalNetwork() bool { + return c.externalNetwork +} + +func (c *SSHContainer) createNetwork() error { + if err := c.isContainerStarted("create network"); err == nil { + return c.wrapError("container %s is already running", c.GetContainerId()) + } + + hasNetwork := c.hasNetwork("create network") == nil + isExternal := c.isExternalNetwork() + + if hasNetwork { + if isExternal { + c.logInfo("Skip creating network '%s'. Has external network", c.GetNetwork()) + // do not need to create network + return nil + } + + return c.wrapError("network %s is already created", c.GetContainerId()) + } + + network := c.generateDockerNetworkName() + + c.logInfo("Creating network '%s'...", network) + + description := fmt.Sprintf("create network '%s'", network) + if err := c.runDocker(description, "network", "create", network); err != nil { + return err + } + + c.network = network + + return nil +} + +func (c *SSHContainer) removeNetwork() error { + hasNetwork := c.hasNetwork("remove network") == nil + isExternal := c.isExternalNetwork() + + network := c.GetNetwork() + + if !hasNetwork || isExternal { + c.logInfo("Skip deleting network '%s'. Has external network or empty", network) + return nil + } + + c.logInfo("Deleting network %s...", network) + + if err := c.runDocker(fmt.Sprintf("remove network %s", network), "network", "rm", network); err != nil { + return err + } + + c.network = "" + + return nil +} + +func (c *SSHContainer) logInfo(format string, args ...any) { + format += fmt.Sprintf(" (%s)", c.settings.String()) + c.settings.Logger.InfoF(format, args...) +} + +func (c *SSHContainer) runDockerNetworkConnect(isDisconnect bool) error { + cmdName := "connect" + if isDisconnect { + cmdName = "disconnect" + } + + description := fmt.Sprintf("network %s", cmdName) + + if err := c.isContainerStarted(description); err != nil { + return err + } + + if err := c.hasNetwork(description); err != nil { + return err + } + + network := c.GetNetwork() + + c.logInfo( + "%s network %s to container %s...", + strings.ToTitle(cmdName), + network, + c.ShortContainerId(), + ) + + return c.runDocker(cmdName, "network", cmdName, network, c.GetContainerId()) +} + +func (c *SSHContainer) discoveryContainerIP() (string, error) { + description := "Getting IP address of container" + if err := c.hasNetwork(description); err != nil { + return "", err + } + + if err := c.isContainerStarted(description); err != nil { + return "", err + } + + getIPLoopParams := c.defaultRetryParams(fmt.Sprintf("%s %s", description, c.ShortContainerId())) + getIPCmd := []string{ + "inspect", + "-f", "{{range.NetworkSettings.Networks}}{{.IPAddress}}{{end}}", + c.GetContainerId(), + } + + ip := "" + + err := retry.NewLoopWithParams(getIPLoopParams).Run(func() error { + ipFromRun, err := c.runDockerWithOut(description, getIPCmd...) + if err != nil { + return err + } + + ipFromRun = strings.TrimSpace(ipFromRun) + if ipFromRun == "" { + return errors.New("container IP is empty") + } + + ip = ipFromRun + + return nil + }) + + if err != nil { + return "", err + } + + return ip, nil +} + +func (c *SSHContainer) defaultRetryParams(name string) retry.Params { + logger := c.ContainerSettings().Test.Logger + + return retry.NewEmptyParams( + retry.WithName(name), + retry.WithAttempts(5), + retry.WithWait(3*time.Second), + retry.WithLogger(logger), + ) +} diff --git a/pkg/ssh/gossh/testing/test.go b/pkg/ssh/gossh/testing/test.go new file mode 100644 index 0000000..172f8d2 --- /dev/null +++ b/pkg/ssh/gossh/testing/test.go @@ -0,0 +1,298 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/name212/govalue" + "github.com/stretchr/testify/require" +) + +const ( + randomSuffixSeparator = "." +) + +type Test struct { + tmpDir string + id string + Logger *log.InMemoryLogger + + testName string + subTestName string +} + +func ShouldNewTest(t *testing.T, testName string) *Test { + CheckSkipSSHTest(t, testName) + + err := os.Setenv("SSH_AUTH_SOCK", "") + require.NoError(t, err, "cleanup SSH_AUTH_SOCK env") + + tst, err := NewTest(testName) + require.NoError(t, err, "failed to create Test '%s'", testName) + tst.RegisterCleanup(t) + return tst +} + +func NewTest(testName string) (*Test, error) { + if testName == "" { + return nil, fmt.Errorf("testName is empty") + } + + id := GenerateID(testName) + + resTest := &Test{ + testName: testName, + id: id, + } + + if govalue.Nil(resTest.Logger) { + resTest.Logger = TestLogger() + } + + localTmpDirStr := filepath.Join(os.TempDir(), tmpGlobalDirName, id) + + err := os.MkdirAll(localTmpDirStr, 0777) + if err != nil { + return nil, resTest.WrapError("failed to create local tmp dir %s: %v", localTmpDirStr, err) + } + + resTest.tmpDir = localTmpDirStr + + return resTest, nil +} + +func (s *Test) IsZero() bool { + return s.TmpDir() == "" || s.GetID() == "" || s.Name() == "" +} + +var forReplace = []string{" ", ",", ".", "-"} + +func (s *Test) SetSubTest(names ...string) *Test { + resName := "" + + l := len(names) + if l > 0 { + tests := make([]string, 0, l) + for _, name := range names { + for _, old := range forReplace { + name = strings.ReplaceAll(name, old, "_") + } + tests = append(tests, name) + } + resName = strings.Join(tests, "/") + } + + s.subTestName = strings.TrimPrefix(resName, s.Name()+"/") + return s +} + +func (s *Test) WrapError(format string, args ...any) error { + f := s.FullName() + ": " + format + return fmt.Errorf(f, args...) +} + +func (s *Test) WrapErrorWithAfterName(aftername, format string, args ...any) error { + f := fmt.Sprintf("%s (%s): ", s.FullName(), aftername) + format + return fmt.Errorf(f, args...) +} + +func (s *Test) FullName() string { + res := s.Name() + if s.subTestName != "" { + res = fmt.Sprintf("%s/%s", res, s.subTestName) + } + + return res +} + +func (s *Test) Name() string { + return s.testName +} + +func (s *Test) GetID() string { + return s.id +} + +func (s *Test) TmpDir() string { + return s.tmpDir +} + +func (s *Test) MustMkSubDirs(t *testing.T, dirs ...string) string { + testDir, err := s.MkSubDirs(dirs...) + require.NoError(t, err, "MustMkSubDirs should create sub dirs") + + return testDir +} + +func (s *Test) MustCreateTmpFile(t *testing.T, content string, executable bool, pathInTestDir ...string) string { + result, err := s.CreateTmpFile(content, executable, pathInTestDir...) + require.NoError(t, err, "MustCreateTmpFile should create tmp file") + return result +} + +func (s *Test) MustCreateFile(t *testing.T, content string, executable bool, pathInTestDir ...string) string { + result, err := s.CreateFile(content, executable, pathInTestDir...) + require.NoError(t, err, "MustCreateFile should create file") + + return result +} + +func (s *Test) CreateTmpFile(content string, executable bool, pathInTestDir ...string) (string, error) { + if err := s.validateCreateDirsFilesArgs(pathInTestDir...); err != nil { + return "", err + } + + filePrefix, subDirs := s.fileNameAndSubDirs(pathInTestDir...) + + suffix := GenerateID(s.FullName(), filePrefix) + + fileName := addRandomSuffix(filePrefix, suffix) + + return s.CreateFile(content, executable, append(subDirs, fileName)...) +} + +func (s *Test) CreateFileWithSameSuffix(sourceFile string, content string, executable bool, pathInTestDir ...string) (string, error) { + if err := s.validateCreateDirsFilesArgs(pathInTestDir...); err != nil { + return "", err + } + + if sourceFile == "" { + return "", fmt.Errorf("source file is empty for file") + } + + sourceFileBase := filepath.Base(sourceFile) + + fileNameSeparated := strings.Split(sourceFileBase, randomSuffixSeparator) + if len(fileNameSeparated) < 2 { + return "", fmt.Errorf("suffix is empty for file %s", sourceFile) + } + + l := len(fileNameSeparated) + sourceName := fileNameSeparated[l-2] + suffix := fileNameSeparated[l-1] + + fileName, subDirs := s.fileNameAndSubDirs(pathInTestDir...) + + if sourceName == fileName { + return "", fmt.Errorf("source file name %s is same as destination for file %s", fileName, sourceFile) + } + + resFileName := addRandomSuffix(fileName, suffix) + + return s.CreateFile(content, executable, append(subDirs, resFileName)...) +} + +func (s *Test) CreateFile(content string, executable bool, pathInTestDir ...string) (string, error) { + if err := s.validateCreateDirsFilesArgs(pathInTestDir...); err != nil { + return "", err + } + + fileName, subDirs := s.fileNameAndSubDirs(pathInTestDir...) + + fullPathSlice := []string{s.TmpDir()} + if len(subDirs) > 0 { + if _, err := s.MkSubDirs(subDirs...); err != nil { + return "", fmt.Errorf("failed to create sub dirs: %v", err) + } + + fullPathSlice = append(fullPathSlice, subDirs...) + } + + fullPathSlice = append(fullPathSlice, fileName) + + fullPath := filepath.Join(fullPathSlice...) + + mode := os.FileMode(0666) + if executable { + mode = os.FileMode(0755) + } + + err := os.WriteFile(fullPath, []byte(content), mode) + if err != nil { + return "", s.WrapError("failed to create file %s: %v", fullPath, err) + } + + return fullPath, nil +} + +func (s *Test) MkSubDirs(dirs ...string) (string, error) { + if err := s.validateCreateDirsFilesArgs(dirs...); err != nil { + return "", err + } + + fullPathSlice := append([]string{s.TmpDir()}, dirs...) + testDir := filepath.Join(fullPathSlice...) + + if err := os.MkdirAll(testDir, 0755); err != nil { + return "", s.WrapError("failed to create test dir %s: %v", testDir, err) + } + + return testDir, nil +} + +func (s *Test) validateCreateDirsFilesArgs(paths ...string) error { + if len(paths) == 0 { + return s.WrapError("paths parts is empty") + } + + if s.TmpDir() == "" { + return s.WrapError("tmpDir is empty") + } + + return nil +} + +func (s *Test) RegisterCleanup(t *testing.T) { + t.Cleanup(func() { + s.Cleanup(t) + }) +} + +func (s *Test) Cleanup(t *testing.T) { + tmpDir := s.TmpDir() + if tmpDir == "" || tmpDir == "/" { + return + } + + err := os.RemoveAll(tmpDir) + if err != nil && !os.IsNotExist(err) { + LogErrorOrAssert(t, fmt.Sprintf("Remove local tmp dir: %s", tmpDir), err, s.Logger) + return + } + + if !govalue.Nil(s.Logger) { + s.Logger.InfoF("Temp dir '%s' removed for test %s", tmpDir, s.FullName()) + } +} + +func (s *Test) fileNameAndSubDirs(pathInTestDir ...string) (string, []string) { + l := len(pathInTestDir) + + if l == 1 { + return pathInTestDir[0], nil + } + + return pathInTestDir[l-1], pathInTestDir[:l-1] +} + +func addRandomSuffix(name string, suffix string) string { + return fmt.Sprintf("%s%s%s", name, randomSuffixSeparator, suffix) +} diff --git a/pkg/ssh/gossh/testing/test_container_wrapper.go b/pkg/ssh/gossh/testing/test_container_wrapper.go new file mode 100644 index 0000000..350076e --- /dev/null +++ b/pkg/ssh/gossh/testing/test_container_wrapper.go @@ -0,0 +1,221 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "testing" + + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/name212/govalue" + "github.com/stretchr/testify/require" +) + +type TestContainerWrapperSettingsOpts func(container *TestContainerWrapperSettings) +type TestContainerWrapperSettings struct { + *ContainerSettings + + PrivateKeyPassword string + ExternalNetwork string + + NoStartContainerDuringCreate bool + NoWaitStartingSSHDAfterStartContainer bool + NoGeneratePrivateKey bool + NoWriteSSHDConfig bool +} + +type TestContainerWrapper struct { + Container *SSHContainer + Settings *TestContainerWrapperSettings + PrivateKeyPath string +} + +func NewTestContainerWrapper(t *testing.T, test *Test, opts ...TestContainerWrapperSettingsOpts) *TestContainerWrapper { + require.False(t, govalue.Nil(test), "test must not be nil") + + testSettings := &TestContainerWrapperSettings{ + ContainerSettings: &ContainerSettings{ + Test: test, + Password: RandPassword(12), + Username: "user", + SudoAccess: true, + NodeTmpPath: "/opt/deckhouse/tmp", + }, + } + + for _, opt := range opts { + opt(testSettings) + } + + logger := testSettings.Logger + if govalue.Nil(logger) { + logger = TestLogger() + } + + testContainer := &TestContainerWrapper{ + Settings: testSettings, + } + + if !testSettings.HasPublicKeyContent() || !testSettings.HasPublicKeyPath() { + privateKeyPath, publicKey, err := GenerateKeys(test, testSettings.PrivateKeyPassword) + if err != nil { + testContainer.Cleanup(t) + require.NoError(t, err) + } + + publicKeyPath, err := WritePubKeyFileForPrivate(test, privateKeyPath, publicKey) + if err != nil { + testContainer.Cleanup(t) + require.NoError(t, err) + } + + test.Logger.InfoF("Private key created: path '%s' pub key path: %s", privateKeyPath, publicKeyPath) + + testSettings.PublicKey = &PublicKey{ + Path: publicKeyPath, + Key: publicKey, + } + testContainer.PrivateKeyPath = privateKeyPath + } + + container, err := NewSSHContainer(testSettings.ContainerSettings) + require.NoError(t, err) + + testContainer.Container = container + t.Cleanup(func() { + testContainer.Cleanup(t) + }) + + if testSettings.ExternalNetwork != "" { + container.WithExternalNetwork(testSettings.ExternalNetwork) + } + + if !testSettings.NoWriteSSHDConfig { + err := container.WriteConfig() + require.NoError(t, err) + } + + if !testSettings.NoStartContainerDuringCreate { + wait := !testSettings.NoWaitStartingSSHDAfterStartContainer + err = container.Start(wait) + require.NoError(t, err) + } + + return testContainer +} + +func (c *TestContainerWrapper) LocalPort() int { + return c.Container.ContainerSettings().LocalPort +} + +func (c *TestContainerWrapper) ContainerIP() string { + return c.Container.GetContainerIP() +} + +func (c *TestContainerWrapper) AgentPrivateKeys() []session.AgentPrivateKey { + if !c.Settings.HasPublicKey() || !c.Settings.HasPublicKeyPath() { + return make([]session.AgentPrivateKey, 0, 1) + } + + return []session.AgentPrivateKey{{Key: c.PrivateKeyPath, Passphrase: c.Settings.PrivateKeyPassword}} +} + +func (c *TestContainerWrapper) PublicKeyPath() string { + if c.Settings.HasPublicKeyPath() { + return c.Settings.PublicKey.Path + } + + return "" +} + +func (c *TestContainerWrapper) Cleanup(t *testing.T) { + containerStopError := c.Container.Stop() + removeErrors := removeFiles(c.PrivateKeyPath, c.PublicKeyPath(), c.Container.GetSSHDConfigPath()) + + logger := c.Settings.Test.Logger + + LogErrorOrAssert(t, "failed to stop container", containerStopError, logger) + for _, removeError := range removeErrors { + LogErrorOrAssert(t, "failed to remove private key", removeError, logger) + } + + c.Settings.Cleanup(t) +} + +func WithNoStartContainer() TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.NoStartContainerDuringCreate = true + } +} + +func WithNoWriteSSHDConfig() TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.NoWriteSSHDConfig = true + } +} + +func WithUserName(name string) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.Username = name + } +} + +func WithConnectToContainerNetwork(testContainer *TestContainerWrapper) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ExternalNetwork = testContainer.Container.GetNetwork() + } +} + +func WithAuthSettings(testContainer *TestContainerWrapper) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.Password = testContainer.Settings.Password + s.ContainerSettings.PublicKey = testContainer.Settings.PublicKey + } +} + +func WithNoWaitStartingSSHDAfterStartContainer() TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.NoWaitStartingSSHDAfterStartContainer = true + } +} + +func WithNoPassword() TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.Password = "" + } +} + +func WithPassword(password string) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.Password = password + } +} + +func WithContainerName(name string) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.ContainerName = name + } +} + +func WithNoSudo() TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.SudoAccess = false + } +} + +func WithLocalPort(port int) TestContainerWrapperSettingsOpts { + return func(s *TestContainerWrapperSettings) { + s.ContainerSettings.LocalPort = port + } +} diff --git a/pkg/ssh/gossh/testing/web_server.go b/pkg/ssh/gossh/testing/web_server.go new file mode 100644 index 0000000..b5cffb7 --- /dev/null +++ b/pkg/ssh/gossh/testing/web_server.go @@ -0,0 +1,263 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_testing + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/name212/govalue" + "github.com/stretchr/testify/require" +) + +const HealthzPath = "/healthz" + +type HTTPHandler struct { + Path string + Handle func(w http.ResponseWriter, r *http.Request, logger *PrefixLogger) +} + +func NewSimpleHTTPHandler(path string, response string) *HTTPHandler { + return &HTTPHandler{ + Path: path, + Handle: func(w http.ResponseWriter, r *http.Request, logger *PrefixLogger) { + _, err := fmt.Fprintf(w, "%s", response) + status := http.StatusOK + if err != nil { + logger.Error("Error writing %s response: %v", r.URL.Path, err) + status = http.StatusInternalServerError + } + w.WriteHeader(status) + }, + } +} + +func (h *HTTPHandler) IsValid() error { + if h.Path == "" { + return errors.New("missing path for handler") + } + + if !strings.HasPrefix(h.Path, "/") { + return fmt.Errorf("path '%s' must start with a slash", h.Path) + } + + if govalue.Nil(h.Handle) { + return fmt.Errorf("handle is nil for path %s", h.Path) + } + + return nil +} + +type HTTPServer struct { + mux *http.ServeMux + server *http.Server + logger *PrefixLogger + address string + stopped bool +} + +func MustStartHTTPServer(t *testing.T, test *Test, port int, handlers ...*HTTPHandler) *HTTPServer { + server := NewHTTPServer(port, test.Logger, handlers...).WithLogPrefix(test.FullName()) + err := server.Start(true) + require.NoError(t, err) + server.RegisterCleanup(t) + + return server +} + +func NewHTTPServer(port int, logger *log.InMemoryLogger, handlers ...*HTTPHandler) *HTTPServer { + mux := http.NewServeMux() + + address := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", port)) + server := &http.Server{ + Addr: address, + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + res := &HTTPServer{ + mux: mux, + server: server, + logger: newPrefixLoggerWithAddress(logger, address).WithPrefix(""), + address: address, + } + + healthz := NewSimpleHTTPHandler(HealthzPath, "OK\n") + + fullHandlers := append([]*HTTPHandler{healthz}, handlers...) + for _, h := range fullHandlers { + res.AddHandler(h) + } + + return res +} + +func (s *HTTPServer) WithLogPrefix(p string) *HTTPServer { + s.logger.WithPrefix(p) + return s +} + +func (s *HTTPServer) AddHandler(handler *HTTPHandler) { + if err := handler.IsValid(); err != nil { + s.logger.Error("Handler %s is not valid: %v", handler.Path, err) + return + } + + if s.stopped { + s.logger.Error("AddHandler %s: server already stopped", handler.Path) + return + } + + s.mux.HandleFunc(handler.Path, func(writer http.ResponseWriter, request *http.Request) { + handler.Handle(writer, request, s.logger) + }) +} + +func (s *HTTPServer) Start(waitStart bool) error { + go func() { + s.logger.Info("Starting HTTP server") + if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("Error starting HTTP server: %v", err) + } + }() + + if !waitStart { + return nil + } + + url := fmt.Sprintf("http://%s%s", s.address, HealthzPath) + + loop := retry.NewEmptyParams( + retry.WithName(fmt.Sprintf("Check HTTP server %s started", s.logger.prefix)), + retry.WithAttempts(10), + retry.WithWait(500*time.Millisecond), + retry.WithLogger(s.logger.Logger), + ) + + _, err := DoGetRequest(url, loop, s.logger) + if err != nil { + err = fmt.Errorf("error starting HTTP server: %w", err) + + errStop := s.Stop() + if errStop != nil { + err = fmt.Errorf("%w and Stop error %w", err, errStop) + } + + s.logger.Error("%v", err) + return err + } + + return nil +} + +func (s *HTTPServer) Stop() error { + if s.stopped { + return nil + } + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err := s.server.Shutdown(ctx) + if err != nil { + s.logger.Error("Error shutting down server: %v", err) + return err + } + + s.stopped = true + s.logger.Info("Server stopped") + return nil +} + +func (s *HTTPServer) RegisterCleanup(t *testing.T) { + t.Cleanup(func() { + if err := s.Stop(); err != nil { + s.logger.Error("Error cleanup server: %v", err) + } + }) +} + +func DoGetRequest(url string, loop retry.Params, logger *PrefixLogger) (string, error) { + if url == "" { + return "", errors.New("missing url for GET request") + } + + if govalue.Nil(loop) { + return "", errors.New("loop params is nil for GET request") + } + + if govalue.Nil(logger) { + return "", errors.New("logger is nil for GET request") + } + + logError := func(msg string, err error) error { + logger.Error("Error GET %s request. %s: %v", url, msg, err) + return err + } + + response := "" + + if loop.Name() == retry.NotSetName { + loop.Clone().WithName(fmt.Sprintf("Do GET request %s", url)) + } + + err := retry.NewLoopWithParams(loop).Run(func() error { + client := &http.Client{} + + ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) + defer cancel() // Ensure the context is canceled to release resources + + // Create a new HTTP GET request with the context + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return logError("creating", err) + } + + resp, err := client.Do(req) + if err != nil { + return logError("do", err) + } + + defer func() { + if err := resp.Body.Close(); err != nil { + _ = logError("closing body", err) + } + }() + + responseBytes, err := io.ReadAll(resp.Body) + if err != nil { + return logError("reading response body", err) + } + + response = string(responseBytes) + return nil + }) + + if err != nil { + return "", err + } + + return response, nil +} diff --git a/pkg/ssh/gossh/tunnel.go b/pkg/ssh/gossh/tunnel.go new file mode 100644 index 0000000..5f62da7 --- /dev/null +++ b/pkg/ssh/gossh/tunnel.go @@ -0,0 +1,187 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "fmt" + "io" + "math/rand/v2" + "net" + "strings" + "sync" + + "github.com/pkg/errors" +) + +type Tunnel struct { + sshClient *Client + address string + + tunMutex sync.Mutex + + started bool + stopCh chan struct{} + remoteListener net.Listener + + errorCh chan error +} + +func NewTunnel(sshClient *Client, address string) *Tunnel { + return &Tunnel{ + sshClient: sshClient, + address: address, + errorCh: make(chan error, 10), + } +} + +func (t *Tunnel) Up() error { + _, err := t.upNewTunnel(-1) + return err +} + +func (t *Tunnel) upNewTunnel(oldId int) (int, error) { + logger := t.sshClient.settings.Logger() + + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + if t.started { + logger.DebugF("[%d] Tunnel already up\n", oldId) + return -1, fmt.Errorf("already up") + } + + id := rand.Int() + + parts := strings.Split(t.address, ":") + if len(parts) != 4 { + return -1, fmt.Errorf("invalid address must be 'remote_bind:remote_port:local_bind:local_port': %s", t.address) + } + + remoteBind, remotePort, localBind, localPort := parts[0], parts[1], parts[2], parts[3] + + logger.DebugF("[%d] Remote bind: %s remote port: %s local bind: %s local port: %s\n", id, remoteBind, remotePort, localBind, localPort) + + logger.DebugF("[%d] Start tunnel\n", id) + + remoteAddress := net.JoinHostPort(remoteBind, remotePort) + localAddress := net.JoinHostPort(localBind, localPort) + + listener, err := net.Listen("tcp", localAddress) + if err != nil { + return -1, errors.Wrap(err, fmt.Sprintf("failed to listen local on %s", localAddress)) + } + + logger.DebugF("[%d] Listen remote %s successful\n", id, localAddress) + + go t.acceptTunnelConnection(id, remoteAddress, listener) + + t.remoteListener = listener + t.started = true + + return id, nil +} + +func (t *Tunnel) acceptTunnelConnection(id int, remoteAddress string, listener net.Listener) { + for { + localConn, err := listener.Accept() + if err != nil { + e := fmt.Errorf("[%d] Accept(): %s", id, err.Error()) + t.errorCh <- e + continue + } + + remoteConn, err := t.sshClient.GetClient().Dial("tcp", remoteAddress) + if err != nil { + e := fmt.Errorf("[%d] Cannot dial to %s: %s", id, remoteAddress, err.Error()) + t.errorCh <- e + continue + } + + go func() { + defer localConn.Close() + defer remoteConn.Close() + go func() { + _, err := io.Copy(remoteConn, localConn) + if err != nil { + t.errorCh <- err + } + + }() + + _, err := io.Copy(localConn, remoteConn) + if err != nil { + t.errorCh <- err + } + + }() + } +} + +func (t *Tunnel) HealthMonitor(errorOutCh chan<- error) { + logger := t.sshClient.settings.Logger() + + defer logger.DebugF("Tunnel health monitor stopped\n") + logger.DebugF("Tunnel health monitor started\n") + + t.stopCh = make(chan struct{}, 1) + + for { + select { + case err := <-t.errorCh: + errorOutCh <- err + case <-t.stopCh: + if t.remoteListener != nil { + _ = t.remoteListener.Close() + } + return + } + } +} + +func (t *Tunnel) Stop() { + t.stop(-1, true) +} + +func (t *Tunnel) stop(id int, full bool) { + logger := t.sshClient.settings.Logger() + + t.tunMutex.Lock() + defer t.tunMutex.Unlock() + + if !t.started { + logger.DebugF("[%d] Tunnel already stopped\n", id) + return + } + + logger.DebugF("[%d] Stop tunnel\n", id) + defer logger.DebugF("[%d] End stop tunnel\n", id) + + if full && t.stopCh != nil { + logger.DebugF("[%d] Stop tunnel health monitor\n", id) + t.stopCh <- struct{}{} + } + + err := t.remoteListener.Close() + if err != nil { + logger.WarnF("[%d] Cannot close listener: %s\n", id, err.Error()) + } + + t.remoteListener = nil + t.started = false +} + +func (t *Tunnel) String() string { + return fmt.Sprintf("%s:%s", "L", t.address) +} diff --git a/pkg/ssh/gossh/tunnel_test.go b/pkg/ssh/gossh/tunnel_test.go new file mode 100644 index 0000000..f1143df --- /dev/null +++ b/pkg/ssh/gossh/tunnel_test.go @@ -0,0 +1,235 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + ssh "github.com/deckhouse/lib-gossh" + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" +) + +func TestTunnel(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestTunnel") + + sshClient, container := startContainerAndClientWithContainer(t, test, sshtesting.WithNoWriteSSHDConfig()) + sshClient.WithLoopsParams(ClientLoopsParams{ + NewSession: retry.NewEmptyParams( + retry.WithAttempts(5), + retry.WithWait(250*time.Millisecond), + ), + }) + + // we don't have /opt/deckhouse in the container, so we should create it before start any UploadScript with sudo + err := container.Container.CreateDeckhouseDirs() + require.NoError(t, err, "could not create deckhouse dirs") + + remoteServerPort := sshtesting.RandPortExclude([]int{container.Container.RemotePort()}) + remoteServerScript := fmt.Sprintf(`#!/bin/bash +while true ; do { + echo -ne "HTTP/1.0 200 OK\r\nContent-Length: 2\r\n\r\n" ; + echo -n "OK"; +} | nc -l -p %d ; +done`, remoteServerPort) + + const remoteServerFile = "/tmp/server.sh" + localServerFile := test.MustCreateTmpFile(t, remoteServerScript, true, "remote_server", "server.sh") + + err = sshClient.File().Upload(context.TODO(), localServerFile, remoteServerFile) + require.NoError(t, err) + + runRemoteServerSession, err := sshClient.NewSSHSession() + require.NoError(t, err) + + t.Cleanup(func() { + err := runRemoteServerSession.Signal(ssh.SIGKILL) + if err != nil { + test.Logger.ErrorF("error killing remote server: %v", err) + } + err = runRemoteServerSession.Close() + if err != nil { + test.Logger.ErrorF("error closing remote server session: %v", err) + } + }) + + err = runRemoteServerSession.Start(remoteServerFile) + require.NoError(t, err, "error starting remote server") + + localsReservedPorts := []int{container.LocalPort()} + + t.Run("Tunnel to container", func(t *testing.T) { + localServerPort := sshtesting.RandPortExclude(localsReservedPorts) + localsReservedPorts = append(localsReservedPorts, localServerPort) + + localServerInvalidPort := sshtesting.RandInvalidPortExclude(localsReservedPorts) + remoteServerInvalidPort := sshtesting.RandPortExclude([]int{remoteServerPort, container.Container.RemotePort()}) + + cases := []struct { + title string + + localPort int + remotePort int + + wantErr bool + err string + }{ + { + title: "Tunnel, success", + localPort: localServerPort, + remotePort: remoteServerPort, + wantErr: false, + }, + { + title: "Invalid address", + localPort: localServerPort, + remotePort: remoteServerInvalidPort, + wantErr: true, + err: "invalid address must be 'remote_bind:remote_port:local_bind:local_port'", + }, + { + title: "Invalid local bind", + localPort: localServerInvalidPort, + remotePort: remoteServerPort, + wantErr: true, + err: fmt.Sprintf("failed to listen local on 127.0.0.1:%d", localServerInvalidPort), + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + address := tunnelAddressString(c.localPort, c.remotePort) + tun := NewTunnel(sshClient, address) + err = tun.Up() + registerStopTunnel(t, tun) + + if c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + } + + checkLocalTunnel(t, test, localServerPort, false) + + // try to up again: expectiong error + err = tun.Up() + require.Error(t, err) + require.Equal(t, err.Error(), "already up") + }) + } + }) + + t.Run("Health monitor", func(t *testing.T) { + upTunnelWithMonitor := func(t *testing.T, address string) chan error { + tun := NewTunnel(sshClient, address) + err = tun.Up() + registerStopTunnel(t, tun) + + // starting HealthMonitor + errChan := make(chan error, 10) + go tun.HealthMonitor(errChan) + + t.Cleanup(func() { + close(errChan) + }) + + return errChan + } + + t.Run("Dial to unreacheble host", func(t *testing.T) { + incorrectHost := sshtesting.IncorrectHost() + incorrectPort := sshtesting.RandPort() + localServerPort := sshtesting.RandPortExclude(localsReservedPorts) + localsReservedPorts = append(localsReservedPorts, localServerPort) + + remoteStr := fmt.Sprintf("%s:%d", incorrectHost, incorrectPort) + address := fmt.Sprintf("%s:127.0.0.1:%d", remoteStr, localServerPort) + + errChan := upTunnelWithMonitor(t, address) + + checkLocalTunnel(t, test, localServerPort, true) + + msg := "" + select { + case m, ok := <-errChan: + if !ok { + msg = "monitor channel closed" + } else { + if m != nil { + msg = m.Error() + } + } + + default: + msg = "" + } + + require.Contains(t, msg, fmt.Sprintf("Cannot dial to %s", remoteStr), "got: '%s'", msg) + }) + + t.Run("Restart connection", func(t *testing.T) { + localServerPort := sshtesting.RandPortExclude(localsReservedPorts) + localsReservedPorts = append(localsReservedPorts, localServerPort) + + upTunnelWithMonitor(t, tunnelAddressString(localServerPort, remoteServerPort)) + + checkLocalTunnel(t, test, localServerPort, false) + + restartSleep := 5 * time.Second + upMonitorSleep := 2 * time.Second + + test.Logger.InfoF( + "Disconnect (fail connection between server and client) case. Wait %s before connect. Wait %s before check", + restartSleep.String(), + upMonitorSleep.String(), + ) + + err = container.Container.FailAndUpConnection(restartSleep) + require.NoError(t, err) + + time.Sleep(upMonitorSleep) + + checkLocalTunnel(t, test, localServerPort, false) + }) + }) +} + +func checkLocalTunnel(t *testing.T, test *sshtesting.Test, localServerPort int, wantError bool) { + url := fmt.Sprintf("http://127.0.0.1:%d", localServerPort) + + requestLoop := retry.NewEmptyParams( + retry.WithName(fmt.Sprintf("Check local tunnel available by %s", url)), + retry.WithAttempts(10), + retry.WithWait(500*time.Millisecond), + retry.WithLogger(test.Logger), + ) + + _, err := sshtesting.DoGetRequest( + url, + requestLoop, + sshtesting.NewPrefixLogger(test.Logger).WithPrefix(test.FullName()), + ) + + assert := require.NoError + if wantError { + assert = require.Error + } + + assert(t, err, "check local tunnel. Want error %v", wantError) +} diff --git a/pkg/ssh/gossh/upload-script.go b/pkg/ssh/gossh/upload-script.go new file mode 100644 index 0000000..6d81b24 --- /dev/null +++ b/pkg/ssh/gossh/upload-script.go @@ -0,0 +1,328 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "errors" + "fmt" + "os/exec" + "path/filepath" + "regexp" + "strings" + "time" + + "al.essio.dev/pkg/shellescape" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-dhctl/pkg/log" + gossh "github.com/deckhouse/lib-gossh" + + "github.com/deckhouse/lib-connection/pkg/ssh/utils" + "github.com/deckhouse/lib-connection/pkg/ssh/utils/tar" +) + +type SSHUploadScript struct { + sshClient *Client + + uploadDir string + + ScriptPath string + Args []string + envs map[string]string + + sudo bool + + cleanupAfterExec bool + + stdoutHandler func(string) + + timeout time.Duration + + commanderMode bool +} + +func NewSSHUploadScript(sshClient *Client, scriptPath string, args ...string) *SSHUploadScript { + return &SSHUploadScript{ + sshClient: sshClient, + ScriptPath: scriptPath, + Args: args, + + cleanupAfterExec: true, + } +} + +func (u *SSHUploadScript) Sudo() { + u.sudo = true +} + +func (u *SSHUploadScript) WithStdoutHandler(handler func(string)) { + u.stdoutHandler = handler +} + +func (u *SSHUploadScript) WithTimeout(timeout time.Duration) { + u.timeout = timeout +} + +func (u *SSHUploadScript) WithEnvs(envs map[string]string) { + u.envs = envs +} + +func (u *SSHUploadScript) WithCommanderMode(enabled bool) { + u.commanderMode = enabled +} + +func (u *SSHUploadScript) IsSudo() bool { + return u.sudo +} + +func (u *SSHUploadScript) UploadDir() string { + return u.uploadDir +} + +func (u *SSHUploadScript) Settings() settings.Settings { + return u.sshClient.settings +} + +// WithCleanupAfterExec option tells if ssh executor should delete uploaded script after execution was attempted or not. +// It does not care if script was executed successfully of failed. +func (u *SSHUploadScript) WithCleanupAfterExec(doCleanup bool) { + u.cleanupAfterExec = doCleanup +} + +func (u *SSHUploadScript) WithExecuteUploadDir(dir string) { + u.uploadDir = dir +} + +func (u *SSHUploadScript) Execute(ctx context.Context) (stdout []byte, err error) { + logger := u.sshClient.settings.Logger() + + scriptName := filepath.Base(u.ScriptPath) + + remotePath := utils.ExecuteRemoteScriptPath(u, scriptName, false) + logger.DebugF("Uploading script %s to %s\n", u.ScriptPath, remotePath) + err = NewSSHFile(u.sshClient.settings, u.sshClient.sshClient).Upload(ctx, u.ScriptPath, remotePath) + if err != nil { + return nil, fmt.Errorf("upload: %v", err) + } + + var cmd *SSHCommand + scriptFullPath := u.pathWithEnv(utils.ExecuteRemoteScriptPath(u, scriptName, true)) + if u.sudo { + cmd = NewSSHCommand(u.sshClient, scriptFullPath, u.Args...) + cmd.Sudo(ctx) + } else { + cmd = NewSSHCommand(u.sshClient, scriptFullPath, u.Args...) + cmd.Cmd(ctx) + } + + if u.stdoutHandler != nil { + cmd.WithStdoutHandler(u.stdoutHandler) + } + + if u.timeout > 0 { + cmd.WithTimeout(u.timeout) + } + + err = cmd.Run(ctx) + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + // exitErr.Stderr is set in the "os/exec".Cmd.Output method from the Golang standard library. + // But we call the "os/exec".Cmd.Wait method, which does not set the Stderr field. + // We can reuse the exec.ExitError type when handling errors. + exitErr.Stderr = cmd.StderrBytes() + } + + err = fmt.Errorf("execute on remote: %w", err) + } + + if u.cleanupAfterExec { + defer func() { + err := NewSSHCommand(u.sshClient, "rm", "-f", scriptFullPath).Run(ctx) + if err != nil { + logger.DebugF("Failed to delete uploaded script %s: %v", scriptFullPath, err) + } + }() + } + + return cmd.StdoutBytes(), err +} + +func (u *SSHUploadScript) pathWithEnv(path string) string { + if len(u.envs) == 0 { + return path + } + + arrayToJoin := make([]string, 0, len(u.envs)*2) + + for k, v := range u.envs { + vEscaped := shellescape.Quote(v) + kvStr := fmt.Sprintf("%s=%s", k, vEscaped) + arrayToJoin = append(arrayToJoin, kvStr) + } + + envs := strings.Join(arrayToJoin, " ") + + return fmt.Sprintf("%s %s", envs, path) +} + +var ErrBashibleTimeout = errors.New("Timeout bashible step running") + +func (u *SSHUploadScript) ExecuteBundle(ctx context.Context, parentDir, bundleDir string) (stdout []byte, err error) { + logger := u.sshClient.settings.Logger() + + bundleName := fmt.Sprintf("bundle-%s.tar", time.Now().Format("20060102-150405")) + bundleLocalFilepath := filepath.Join(u.sshClient.settings.TmpDir(), bundleName) + + // tar cpf bundle.tar -C /tmp/dhctl.1231qd23/var/lib bashible + err = tar.CreateTar(bundleLocalFilepath, parentDir, bundleDir) + if err != nil { + return nil, fmt.Errorf("tar bundle: %v", err) + } + + // todo + //tomb.RegisterOnShutdown( + // "Delete bashible bundle folder", + // func() { _ = os.Remove(bundleLocalFilepath) }, + //) + + // upload to node's deckhouse tmp directory + err = NewSSHFile(u.sshClient.settings, u.sshClient.sshClient). + Upload(ctx, bundleLocalFilepath, u.sshClient.settings.NodeTmpDir()) + if err != nil { + return nil, fmt.Errorf("upload: %v", err) + } + + // sudo: + // tar xpof ${app.DeckhouseNodeTmpPath}/bundle.tar -C /var/lib && /var/lib/bashible/bashible.sh args... + tarCmdline := fmt.Sprintf( + "tar xpof %s/%s -C /var/lib && /var/lib/%s/%s %s", + u.sshClient.settings.NodeTmpDir(), + bundleName, + bundleDir, + u.ScriptPath, + strings.Join(u.Args, " "), + ) + bundleCmd := NewSSHCommand(u.sshClient, tarCmdline) + bundleCmd.Sudo(ctx) + + // Buffers to implement output handler logic + lastStep := "" + failsCounter := 0 + isBashibleTimeout := false + + processLogger := logger.ProcessLogger() + + handler := bundleSSHOutputHandler( + bundleCmd, + processLogger, + &lastStep, + &failsCounter, + &isBashibleTimeout, + u.commanderMode, + logger, + ) + bundleCmd.WithStdoutHandler(handler) + bundleCmd.CaptureStdout(nil) + bundleCmd.CaptureStderr(nil) + err = bundleCmd.Run(ctx) + if err != nil { + if lastStep != "" { + processLogger.ProcessFail() + } + + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + // exitErr.Stderr is set in the "os/exec".Cmd.Output method from the Golang standard library. + // But we call the "os/exec".Cmd.Wait method, which does not set the Stderr field. + // We can reuse the exec.ExitError type when handling errors. + exitErr.Stderr = bundleCmd.StderrBytes() + } + + err = fmt.Errorf("execute bundle: %w", err) + } else { + processLogger.ProcessEnd() + } + + if isBashibleTimeout { + return bundleCmd.StdoutBytes(), ErrBashibleTimeout + } + + return bundleCmd.StdoutBytes(), err +} + +var stepHeaderRegexp = regexp.MustCompile("^=== Step: /var/lib/bashible/bundle_steps/(.*)$") + +func bundleSSHOutputHandler( + cmd *SSHCommand, + processLogger log.ProcessLogger, + lastStep *string, + failsCounter *int, + isBashibleTimeout *bool, + commanderMode bool, + logger log.Logger, +) func(string) { + stepLogs := make([]string, 0) + return func(l string) { + if l == "===" { + return + } + if stepHeaderRegexp.Match([]byte(l)) { + match := stepHeaderRegexp.FindStringSubmatch(l) + stepName := match[1] + + if *lastStep == stepName { + logMessage := strings.Join(stepLogs, "\n") + switch { + case commanderMode && *failsCounter == 0: + logger.ErrorF("%s", logMessage) + case commanderMode && *failsCounter > 0: + logger.ErrorF("Run step %s finished with error^^^\n", stepName) + logger.DebugF("%s", logMessage) + default: + logger.ErrorF("%s", logMessage) + } + *failsCounter++ + stepLogs = stepLogs[:0] + if *failsCounter > 10 { + *isBashibleTimeout = true + if cmd != nil { + // Force kill bashible and close session/streams to unblock Wait/readers + _ = cmd.session.Signal(gossh.SIGABRT) + if cmd.Stdin != nil { + _ = cmd.Stdin.Close() + } + _ = cmd.session.Close() + } + return + } + + processLogger.ProcessFail() + stepName = fmt.Sprintf("%s, retry attempt #%d of 10", stepName, *failsCounter) + } else if *lastStep != "" { + stepLogs = make([]string, 0) + processLogger.ProcessEnd() + *failsCounter = 0 + } + + processLogger.ProcessStart("Run step " + stepName) + *lastStep = match[1] + return + } + + stepLogs = append(stepLogs, l) + logger.DebugLn(l) + } +} diff --git a/pkg/ssh/gossh/upload-script_test.go b/pkg/ssh/gossh/upload-script_test.go new file mode 100644 index 0000000..7148dce --- /dev/null +++ b/pkg/ssh/gossh/upload-script_test.go @@ -0,0 +1,317 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gossh + +import ( + "context" + "testing" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/retry" + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" +) + +func TestUploadScriptExecute(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestUploadScriptExecute") + + sshClient, container := startContainerAndClientWithContainer(t, test, sshtesting.WithNoWriteSSHDConfig()) + sshClient.WithLoopsParams(ClientLoopsParams{ + NewSession: retry.NewEmptyParams( + retry.WithAttempts(5), + retry.WithWait(250*time.Millisecond), + ), + }) + + // we don't have /opt/deckhouse in the container, so we should create it before start any UploadScript with sudo + err := container.Container.CreateDeckhouseDirs() + require.NoError(t, err, "could not create deckhouse dirs") + + script := `#!/bin/bash +if [[ $# -eq 0 ]]; then + echo "Error: No arguments provided." + exit 1 +elif [[ $# -gt 1 ]]; then + echo "Usage: $0 " + exit 1 +else + echo "provided: $1" +fi +` + scriptFile := test.MustCreateTmpFile(t, script, true, "execute_script", "script.sh") + + // evns test + envs := map[string]string{ + "TEST_ENV": "test", + } + + t.Run("Upload and execute script to container via existing ssh client", func(t *testing.T) { + cases := []struct { + title string + scriptPath string + scriptArgs []string + expected string + wantSudo bool + envs map[string]string + wantErr bool + err string + }{ + { + title: "Happy case", + scriptPath: scriptFile, + scriptArgs: []string{"one"}, + expected: "provided: one\n", + wantSudo: false, + wantErr: false, + }, + { + title: "Happy case with sudo", + scriptPath: scriptFile, + scriptArgs: []string{"one"}, + expected: "SUDO-SUCCESS\nprovided: one\n", + wantSudo: true, + wantErr: false, + }, + { + title: "Error by remote script execution", + scriptPath: scriptFile, + scriptArgs: []string{"one", "two"}, + wantSudo: false, + wantErr: true, + err: "execute on remote", + }, + { + title: "With envs", + scriptPath: scriptFile, + scriptArgs: []string{"one"}, + expected: "provided: one\n", + wantSudo: false, + envs: envs, + wantErr: false, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + s := sshClient.UploadScript(c.scriptPath, c.scriptArgs...) + s.WithCleanupAfterExec(true) + + if c.wantSudo { + s.Sudo() + } + if len(c.envs) > 0 { + s.WithEnvs(c.envs) + } + + out, err := s.Execute(context.Background()) + if !c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + return + } + + require.NoError(t, err) + require.Equal(t, c.expected, string(out)) + }) + } + }) + +} + +func TestUploadScriptExecuteBundle(t *testing.T) { + test := sshtesting.ShouldNewTest(t, "TestUploadScriptExecuteBundle") + + sshClient, container := startContainerAndClientWithContainer(t, test, sshtesting.WithNoWriteSSHDConfig()) + sshClient.WithLoopsParams(ClientLoopsParams{ + NewSession: retry.NewEmptyParams( + retry.WithAttempts(5), + retry.WithWait(250*time.Millisecond), + ), + }) + + // we don't have /opt/deckhouse in the container, so we should create it before start any UploadScript with sudo + err := container.Container.CreateDeckhouseDirs() + require.NoError(t, err, "could not create deckhouse dirs") + + const entrypoint = "test.sh" + + testDir := prepareFakeBashibleBundle(t, test, entrypoint, "bashible") + + t.Run("Upload and execute bundle to container via existing ssh client", func(t *testing.T) { + cases := []struct { + title string + scriptArgs []string + parentDir string + bundleDir string + prepareFunc func() error + wantErr bool + err string + }{ + { + title: "Happy case", + scriptArgs: []string{}, + parentDir: testDir, + bundleDir: "bashible", + wantErr: false, + }, + { + title: "Bundle error", + scriptArgs: []string{"--add-failure"}, + parentDir: testDir, + bundleDir: "bashible", + wantErr: true, + }, + { + title: "Wrong bundle directory", + scriptArgs: []string{}, + parentDir: "/path/to/nonexistent/dir", + bundleDir: "wrong_bundle", + wantErr: true, + err: "tar bundle: failed to walk path", + }, + { + title: "Upload error", + scriptArgs: []string{""}, + parentDir: testDir, + bundleDir: "bashible", + prepareFunc: func() error { + cmd := sshClient.Command("chmod", "700", container.Container.ContainerSettings().NodeTmpPath) + cmd.Sudo(context.Background()) + return cmd.Run(context.Background()) + }, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + s := sshClient.UploadScript(entrypoint, c.scriptArgs...) + parentDir := c.parentDir + bundleDir := c.bundleDir + if c.prepareFunc != nil { + err = c.prepareFunc() + require.NoError(t, err) + } + + _, err := s.ExecuteBundle(context.Background(), parentDir, bundleDir) + if c.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.err) + return + } + + require.NoError(t, err) + }) + } + }) +} + +func prepareFakeBashibleBundle(t *testing.T, test *sshtesting.Test, entrypoint, bundleDir string) string { + bundleDirPath := func() []string { + return []string{"bundle_test", bundleDir} + } + + parentDir := test.MustMkSubDirs(t, bundleDirPath()...) + + entrypointScript := `#!/bin/bash + +echo "starting execute steps..." + +BUNDLE_STEPS_DIR=/var/lib/bashible/bundle_steps +BOOTSTRAP_DIR=/var/lib/bashible +MAX_RETRIES=5 + +for arg in "$@"; do + if [[ "$arg" == "--add-failure" ]] + then + echo "failures included" + export INCLUDE_FAILURE=true + fi +done + +# Execute bashible steps +for step in $BUNDLE_STEPS_DIR/*; do + echo === + echo === Step: $step + echo === + attempt=0 + sx="" + until /bin/bash --noprofile --norc -"$sx"eEo pipefail -c "export TERM=xterm-256color; unset CDPATH; cd $BOOTSTRAP_DIR; source $step" 2> >(tee /var/lib/bashible/step.log >&2) + do + attempt=$(( attempt + 1 )) + if [ -n "${MAX_RETRIES-}" ] && [ "$attempt" -gt "${MAX_RETRIES}" ]; then + >&2 echo "ERROR: Failed to execute step $step. Retry limit is over." + exit 1 + fi + >&2 echo "Failed to execute step "$step" ... retry in 10 seconds." + sleep 10 + echo === + echo === Step: $step + echo === + if [ "$attempt" -gt 2 ]; then + sx=x + fi + done +done + +` + + entrypointPath := append(bundleDirPath(), entrypoint) + test.MustCreateFile(t, entrypointScript, true, entrypointPath...) + + scrips := []struct { + name string + content string + }{ + { + name: "01-step.sh", + content: `#!/bin/bash +echo "just a step" + +for i in {0..3} +do + sleep $(( $RANDOM % 2 )) + echo $i +done +`, + }, + { + name: "02-step.sh", + content: `#!/bin/bash + +echo "second step" + +for i in {0..4} +do + sleep $(( $RANDOM % 2 )) + echo $i + if [[ $i -gt 2 && $INCLUDE_FAILURE == "true" ]] + then + echo "oops! failure!" + exit 1 + fi +done +`, + }, + } + + for _, c := range scrips { + scriptPath := append(bundleDirPath(), "bundle_steps", c.name) + test.MustCreateFile(t, c.content, true, scriptPath...) + } + + return parentDir +} diff --git a/pkg/ssh/local/command.go b/pkg/ssh/local/command.go new file mode 100644 index 0000000..a4c714f --- /dev/null +++ b/pkg/ssh/local/command.go @@ -0,0 +1,235 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" +) + +type Command struct { + settings settings.Settings + + used atomic.Bool + + program string + args []string + sudo bool + env map[string]string + timeout time.Duration + + onStart func() + stdoutLineHandler func(line string) + stderrLineHandler func(line string) + + stdout []byte + stderr []byte +} + +func NewCommand(sett settings.Settings, program string, args ...string) *Command { + return &Command{ + program: program, + args: args, + settings: sett, + } +} + +func (c *Command) Run(ctx context.Context) error { + if !c.used.CompareAndSwap(false, true) { + return fmt.Errorf("command instance reused") + } + + cmd, cancel := c.prepareCmd(ctx) + defer cancel() + + wg := &sync.WaitGroup{} + stdoutBuf := &bytes.Buffer{} + stderrBuf := &bytes.Buffer{} + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("stdout pipe failed: %v", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("stderr pipe failed: %v", err) + } + wg.Add(2) + go c.scanLines(stdout, stdoutBuf, wg, c.stdoutLineHandler) + go c.scanLines(stderr, stderrBuf, wg, c.stderrLineHandler) + + if err = cmd.Start(); err != nil { + return fmt.Errorf("cmd start failed: %v", err) + } + if c.onStart != nil { + c.onStart() + } + + wg.Wait() // Wait for stdout/stderr reads to complete first + c.stdout = stdoutBuf.Bytes() + c.stderr = stderrBuf.Bytes() + return cmd.Wait() +} + +func (c *Command) scanLines( + stream io.Reader, + buf *bytes.Buffer, + wg *sync.WaitGroup, + handler func(string), +) { + defer wg.Done() + + scan := bufio.NewScanner(stream) + for scan.Scan() { + line := scan.Text() + buf.WriteString(line) + if handler != nil { + handler(line) + } + } + if err := scan.Err(); err != nil { + c.settings.Logger().ErrorF("scan cmd output failed: %v\n", err) + } +} + +func (c *Command) OnCommandStart(fn func()) { + c.onStart = fn +} + +func (c *Command) Output(ctx context.Context) ([]byte, []byte, error) { + if !c.used.CompareAndSwap(false, true) { + return nil, nil, fmt.Errorf("command instance reused") + } + + cmd, cancel := c.prepareCmd(ctx) + defer cancel() + + var stdout bytes.Buffer + cmd.Stdout = &stdout + + if err := cmd.Start(); err != nil { + return nil, nil, fmt.Errorf("start %q: %w", c.program, err) + } + if c.onStart != nil { + c.onStart() + } + + if err := cmd.Wait(); err != nil { + return nil, nil, err + } + return stdout.Bytes(), nil, nil // stderr is ignored to preserve compatibility with ssh frontend +} + +func (c *Command) CombinedOutput(ctx context.Context) ([]byte, error) { + if !c.used.CompareAndSwap(false, true) { + return nil, fmt.Errorf("command instance reused") + } + + cmd, cancel := c.prepareCmd(ctx) + defer cancel() + + var output bytes.Buffer + cmd.Stdout = &output + cmd.Stderr = &output + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start %q: %w", c.program, err) + } + if c.onStart != nil { + c.onStart() + } + + if err := cmd.Wait(); err != nil { + return nil, err + } + + return output.Bytes(), nil +} + +func (c *Command) prepareCmd(ctx context.Context) (*exec.Cmd, context.CancelFunc) { + bashBuiltins := []string{"bind", "type", "command", "let", "mapfile", "printf", "readarray", "ulimit"} + + program := c.program + args := c.args + if c.sudo { + program = "sudo" + args = append([]string{c.program}, c.args...) + } else if slices.Contains(bashBuiltins, program) { // For shell built-in things we need to run bash + program = "bash" + args = []string{"-c", strings.Join(append([]string{c.program}, c.args...), " ")} + } + + ctx, cancel := context.WithCancel(ctx) + if c.timeout > 0 { + cancel() + ctx, cancel = context.WithTimeout(ctx, c.timeout) + } + + cmd := exec.CommandContext(ctx, program, args...) + if len(c.env) > 0 { + cmd.Env = os.Environ() + for k, v := range c.env { + cmd.Env = append(cmd.Env, k+"="+v) + } + } + + c.settings.Logger().DebugF("Command prepared: %#v\n", cmd) + + return cmd, cancel +} + +func (c *Command) Sudo(_ context.Context) { + c.sudo = true +} + +func (c *Command) WithTimeout(t time.Duration) { + c.timeout = t +} + +func (c *Command) WithEnv(env map[string]string) { + c.env = env +} + +func (c *Command) WithStdoutHandler(h func(line string)) { + c.stdoutLineHandler = h +} + +func (c *Command) WithStderrHandler(h func(line string)) { + c.stderrLineHandler = h +} + +func (c *Command) StdoutBytes() []byte { + return c.stdout +} + +func (c *Command) StderrBytes() []byte { + return c.stderr +} + +// The rest are no-ops for local execution + +func (c *Command) Cmd(_ context.Context) {} +func (c *Command) WithSSHArgs(_ ...string) {} diff --git a/pkg/ssh/local/command_test.go b/pkg/ssh/local/command_test.go new file mode 100644 index 0000000..38d48a4 --- /dev/null +++ b/pkg/ssh/local/command_test.go @@ -0,0 +1,90 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/testssh" +) + +func TestCommandOutput(t *testing.T) { + s := require.New(t) + testFilePath := filepath.Join(os.TempDir(), "test") + tmpFile, err := os.Create(testFilePath) + s.NoError(err) + t.Cleanup(func() { + _ = os.Remove(testFilePath) + }) + + _, err = tmpFile.WriteString("Hello world") + s.NoError(err) + + cmd := NewCommand(sshtesting.CreateSettings(), "cat", testFilePath) + stdout, _, err := cmd.Output(context.Background()) + s.NoError(err) + s.Equal("Hello world", string(stdout)) +} + +func TestCommandCombinedOutput(t *testing.T) { + s := require.New(t) + testFilePath := filepath.Join(os.TempDir(), "test") + tmpFile, err := os.Create(testFilePath) + s.NoError(err) + t.Cleanup(func() { + _ = os.Remove(testFilePath) + }) + + _, err = tmpFile.WriteString("Hello world") + s.NoError(err) + + cmd := NewCommand(sshtesting.CreateSettings(), "cat", testFilePath) + stdout, err := cmd.CombinedOutput(context.Background()) + s.NoError(err) + s.Equal("Hello world", string(stdout)) +} + +func TestCommandRun(t *testing.T) { + s := require.New(t) + testFilePath := filepath.Join(os.TempDir(), "test") + tmpFile, err := os.Create(testFilePath) + s.NoError(err) + t.Cleanup(func() { + _ = os.Remove(testFilePath) + }) + + _, err = tmpFile.WriteString("Hello world") + s.NoError(err) + + cmd := NewCommand(sshtesting.CreateSettings(), "cat", testFilePath) + err = cmd.Run(context.Background()) + s.NoError(err) + s.Equal("Hello world", string(cmd.StdoutBytes())) + s.Nil(cmd.StderrBytes()) +} + +func TestCommandPipe(t *testing.T) { + s := require.New(t) + + cmd := NewCommand(sshtesting.CreateSettings(), "bash", "-c", `echo "Goodbye world" | sed "s/Goodbye/Hello/g"`) + s.NoError(cmd.Run(context.Background())) + s.Equal("Hello world", string(cmd.StdoutBytes())) + s.Nil(cmd.StderrBytes()) +} diff --git a/pkg/ssh/local/file.go b/pkg/ssh/local/file.go new file mode 100644 index 0000000..9b62414 --- /dev/null +++ b/pkg/ssh/local/file.go @@ -0,0 +1,122 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "context" + "io" + "os" + "path/filepath" +) + +type File struct{} + +func NewFile() *File { + return &File{} +} + +func (File) Upload(_ context.Context, srcPath, dstPath string) error { + if err := copyRecursively(srcPath, dstPath); err != nil { + return err + } + return nil +} + +func (File) Download(_ context.Context, srcPath, dstPath string) error { + if err := copyRecursively(srcPath, dstPath); err != nil { + return err + } + return nil +} + +func (File) UploadBytes(_ context.Context, data []byte, dstPath string) error { + if err := os.WriteFile(dstPath, data, 0666); err != nil { + return err + } + return nil +} + +func (File) DownloadBytes(_ context.Context, srcPath string) ([]byte, error) { + file, err := os.ReadFile(srcPath) + if err != nil { + return nil, err + } + return file, nil +} + +func copyRecursively(src string, dst string) error { + srcStat, err := os.Stat(src) + if err != nil { + return err + } + if !srcStat.IsDir() { + return copyFile(src, filepath.Join(dst, filepath.Base(src))) + } + + if err = os.MkdirAll(dst, srcStat.Mode()); err != nil { + return err + } + + srcEntries, err := os.ReadDir(src) + if err != nil { + return err + } + + for _, entry := range srcEntries { + srcEntryPath := filepath.Join(src, entry.Name()) + destEntryPath := filepath.Join(dst, entry.Name()) + + if entry.IsDir() { + if err = copyRecursively(srcEntryPath, destEntryPath); err != nil { + return err + } + } else { + if err = copyFile(srcEntryPath, destEntryPath); err != nil { + return err + } + } + } + + return nil +} + +func copyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + srcStat, err := srcFile.Stat() + if err != nil { + return err + } + + destFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcStat.Mode()) + if err != nil { + return err + } + defer destFile.Close() + + if _, err = io.Copy(destFile, srcFile); err != nil { + return err + } + + if err = destFile.Sync(); err != nil { + return err + } + + return nil +} diff --git a/pkg/ssh/local/node.go b/pkg/ssh/local/node.go new file mode 100644 index 0000000..0d4deb0 --- /dev/null +++ b/pkg/ssh/local/node.go @@ -0,0 +1,59 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" +) + +type NodeInterface struct { + settings settings.Settings +} + +// NewDefaultNodeInterface +// Deprecated: +// use NewNodeInterface +func NewDefaultNodeInterface() *NodeInterface { + return NewNodeInterface(nil) +} + +func NewNodeInterface(sett settings.Settings) *NodeInterface { + return &NodeInterface{ + settings: sett, + } +} + +func (n *NodeInterface) Command(name string, args ...string) connection.Command { + logger := n.settings.Logger() + + logger.DebugLn("Starting NodeInterface.Command") + defer logger.DebugLn("Stop NodeInterface.Command") + + return NewCommand(n.settings, name, args...) +} + +func (n *NodeInterface) File() connection.File { + return NewFile() +} + +func (n *NodeInterface) UploadScript(scriptPath string, args ...string) connection.Script { + logger := n.settings.Logger() + + logger.DebugLn("Starting NodeInterface.UploadScript") + defer logger.DebugLn("Stop NodeInterface.UploadScript") + + return NewScript(n.settings, scriptPath, args...) +} diff --git a/pkg/ssh/local/script.go b/pkg/ssh/local/script.go new file mode 100644 index 0000000..70575c0 --- /dev/null +++ b/pkg/ssh/local/script.go @@ -0,0 +1,134 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/deckhouse/lib-connection/pkg/settings" +) + +type Script struct { + settings settings.Settings + + scriptPath string + args []string + env map[string]string + sudo bool + stdoutLineHandler func(line string) + timeout time.Duration + cleanupAfterRun bool +} + +func NewScript(sett settings.Settings, path string, args ...string) *Script { + return &Script{ + scriptPath: path, + args: args, + settings: sett, + } +} + +func (s *Script) Execute(ctx context.Context) (stdout []byte, err error) { + cmd := NewCommand(s.settings, s.scriptPath, s.args...) + if s.sudo { + cmd.Sudo(ctx) + } + + if s.timeout > 0 { + cmd.WithTimeout(s.timeout) + } + if s.env != nil { + cmd.WithEnv(s.env) + } + if s.stdoutLineHandler != nil { + cmd.WithStdoutHandler(s.stdoutLineHandler) + } + + if s.cleanupAfterRun { + defer os.Remove(cmd.program) + } + + err = cmd.Run(ctx) + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitErr.Stderr = cmd.StderrBytes() + } + + err = fmt.Errorf("Execute locally failed: %w", err) + } + + return cmd.StdoutBytes(), nil +} + +func (s *Script) ExecuteBundle(ctx context.Context, parentDir, bundleDir string) (stdout []byte, err error) { + srcPath := filepath.Join(parentDir, bundleDir) + dstPath := filepath.Join("/var/lib/", bundleDir) + _ = os.RemoveAll(dstPath) // Cleanup from previous runs + if err = copyRecursively(srcPath, dstPath); err != nil { + return nil, fmt.Errorf("copy bundle to /var/lib/%s: %w", bundleDir, err) + } + + cmd := NewCommand(s.settings, filepath.Join("/var/lib", bundleDir, s.scriptPath), s.args...) + if s.timeout > 0 { + cmd.WithTimeout(s.timeout) + } + if s.env != nil { + cmd.WithEnv(s.env) + } + if s.stdoutLineHandler != nil { + cmd.WithStdoutHandler(s.stdoutLineHandler) + } + if s.sudo { + cmd.Sudo(ctx) + } + + if err = cmd.Run(ctx); err != nil { + s.settings.Logger().DebugF("Execute bundle failed: stdout: %s\n\nstderr: %s\n", cmd.StdoutBytes(), cmd.StderrBytes()) + return nil, fmt.Errorf("Execute bundle failed: %w", err) + } + + return cmd.StdoutBytes(), nil +} + +func (s *Script) Sudo() { + s.sudo = true +} + +func (s *Script) WithStdoutHandler(handler func(string)) { + s.stdoutLineHandler = handler +} + +func (s *Script) WithTimeout(timeout time.Duration) { + s.timeout = timeout +} + +func (s *Script) WithEnvs(envs map[string]string) { + s.env = envs +} + +func (s *Script) WithCleanupAfterExec(doCleanup bool) { + s.cleanupAfterRun = doCleanup +} + +func (s *Script) WithCommanderMode(bool) {} + +func (s *Script) WithExecuteUploadDir(string) {} diff --git a/pkg/ssh/local/script_test.go b/pkg/ssh/local/script_test.go new file mode 100644 index 0000000..02022da --- /dev/null +++ b/pkg/ssh/local/script_test.go @@ -0,0 +1,47 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/testssh" +) + +const testRunScript = `#! /bin/bash +echo $@ +exit 0` + +func TestScriptExecute(t *testing.T) { + t.SkipNow() + + s := require.New(t) + scriptPath := filepath.Join(os.TempDir(), "test_run.sh") + err := os.WriteFile(scriptPath, []byte(testRunScript), 0774) + s.NoError(err) + t.Cleanup(func() { + _ = os.Remove(scriptPath) + }) + + script := NewScript(sshtesting.CreateSettings(), "arg 1", "arg 2") + stdout, err := script.Execute(context.Background()) + s.NoError(err) + s.Equal(string(stdout), "arg 1 arg 2") +} diff --git a/pkg/ssh/session/session.go b/pkg/ssh/session/session.go new file mode 100644 index 0000000..5afb157 --- /dev/null +++ b/pkg/ssh/session/session.go @@ -0,0 +1,324 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "sort" + "strings" + "sync" +) + +type Input struct { + User string + Port string + BastionHost string + BastionPort string + BastionUser string + BastionPassword string + ExtraArgs string + AvailableHosts []Host + BecomePass string +} + +type AgentSettings struct { + PrivateKeys []AgentPrivateKey + + // runtime + AuthSock string +} + +type AgentPrivateKey struct { + Key string + Passphrase string +} + +func (s *AgentSettings) AuthSockEnv() string { + if s.AuthSock != "" { + return fmt.Sprintf("SSH_AUTH_SOCK=%s", s.AuthSock) + } + return "" +} + +func (s *AgentSettings) Clone() *AgentSettings { + return &AgentSettings{ + AuthSock: s.AuthSock, + PrivateKeys: append(make([]AgentPrivateKey, 0), s.PrivateKeys...), + } +} + +// TODO rename to Settings +// Session is used to store ssh settings +type Session struct { + // input + User string + Port string + BastionHost string + BastionPort string + BastionUser string + BastionPassword string + ExtraArgs string + BecomePass string + + AgentSettings *AgentSettings + + lock sync.RWMutex + host string + availableHosts []Host + remainingHosts []Host +} + +type Host struct { + Host string + Name string +} + +func (h *Host) String() string { + name := h.Name + if name != "" { + name = fmt.Sprintf("%s: ", name) + } + return fmt.Sprintf("%s%s", name, h.Host) +} + +type SortByName []Host + +func (h SortByName) Len() int { return len(h) } +func (h SortByName) Less(i, j int) bool { + if h[i].Name == h[j].Name { + return h[i].Host < h[j].Host + } else { + return h[i].Name < h[j].Name + } +} +func (h SortByName) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func NewSession(input Input) *Session { + s := &Session{ + User: input.User, + Port: input.Port, + BastionHost: input.BastionHost, + BastionPort: input.BastionPort, + BastionUser: input.BastionUser, + ExtraArgs: input.ExtraArgs, + BecomePass: input.BecomePass, + BastionPassword: input.BastionPassword, + } + + s.SetAvailableHosts(input.AvailableHosts) + + return s +} + +func (s *Session) Host() string { + defer s.lock.RUnlock() + s.lock.RLock() + return s.host +} + +// ChoiceNewHost choice new host for connection +func (s *Session) ChoiceNewHost() { + defer s.lock.Unlock() + s.lock.Lock() + + s.selectNewHost() +} + +func (s *Session) AddAvailableHosts(hosts ...Host) { + defer s.lock.Unlock() + s.lock.Lock() + + availableHostsMap := make(map[string]string, len(s.availableHosts)) + + for _, host := range s.availableHosts { + availableHostsMap[host.Host] = host.Name + } + + for _, host := range hosts { + availableHostsMap[host.Host] = host.Name + } + + availableHosts := make([]Host, 0, len(availableHostsMap)) + + for key, value := range availableHostsMap { + availableHosts = append(availableHosts, Host{Host: key, Name: value}) + } + + sort.Sort(SortByName(availableHosts)) + s.availableHosts = availableHosts + + s.resetUsedHosts() + s.selectNewHost() +} + +func (s *Session) RemoveAvailableHosts(hosts ...Host) { + defer s.lock.Unlock() + s.lock.Lock() + + availableHostsMap := make(map[string]string, len(s.availableHosts)) + + for _, host := range s.availableHosts { + availableHostsMap[host.Host] = host.Name + } + + for _, host := range hosts { + delete(availableHostsMap, host.Host) + } + + availableHosts := make([]Host, 0, len(availableHostsMap)) + + for key, value := range availableHostsMap { + availableHosts = append(availableHosts, Host{Host: key, Name: value}) + } + + sort.Sort(SortByName(availableHosts)) + s.availableHosts = availableHosts + + s.resetUsedHosts() + s.selectNewHost() +} + +// SetAvailableHosts +// Set Available hosts. Current host can choice +func (s *Session) SetAvailableHosts(hosts []Host) { + defer s.lock.Unlock() + s.lock.Lock() + + s.availableHosts = make([]Host, len(hosts)) + copy(s.availableHosts, hosts) + + s.resetUsedHosts() + s.selectNewHost() +} + +func (s *Session) AvailableHosts() []Host { + s.lock.RLock() + defer s.lock.RUnlock() + + return append(make([]Host, 0), s.availableHosts...) +} + +func (s *Session) CountHosts() int { + defer s.lock.RUnlock() + s.lock.RLock() + + return len(s.availableHosts) +} + +// RemoteAddress returns host or username@host +func (s *Session) RemoteAddress() string { + defer s.lock.RUnlock() + s.lock.RLock() + + addr := s.host + if s.User != "" { + addr = s.User + "@" + addr + } + return addr +} + +func (s *Session) String() string { + defer s.lock.RUnlock() + s.lock.RLock() + + builder := strings.Builder{} + builder.WriteString("ssh ") + + if s.BastionHost != "" { + builder.WriteString("-J ") + if s.BastionUser != "" { + builder.WriteString(fmt.Sprintf("%s@%s", s.BastionUser, s.BastionHost)) + } else { + builder.WriteString(s.BastionHost) + } + if s.BastionPort != "" { + builder.WriteString(fmt.Sprintf(":%s", s.BastionPort)) + } + builder.WriteString(" ") + } + + if s.User != "" { + builder.WriteString(fmt.Sprintf("%s@%s", s.User, s.host)) + } else { + builder.WriteString(s.host) + } + + if s.Port != "" && s.Port != "22" { + builder.WriteString(fmt.Sprintf(" -p %s", s.Port)) + } + + return builder.String() +} + +func (s *Session) Copy() *Session { + defer s.lock.RUnlock() + s.lock.RLock() + + ses := &Session{} + + ses.Port = s.Port + ses.User = s.User + ses.BastionHost = s.BastionHost + ses.BastionPort = s.BastionPort + ses.BastionUser = s.BastionUser + ses.BastionPassword = s.BastionPassword + ses.ExtraArgs = s.ExtraArgs + ses.host = s.host + + if s.AgentSettings != nil { + ses.AgentSettings = s.AgentSettings.Clone() + } + + ses.availableHosts = make([]Host, len(s.availableHosts)) + copy(ses.availableHosts, s.availableHosts) + + ses.resetUsedHosts() + + return ses +} + +// resetUsedHosts if all available host is used this function reset +func (s *Session) resetUsedHosts() { + s.remainingHosts = make([]Host, len(s.availableHosts)) + copy(s.remainingHosts, s.availableHosts) + s.host = "" +} + +// selectNewHost selects new host from available and updates remaining hosts +func (s *Session) selectNewHost() { + if len(s.availableHosts) == 0 { + s.host = "" + return + } + + hosts := make([]Host, len(s.availableHosts)) + copy(hosts, s.availableHosts) + hostIndx := 0 + if s.host != "" { + for i, host := range hosts { + if host.Host == s.host { + if i != len(hosts)-1 { + hostIndx = i + 1 + } + break + } + } + } + + host := hosts[hostIndx] + s.remainingHosts = append(hosts[:hostIndx], hosts[hostIndx+1:]...) + + s.host = host.Host +} diff --git a/pkg/ssh/session/session_test.go b/pkg/ssh/session/session_test.go new file mode 100644 index 0000000..3d9c7fe --- /dev/null +++ b/pkg/ssh/session/session_test.go @@ -0,0 +1,246 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCreatingNewSShSession(t *testing.T) { + host := Host{Host: "a", Name: "master-0"} + + ses := NewSession(Input{ + AvailableHosts: []Host{host}, + }) + + t.Run("Create settings with not empty AvailableHosts returns session struct without errors", func(t *testing.T) { + require.NotNil(t, ses) + }) + + t.Run("Create settings with not empty AvailableHosts sets hosts field", func(t *testing.T) { + require.Equal(t, ses.host, host.Host) + }) + + t.Run("Create settings with not empty AvailableHosts choices host from remainingHosts (not contains host in remainingHosts)", func(t *testing.T) { + require.NotContains(t, ses.remainingHosts, host) + }) +} + +func TestSession_SetNewAvailableHosts(t *testing.T) { + oldHost := Host{Host: "a", Name: "master-0"} + newHost := Host{Host: "b", Name: "master-1"} + + oldHostsList := []Host{oldHost} + newHostsList := []Host{newHost} + + tests := []struct { + name string + assert func(t *testing.T, s *Session) + }{ + { + name: "Set new available hosts sets new host", + assert: func(t *testing.T, s *Session) { + require.Equal(t, s.host, newHost.Host) + }, + }, + + { + name: "Set new available sets new available list", + assert: func(t *testing.T, s *Session) { + require.Equal(t, s.availableHosts, newHostsList) + }, + }, + + { + name: "Set new available hosts choices host from remainingHosts (not contains host in remainingHosts)", + assert: func(t *testing.T, s *Session) { + require.NotContains(t, s.remainingHosts, oldHost) + require.NotContains(t, s.remainingHosts, newHost) + }, + }, + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + s := NewSession(Input{ + AvailableHosts: oldHostsList, + }) + + s.SetAvailableHosts(newHostsList) + + tst.assert(t, s) + }) + } +} + +func TestSession_ChoiceNewHost(t *testing.T) { + t.Run("ChoiceNewHost should always return one host when setting contains one host", func(t *testing.T) { + host := Host{Host: "a", Name: "master-0"} + ses := NewSession(Input{ + AvailableHosts: []Host{host}, + }) + + for i := 0; i < 3; i++ { + ses.ChoiceNewHost() + require.Equal(t, ses.host, host.Host) + } + }) + + t.Run("With multiple hosts ChoiceNewHost does not repeat hosts for calling count - 1 times", func(t *testing.T) { + availableHosts := []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}} + ses := NewSession(Input{ + AvailableHosts: availableHosts, + }) + + choicedHosts := make(map[string]bool) + choicedHosts[ses.host] = true + + for i := 0; i < len(availableHosts)-1; i++ { + ses.ChoiceNewHost() + + require.NotContains(t, choicedHosts, ses.host) + + choicedHosts[ses.host] = true + } + }) + + t.Run("With multiple hosts ChoiceNewHost should resets remainingHosts", func(t *testing.T) { + availableHosts := []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}} + ses := NewSession(Input{ + AvailableHosts: availableHosts, + }) + + for i := 0; i < len(availableHosts); i++ { + ses.ChoiceNewHost() + } + + var remainedHosts []Host + for i, host := range availableHosts { + if host.Host == ses.host { + remainedHosts = append(availableHosts[:i], availableHosts[i+1:]...) + break + } + } + var expectedRemainedHosts []Host + expectedRemainedHosts = append(expectedRemainedHosts, remainedHosts...) + + require.Len(t, ses.remainingHosts, len(availableHosts)-1) + require.NotContains(t, ses.remainingHosts, ses.host) + + for _, h := range expectedRemainedHosts { + require.Contains(t, ses.remainingHosts, h) + } + }) +} + +func TestSession_AddAvailableHosts(t *testing.T) { + tests := []struct { + hosts []Host + newHosts []Host + expected []Host + }{ + { + newHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + }, + { + newHosts: []Host{{Host: "b", Name: ""}, {Host: "a", Name: ""}, {Host: "c", Name: ""}}, + expected: []Host{{Host: "a", Name: ""}, {Host: "b", Name: ""}, {Host: "c", Name: ""}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + newHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + }, + { + hosts: []Host{{Host: "c", Name: ""}, {Host: "b", Name: ""}, {Host: "a", Name: ""}}, + newHosts: []Host{{Host: "b", Name: ""}, {Host: "a", Name: ""}, {Host: "c", Name: ""}}, + expected: []Host{{Host: "a", Name: ""}, {Host: "b", Name: ""}, {Host: "c", Name: ""}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + newHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}, {Host: "d", Name: "master-3"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}, {Host: "d", Name: "master-3"}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + newHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + newHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "d", Name: "master-3"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}, {Host: "d", Name: "master-3"}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}}, + expected: []Host{{Host: "a", Name: "master-0"}}, + }, + } + + for _, test := range tests { + s := NewSession(Input{ + AvailableHosts: test.hosts, + }) + + s.AddAvailableHosts(test.newHosts...) + + availableHosts := s.AvailableHosts() + + require.Equal(t, test.expected, availableHosts) + } +} + +func TestSession_RemoveAvailableHosts(t *testing.T) { + tests := []struct { + hosts []Host + removeHosts []Host + expected []Host + }{ + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + expected: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + removeHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + expected: []Host{}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + removeHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}}, + expected: []Host{{Host: "c", Name: "master-2"}}, + }, + { + hosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "c", Name: "master-2"}}, + removeHosts: []Host{{Host: "a", Name: "master-0"}, {Host: "b", Name: "master-1"}, {Host: "d", Name: "master-3"}}, + expected: []Host{{Host: "c", Name: "master-2"}}, + }, + } + + for _, test := range tests { + s := NewSession(Input{ + AvailableHosts: test.hosts, + }) + + s.RemoveAvailableHosts(test.removeHosts...) + + availableHosts := s.AvailableHosts() + + require.Equal(t, test.expected, availableHosts) + } +} diff --git a/pkg/ssh/testssh/client.go b/pkg/ssh/testssh/client.go new file mode 100644 index 0000000..31ac5a2 --- /dev/null +++ b/pkg/ssh/testssh/client.go @@ -0,0 +1,797 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "math/rand" + "os" + "reflect" + "strings" + "sync" + "time" + + "github.com/deckhouse/lib-dhctl/pkg/log" + "github.com/name212/govalue" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-connection/pkg/ssh/utils" +) + +type ( + Bastion struct { + Host string + Port string + User string + + NoSession bool + } + UploadScriptProvider func(withBastion Bastion, scriptPath string, args ...string) *Script + CommandProvider func(withBastion Bastion, scriptPath string, args ...string) *Command + FileProvider func(bastion Bastion) *File + SwitchHandler func(s Switch) + + Switch struct { + Bastion Bastion + Session *session.Session + PrivateKeys []session.AgentPrivateKey + } +) + +type SSHProvider struct { + once bool + initSession *session.Session + initPrivateKeys []session.AgentPrivateKey + client *Client + + scriptProviders *providersMap[UploadScriptProvider] + commandProviders *providersMap[CommandProvider] + fileProviders *providersMap[FileProvider] + + switchHandler SwitchHandler + switches []Switch +} + +func NewSSHProvider(initSession *session.Session, once bool) *SSHProvider { + return &SSHProvider{ + initSession: initSession, + initPrivateKeys: make([]session.AgentPrivateKey, 0), + once: once, + + scriptProviders: newProvidersMap[UploadScriptProvider](), + commandProviders: newProvidersMap[CommandProvider](), + fileProviders: newOneProviderMap[FileProvider](), + + switches: make([]Switch, 0), + } +} + +func (p *SSHProvider) AddScriptProvider(host string, f UploadScriptProvider) *SSHProvider { + p.scriptProviders.add(host, f) + return p +} + +func (p *SSHProvider) AddCommandProvider(host string, f CommandProvider) *SSHProvider { + p.commandProviders.add(host, f) + return p +} + +func (p *SSHProvider) SetFileProvider(host string, f FileProvider) *SSHProvider { + p.fileProviders.add(host, f) + return p +} + +func (p *SSHProvider) WithInitPrivateKeys(k []session.AgentPrivateKey) *SSHProvider { + p.initPrivateKeys = k + return p +} + +func (p *SSHProvider) WithSwitchHandler(f SwitchHandler) *SSHProvider { + p.switchHandler = f + return p +} + +func (p *SSHProvider) Switches() []Switch { + return p.switches +} + +func (p *SSHProvider) Client() (connection.SSHClient, error) { + if p.initSession == nil { + return nil, fmt.Errorf("Init session is nil") + } + + if p.once { + if p.client == nil { + client, err := p.newClient(p.initSession, p.initPrivateKeys) + if err != nil { + return nil, err + } + p.client = client + } + + return p.client, nil + } + + return p.newClient(p.initSession, p.initPrivateKeys) +} +func (p *SSHProvider) SwitchClient(_ context.Context, sess *session.Session, privateKeys []session.AgentPrivateKey, _ connection.SSHClient) (connection.SSHClient, error) { + privateKeysCpy := make([]session.AgentPrivateKey, len(privateKeys)) + copy(privateKeysCpy, privateKeys) + + sessCopy := sess.Copy() + // copy reset current host + sessCopy.ChoiceNewHost() + + bastion := getBastion(sessCopy) + + s := Switch{ + Bastion: bastion, + Session: sessCopy, + PrivateKeys: privateKeysCpy, + } + + if !govalue.Nil(p.switchHandler) { + p.switchHandler(s) + } + + p.switches = append(p.switches, s) + + return p.newClient(sess, privateKeys) +} + +func (p *SSHProvider) InitSession() *session.Session { + c := p.initSession.Copy() + // copy reset current host + c.ChoiceNewHost() + + return c +} + +func (p *SSHProvider) newClient(session *session.Session, k []session.AgentPrivateKey) (*Client, error) { + c := NewClient(session, k) + + p.scriptProviders.copyTo(p.scriptProviders) + p.commandProviders.copyTo(c.commandProviders) + p.fileProviders.copyTo(c.fileProviders) + + err := c.Start() + return c, err +} + +func NewClient(session *session.Session, privKeys []session.AgentPrivateKey) *Client { + return &Client{ + Settings: settings.NewBaseProviders(settings.ProviderParams{ + LoggerProvider: log.SimpleLoggerProvider(log.NewSimpleLogger(log.LoggerOptions{ + IsDebug: true, + })), + IsDebug: true, + }), + SessionSettings: session, + privateKeys: privKeys, + + scriptProviders: newProvidersMap[UploadScriptProvider](), + commandProviders: newProvidersMap[CommandProvider](), + fileProviders: newOneProviderMap[FileProvider](), + } +} + +type Client struct { + Settings settings.Settings + + SessionSettings *session.Session + + commandProviders *providersMap[CommandProvider] + scriptProviders *providersMap[UploadScriptProvider] + fileProviders *providersMap[FileProvider] + + privateKeys []session.AgentPrivateKey + + mu sync.Mutex + + kubeProxies []*kubeProxy + started bool + stopped bool +} + +func (c *Client) AddScriptProvider(host string, f UploadScriptProvider) *Client { + c.scriptProviders.add(host, f) + return c +} + +func (c *Client) AddCommandProvider(host string, f CommandProvider) *Client { + c.commandProviders.add(host, f) + return c +} + +func (c *Client) SetFileProvider(host string, f FileProvider) *Client { + c.fileProviders.add(host, f) + return c +} + +func (c *Client) WithSettings(sett settings.Settings) *Client { + c.Settings = sett + return c +} + +func (c *Client) OnlyPreparePrivateKeys() error { + // Double start is safe here because for initializing private keys we are using sync.Once + return c.Start() +} + +func (c *Client) Start() error { + if c.SessionSettings == nil { + return fmt.Errorf("Possible bug in ssh client: session should be created before start") + } + + if c.isStopped() { + return fmt.Errorf("Possible bug in ssh client: client stopped") + } + + c.setStarted() + + return nil +} + +// Easy access to frontends + +// Tunnel is used to open local (L) and remote (R) tunnels +func (c *Client) Tunnel(address string) connection.Tunnel { + return &tunnel{address: address} +} + +// ReverseTunnel is used to open remote (R) tunnel +func (c *Client) ReverseTunnel(address string) connection.ReverseTunnel { + return &reverseTunnel{address: address} +} + +func errorCommand(name, errStr string) connection.Command { + return NewCommand(nil).WithErr(fmt.Errorf("%s: '%s'", errStr, name)) +} + +// Command is used to run commands on remote server +func (c *Client) Command(name string, arg ...string) connection.Command { + if err := c.checkClient(); err != nil { + return errorCommand(name, err.Error()) + } + + host := c.SessionSettings.Host() + providers, err := c.commandProviders.get(host) + if err != nil { + return errorCommand(name, err.Error()) + } + + bastion := getBastion(c.SessionSettings) + + for _, provider := range providers { + cmd := provider(bastion, name, arg...) + if !govalue.Nil(cmd) { + return cmd + } + } + + return errorCommand(name, fmt.Sprintf("All commands providers (%d) returns nil command for host: %s", len(providers), host)) +} + +// KubeProxy is used to start kubectl proxy and create a tunnel from local port to proxy port +func (c *Client) KubeProxy() connection.KubeProxy { + p := &kubeProxy{} + c.kubeProxies = append(c.kubeProxies, p) + return p +} + +func errorFile(errStr string) connection.File { + err := errors.New(errStr) + + upload := func(data []byte, dstPath string) error { + return err + } + + download := func(srcPath string) ([]byte, error) { + return nil, err + } + + return NewFile(upload, download) +} + +// File is used to upload and download files and directories +func (c *Client) File() connection.File { + if err := c.checkClient(); err != nil { + return errorFile(err.Error()) + } + + host := c.SessionSettings.Host() + provider, err := c.fileProviders.get(host) + if err != nil { + return errorFile(err.Error()) + } + + bastion := getBastion(c.SessionSettings) + + // get returns error if not found + file := provider[0](bastion) + if govalue.Nil(file) { + return errorFile(fmt.Sprintf("File provider returns nil File for host: %s", host)) + } + + return file +} + +func errorScript(path, errStr string) connection.Script { + return NewScript(nil).WithError(fmt.Errorf("%s: %s", errStr, path)) +} + +func (c *Client) checkClient() error { + var errs []string + if c.SessionSettings == nil { + errs = append(errs, "Settings is nil") + } + if c.isStopped() { + errs = append(errs, "Already stopped") + } + if !c.isStarted() { + errs = append(errs, "Client not started") + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, ",")) + } + + return nil +} + +// UploadScript is used to upload script and execute it on remote server +func (c *Client) UploadScript(scriptPath string, args ...string) connection.Script { + if err := c.checkClient(); err != nil { + return errorScript(scriptPath, err.Error()) + } + + host := c.SessionSettings.Host() + + providers, err := c.scriptProviders.get(host) + if err != nil { + return errorScript(scriptPath, err.Error()) + } + + bastion := getBastion(c.SessionSettings) + + for _, provider := range providers { + s := provider(bastion, scriptPath, args...) + if !govalue.Nil(s) { + return s + } + } + + return errorScript(scriptPath, fmt.Sprintf("All script providers (%d) returns nil command for host: %s", len(providers), host)) +} + +// UploadScript is used to upload script and execute it on remote server +func (c *Client) Check() connection.Check { + f := func(sess *session.Session, cmd string) connection.Command { + return NewCommand([]byte("ok")) + } + return utils.NewCheck(f, c.SessionSettings, c.Settings) +} + +// Stop the client +func (c *Client) Stop() { + if c.isStopped() { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + for _, p := range c.kubeProxies { + p.StopAll() + } + + c.kubeProxies = nil + c.stopped = true +} + +func (c *Client) Session() *session.Session { + return c.SessionSettings +} + +func (c *Client) PrivateKeys() []session.AgentPrivateKey { + return c.privateKeys +} + +func (c *Client) RefreshPrivateKeys() error { + return nil +} + +// Loop Looping all available hosts +func (c *Client) Loop(fn connection.SSHLoopHandler) error { + var err error + + resetSession := func() { + c.SessionSettings = c.SessionSettings.Copy() + c.SessionSettings.ChoiceNewHost() + } + defer resetSession() + resetSession() + + for range c.SessionSettings.AvailableHosts() { + err = fn(c) + if err != nil { + return err + } + c.SessionSettings.ChoiceNewHost() + } + + return nil +} + +func (c *Client) setStarted() { + c.mu.Lock() + defer c.mu.Unlock() + + c.started = true +} + +func (c *Client) isStarted() bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.started +} + +func (c *Client) isStopped() bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.stopped +} + +func (c *Client) appendProxy(p *kubeProxy) { + c.mu.Lock() + defer c.mu.Unlock() + + c.kubeProxies = append(c.kubeProxies, p) +} + +type kubeProxy struct{} + +func (k *kubeProxy) Start(useLocalPort int) (port string, err error) { + i := rand.New(rand.NewSource(time.Now().UnixNano())).Int() + return fmt.Sprintf("%d", i), nil +} + +func (k *kubeProxy) StopAll() {} + +func (k *kubeProxy) Stop(startID int) {} + +type tunnel struct { + address string +} + +func (t *tunnel) Up() error { + return nil +} + +func (t *tunnel) HealthMonitor(errorOutCh chan<- error) {} + +func (t *tunnel) Stop() {} + +func (t *tunnel) String() string { + return "tunnel: " + t.address +} + +type reverseTunnel struct { + address string +} + +func (t *reverseTunnel) Up() error { + return nil +} + +func (t *reverseTunnel) StartHealthMonitor(ctx context.Context, checker connection.ReverseTunnelChecker, killer connection.ReverseTunnelKiller) { +} + +func (t *reverseTunnel) Stop() {} + +func (t *reverseTunnel) String() string { + return "reverseTunnel: " + t.address +} + +var newLine = []byte("\n") + +type Script struct { + stdOut []byte + err error + + handler func(string) + run func() +} + +func NewScript(stdOut []byte) *Script { + return &Script{ + stdOut: stdOut, + } +} + +func (t *Script) WithError(err error) *Script { + t.err = err + return t +} + +func (t *Script) WithRun(f func()) *Script { + t.run = f + return t +} + +func (t *Script) Execute(context.Context) (stdout []byte, err error) { + return t.execute() +} + +func (t *Script) ExecuteBundle(ctx context.Context, parentDir, bundleDir string) (stdout []byte, err error) { + return t.execute() +} + +func (t *Script) Sudo() {} +func (t *Script) WithStdoutHandler(handler func(string)) { + t.handler = handler +} +func (t *Script) WithTimeout(timeout time.Duration) {} +func (t *Script) WithEnvs(envs map[string]string) {} +func (t *Script) WithCleanupAfterExec(doCleanup bool) {} +func (t *Script) WithCommanderMode(enabled bool) {} +func (t *Script) WithExecuteUploadDir(dir string) {} +func (t *Script) execute() (stdout []byte, err error) { + if t.handler != nil { + t.handler(string(t.stdOut)) + } + if t.run != nil { + t.run() + } + + return t.stdOut, t.err +} + +type Command struct { + stdOut, stdErr []byte + err error + + onStart func() + run func() + + stdOutFunc func(line string) + stdErrFunc func(line string) +} + +func NewCommand(stdOut []byte) *Command { + return &Command{ + stdOut: stdOut, + } +} + +func (t *Command) WithStdErr(s []byte) *Command { + t.stdErr = s + return t +} + +func (t *Command) WithErr(err error) *Command { + t.err = err + return t +} + +func (t *Command) WithRun(f func()) *Command { + t.run = f + return t +} + +func (t *Command) Run(ctx context.Context) error { + return t.doRun() +} + +func (t *Command) Cmd(ctx context.Context) {} +func (t *Command) Sudo(ctx context.Context) {} + +func (t *Command) StdoutBytes() []byte { + return t.stdOut +} + +func (t *Command) StderrBytes() []byte { + return t.stdErr +} + +func (t *Command) Output(context.Context) ([]byte, []byte, error) { + return t.stdOut, t.stdErr, t.err +} + +func (t *Command) CombinedOutput(context.Context) ([]byte, error) { + return bytes.Join([][]byte{t.stdOut, t.stdErr}, newLine), t.err +} + +func (t *Command) OnCommandStart(fn func()) { + t.onStart = fn +} + +func (t *Command) WithEnv(env map[string]string) {} +func (t *Command) WithTimeout(timeout time.Duration) {} +func (t *Command) WithStdoutHandler(h func(line string)) { + t.stdOutFunc = h +} +func (t *Command) WithStderrHandler(h func(line string)) { + t.stdErrFunc = h +} +func (t *Command) WithSSHArgs(args ...string) {} + +func (t *Command) doRun() error { + if t.onStart != nil { + t.onStart() + } + + if t.stdOutFunc != nil { + for _, line := range bytes.Split(t.stdOut, newLine) { + t.stdOutFunc(string(line)) + } + } + if t.stdErrFunc != nil { + for _, line := range bytes.Split(t.stdErr, newLine) { + t.stdErrFunc(string(line)) + } + } + + if t.run != nil { + t.run() + } + + return t.err +} + +type ( + UploadFn func(data []byte, dstPath string) error + DownloadFn func(srcPath string) ([]byte, error) +) + +type File struct { + uploadFn UploadFn + downloadFn DownloadFn +} + +func NewFile(upload UploadFn, download DownloadFn) *File { + return &File{ + uploadFn: upload, + downloadFn: download, + } +} + +func (f *File) Upload(ctx context.Context, srcPath, dstPath string) error { + if govalue.Nil(f.uploadFn) { + return fmt.Errorf("uploadFn is nil for path '%s'", dstPath) + } + + data, err := os.ReadFile(srcPath) + if err != nil { + return err + } + + return f.uploadFn(data, dstPath) +} +func (f *File) Download(ctx context.Context, srcPath, dstPath string) error { + data, err := f.DownloadBytes(ctx, srcPath) + if err != nil { + return err + } + + return os.WriteFile(dstPath, data, os.ModePerm) +} + +func (f *File) UploadBytes(ctx context.Context, data []byte, remotePath string) error { + if govalue.Nil(f.uploadFn) { + return fmt.Errorf("uploadFn is nil for path '%s'", remotePath) + } + return f.uploadFn(data, remotePath) +} + +func (f *File) DownloadBytes(ctx context.Context, remotePath string) ([]byte, error) { + if govalue.Nil(f.downloadFn) { + return nil, fmt.Errorf("downloadFn is nil for path '%s'", remotePath) + } + return f.downloadFn(remotePath) +} + +type providersMap[T any] struct { + hostToProviders map[string][]T + hasOne bool +} + +func newProvidersMap[T any]() *providersMap[T] { + return &providersMap[T]{ + hostToProviders: make(map[string][]T), + hasOne: false, + } +} + +func newOneProviderMap[T any]() *providersMap[T] { + return &providersMap[T]{ + hostToProviders: make(map[string][]T), + } +} + +func (m *providersMap[T]) add(host string, provider T) { + mp := m.hostToProviders + if len(mp) == 0 { + mp = make(map[string][]T) + } + + providers, ok := mp[host] + if !ok || len(providers) == 0 { + providers = make([]T, 0, 1) + } + + if m.hasOne { + providers = []T{provider} + } else { + providers = append(providers, provider) + } + + mp[host] = providers + + m.hostToProviders = mp +} + +func (m *providersMap[T]) copyTo(dst *providersMap[T]) { + for host, providers := range m.hostToProviders { + for _, provider := range providers { + dst.add(host, provider) + } + } +} + +func (m *providersMap[T]) createErr(host, err string) error { + tp := reflect.TypeFor[T]() + return fmt.Errorf("Providers for %s %s for host '%s'", tp.String(), err, host) +} + +func (m *providersMap[T]) get(host string) ([]T, error) { + if len(m.hostToProviders) == 0 { + return nil, m.createErr(host, "not initialized") + } + + providers, ok := m.hostToProviders[host] + if !ok || len(providers) == 0 { + return nil, m.createErr(host, "no providers found") + } + + return providers, nil +} + +func getBastion(s *session.Session) Bastion { + if s == nil { + return Bastion{NoSession: true} + } + return Bastion{ + Host: s.BastionHost, + Port: s.BastionPort, + User: s.BastionUser, + + NoSession: false, + } +} + +func CreateSettings() settings.Settings { + return settings.NewBaseProviders(settings.ProviderParams{ + LoggerProvider: log.SimpleLoggerProvider(log.NewSimpleLogger(log.LoggerOptions{IsDebug: true})), + IsDebug: true, + }) +} diff --git a/pkg/ssh/utils/checks.go b/pkg/ssh/utils/checks.go new file mode 100644 index 0000000..e0a74a1 --- /dev/null +++ b/pkg/ssh/utils/checks.go @@ -0,0 +1,104 @@ +// Copyright 2022 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "sort" + "strings" + + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +const ( + // skip G101: Potential hardcoded credentials + //nolint:gosec + notPassedWarn = "SSH-hosts was not passed. Maybe you run converge in pod?" + notEnthoughtWarn = "Not enough master SSH-hosts." + tooManyWarn = "Too many master SSH-hosts. Maybe you want to delete nodes, but pass hosts for delete via --ssh-host?" + + checkHostsMsg = "Please check, is correct mapping node name to host?" + checkWarnMsg = `Warning! %s +If you lose connection to node, converge may not be finished. +Also, SSH connectivity to another nodes will not check before converge node. + +And be attentive when you create new control-plane nodes and change another control-plane instances both. +dhctl can not add new master IP's for connection. + +%s +Do you want to continue? +` +) + +func CheckSSHHosts(userPassedHosts []session.Host, nodesNames []string, phase string, runConfirm func(string) bool) (map[string]string, error) { + userPassedHostsLen := len(userPassedHosts) + replicas := len(nodesNames) + + nodeToHost := make(map[string]string) + for _, nodeName := range nodesNames { + nodeToHost[nodeName] = "" + } + + warnMsg := "" + + switch { + case userPassedHostsLen == 0: + warnMsg = notPassedWarn + case userPassedHostsLen < replicas: + warnMsg = notEnthoughtWarn + // Happens only when we make destructive changes to the only master in the cluster and + // to avoid reporting a warning that the number of replicas does not match the number + // of servers accessed by ssh when the number of masters is reduced. + // 1 -> 3 -> update(0) -> (message) -> 1 + case userPassedHostsLen == 3 && replicas == 1 && phase == "scale-to-single-master": + warnMsg = "" + case userPassedHostsLen > replicas: + warnMsg = tooManyWarn + } + + var nodesSorted []string + nodesSorted = append(nodesSorted, nodesNames...) + sort.Strings(nodesSorted) + + forConfirmation := make([]string, userPassedHostsLen) + + for i, host := range userPassedHosts { + nodeNameTrue := false + for _, nodeName := range nodesSorted { + if nodeName == host.Name { + forConfirmation[i] = fmt.Sprintf("%s -> %s", nodeName, host.Host) + nodeToHost[nodeName] = host.Host + nodeNameTrue = true + break + } + } + if !nodeNameTrue { + forConfirmation[i] = fmt.Sprintf("%s -> %s (ignored)", host.Name, host.Host) + } + } + + if warnMsg != "" { + msg := fmt.Sprintf(checkWarnMsg, warnMsg, strings.Join(forConfirmation, "\n")) + if !runConfirm(msg) { + return nil, fmt.Errorf("Hosts warning was not confirmed.") + } + } else { + msg := fmt.Sprintf("%s\n%s\n", checkHostsMsg, strings.Join(forConfirmation, "\n")) + if !runConfirm(msg) { + return nil, fmt.Errorf("Node name to host mapping was not confirmed. Please pass hosts in order.") + } + } + return nodeToHost, nil +} diff --git a/pkg/ssh/utils/checks_test.go b/pkg/ssh/utils/checks_test.go new file mode 100644 index 0000000..b862018 --- /dev/null +++ b/pkg/ssh/utils/checks_test.go @@ -0,0 +1,151 @@ +// Copyright 2022 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/deckhouse/lib-connection/pkg/ssh/session" +) + +func TestSSHHostChecks(t *testing.T) { + t.Run("User passed incorrect count of hosts", func(t *testing.T) { + nodes := []string{ + "pref-master-0", + "pref-master-1", + } + cases := []struct { + title string + passedHost []session.Host + warnMsg string + }{ + { + title: "User passed zero hosts", + passedHost: nil, + warnMsg: notPassedWarn, + }, + + { + title: "User passed less hosts than nodes", + passedHost: []session.Host{{Host: "127.0.0.1", Name: "master-0"}}, + warnMsg: notEnthoughtWarn, + }, + + { + title: "User passed more hosts than nodes", + passedHost: []session.Host{{Host: "127.0.0.1", Name: "master-0"}, {Host: "127.0.0.2"}, {Host: "127.0.0.3"}}, + warnMsg: tooManyWarn, + }, + } + + for _, c := range cases { + t.Run(c.title, func(t *testing.T) { + t.Run("Does not confirm incorrect host warning", func(t *testing.T) { + nodesToHosts, err := CheckSSHHosts(c.passedHost, nodes, "all-nodes", func(msg string) bool { + require.Contains(t, msg, c.warnMsg, "Incorrect warning") + return false + }) + + require.Error(t, err, "should return error") + require.Nil(t, nodesToHosts) + }) + + t.Run("Confirm incorrect host warning", func(t *testing.T) { + nodesToHosts, err := CheckSSHHosts(c.passedHost, nodes, "all-nodes", func(msg string) bool { + require.Contains(t, msg, c.warnMsg, "Incorrect warning") + return true + }) + + require.NoError(t, err) + require.Equal(t, nodesToHosts, map[string]string{ + "pref-master-0": "", + "pref-master-1": "", + }, "should return empty host for every node") + }) + }) + } + }) + + t.Run("User passed correct count of hosts", func(t *testing.T) { + assertNotShowIncorrectCountWarn := func(t *testing.T, msg string) { + for _, w := range []string{notPassedWarn, notEnthoughtWarn, tooManyWarn} { + require.NotContains(t, msg, w, "should not show incorrect count warning") + } + } + + passedHosts := []session.Host{{Host: "127.0.0.1", Name: "master-0"}, {Host: "127.0.0.2", Name: "master-1"}, {Host: "127.0.0.3", Name: "master-2"}} + + t.Run("Does not confirm nodes to hosts mapping", func(t *testing.T) { + nodes := []string{"master-0", "master-1", "master-2"} + nodesToHosts, err := CheckSSHHosts(passedHosts, nodes, "all-nodes", func(msg string) bool { + assertNotShowIncorrectCountWarn(t, msg) + + require.Contains(t, msg, checkHostsMsg, "Incorrect message") + + return false + }) + + require.Error(t, err, "should return error") + require.Nil(t, nodesToHosts) + }) + + t.Run("Confirms nodes to hosts mapping", func(t *testing.T) { + confirmFunc := func(msg string) bool { + assertNotShowIncorrectCountWarn(t, msg) + + require.Contains(t, msg, checkHostsMsg, "Incorrect message") + + return true + } + + t.Run("Nodes passed in incorrect order", func(t *testing.T) { + nodes := []string{"master-1", "master-2", "master-0"} + nodesToHosts, err := CheckSSHHosts(passedHosts, nodes, "all-nodes", confirmFunc) + + require.NoError(t, err, "should not return error") + require.Equal(t, nodesToHosts, map[string]string{ + "master-0": passedHosts[0].Host, + "master-1": passedHosts[1].Host, + "master-2": passedHosts[2].Host, + }, "nodes names should sorted") + }) + + t.Run("Nodes passed in correct order", func(t *testing.T) { + nodes := []string{"master-0", "master-1", "master-2"} + nodesToHosts, err := CheckSSHHosts(passedHosts, nodes, "all-nodes", confirmFunc) + + require.NoError(t, err, "should not return error") + require.Equal(t, nodesToHosts, map[string]string{ + "master-0": passedHosts[0].Host, + "master-1": passedHosts[1].Host, + "master-2": passedHosts[2].Host, + }, "nodes names should sorted") + }) + + t.Run("Reducing a cluster to a single master", func(t *testing.T) { + nodes := []string{"master-0"} + nodesToHosts, err := CheckSSHHosts(passedHosts, nodes, "scale-to-single-master", confirmFunc) + + require.NoError(t, err, "should not return error") + require.Equal(t, nodesToHosts, map[string]string{ + "master-0": passedHosts[0].Host, + }, "nodes name must be the same") + }) + + }) + }) +} diff --git a/pkg/ssh/utils/closer.go b/pkg/ssh/utils/closer.go new file mode 100644 index 0000000..dec61ba --- /dev/null +++ b/pkg/ssh/utils/closer.go @@ -0,0 +1,46 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "errors" + "io" + "net" + + "github.com/name212/govalue" +) + +type PresentCloseHandler func(isPresent bool) + +func callPresentHandlers(isPresent bool, presentHandlers ...PresentCloseHandler) { + for _, p := range presentHandlers { + p(isPresent) + } +} + +func SafeClose(conn io.Closer, presentHandlers ...PresentCloseHandler) error { + if govalue.Nil(conn) { + callPresentHandlers(false, presentHandlers...) + return nil + } + + callPresentHandlers(true, presentHandlers...) + + if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + + return nil +} diff --git a/pkg/ssh/utils/execute_path.go b/pkg/ssh/utils/execute_path.go new file mode 100644 index 0000000..f447d09 --- /dev/null +++ b/pkg/ssh/utils/execute_path.go @@ -0,0 +1,51 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "path/filepath" + + "github.com/deckhouse/lib-connection/pkg/settings" +) + +type ScriptPath interface { + IsSudo() bool + UploadDir() string + Settings() settings.Settings +} + +// ExecuteRemoteScriptPath +// deprecated - ugly solution +func ExecuteRemoteScriptPath(u ScriptPath, scriptName string, full bool) string { + root := "" + if u.IsSudo() { + root = u.Settings().NodeTmpDir() + } + + if uploadDir := u.UploadDir(); uploadDir != "" { + root = uploadDir + } + + if root == "" { + res := "." + if full { + res = res + "/" + scriptName + } + + return res + } + + return filepath.Join(root, scriptName) +} diff --git a/pkg/ssh/utils/execute_path_test.go b/pkg/ssh/utils/execute_path_test.go new file mode 100644 index 0000000..3cc4049 --- /dev/null +++ b/pkg/ssh/utils/execute_path_test.go @@ -0,0 +1,123 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" + + "github.com/deckhouse/lib-connection/pkg/settings" + sshtesting "github.com/deckhouse/lib-connection/pkg/ssh/gossh/testing" + "github.com/stretchr/testify/require" +) + +type testScriptPath struct { + sudo bool + uploadDir string + settings settings.Settings +} + +func (s *testScriptPath) IsSudo() bool { + return s.sudo +} +func (s *testScriptPath) UploadDir() string { + return s.uploadDir +} +func (s *testScriptPath) Settings() settings.Settings { + return s.settings +} + +func TestExecuteRemoteScriptPath(t *testing.T) { + type test struct { + name string + sudo bool + uploadDir string + expectedPath string + full bool + } + + const ( + testWithUploadDir = "/tmp" + testWithUploadDirExpected = "/tmp/script" + testWithSudoExpected = "/opt/deckhouse/tmp/script" + ) + + tests := []test{ + { + name: "with upload dir no sudo no full", + sudo: false, + uploadDir: testWithUploadDir, + expectedPath: testWithUploadDirExpected, + full: false, + }, + + { + name: "with upload dir no sudo with full", + sudo: false, + uploadDir: testWithUploadDir, + expectedPath: testWithUploadDirExpected, + full: true, + }, + + { + name: "with upload dir with sudo with full", + sudo: true, + uploadDir: testWithUploadDir, + expectedPath: testWithUploadDirExpected, + full: true, + }, + + { + name: "without upload dir no sudo no full", + sudo: false, + uploadDir: "", + expectedPath: ".", + full: false, + }, + { + name: "without upload dir with sudo no full", + sudo: true, + uploadDir: "", + expectedPath: testWithSudoExpected, + full: false, + }, + { + name: "without upload dir with sudo with full", + sudo: true, + uploadDir: "", + expectedPath: testWithSudoExpected, + full: true, + }, + { + name: "without upload dir without sudo with full", + sudo: false, + uploadDir: "", + expectedPath: "./script", + full: true, + }, + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + sshSettings := sshtesting.CreateDefaultTestSettings(sshtesting.ShouldNewTest(t, "")) + script := &testScriptPath{ + sudo: tst.sudo, + uploadDir: tst.uploadDir, + settings: sshSettings, + } + res := ExecuteRemoteScriptPath(script, "script", tst.full) + require.Equal(t, tst.expectedPath, res) + }) + } +} diff --git a/pkg/ssh/utils/matcher.go b/pkg/ssh/utils/matcher.go new file mode 100644 index 0000000..0503857 --- /dev/null +++ b/pkg/ssh/utils/matcher.go @@ -0,0 +1,96 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// ByteSequenceMatcher can be used to match byte stream against a string +// byte by byte. +type ByteSequenceMatcher struct { + // settings + Pattern string + waitNonMatched bool + + patternBytes []byte + patternLen int + + // state + // index of a byte that should be matched + state int + patternFound bool + matchFound bool +} + +func NewByteSequenceMatcher(pattern string) *ByteSequenceMatcher { + b := []byte(pattern) + return &ByteSequenceMatcher{ + Pattern: pattern, + patternBytes: b, + patternLen: len(b), + state: 0, // need to check first byte + } +} + +func (m *ByteSequenceMatcher) WaitNonMatched() *ByteSequenceMatcher { + m.waitNonMatched = true + return m +} + +// Analyze matches Pattern from byte stream and ignores \r and \n after it. +// when match is not found, return n +// when match is found, return 0 +// return index (0 or more) of a first byte after pattern and \r, \n +// This behaviour is used to write bytes to Reader only after match is found. +func (m *ByteSequenceMatcher) Analyze(buf []byte) (n int) { + for i, b := range buf { + // ignore \r and \n + if b == '\r' || b == '\n' { + // reset pattern state + m.state = 0 + continue + } + + if m.matchFound { + return i + } + + if m.patternFound { + m.matchFound = true + return i + } + if b == m.patternBytes[m.state] { + m.state++ + } else { + m.state = 0 + } + if m.state == m.patternLen { + m.patternFound = true + if !m.waitNonMatched { + m.matchFound = true + return i + 1 + } + } + } + + return len(buf) +} + +func (m *ByteSequenceMatcher) Reset() { + m.matchFound = false + m.patternFound = false + m.state = 0 +} + +func (m *ByteSequenceMatcher) IsMatched() bool { + return m.matchFound +} diff --git a/pkg/ssh/utils/matcher_test.go b/pkg/ssh/utils/matcher_test.go new file mode 100644 index 0000000..8c5be0b --- /dev/null +++ b/pkg/ssh/utils/matcher_test.go @@ -0,0 +1,135 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestMatchSeveralBuffers(t *testing.T) { + m := NewByteSequenceMatcher("SUCCESS").WaitNonMatched() + + r := m.Analyze([]byte("SUCC")) + + if r != 4 { + t.Errorf("Should return len buf when match is not triggered: expect 4, got %d", r) + } + + r = m.Analyze([]byte("ESS")) + + if r != 3 { + t.Errorf("Should return len buf when match is not triggered: expect 3, got %d", r) + } + + r = m.Analyze([]byte("\r\n\r\n\n\nOutput")) + if r != 6 { + t.Errorf("Should return first non \r \n byte after match is triggered: expect 6, got %d", r) + } +} + +func TestMatchSeveralBuffersNoRN(t *testing.T) { + m := NewByteSequenceMatcher("SUCCESS").WaitNonMatched() + + r := m.Analyze([]byte("SUCC")) + + if r != 4 { + t.Errorf("Should return len buf when match is not triggered: expect 4, got %d", r) + } + + r = m.Analyze([]byte("ESS")) + + if r != 3 { + t.Errorf("Should return len buf when match is not triggered: expect 3, got %d", r) + } + + r = m.Analyze([]byte("Output")) + if r != 0 { + t.Errorf("Should return first non \r \n byte after match is triggered: expect 6, got %d", r) + } +} + +func TestMatchSeveralBuffersAlmostMatch(t *testing.T) { + m := NewByteSequenceMatcher("SUCCESS").WaitNonMatched() + + r := m.Analyze([]byte("SUCC")) + + if r != 4 { + t.Errorf("Should return len buf when match is not triggered: expect 4, got %d", r) + } + + r = m.Analyze([]byte("ES")) + + if r != 2 { + t.Errorf("Should return len buf when match is not triggered: expect 2, got %d", r) + } + + r = m.Analyze([]byte("-SUCC")) + + if r != 5 { + t.Errorf("Should return len buf when match is not triggered: expect 5, got %d", r) + } + + r = m.Analyze([]byte("ESS")) + + if r != 3 { + t.Errorf("Should return len buf when match is not triggered: expect 3, got %d", r) + } + + r = m.Analyze([]byte("Output")) + if r != 0 { + t.Errorf("Should return first non \r \n byte after match is triggered: expect 6, got %d", r) + } +} + +func TestMatchOneBuffer(t *testing.T) { + m := NewByteSequenceMatcher("SUCCESS").WaitNonMatched() + + before := []byte("Sometext\r\n\r\r\n") + pattern := []byte("SUCCESS\r") + after := []byte("More text") + + var buf []byte + + buf = append(buf, before...) + buf = append(buf, pattern...) + buf = append(buf, after...) + + r := m.Analyze(buf) + + expect := len(before) + len(pattern) + if r != expect { + t.Errorf("Should return first non \r \n byte after match is triggered: expect %d, got %d", expect, r) + } +} + +func TestMatchNoWait(t *testing.T) { + m := NewByteSequenceMatcher("SUCCESS") + + r := m.Analyze([]byte("SUCC")) + + if r != 4 { + t.Errorf("Should return len buf when match is not triggered: expect 4, got %d", r) + } + + r = m.Analyze([]byte("ESS")) + if r != 3 { + t.Errorf("Should return len buf when match is not triggered: expect 3, got %d", r) + } + + r = m.Analyze([]byte("\r\n\r\n\n\nOutput")) + if r != 6 { + t.Errorf("Should return first non \r \n byte after match is triggered: expect 6, got %d", r) + } +} diff --git a/pkg/ssh/utils/privatekeys.go b/pkg/ssh/utils/privatekeys.go new file mode 100644 index 0000000..bdff75f --- /dev/null +++ b/pkg/ssh/utils/privatekeys.go @@ -0,0 +1,58 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "errors" + "fmt" + "os" + + gossh "github.com/deckhouse/lib-gossh" + "golang.org/x/crypto/ssh" +) + +func GetSSHPrivateKey(keyPath string, passphrase string) (any, error) { + keyData, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("Reading key file %q got error: %w", keyPath, err) + } + + keyData = append(bytes.TrimSpace(keyData), '\n') + + var sshKey any + + if len(passphrase) > 0 { + sshKey, err = gossh.ParseRawPrivateKeyWithPassphrase(keyData, []byte(passphrase)) + } else { + sshKey, err = gossh.ParseRawPrivateKey(keyData) + } + + if err != nil { + var passphraseMissingError *gossh.PassphraseMissingError + switch { + case errors.As(err, &passphraseMissingError): + var err error + sshKey, err = ssh.ParseRawPrivateKeyWithPassphrase(keyData, []byte(passphrase)) + if err != nil { + return nil, fmt.Errorf("Wrong passphrase for ssh key") + } + default: + return nil, fmt.Errorf("Parsing private key %q got error: %w", keyPath, err) + } + } + + return sshKey, nil +} diff --git a/pkg/ssh/utils/reverse-tunnel-routines.go b/pkg/ssh/utils/reverse-tunnel-routines.go new file mode 100644 index 0000000..b922564 --- /dev/null +++ b/pkg/ssh/utils/reverse-tunnel-routines.go @@ -0,0 +1,127 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "context" + + ssh "github.com/deckhouse/lib-connection/pkg" +) + +type BaseReverseTunnelRoutines[T any] struct { + uploadDir string + cleanup bool + + impl *T +} + +func newBaseReverseTunnel[T any](impl *T) *BaseReverseTunnelRoutines[T] { + return &BaseReverseTunnelRoutines[T]{ + impl: impl, + cleanup: false, + uploadDir: "", + } +} + +func (b *BaseReverseTunnelRoutines[T]) WithUploadDir(dir string) *T { + b.uploadDir = dir + return b.impl +} + +func (b *BaseReverseTunnelRoutines[T]) WithCleanup() *T { + b.cleanup = true + return b.impl +} + +func (b *BaseReverseTunnelRoutines[T]) SetUploadDirAndCleanup(dir string) *T { + b.WithUploadDir(dir) + b.WithCleanup() + + return b.impl +} + +func (b *BaseReverseTunnelRoutines[T]) prepareScript(script ssh.Script) { + if b.uploadDir != "" { + script.WithExecuteUploadDir(b.uploadDir) + } + + if b.cleanup { + script.WithCleanupAfterExec(b.cleanup) + } +} + +type RunScriptReverseTunnelChecker struct { + *BaseReverseTunnelRoutines[RunScriptReverseTunnelChecker] + + client ssh.SSHClient + scriptPath string +} + +func NewRunScriptReverseTunnelChecker(c ssh.SSHClient, scriptPath string) *RunScriptReverseTunnelChecker { + checker := &RunScriptReverseTunnelChecker{ + client: c, + scriptPath: scriptPath, + } + + checker.BaseReverseTunnelRoutines = newBaseReverseTunnel(checker) + + return checker +} + +func (s *RunScriptReverseTunnelChecker) CheckTunnel(ctx context.Context) (string, error) { + script := s.client.UploadScript(s.scriptPath) + + script.Sudo() + + s.prepareScript(script) + + out, err := script.Execute(ctx) + return string(out), err +} + +type RunScriptReverseTunnelKiller struct { + *BaseReverseTunnelRoutines[RunScriptReverseTunnelKiller] + + client ssh.SSHClient + scriptPath string +} + +func NewRunScriptReverseTunnelKiller(c ssh.SSHClient, scriptPath string) *RunScriptReverseTunnelKiller { + killer := &RunScriptReverseTunnelKiller{ + client: c, + scriptPath: scriptPath, + } + + killer.BaseReverseTunnelRoutines = newBaseReverseTunnel(killer) + + return killer +} + +func (s *RunScriptReverseTunnelKiller) KillTunnel(ctx context.Context) (string, error) { + script := s.client.UploadScript(s.scriptPath) + + script.Sudo() + + s.prepareScript(script) + + out, err := script.Execute(ctx) + return string(out), err +} + +type EmptyReverseTunnelKiller struct{} + +func (k EmptyReverseTunnelKiller) KillTunnel(context.Context) (string, error) { + return "", nil +} diff --git a/pkg/ssh/utils/tar/tar.go b/pkg/ssh/utils/tar/tar.go new file mode 100644 index 0000000..e141596 --- /dev/null +++ b/pkg/ssh/utils/tar/tar.go @@ -0,0 +1,72 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tar + +import ( + "archive/tar" + "fmt" + "io" + "os" + "path/filepath" +) + +func CreateTar(tarFilePath, baseDir, targetDir string) error { + tarFile, err := os.Create(tarFilePath) + if err != nil { + return fmt.Errorf("failed to create tar file: %w", err) + } + defer tarFile.Close() + + tarWriter := tar.NewWriter(tarFile) + defer tarWriter.Close() + + err = filepath.Walk(filepath.Join(baseDir, targetDir), func(filePath string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("failed to walk path: %w", err) + } + + relPath, err := filepath.Rel(baseDir, filePath) + if err != nil { + return fmt.Errorf("failed to get relative path: %w", err) + } + + header, err := tar.FileInfoHeader(info, relPath) + if err != nil { + return fmt.Errorf("failed to create tar header: %w", err) + } + + header.Name = relPath + + if err := tarWriter.WriteHeader(header); err != nil { + return fmt.Errorf("failed to write header to tar file: %w", err) + } + + if info.Mode().IsRegular() { + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + if _, err := io.Copy(tarWriter, file); err != nil { + return fmt.Errorf("failed to copy file to tar: %w", err) + } + } + + return nil + }) + + return err +} diff --git a/pkg/ssh/utils/tar/tar_test.go b/pkg/ssh/utils/tar/tar_test.go new file mode 100644 index 0000000..dc5f269 --- /dev/null +++ b/pkg/ssh/utils/tar/tar_test.go @@ -0,0 +1,94 @@ +// Copyright 2025 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tar + +import ( + "archive/tar" + "io" + "os" + "path/filepath" + "testing" +) + +func createTestDir(baseDir string) error { + testDir := filepath.Join(baseDir, "testdir") + err := os.MkdirAll(testDir, 0755) + if err != nil { + return err + } + + files := []string{"file1.txt", "file2.txt"} + for _, file := range files { + err := os.WriteFile(filepath.Join(testDir, file), []byte("This is a test file."), 0644) + if err != nil { + return err + } + } + + return nil +} + +func TestCreateTar(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "tar_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + err = createTestDir(tmpDir) + if err != nil { + t.Fatalf("Failed to create test directory: %v", err) + } + + tarFilePath := filepath.Join(tmpDir, "output.tar") + err = CreateTar(tarFilePath, tmpDir, "testdir") + if err != nil { + t.Fatalf("Failed to create tar file: %v", err) + } + + tarFile, err := os.Open(tarFilePath) + if err != nil { + t.Fatalf("Failed to open tar file: %v", err) + } + defer tarFile.Close() + + tarReader := tar.NewReader(tarFile) + + _, err = tarReader.Next() + if err != nil { + t.Fatalf("Failed to read from tar file: %v", err) + } + + expectedFiles := []string{"testdir/file1.txt", "testdir/file2.txt"} + for _, expectedFile := range expectedFiles { + header, err := tarReader.Next() + if err != nil { + t.Fatalf("Failed to read from tar file: %v", err) + } + + if header.Name != expectedFile { + t.Errorf("Expected %s, got %s", expectedFile, header.Name) + } + + content, err := io.ReadAll(tarReader) + if err != nil { + t.Fatalf("Failed to read file content from tar: %v", err) + } + + if string(content) != "This is a test file." { + t.Errorf("Content mismatch for %s", expectedFile) + } + } +} diff --git a/pkg/ssh/utils/waiting.go b/pkg/ssh/utils/waiting.go new file mode 100644 index 0000000..0fd260b --- /dev/null +++ b/pkg/ssh/utils/waiting.go @@ -0,0 +1,140 @@ +// Copyright 2021 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "time" + + ssh "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" + "github.com/deckhouse/lib-connection/pkg/ssh/session" + "github.com/deckhouse/lib-dhctl/pkg/retry" +) + +var defaultAvailabilityOpts = []retry.ParamsBuilderOpt{ + retry.WithAttempts(50), + retry.WithWait(5 * time.Second), +} + +type CommandConsumer func(*session.Session, string) ssh.Command + +type Check struct { + settings settings.Settings + + Session *session.Session + createCommand CommandConsumer + delay time.Duration +} + +func NewCheck(createCommand CommandConsumer, sess *session.Session, sett settings.Settings) *Check { + return &Check{ + Session: sess, + createCommand: createCommand, + settings: sett, + } +} + +func (c *Check) WithDelaySeconds(seconds int) ssh.Check { + c.delay = time.Duration(seconds) * time.Second + return c +} + +func (c *Check) AwaitAvailability(ctx context.Context, loopParams retry.Params) error { + if c.Session.Host() == "" { + return fmt.Errorf("Empty host for connection received") + } + + select { + case <-time.After(c.delay): + case <-ctx.Done(): + return ctx.Err() + } + + logger := c.settings.Logger() + retryParams := retry.SafeCloneOrNewParams(loopParams, defaultAvailabilityOpts...). + WithLogger(logger). + WithName("Waiting for SSH connection") + + return retry.NewLoopWithParams(retryParams).RunContext(ctx, func() error { + host := c.Session.Host() + logger.InfoF("Try to connect to host: %v", host) + + output, err := c.ExpectAvailable(ctx) + if err == nil { + logger.InfoF("Successfully connected to host: %v", host) + return nil + } + + target := c.Session.Host() + + logger.InfoF("Connection attempt failed to host: %v", target) + + c.Session.ChoiceNewHost() + + return fmt.Errorf("SSH error: %s\nSSH connect failed to %s: %s", err.Error(), target, string(output)) + }) +} + +func (c *Check) CheckAvailability(ctx context.Context) error { + if c.Session.Host() == "" { + return fmt.Errorf("Empty host for connection received") + } + + logger := c.settings.Logger() + + logger.InfoF("Try to connect to %v host", c.Session.Host()) + output, err := c.ExpectAvailable(ctx) + if err != nil { + logger.InfoF(string(output)) + return err + } + return nil +} + +func (c *Check) ExpectAvailable(ctx context.Context) ([]byte, error) { + cmd := c.createCommand(c.Session, "echo SUCCESS") + cmd.Cmd(ctx) + + output, _, err := cmd.Output(ctx) + if err != nil { + var stderr []byte + if ee := errors.Unwrap(err); ee != nil { + var exitErr *exec.ExitError + if errors.As(ee, &exitErr) && len(exitErr.Stderr) > 0 { + stderr = exitErr.Stderr + } + } + if len(stderr) == 0 { + stderr = []byte(err.Error()) + } + + return stderr, err + } + + if strings.Contains(string(output), "SUCCESS") { + return nil, nil + } + + return output, fmt.Errorf("SSH command output should contain \"SUCCESS\", error: %w", err) +} + +func (c *Check) String() string { + return c.Session.String() +} diff --git a/pkg/ssh/wrapper.go b/pkg/ssh/wrapper.go new file mode 100644 index 0000000..de29669 --- /dev/null +++ b/pkg/ssh/wrapper.go @@ -0,0 +1,65 @@ +// Copyright 2024 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "github.com/name212/govalue" + + connection "github.com/deckhouse/lib-connection/pkg" + "github.com/deckhouse/lib-connection/pkg/settings" +) + +type NodeInterfaceWrapper struct { + settings settings.Settings + + sshClient connection.SSHClient +} + +func NewNodeInterfaceWrapper(sshClient connection.SSHClient, sett settings.Settings) *NodeInterfaceWrapper { + if govalue.Nil(sshClient) { + return nil + } + + return &NodeInterfaceWrapper{ + sshClient: sshClient, + settings: sett, + } +} + +func (n *NodeInterfaceWrapper) Command(name string, args ...string) connection.Command { + logger := n.settings.Logger() + + logger.DebugLn("Starting NodeInterfaceWrapper.command") + defer logger.DebugLn("Stop NodeInterfaceWrapper.command") + + return n.sshClient.Command(name, args...) +} + +func (n *NodeInterfaceWrapper) File() connection.File { + return n.sshClient.File() +} + +func (n *NodeInterfaceWrapper) UploadScript(scriptPath string, args ...string) connection.Script { + logger := n.settings.Logger() + + logger.DebugLn("Starting NodeInterfaceWrapper.UploadScript") + defer logger.DebugLn("Stop NodeInterfaceWrapper.UploadScript") + + return n.sshClient.UploadScript(scriptPath, args...) +} + +func (n *NodeInterfaceWrapper) Client() connection.SSHClient { + return n.sshClient +} diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..4cd37af --- /dev/null +++ b/tests/README.md @@ -0,0 +1,3 @@ +# lib-connection/tests + +Module for testing purposes like SSHConfig loading inside another module, because we use embed openapi spec inside package \ No newline at end of file diff --git a/tests/go.mod b/tests/go.mod new file mode 100644 index 0000000..5ee6068 --- /dev/null +++ b/tests/go.mod @@ -0,0 +1,14 @@ +module github.com/deckhouse/lib-connection/tests + +go 1.25.5 + +require github.com/deckhouse/lib-connection v0.0.0-00010101000000-000000000000 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/deckhouse/lib-connection => ../ diff --git a/tests/go.sum b/tests/go.sum new file mode 100644 index 0000000..c4c1710 --- /dev/null +++ b/tests/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/ssh/embed_load_test.go b/tests/ssh/embed_load_test.go new file mode 100644 index 0000000..221999a --- /dev/null +++ b/tests/ssh/embed_load_test.go @@ -0,0 +1,42 @@ +// Copyright 2026 Flant JSC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "os" + "path/filepath" + "testing" + + "github.com/deckhouse/lib-connection/pkg/ssh/config" + "github.com/stretchr/testify/require" +) + +func TestEmbedLoadOpenAPISpec(t *testing.T) { + assertSpec(t, "ssh_host_configuration.yaml", config.HostOpenAPISpec) + assertSpec(t, "ssh_configuration.yaml", config.ConfigurationOpenAPISpec) +} + +func assertSpec(t *testing.T, fileName string, specProvider func() string) { + actualContent := specProvider() + require.NotEmpty(t, actualContent, "spec provider should not be empty for file %s", fileName) + + path := filepath.Join("../../pkg/ssh/config/openapi/", fileName) + fullPath, err := filepath.Abs(path) + + expectedContent, err := os.ReadFile(fullPath) + require.NoError(t, err, "failed to read host spec file %s; full: %s", fileName, fullPath) + + require.Equal(t, string(expectedContent), actualContent, "expected content does not match for file %s", fileName) +}