diff --git a/internal/runner/options.go b/internal/runner/options.go index 3d278348..bf256223 100644 --- a/internal/runner/options.go +++ b/internal/runner/options.go @@ -6,6 +6,7 @@ import ( "os" "os/user" "path/filepath" + "strings" "github.com/projectdiscovery/cloudlist/pkg/inventory" "github.com/projectdiscovery/cloudlist/pkg/schema" @@ -28,9 +29,9 @@ type Options struct { Config string // Config is the location of the config file. Output string // Output is the file to write found results too. ExcludePrivate bool // ExcludePrivate excludes private IPs from results - Provider []string // Provider specifies what providers to fetch assets for. + Providers goflags.StringSlice // Providers specifies what providers to fetch assets for. Id goflags.StringSlice // Id specifies what id's to fetch assets for. - Services []string // Services specifies what services to fetch assets for a provider. + Services goflags.StringSlice // Services specifies what services to fetch assets for a provider. ProviderConfig string // ProviderConfig is the location of the provider config file. DisableUpdateCheck bool // DisableUpdateCheck disable automatic update check } @@ -38,19 +39,19 @@ type Options struct { var ( defaultConfigLocation = filepath.Join(userHomeDir(), ".config/cloudlist/config.yaml") defaultProviderConfigLocation = filepath.Join(userHomeDir(), ".config/cloudlist/provider-config.yaml") - defaultProviders, defaultServies = []goflags.EnumVariable{}, []goflags.EnumVariable{} - allowedProviders, allowedServices = goflags.AllowdTypes{}, goflags.AllowdTypes{} + defaultProviders, defaultServies = []string{}, []string{} + allowedProviders, allowedServices = []string{}, []string{} ) func init() { - for i, provider := range inventory.GetProviders() { - allowedProviders[provider] = goflags.EnumVariable(i) - defaultProviders = append(defaultProviders, goflags.EnumVariable(i)) + for _, provider := range inventory.GetProviders() { + allowedProviders = append(allowedProviders, provider) + defaultProviders = append(defaultProviders, provider) } - for i, service := range inventory.GetServices() { - defaultServies = append(defaultServies, goflags.EnumVariable(i)) - allowedServices[service] = goflags.EnumVariable(i) + for _, service := range inventory.GetServices() { + defaultServies = append(defaultServies, service) + allowedServices = append(allowedServices, service) } } @@ -77,11 +78,11 @@ func ParseOptions() *Options { flagSet.StringVarP(&options.ProviderConfig, "provider-config", "pc", defaultProviderConfigLocation, "provider config file"), ) flagSet.CreateGroup("filter", "Filters", - flagSet.EnumSliceVarP(&options.Provider, "provider", "p", defaultProviders, "display results for given providers (comma-separated)", allowedProviders), + flagSet.StringSliceVarP(&options.Providers, "provider", "p", nil, "display results for given providers (comma-separated) (default "+strings.Join(defaultProviders, ",")+")", goflags.CommaSeparatedStringSliceOptions), flagSet.StringSliceVar(&options.Id, "id", nil, "display results for given ids (comma-separated)", goflags.NormalizedStringSliceOptions), flagSet.BoolVar(&options.Hosts, "host", false, "display only hostnames in results"), flagSet.BoolVar(&options.IPAddress, "ip", false, "display only ips in results"), - flagSet.EnumSliceVarP(&options.Services, "service", "s", defaultServies, "query and display results from given service (comma-separated))", allowedServices), + flagSet.StringSliceVarP(&options.Services, "service", "s", nil, "query and display results from given service (comma-separated)) (default "+strings.Join(defaultServies, ",")+")", goflags.CommaSeparatedStringSliceOptions), flagSet.BoolVarP(&options.ExcludePrivate, "exclude-private", "ep", false, "exclude private ips in cli output"), ) flagSet.CreateGroup("update", "Update", diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 63bee2a0..26383876 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -31,6 +31,20 @@ func New(options *Options) (*Runner, error) { if err != nil { return nil, err } + + // CLI overrides config + if len(options.Services) == 0 { + options.Services = append(options.Services, config.GetServiceNames()...) + } + + // assign default services if not provided + if len(options.Services) == 0 { + options.Services = append(options.Services, defaultServies...) + } + if len(options.Providers) == 0 { + options.Providers = append(options.Providers, defaultProviders...) + } + return &Runner{config: config, options: options}, nil } @@ -53,8 +67,8 @@ func (r *Runner) Enumerate() { item["services"] = strings.Join(services, ",") } // Validate and only pass the correct items to input - if len(r.options.Provider) != 0 || len(r.options.Id) != 0 { - if len(r.options.Provider) != 0 && !Contains(r.options.Provider, item["provider"]) { + if len(r.options.Providers) != 0 || len(r.options.Id) != 0 { + if len(r.options.Providers) != 0 && !Contains(r.options.Providers, item["provider"]) { continue } if len(r.options.Id) != 0 && !Contains(r.options.Id, item["id"]) { diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index ae1120a8..2c980602 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -162,6 +162,22 @@ func (e *ErrNoSuchKey) Error() string { // Options contains configuration options for a provider type Options []OptionBlock +// GetServiceNames returns the services from the options +func (o Options) GetServiceNames() []string { + services := make([]string, 0) + for _, option := range o { + if serviceNameList, ok := option["services"]; ok { + for _, serviceName := range strings.Split(serviceNameList, ",") { + trimmedServiceName := strings.TrimSpace(serviceName) + if trimmedServiceName != "" { + services = append(services, trimmedServiceName) + } + } + } + } + return services +} + // OptionBlock is a single option on which operation is possible type OptionBlock map[string]string @@ -176,7 +192,7 @@ func (ob *OptionBlock) UnmarshalYAML(unmarshal func(interface{}) error) error { // Convert raw map to OptionBlock and handle special cases for key, value := range rawMap { switch key { - case "account_ids", "urls": + case "account_ids", "urls", "services": if valueArr, ok := value.([]interface{}); ok { var strArr []string for _, v := range valueArr {