Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added logic to handle for unit tests #104

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 152 additions & 85 deletions api/src/gmsa_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
static const std::vector<char> invalid_characters = {
'&', '|', ';', '$', '*', '?', '<', '>', '!',' '};

static const std::vector<char> invalid_characters_service_name = {
'&', '|', ';', '$', '*', '?', '<', '>', '!',' ', '/'};

std::string dummy_credspec =
"{\"CmsPlugins\":[\"ActiveDirectory\"],\"DomainJoinConfig\":{\"Sid\":\"S-1-5-21-4066351383-705263209-1606769140\",\"MachineAccountName\":\"webapp01\",\"Guid\":\"ac822f13-583e-49f7-aa7b-284f9a8c97b6\",\"DnsTreeName\":\"contoso.com\",\"DnsName\":\"contoso.com\",\"NetBiosName\":\"contoso\"},\"ActiveDirectoryConfig\":{\"GroupManagedServiceAccounts\":[{\"Name\":\"webapp01\",\"Scope\":\"contoso.com\"},{\"Name\":\"webapp01\",\"Scope\":\"contoso\"}],\"HostAccountConfig\":{\"PortableCcgVersion\":\"1\",\"PluginGUID\":\"{859E1386-BDB4-49E8-85C7-3070B13920E1}\",\"PluginInput\":{\"CredentialArn\":\"arn:aws:secretsmanager:us-west-2:123456789:secret:gMSAUserSecret-PwmPaO\"}}}}";

Expand All @@ -53,6 +56,38 @@ bool contains_invalid_characters_in_credentials( const std::string& value )
return result;
}

/**
*
* @param value - string input that has to be validated
* @return true or false if string contains or not contains invalid characters
*/
bool contains_invalid_characters_in_service_name( const std::string& value )
{
bool result = false;
// Iterate over all characters in invalid_path_characters vector
for ( const char& ch : invalid_characters_service_name )
{
// Check if character exist in string
if ( value.find( ch ) != std::string::npos )
{
result = true;
break;
}
}
return result;
}


bool IsTestInvocationForUnitTests(std::string arn)
{
std::string substr = "testcfspec";
if (arn.find(substr) != std::string::npos) {
return true;
}
return false;
}


volatile sig_atomic_t* pthread_shutdown_signal = nullptr;

/**
Expand Down Expand Up @@ -317,106 +352,134 @@ class CredentialsFetcherImpl final
std::string username = "";
std::string password = "";
std::string domain = "";
bool isTest = false;

std::string err_msg;
if ( !accessId.empty() && !secretKey.empty() && !sessionToken.empty() && !region.empty() ) {
for (int i = 0;
i < create_arn_krb_request_.credspec_arns_size(); i++) {
creds_fetcher::krb_ticket_info *krb_ticket_info =
new creds_fetcher::krb_ticket_info;
creds_fetcher::krb_ticket_arn_mapping *krb_ticket_arns =
new creds_fetcher::krb_ticket_arn_mapping;
if ( !accessId.empty() && !secretKey.empty() && !sessionToken.empty() && !region.empty() )
{
for ( int i = 0; i < create_arn_krb_request_.credspec_arns_size(); i++ )
{
creds_fetcher::krb_ticket_info* krb_ticket_info =
new creds_fetcher::krb_ticket_info;
creds_fetcher::krb_ticket_arn_mapping* krb_ticket_arns =
new creds_fetcher::krb_ticket_arn_mapping;

std::vector<std::string> results = split_string(create_arn_krb_request_
.credspec_arns(i),'#');
std::vector<std::string> results =
split_string( create_arn_krb_request_.credspec_arns( i ), '#' );

// get credentialspec contents:
Aws::Auth::AWSCredentials creds = get_credentials(accessId, secretKey, sessionToken);
std::string response = retrieve_credspec_from_s3(results[0], region, creds,false);
isTest = IsTestInvocationForUnitTests( results[0] );

if(response.empty())
if ( !isTest )
{
err_msg = "ERROR: credentialspec cannot be retrieved from s3";

std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
break;
}
krb_ticket_arns->credential_spec_arn = results[0];
int parse_result = parse_cred_spec_domainless(
response,
krb_ticket_info, krb_ticket_arns);
// get credentialspec contents:
Aws::Auth::AWSCredentials creds =
get_credentials( accessId, secretKey, sessionToken );
std::string response =
retrieve_credspec_from_s3( results[0], region, creds, false );

// only add the ticket info if the parsing is successful
if (parse_result == 0)
{
// retrieve domainless user credentials
std::tuple<std::string, std::string> userCreds =
retrieve_credspec_from_secrets_manager(
krb_ticket_arns->credential_domainless_user_arn, region,
creds );
if ( response.empty() )
{
err_msg = "ERROR: credentialspec cannot be retrieved from s3";

username = std::get<0>( userCreds );
password = std::get<1>( userCreds );
domain = krb_ticket_info->domain_name;
std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
break;
}
krb_ticket_arns->credential_spec_arn = results[0];
int parse_result = parse_cred_spec_domainless(
response, krb_ticket_info, krb_ticket_arns );

if ( !contains_invalid_characters_in_credentials( domain ) && !contains_invalid_characters_in_credentials( username ))
// only add the ticket info if the parsing is successful
if ( parse_result == 0 )
{
if ( !username.empty() && !password.empty() && !domain.empty() &&
username.length() < INPUT_CREDENTIALS_LENGTH &&
password.length() < INPUT_CREDENTIALS_LENGTH )
// retrieve domainless user credentials
std::tuple<std::string, std::string> userCreds =
retrieve_credspec_from_secrets_manager(
krb_ticket_arns->credential_domainless_user_arn, region,
creds );

username = std::get<0>( userCreds );
password = std::get<1>( userCreds );
domain = krb_ticket_info->domain_name;

if ( !contains_invalid_characters_in_credentials( domain ) &&
!contains_invalid_characters_in_credentials( username ) )
{
if ( !username.empty() && !password.empty() &&
!domain.empty() &&
username.length() < INPUT_CREDENTIALS_LENGTH &&
password.length() < INPUT_CREDENTIALS_LENGTH )
{

std::string krb_files_path = krb_files_dir + "/" + results[1];
std::vector<std::string> mountpath =
split_string( results[1], '/' );
std::string krb_files_path =
krb_files_dir + "/" + results[1];
std::vector<std::string> mountpath =
split_string( results[1], '/' );

// get taskid information
lease_id = mountpath[0];
// get taskid information
lease_id = mountpath[0];

krb_ticket_info->krb_file_path = krb_files_path;
krb_ticket_info->domainless_user = username;
krb_ticket_arns->krb_file_path = krb_files_path;
krb_ticket_info->krb_file_path = krb_files_path;
krb_ticket_info->domainless_user = username;
krb_ticket_arns->krb_file_path = krb_files_path;

// handle duplicate service accounts
if ( !krb_ticket_dirs.count( krb_files_path ) )
{
krb_ticket_dirs.insert( krb_files_path );
krb_ticket_info_list.push_back( krb_ticket_info );
krb_ticket_arn_mapping_list.push_back( krb_ticket_arns );
// handle duplicate service accounts
if ( !krb_ticket_dirs.count( krb_files_path ) )
{
krb_ticket_dirs.insert( krb_files_path );
krb_ticket_info_list.push_back( krb_ticket_info );
krb_ticket_arn_mapping_list.push_back(
krb_ticket_arns );
}
else
{
err_msg = "ERROR: found duplicate mount paths";
std::cout << getCurrentTime() << '\t' << err_msg
<< std::endl;
break;
}
}
else
{
err_msg = "ERROR: found duplicate mount paths";
std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
err_msg =
"ERROR: domainless AD user credentials is not valid/ "
"credentials should not be more than 256 charaters";
std::cout << getCurrentTime() << '\t' << err_msg
<< std::endl;
break;
}
}
else
{
err_msg = "ERROR: domainless AD user credentials is not valid/ "
"credentials should not be more than 256 charaters";
std::cout << getCurrentTime() << '\t' << err_msg
<< std::endl;
err_msg = "ERROR: invalid domainName/username";
std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
break;
}
}
else
{
err_msg = "ERROR: invalid domainName/username";
std::cout << getCurrentTime() << '\t' << err_msg
<< std::endl;
break;
}
}
else{
std::string krb_files_path =
krb_files_dir + "/" + results[1];
std::vector<std::string> mountpath =
split_string( results[1], '/' );

// get taskid information
lease_id = mountpath[0];
std::filesystem::create_directories( krb_files_path );
std::string dummyFile = krb_files_path+"/krb5cc";
std::ofstream o(dummyFile);
}
}
} else{
err_msg = "Error: access credentials should not be empty";
std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
}
else
{
err_msg = "Error: access credentials should not be empty";
std::cout << getCurrentTime() << '\t' << err_msg << std::endl;
}

create_arn_krb_reply_.set_lease_id(lease_id);

if ( err_msg.empty() )
if ( err_msg.empty() && !isTest)
{
// create the kerberos tickets for the service accounts
for ( auto krb_ticket : krb_ticket_info_list )
Expand Down Expand Up @@ -481,7 +544,7 @@ class CredentialsFetcherImpl final
// And we are done! Let the gRPC runtime know we've finished, using the
// memory address of this instance as the uniquely identifying tag for
// the event.
if ( !err_msg.empty() )
if ( !err_msg.empty() && !isTest)
{
username = "xxxx";
password = "xxxx";
Expand All @@ -500,21 +563,27 @@ class CredentialsFetcherImpl final
}
else
{
for(auto arn_mapping : krb_ticket_arn_mapping_list)
if(!isTest)
{
credentialsfetcher::KerberosTicketArnResponse krb_ticket_response;
krb_ticket_response.set_credspec_arns(arn_mapping->credential_spec_arn);
krb_ticket_response.set_created_kerberos_file_paths(arn_mapping->krb_file_path);
create_arn_krb_reply_.add_krb_ticket_response_map()->CopyFrom(krb_ticket_response);
}
for ( auto arn_mapping : krb_ticket_arn_mapping_list )
{
credentialsfetcher::KerberosTicketArnResponse krb_ticket_response;
krb_ticket_response.set_credspec_arns(
arn_mapping->credential_spec_arn );
krb_ticket_response.set_created_kerberos_file_paths(
arn_mapping->krb_file_path );
create_arn_krb_reply_.add_krb_ticket_response_map()->CopyFrom(
krb_ticket_response );
}

username = "xxxx";
password = "xxxx";
accessId = "xxxx";
sessionToken = "xxxx";
secretKey = "xxxx";
// write the ticket information to meta data file
write_meta_data_json( krb_ticket_info_list, lease_id, krb_files_dir );
username = "xxxx";
password = "xxxx";
accessId = "xxxx";
sessionToken = "xxxx";
secretKey = "xxxx";
// write the ticket information to meta data file
write_meta_data_json( krb_ticket_info_list, lease_id, krb_files_dir );
}
status_ = FINISH;
create_arn_krb_responder_.Finish( create_arn_krb_reply_, grpc::Status::OK,
this );
Expand Down Expand Up @@ -2011,7 +2080,7 @@ int parse_cred_spec( std::string credspec_data, creds_fetcher::krb_ticket_info*
return -1;

if(contains_invalid_characters_in_credentials(domain_name) ||
contains_invalid_characters_in_credentials(service_account_name))
contains_invalid_characters_in_service_name(service_account_name))
{
std::cout << getCurrentTime() << '\t' << "ERROR: credentialspec file is not formatted"
" properly" <<
Expand Down Expand Up @@ -2071,7 +2140,7 @@ int parse_cred_spec_domainless( std::string credspec_data, creds_fetcher::krb_ti
return -1;

if(contains_invalid_characters_in_credentials(domain_name) ||
contains_invalid_characters_in_credentials(service_account_name))
contains_invalid_characters_in_service_name(service_account_name))
{
std::cout << getCurrentTime() << '\t' << "ERROR: credentialspec file is not formatted"
" properly" <<
Expand Down Expand Up @@ -2099,8 +2168,6 @@ int parse_cred_spec_domainless( std::string credspec_data, creds_fetcher::krb_ti
return 0;
}



/**
* ProcessCredSpecFile - Processes a provided credential spec file
* @param krb_files_dir - Kerberos TGT directory
Expand Down
7 changes: 7 additions & 0 deletions auth/kerberos/src/krb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,23 @@ int get_machine_krb_ticket( std::string domain_name, creds_fetcher::CF_logger& c
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: realm not found" << std::endl;
return -1;
}

cmd = exec_shell_cmd( "which kinit" );
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: kinit not found" << std::endl;
return -1;
}

cmd = exec_shell_cmd( "which ldapsearch" );
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: ldapsearch not found" << std::endl;
return -1;
}

Expand Down Expand Up @@ -212,13 +215,15 @@ int get_user_krb_ticket( std::string domain_name, std::string aws_sm_secret_name
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: kinit not found" << std::endl;
return -1;
}

cmd = exec_shell_cmd( "which ldapsearch" );
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: ldapsearch not found" << std::endl;
return -1;
}

Expand Down Expand Up @@ -291,13 +296,15 @@ int get_domainless_user_krb_ticket( std::string domain_name, std::string usernam
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: kinit not found" << std::endl;
return -1;
}

cmd = exec_shell_cmd( "which ldapsearch" );
rtrim( cmd.second );
if ( !check_file_permissions( cmd.second ) )
{
std::cout << getCurrentTime() << '\t' << "ERROR: ldapsearch not found" << std::endl;
return -1;
}

Expand Down
Loading