Skip to content

Commit

Permalink
Demonstration of using pgvector to improve searching
Browse files Browse the repository at this point in the history
By using embeddings generated from an LLM we can hope to gain better results
when searching for existing petitions and also when people are searching generally.
  • Loading branch information
pixeltrix committed Feb 10, 2025
1 parent b9d9c8f commit 8c25186
Show file tree
Hide file tree
Showing 34 changed files with 445 additions and 38 deletions.
1 change: 1 addition & 0 deletions .env.test
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ EPETITIONS_PROTOCOL=https
MODERATE_HOST=moderate.petition.parliament.uk
SITE_TITLE="Petition parliament (Test)"
INLINE_UPDATES=true
EMBEDDING_BACKEND=AmazonBedrock
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:

services:
postgres:
image: postgres:16
image: pgvector/pgvector:pg16
ports: ["5432:5432"]
env:
POSTGRES_PASSWORD: postgres
Expand Down
1 change: 1 addition & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ gem 'omniauth'
gem 'omniauth-rails_csrf_protection'
gem 'omniauth-saml'

gem 'aws-sdk-bedrockruntime'
gem 'aws-sdk-codedeploy'
gem 'aws-sdk-cloudwatchlogs'
gem 'aws-sdk-s3'
Expand Down
4 changes: 4 additions & 0 deletions Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ GEM
rack
aws-eventstream (1.3.0)
aws-partitions (1.1047.0)
aws-sdk-bedrockruntime (1.37.0)
aws-sdk-core (~> 3, >= 3.216.0)
aws-sigv4 (~> 1.5)
aws-sdk-cloudwatchlogs (1.108.0)
aws-sdk-core (~> 3, >= 3.216.0)
aws-sigv4 (~> 1.5)
Expand Down Expand Up @@ -484,6 +487,7 @@ PLATFORMS

DEPENDENCIES
appsignal
aws-sdk-bedrockruntime
aws-sdk-cloudwatchlogs
aws-sdk-codedeploy
aws-sdk-s3
Expand Down
7 changes: 7 additions & 0 deletions app/controllers/admin/archived/petition_details_controller.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class Admin::Archived::PetitionDetailsController < Admin::AdminController
before_action :fetch_petition
after_action :enqueue_job_to_update_embedding, only: :update

# Intentionally empty action body. The fetch_petition before_action declared
# on this controller runs first.
def show
end
Expand Down Expand Up @@ -32,4 +33,10 @@ def petition_attributes
# Filters the request parameters for the archived petition form: the
# :archived_petition key is required and only the names returned by
# petition_attributes are permitted.
def petition_params
  scoped = params.require(:archived_petition)
  scoped.permit(*petition_attributes)
end

# after_action hook for :update. Queues UpdatePetitionEmbeddingJob for
# @petition, but only when @petition.saved_changes? reports that the
# preceding save changed something.
def enqueue_job_to_update_embedding
  return unless @petition.saved_changes?

  UpdatePetitionEmbeddingJob.perform_later(@petition)
end
end
7 changes: 7 additions & 0 deletions app/controllers/admin/petition_details_controller.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class Admin::PetitionDetailsController < Admin::AdminController
before_action :fetch_petition
after_action :enqueue_job_to_update_embedding, only: :update

# Intentionally empty action body. The fetch_petition before_action declared
# on this controller runs first.
def show
end
Expand All @@ -25,4 +26,10 @@ def petition_params
:creator_attributes => [:name, :email]
)
end

# after_action hook for :update. Queues UpdatePetitionEmbeddingJob for
# @petition, but only when @petition.saved_changes? reports that the
# preceding save changed something.
def enqueue_job_to_update_embedding
  return unless @petition.saved_changes?

  UpdatePetitionEmbeddingJob.perform_later(@petition)
end
end
2 changes: 1 addition & 1 deletion app/controllers/admin/sites_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def site_params
:signature_count_interval, :update_signature_counts,
:disable_trending_petitions, :threshold_for_moderation_delay,
:disable_invalid_signature_count_check, :disable_daily_update_statistics_job,
:disable_plus_address_check, :disable_feedback_sending,
:disable_plus_address_check, :disable_feedback_sending, :semantic_searching,
:show_feedback_page_message, :feedback_page_message, :feedback_page_message_colour,
:show_home_page_message, :home_page_message, :home_page_message_colour,
:show_petition_page_message, :petition_page_message, :petition_page_message_colour,
Expand Down
6 changes: 5 additions & 1 deletion app/helpers/petition_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ def signatures_threshold_percentage(petition)
end

# Returns the translated header for the current petition list, memoized in
# @_petition_list_header.
#
# Semantic searches use the "semantic_html" key; every other listing uses
# "<scope>_html". Both are looked up under the petitions.list_headers scope
# with the URL-escaped query passed for interpolation, and fall back to ""
# when no translation exists.
def petition_list_header
  # The key must be chosen inside the single memoizing assignment: an earlier
  # unconditional `||=` would always win, because even the "" default is
  # truthy in Ruby and the semantic header would never be used.
  @_petition_list_header ||= begin
    key = @petitions.semantic_search? ? :semantic_html : :"#{@petitions.scope}_html"
    t(key, scope: :"petitions.list_headers", query: @petitions.url_safe_query, default: "")
  end
end

def petition_list_header?
Expand Down
1 change: 1 addition & 0 deletions app/jobs/archive_petition_job.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def perform(petition)
p.action = petition.action
p.background = petition.background
p.additional_details = petition.additional_details
p.embedding = petition.embedding
p.committee_note = petition.committee_note
p.departments = petition.departments
p.tags = petition.tags
Expand Down
7 changes: 7 additions & 0 deletions app/jobs/update_petition_embedding_job.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Background job that regenerates the stored embedding for a petition.
#
# Embedding::GenerationError is retried (attempts: 10) with the
# :polynomially_longer wait strategy.
class UpdatePetitionEmbeddingJob < ApplicationJob
  retry_on Embedding::GenerationError, attempts: 10, wait: :polynomially_longer

  # Generates an embedding from petition.content and writes it straight to
  # the embedding column with update_columns.
  def perform(petition)
    vector = Embedding.generate(petition.content)
    petition.update_columns(embedding: vector)
  end
end
93 changes: 93 additions & 0 deletions app/lib/embedding.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
require 'faraday'
require 'aws-sdk-bedrockruntime'

# Generates vector embeddings for text through a pluggable backend.
#
# The backend class is named by the EMBEDDING_BACKEND environment variable
# (default "Ollama") and looked up inside Embedding::Backends. One client
# instance is cached in Thread.current[:__embedding__].
module Embedding
  # Raised by every backend when an embedding could not be produced.
  class GenerationError < RuntimeError; end

  module Backends
    # Backend that POSTs the text to an Ollama server's embed endpoint.
    # URL and model come from OLLAMA_URL / OLLAMA_MODEL when set.
    class Ollama
      with_options instance_writer: false do
        class_attribute :url, default: ENV.fetch('OLLAMA_URL', 'http://127.0.0.1:11434')
        class_attribute :path, default: '/api/embed'
        class_attribute :model, default: ENV.fetch('OLLAMA_MODEL', 'mxbai-embed-large')
        class_attribute :headers, default: { content_type: 'application/json' }
        class_attribute :open_timeout, default: 5
        class_attribute :timeout, default: 5
      end

      # Requests an embedding for +input+ and returns the first entry of the
      # "embeddings" array in the JSON response.
      #
      # Raises Embedding::GenerationError when anything goes wrong (HTTP
      # error, timeout, missing "embeddings" key, ...).
      def generate(input)
        body = { input: input, model: model }.to_json

        response = faraday.post(path, body, headers) do |request|
          request.options[:timeout] = timeout
          request.options[:open_timeout] = open_timeout
        end

        response.body.fetch('embeddings').first
      rescue StandardError
        # Raising inside the rescue keeps the original exception available
        # through GenerationError#cause.
        raise Embedding::GenerationError, "Unable to generate an embedding using Ollama"
      end

      private

      # Memoized Faraday connection: follows redirects, parses JSON bodies
      # and raises on error responses.
      def faraday
        @faraday ||= Faraday.new(url) do |f|
          f.response :follow_redirects
          f.response :json
          f.response :raise_error
          f.adapter :net_http_persistent
        end
      end
    end

    # Backend that invokes an Amazon Bedrock embedding model. The model id
    # comes from BEDROCK_MODEL_ID when set.
    class AmazonBedrock
      with_options instance_writer: false do
        class_attribute :model_id, default: ENV.fetch('BEDROCK_MODEL_ID', 'amazon.titan-embed-text-v2:0')
      end

      # Requests a 1024-dimension float embedding for +input+ and returns the
      # "embedding" value of the JSON response.
      #
      # Raises Embedding::GenerationError when anything goes wrong, including
      # a response that has no "embedding" key.
      def generate(input)
        params = {
          body: {
            inputText: input,
            dimensions: 1024,
            embeddingTypes: ['float']
          }.to_json,
          content_type: 'application/json',
          accept: 'application/json',
          model_id: model_id
        }

        response = bedrock.invoke_model(**params)
        json = JSON.parse(response.body.read)

        # fetch (as in the Ollama backend) turns a malformed response into a
        # GenerationError instead of silently returning nil, which callers
        # would otherwise store as the embedding.
        json.fetch('embedding')
      rescue StandardError
        # Raising inside the rescue keeps the original exception available
        # through GenerationError#cause.
        raise Embedding::GenerationError, "Unable to generate an embedding using Amazon Bedrock"
      end

      private

      # Memoized Bedrock runtime client built with default configuration.
      def bedrock
        @bedrock ||= Aws::BedrockRuntime::Client.new
      end
    end
  end

  class << self
    # Generates an embedding for +input+ using the cached backend client.
    # Raises Embedding::GenerationError on failure.
    def generate(input)
      client.generate(input)
    end

    # Resolves the backend class named by EMBEDDING_BACKEND (default
    # "Ollama"). The lookup is limited to constants defined directly on
    # Backends (inherit = false) so a value such as "Object" or "String"
    # cannot resolve to an unrelated top-level class; unknown names raise
    # NameError.
    def backend
      Backends.const_get(ENV.fetch('EMBEDDING_BACKEND', 'Ollama'), false)
    end

    # Backend instance cached in Thread.current[:__embedding__].
    def client
      Thread.current[:__embedding__] ||= backend.new
    end

    # Drops the cached client so the next call builds a fresh one (and
    # re-reads EMBEDDING_BACKEND).
    def reload
      Thread.current[:__embedding__] = nil
    end
  end
end
6 changes: 5 additions & 1 deletion app/models/archived/petition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Petition < ActiveRecord::Base
before_save :update_debate_state, if: :scheduled_debate_date_changed?

extend Searchable(:action, :background, :additional_details)
include Browseable, Taggable, Departments, Topics, Anonymization
include Browseable, NearestNeighbours, Taggable, Departments, Topics, Anonymization

facet :all, -> { by_most_signatures }
facet :awaiting_response, -> { awaiting_response.by_waiting_for_response_longest }
Expand Down Expand Up @@ -243,6 +243,10 @@ def scheduled_for_debate
end
end

# Text used to build this petition's embedding: the action and the
# background joined by " - ".
def content
  Kernel.format('%s - %s', action, background)
end

# Truthy when a note exists and its details are present. Returns the falsy
# note value itself when there is no note.
def notes?
  record = note
  return record unless record

  record.details.present?
end
Expand Down
56 changes: 39 additions & 17 deletions app/models/concerns/browseable.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ module Browseable
VALID_PAGE_SIZE = /\A(?:[1-9]|[1-5][0-9])\z/

included do
  # Per-class browsing configuration. Each attribute is declared once with
  # its default; instance_writer: false suppresses the instance-level setter.
  with_options instance_writer: false do
    class_attribute :facet_definitions, default: {}
    class_attribute :filter_definitions, default: {}
    class_attribute :default_page_size, default: 50
    class_attribute :max_page_size, default: 50
  end
end

class Facets
Expand Down Expand Up @@ -160,6 +155,10 @@ def query
@query ||= params[:q].to_s
end

# Embedding for the current query, generated on first use. A nil or false
# result from generate_embedding is not cached, so it is recomputed on the
# next call.
def embedding
  return @embedding if @embedding

  @embedding = generate_embedding
end

# The search query passed through Rack::Utils.escape.
def url_safe_query
  raw = query
  Rack::Utils.escape(raw)
end
Expand Down Expand Up @@ -200,6 +199,10 @@ def search?
query.present?
end

# True when an embedding is available for the current query, i.e. the
# value returned by #embedding is present.
def semantic_search?
  vector = embedding
  vector.present?
end

def in_batches(&block)
execute_search.find_each do |obj|
block.call obj
Expand All @@ -223,6 +226,18 @@ def model

private

# Whether the searched model lists an "embedding" column.
def embedding_column?
  columns = model.column_names
  columns.include?("embedding")
end

# Generates an embedding for the query, or returns nil when any
# precondition fails. The checks run in this order and stop at the first
# failure: Site.semantic_searching?, a present query, and an embedding
# column on the model.
def generate_embedding
  if Site.semantic_searching? && query.present? && embedding_column?
    Embedding.generate(query)
  end
end

def new_params(page)
{}.tap do |new_params|
new_params[:q] = query if query.present?
Expand All @@ -242,16 +257,23 @@ def execute_search_with_pagination
end

# Builds the relation for the current request. Exactly one of three paths
# is taken:
#
# * semantic search - filters and the facet scope are applied first, the
#   resulting ordering is dropped and the rows are ordered by nearness to
#   the query embedding;
# * text search - basic_search runs on the model, its select list is
#   replaced with `star`, its ordering is dropped, then filters and the
#   facet scope are applied;
# * plain browsing - filters and the facet scope are applied to the model.
def execute_search
  relation = klass

  if semantic_search?
    relation = filters.apply(relation)
    relation = relation.instance_exec(&klass.facet_definitions[scope])
    # The facet's own ordering must not compete with distance ordering.
    relation = relation.except(:order)
    relation.nearest_neighbours(embedding)
  elsif search?
    relation = relation.basic_search(query)
    relation = relation.except(:select).select(star)
    relation = relation.except(:order)
    relation = filters.apply(relation)
    relation.instance_exec(&klass.facet_definitions[scope])
  else
    relation = filters.apply(relation)
    relation.instance_exec(&klass.facet_definitions[scope])
  end
end

def star
Expand Down
13 changes: 13 additions & 0 deletions app/models/concerns/nearest_neighbours.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Adds ordering by vector nearness to a model that stores an embedding.
module NearestNeighbours
  extend ActiveSupport::Concern

  module ClassMethods
    # Reorders the current relation by `arel_table[column].nearest(embedding)`.
    # NOTE(review): `nearest` is not a stock Arel predicate; presumably it is
    # an extension defined elsewhere in the project - confirm.
    def nearest_neighbours(embedding, column: :embedding)
      distance = arel_table[column].nearest(embedding)
      reorder(distance)
    end
  end

  # Records of the same class other than this one, ordered by nearness to
  # this record's own value of +column+.
  def nearest_neighbours(column: :embedding)
    others = self.class.excluding(self)
    others.nearest_neighbours(read_attribute(column), column: column)
  end
end
6 changes: 5 additions & 1 deletion app/models/petition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Petition < ActiveRecord::Base
after_create :update_last_petition_created_at

extend Searchable(:action, :background, :additional_details)
include Browseable, Taggable, Departments, Topics, Anonymization
include Browseable, NearestNeighbours, Taggable, Departments, Topics, Anonymization

facet :all, -> { by_most_popular }
facet :open, -> { open_state.by_most_popular }
Expand Down Expand Up @@ -512,6 +512,10 @@ def scheduled_debate_state
end
end

# Text used to build this petition's embedding: the action and the
# background joined by " - ".
def content
  Kernel.format('%s - %s', action, background)
end

def statistics
super || create_statistics!
rescue ActiveRecord::RecordNotUnique => e
Expand Down
6 changes: 6 additions & 0 deletions app/models/petition_creator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def save

unless rate_limit.exceeded?(@petition.creator)
@petition.save!

send_email_to_gather_sponsors(@petition)
enqueue_job_to_update_embedding(@petition)
end

return true
Expand Down Expand Up @@ -255,6 +257,10 @@ def send_email_to_gather_sponsors(petition)
GatherSponsorsForPetitionEmailJob.perform_later(petition)
end

# Queues UpdatePetitionEmbeddingJob for the given petition.
def enqueue_job_to_update_embedding(petition)
UpdatePetitionEmbeddingJob.perform_later(petition)
end

private

def rate_limit
Expand Down
Loading

0 comments on commit 8c25186

Please sign in to comment.