Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #18

Merged
merged 4 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,32 @@ After checking out the repo, run `bin/setup` to install dependencies. Then, run

To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).

### update submodule to try a new feature on Stability-AI/api-interfaces

- update submodule

```sh
git submodule update --init # if you haven't fetched the content of the submodule yet
cd api-interfaces

# checkout some branch/commit you need
git fetch
git reset --hard origin/some_branch

cd ..
```

- build

```sh
bundle exec rake protoc

git diff
# now you may be able to confirm that the diff is created in lib/generation_pb.rb
```

- modify the `lib/stability_sdk/client.rb` to try some new features

## Contributing

Bug reports and pull requests are welcome on GitHub at https://github.com/cou929/stability-sdk-ruby.
26 changes: 13 additions & 13 deletions exe/stability-client
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@ opt.banner = "Usage: stability-client [options] YOUR_PROMPT_TEXT"
opt.separator ""
opt.separator "Options:"
opt.on("--api_key=VAL", "api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable") {|v| options[:api_key] = v }
opt.on("-H", "--height=VAL", "height of image in pixel. default 512") {|v| options[:height] = v }
opt.on("-W", "--width=VAL", "width of image in pixel. default 512") {|v| options[:width] = v }
opt.on("-C", "--cfg_scale=VAL", "CFG scale factor. default 7.0") {|v| options[:cfg_scale] = v }
opt.on("-A", "--sampler=VAL", "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") {|v| options[:sampler] = v }
opt.on("-s", "--steps=VAL", "number of steps. default 50") {|v| options[:steps] = v }
opt.on("-S", "--seed=VAL", "random seed to use in integer") {|v| options[:seed] = v }
opt.on("-p", "--prefix=VAL", "output prefixes for artifacts. default `generation`") {|v| options[:prefix] = v }
opt.on("-H", "--height=VAL", Integer, "height of image in pixel. default 512") {|v| options[:height] = v }
opt.on("-W", "--width=VAL", Integer, "width of image in pixel. default 512") {|v| options[:width] = v }
opt.on("-C", "--cfg_scale=VAL", Float, "CFG scale factor. default 7.0") {|v| options[:cfg_scale] = v }
opt.on("-A", "--sampler=VAL", String, "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") {|v| options[:sampler] = v }
opt.on("-s", "--steps=VAL", Integer, "number of steps. default 50") {|v| options[:steps] = v }
opt.on("-S", "--seed=VAL", Integer, "random seed to use in integer") {|v| options[:seed] = v }
opt.on("-p", "--prefix=VAL", String, "output prefixes for artifacts. default `generation`") {|v| options[:prefix] = v }
opt.on("--no-store", "do not write out artifacts") {|v| options[:no_store] = v }
opt.on("-n", "--num_samples=VAL", "number of samples to generate") {|v| options[:num_samples] = v }
opt.on("-e", "--engine=VAL", "engine to use for inference. default `stable-diffusion-v1`") {|v| options[:engine_id] = v }
opt.on("-i", "--init_image=VAL", "path to init image") {|v| options[:init_image] = v }
opt.on("-m", "--mask_image=VAL", "path to mask image") {|v| options[:mask_image] = v }
opt.on("--start_schedule=VAL", "start schedule for init image (must be greater than 0, 1 is full strength text prompt, no trace of image). default 1.0") {|v| options[:start_schedule] = v }
opt.on("--end_schedule=VAL", "end schedule for init image. default 0.01") {|v| options[:end_schedule] = v }
opt.on("-n", "--num_samples=VAL", Integer, "number of samples to generate. default 1") {|v| options[:num_samples] = v }
opt.on("-e", "--engine=VAL", String, "engine to use for inference. default `stable-diffusion-v1-5`") {|v| options[:engine_id] = v }
opt.on("-i", "--init_image=VAL", String, "path to init image") {|v| options[:init_image] = v }
opt.on("-m", "--mask_image=VAL", String, "path to mask image") {|v| options[:mask_image] = v }
opt.on("--start_schedule=VAL", Float, "start schedule for init image (must be greater than 0, 1 is full strength text prompt, no trace of image). default 1.0") {|v| options[:start_schedule] = v }
opt.on("--end_schedule=VAL", Float, "end schedule for init image. default 0.01") {|v| options[:end_schedule] = v }
opt.on("--guidance_preset=VAL", String,"Guidance preset to use. See generation.GuidancePreset for supported values. default `GUIDANCE_PRESET_NONE`") {|v| options[:guidance_preset] = v }
opt.on("--guidance_cuts=VAL", Integer, "Number of cuts to use for guidance. default 0") {|v| options[:guidance_cuts] = v }
opt.on("--guidance_strength=VAL", Float, "Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25. default nil") {|v| options[:guidance_strength] = v }
Expand Down
5 changes: 5 additions & 0 deletions lib/stability_sdk/cli.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ def self.save_answer(answer, options, logger)
logger.warn "not implemented for ArtifactType #{artifact.type}"
end

if artifact.finish_reason == :FILTER
logger.debug "the generated image is filtered"
end

next if filename == "" || contents == ""

File.open(filename, "wb") do |f|
f.write(contents)
logger.debug "wrote #{artifact.type} to #{filename}"
end
end
end
Expand Down
30 changes: 21 additions & 9 deletions lib/stability_sdk/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ def initialize(options={})
end

def generate(prompt, options, &block)
width = options.has_key?(:width) ? options[:width].to_i : DEFAULT_IMAGE_WIDTH
height = options.has_key?(:height) ? options[:height].to_i : DEFAULT_IMAGE_HEIGHT
width = options.has_key?(:width) ? options[:width] : DEFAULT_IMAGE_WIDTH
height = options.has_key?(:height) ? options[:height] : DEFAULT_IMAGE_HEIGHT

if width % 64 != 0 || height % 64 != 0
raise InvalidParameter, "width and height must be a multiple of 64"
end

samples = options.has_key?(:num_samples) ? options[:num_samples].to_i : DEFAULT_SAMPLE_SIZE
steps = options.has_key?(:steps) ? options[:steps].to_i : DEFAULT_STEPS
seed = options.has_key?(:seed) ? [options[:seed].to_i] : [rand(4294967295)]
samples = options.has_key?(:num_samples) ? options[:num_samples] : DEFAULT_SAMPLE_SIZE
steps = options.has_key?(:steps) ? options[:steps] : DEFAULT_STEPS
seed = options.has_key?(:seed) ? [options[:seed]] : [rand(4294967295)]
transform = Gooseai::TransformType.new(
diffusion: options.has_key?(:sampler) ? SAMPLER_ALGORITHMS[options[:sampler].to_sym] : DEFAULT_SAMPLER_ALGORITHM,
)
step_parameter = Gooseai::StepParameter.new(
scaled_step: 0,
sampler: Gooseai::SamplerParameters.new(
cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale].to_f : DEFAULT_CFG_SCALE,
cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale] : DEFAULT_CFG_SCALE,
),
)

Expand All @@ -77,11 +77,11 @@ def generate(prompt, options, &block)
prompt_param << init_image_to_prompt(options[:init_image])
step_parameter.scaled_step = 0
step_parameter.sampler = Gooseai::SamplerParameters.new(
cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale].to_f : DEFAULT_CFG_SCALE,
cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale] : DEFAULT_CFG_SCALE,
)
step_parameter.schedule = Gooseai::ScheduleParameters.new(
start: options.has_key?(:start_schedule) ? options[:start_schedule].to_f : DEFAULT_START_SCHEDULE,
end: options.has_key?(:end_schedule) ? options[:end_schedule].to_f : DEFAULT_END_SCHEDULE,
start: options.has_key?(:start_schedule) ? options[:start_schedule] : DEFAULT_START_SCHEDULE,
end: options.has_key?(:end_schedule) ? options[:end_schedule] : DEFAULT_END_SCHEDULE,
)
end
if options.has_key?(:mask_image)
Expand Down Expand Up @@ -150,8 +150,20 @@ def generate(prompt, options, &block)
image: image_param
)

@logger.debug "sending request."
start = Time.now
@stub.generate(req).each do |answer|
duration = Time.now - start
if answer.artifacts.size > 0
artifact_types = answer.artifacts.map { |a| a.type }
@logger.debug "got #{answer.answer_id} with #{artifact_types} in #{duration.round(2)}s"
else
@logger.debug "got keepalive #{answer.answer_id} in #{duration.round(2)}s"
end

block.call(answer)

start = Time.now
end
end

Expand Down
2 changes: 1 addition & 1 deletion lib/stability_sdk/version.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module StabilitySDK
VERSION = "0.2.11"
VERSION = "0.2.12"
end