From 19a64c968749f163ec767d741d90826dd32983c4 Mon Sep 17 00:00:00 2001
From: Eugen Rochko <eugen@zeonfederated.com>
Date: Fri, 12 Mar 2021 07:00:05 +0100
Subject: [PATCH] Refactor raw SQL queries to use Arel

---
 Gemfile                                       |   1 +
 Gemfile.lock                                  |   3 +
 app/lib/account_search_query_builder.rb       | 188 ++++++++++++++++++
 app/models/account.rb                         |  80 +-------
 app/services/account_search_service.rb        |  20 +-
 spec/lib/account_search_query_builder_spec.rb | 101 ++++++++++
 spec/models/account_spec.rb                   | 119 -----------
 7 files changed, 301 insertions(+), 211 deletions(-)
 create mode 100644 app/lib/account_search_query_builder.rb
 create mode 100644 spec/lib/account_search_query_builder_spec.rb

diff --git a/Gemfile b/Gemfile
index a4a2cc91c3f..4e8e6eb9dfa 100644
--- a/Gemfile
+++ b/Gemfile
@@ -26,6 +26,7 @@ gem 'streamio-ffmpeg', '~> 3.0'
 gem 'blurhash', '~> 0.1'
 
 gem 'active_model_serializers', '~> 0.10'
+gem 'activerecord-cte', '~> 0.1'
 gem 'addressable', '~> 2.7'
 gem 'bootsnap', '~> 1.6.0', require: false
 gem 'browser'
diff --git a/Gemfile.lock b/Gemfile.lock
index b59cfb1f318..8ddf667731f 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -54,6 +54,8 @@ GEM
       activemodel (= 5.2.4.5)
       activesupport (= 5.2.4.5)
       arel (>= 9.0)
+    activerecord-cte (0.1.1)
+      activerecord
     activestorage (5.2.4.5)
       actionpack (= 5.2.4.5)
       activerecord (= 5.2.4.5)
@@ -695,6 +697,7 @@ PLATFORMS
 DEPENDENCIES
   active_model_serializers (~> 0.10)
   active_record_query_trace (~> 1.8)
+  activerecord-cte (~> 0.1)
   addressable (~> 2.7)
   annotate (~> 3.1)
   aws-sdk-s3 (~> 1.89)
diff --git a/app/lib/account_search_query_builder.rb b/app/lib/account_search_query_builder.rb
new file mode 100644
index 00000000000..c757da91481
--- /dev/null
+++ b/app/lib/account_search_query_builder.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+class AccountSearchQueryBuilder
+  DISALLOWED_TSQUERY_CHARACTERS = /['?\\:‘’]/.freeze
+
+  LANGUAGE     = Arel::Nodes.build_quoted('simple').freeze
+  EMPTY_STRING = Arel::Nodes.build_quoted('').freeze
+  WEIGHT_A     = Arel::Nodes.build_quoted('A').freeze
+  WEIGHT_B     = Arel::Nodes.build_quoted('B').freeze
+  WEIGHT_C     = Arel::Nodes.build_quoted('C').freeze
+
+  FIELDS = {
+    display_name: { weight: WEIGHT_A }.freeze,
+    username:     { weight: WEIGHT_B }.freeze,
+    domain:       { weight: WEIGHT_C, nullable: true }.freeze,
+  }.freeze
+
+  RANK_NORMALIZATION = 32
+
+  DEFAULT_OPTIONS = {
+    limit: 10,
+    only_following: false,
+  }.freeze
+
+  # @param [String] terms
+  # @param [Hash] options
+  # @option [Account] :account
+  # @option [Boolean] :only_following
+  # @option [Integer] :limit
+  # @option [Integer] :offset
+  def initialize(terms, options = {})
+    @terms   = terms
+    @options = DEFAULT_OPTIONS.merge(options)
+  end
+
+  # @return [ActiveRecord::Relation]
+  def build
+    search_scope.tap do |scope|
+      scope.merge!(personalization_scope) if with_account?
+
+      if with_account? && only_following?
+        scope.merge!(only_following_scope)
+        scope.with!(first_degree_definition) # `merge!` does not handle `with`
+      end
+    end
+  end
+
+  # @return [Array<Account>]
+  def results
+    build.to_a
+  end
+
+  private
+
+  def search_scope
+    Account.select(projections)
+           .where(match_condition)
+           .searchable
+           .includes(:account_stat)
+           .order(rank: :desc)
+           .limit(limit)
+           .offset(offset)
+  end
+
+  def personalization_scope
+    join_condition = accounts_table.join(follows_table, Arel::Nodes::OuterJoin)
+                                   .on(accounts_table.grouping(accounts_table[:id].eq(follows_table[:account_id]).and(follows_table[:target_account_id].eq(account.id))).or(accounts_table.grouping(accounts_table[:id].eq(follows_table[:target_account_id]).and(follows_table[:account_id].eq(account.id)))))
+                                   .join_sources
+
+    Account.joins(join_condition)
+           .group(accounts_table[:id])
+  end
+
+  def only_following_scope
+    Account.where(accounts_table[:id].in(first_degree_table.project('*')))
+  end
+
+  def first_degree_definition
+    target_account_ids_query = follows_table.project(follows_table[:target_account_id]).where(follows_table[:account_id].eq(account.id))
+    account_id_query         = Arel::SelectManager.new.project(account.id)
+
+    Arel::Nodes::As.new(
+      first_degree_table,
+      target_account_ids_query.union(:all, account_id_query)
+    )
+  end
+
+  def projections
+    rank_column = begin
+      if with_account?
+        weighted_tsrank_template.as('rank')
+      else
+        tsrank_template.as('rank')
+      end
+    end
+
+    [all_columns, rank_column]
+  end
+
+  def all_columns
+    accounts_table[Arel.star]
+  end
+
+  def match_condition
+    Arel::Nodes::InfixOperation.new('@@', tsvector_template, tsquery_template)
+  end
+
+  def tsrank_template
+    @tsrank_template ||= Arel::Nodes::NamedFunction.new('ts_rank_cd', [tsvector_template, tsquery_template, RANK_NORMALIZATION])
+  end
+
+  def weighted_tsrank_template
+    @weighted_tsrank_template ||= Arel::Nodes::Multiplication.new(weight, tsrank_template)
+  end
+
+  def weight
+    Arel::Nodes::Addition.new(follows_table[:id].count, 1)
+  end
+
+  def tsvector_template
+    return @tsvector_template if defined?(@tsvector_template)
+
+    vectors = FIELDS.keys.map do |column|
+      options = FIELDS[column]
+
+      vector = accounts_table[column]
+      vector = Arel::Nodes::NamedFunction.new('coalesce', [vector, EMPTY_STRING]) if options[:nullable]
+      vector = Arel::Nodes::NamedFunction.new('to_tsvector', [LANGUAGE, vector])
+
+      Arel::Nodes::NamedFunction.new('setweight', [vector, options[:weight]])
+    end
+
+    @tsvector_template = Arel::Nodes::Grouping.new(vectors.reduce { |memo, vector| Arel::Nodes::Concat.new(memo, vector) })
+  end
+
+  def query_vector
+    @query_vector ||= Arel::Nodes::NamedFunction.new('to_tsquery', [LANGUAGE, tsquery_template])
+  end
+
+  def sanitized_terms
+    @sanitized_terms ||= @terms.gsub(DISALLOWED_TSQUERY_CHARACTERS, ' ')
+  end
+
+  def tsquery_template
+    return @tsquery_template if defined?(@tsquery_template)
+
+    terms = [
+      Arel::Nodes.build_quoted("' "),
+      Arel::Nodes.build_quoted(sanitized_terms),
+      Arel::Nodes.build_quoted(" '"),
+      Arel::Nodes.build_quoted(':*'),
+    ]
+
+    @tsquery_template = Arel::Nodes::NamedFunction.new('to_tsquery', [LANGUAGE, terms.reduce { |memo, term| Arel::Nodes::Concat.new(memo, term) }])
+  end
+
+  def account
+    @options[:account]
+  end
+
+  def with_account?
+    account.present?
+  end
+
+  def limit
+    @options[:limit]
+  end
+
+  def offset
+    @options[:offset]
+  end
+
+  def only_following?
+    @options[:only_following]
+  end
+
+  def accounts_table
+    Account.arel_table
+  end
+
+  def follows_table
+    Follow.arel_table
+  end
+
+  def first_degree_table
+    Arel::Table.new(:first_degree)
+  end
+end
diff --git a/app/models/account.rb b/app/models/account.rb
index d85fd1f6e9f..a31c2428dd1 100644
--- a/app/models/account.rb
+++ b/app/models/account.rb
@@ -432,75 +432,6 @@ class Account < ApplicationRecord
       DeliveryFailureTracker.without_unavailable(urls)
     end
 
-    def search_for(terms, limit = 10, offset = 0)
-      textsearch, query = generate_query_for_search(terms)
-
-      sql = <<-SQL.squish
-        SELECT
-          accounts.*,
-          ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
-        FROM accounts
-        WHERE #{query} @@ #{textsearch}
-          AND accounts.suspended_at IS NULL
-          AND accounts.moved_to_account_id IS NULL
-        ORDER BY rank DESC
-        LIMIT ? OFFSET ?
-      SQL
-
-      records = find_by_sql([sql, limit, offset])
-      ActiveRecord::Associations::Preloader.new.preload(records, :account_stat)
-      records
-    end
-
-    def advanced_search_for(terms, account, limit = 10, following = false, offset = 0)
-      textsearch, query = generate_query_for_search(terms)
-
-      if following
-        sql = <<-SQL.squish
-          WITH first_degree AS (
-            SELECT target_account_id
-            FROM follows
-            WHERE account_id = ?
-            UNION ALL
-            SELECT ?
-          )
-          SELECT
-            accounts.*,
-            (count(f.id) + 1) * ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
-          FROM accounts
-          LEFT OUTER JOIN follows AS f ON (accounts.id = f.account_id AND f.target_account_id = ?)
-          WHERE accounts.id IN (SELECT * FROM first_degree)
-            AND #{query} @@ #{textsearch}
-            AND accounts.suspended_at IS NULL
-            AND accounts.moved_to_account_id IS NULL
-          GROUP BY accounts.id
-          ORDER BY rank DESC
-          LIMIT ? OFFSET ?
-        SQL
-
-        records = find_by_sql([sql, account.id, account.id, account.id, limit, offset])
-      else
-        sql = <<-SQL.squish
-          SELECT
-            accounts.*,
-            (count(f.id) + 1) * ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
-          FROM accounts
-          LEFT OUTER JOIN follows AS f ON (accounts.id = f.account_id AND f.target_account_id = ?) OR (accounts.id = f.target_account_id AND f.account_id = ?)
-          WHERE #{query} @@ #{textsearch}
-            AND accounts.suspended_at IS NULL
-            AND accounts.moved_to_account_id IS NULL
-          GROUP BY accounts.id
-          ORDER BY rank DESC
-          LIMIT ? OFFSET ?
-        SQL
-
-        records = find_by_sql([sql, account.id, account.id, limit, offset])
-      end
-
-      ActiveRecord::Associations::Preloader.new.preload(records, :account_stat)
-      records
-    end
-
     def from_text(text)
       return [] if text.blank?
 
@@ -512,19 +443,10 @@ class Account < ApplicationRecord
             TagManager.instance.normalize_domain(domain)
           end
         end
+
         EntityCache.instance.mention(username, domain)
       end
     end
-
-    private
-
-    def generate_query_for_search(terms)
-      terms      = Arel.sql(connection.quote(terms.gsub(/['?\\:]/, ' ')))
-      textsearch = "(setweight(to_tsvector('simple', accounts.display_name), 'A') || setweight(to_tsvector('simple', accounts.username), 'B') || setweight(to_tsvector('simple', coalesce(accounts.domain, '')), 'C'))"
-      query      = "to_tsquery('simple', ''' ' || #{terms} || ' ''' || ':*')"
-
-      [textsearch, query]
-    end
   end
 
   def emojis
diff --git a/app/services/account_search_service.rb b/app/services/account_search_service.rb
index 6fe4b6593af..3a4a937a399 100644
--- a/app/services/account_search_service.rb
+++ b/app/services/account_search_service.rb
@@ -53,19 +53,13 @@ class AccountSearchService < BaseService
   end
 
   def from_database
-    if account
-      advanced_search_results
-    else
-      simple_search_results
-    end
-  end
-
-  def advanced_search_results
-    Account.advanced_search_for(terms_for_query, account, limit_for_non_exact_results, options[:following], offset)
-  end
-
-  def simple_search_results
-    Account.search_for(terms_for_query, limit_for_non_exact_results, offset)
+    AccountSearchQueryBuilder.new(
+      terms_for_query,
+      account: account,
+      only_following: options[:following],
+      limit: limit_for_non_exact_results,
+      offset: offset
+    ).results
   end
 
   def from_elasticsearch
diff --git a/spec/lib/account_search_query_builder_spec.rb b/spec/lib/account_search_query_builder_spec.rb
new file mode 100644
index 00000000000..dc59e2784f3
--- /dev/null
+++ b/spec/lib/account_search_query_builder_spec.rb
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+require 'rails_helper'
+
+describe AccountSearchQueryBuilder do
+  before do
+    Fabricate(
+      :account,
+      display_name: "Missing",
+      username: "missing",
+      domain: "missing.com"
+    )
+  end
+
+  context 'without account' do
+    it 'accepts ?, \, : and space as delimiter' do
+      needle = Fabricate(
+        :account,
+        display_name: 'A & l & i & c & e',
+        username: 'username',
+        domain: 'example.com'
+      )
+
+      results = described_class.new('A?l\i:c e').build.to_a
+      expect(results).to eq [needle]
+    end
+
+    it 'finds accounts with matching display_name' do
+      needle = Fabricate(
+        :account,
+        display_name: "Display Name",
+        username: "username",
+        domain: "example.com"
+      )
+
+      results = described_class.new("display").build.to_a
+      expect(results).to eq [needle]
+    end
+
+    it 'finds accounts with matching username' do
+      needle = Fabricate(
+        :account,
+        display_name: "Display Name",
+        username: "username",
+        domain: "example.com"
+      )
+
+      results = described_class.new("username").build.to_a
+      expect(results).to eq [needle]
+    end
+
+    it 'finds accounts with matching domain' do
+      needle = Fabricate(
+        :account,
+        display_name: "Display Name",
+        username: "username",
+        domain: "example.com"
+      )
+
+      results = described_class.new("example").build.to_a
+      expect(results).to eq [needle]
+    end
+
+    it 'limits by 10 by default' do
+      11.times.each { Fabricate(:account, display_name: "Display Name") }
+      results = described_class.new("display").build.to_a
+      expect(results.size).to eq 10
+    end
+
+    it 'accepts arbitrary limits' do
+      2.times.each { Fabricate(:account, display_name: "Display Name") }
+      results = described_class.new("display", limit: 1).build.to_a
+      expect(results.size).to eq 1
+    end
+
+    it 'ranks multiple matches higher' do
+      needles = [
+        { username: "username", display_name: "username" },
+        { display_name: "Display Name", username: "username", domain: "example.com" },
+      ].map(&method(:Fabricate).curry(2).call(:account))
+
+      results = described_class.new("username").build.to_a
+      expect(results).to eq needles
+    end
+  end
+
+  context 'with account' do
+    let(:account) { Fabricate(:account) }
+
+    it 'ranks followed accounts higher' do
+      needle = Fabricate(:account, username: "Matching")
+      followed_needle = Fabricate(:account, username: "Matcher")
+      account.follow!(followed_needle)
+
+      results = described_class.new("match", account: account).build.to_a
+
+      expect(results).to eq [followed_needle, needle]
+      expect(results.first.rank).to be > results.last.rank
+    end
+  end
+end
diff --git a/spec/models/account_spec.rb b/spec/models/account_spec.rb
index 03d6f5fb0f5..a22640f1f44 100644
--- a/spec/models/account_spec.rb
+++ b/spec/models/account_spec.rb
@@ -309,125 +309,6 @@ RSpec.describe Account, type: :model do
     end
   end
 
-  describe '.search_for' do
-    before do
-      _missing = Fabricate(
-        :account,
-        display_name: "Missing",
-        username: "missing",
-        domain: "missing.com"
-      )
-    end
-
-    it 'accepts ?, \, : and space as delimiter' do
-      match = Fabricate(
-        :account,
-        display_name: 'A & l & i & c & e',
-        username: 'username',
-        domain: 'example.com'
-      )
-
-      results = Account.search_for('A?l\i:c e')
-      expect(results).to eq [match]
-    end
-
-    it 'finds accounts with matching display_name' do
-      match = Fabricate(
-        :account,
-        display_name: "Display Name",
-        username: "username",
-        domain: "example.com"
-      )
-
-      results = Account.search_for("display")
-      expect(results).to eq [match]
-    end
-
-    it 'finds accounts with matching username' do
-      match = Fabricate(
-        :account,
-        display_name: "Display Name",
-        username: "username",
-        domain: "example.com"
-      )
-
-      results = Account.search_for("username")
-      expect(results).to eq [match]
-    end
-
-    it 'finds accounts with matching domain' do
-      match = Fabricate(
-        :account,
-        display_name: "Display Name",
-        username: "username",
-        domain: "example.com"
-      )
-
-      results = Account.search_for("example")
-      expect(results).to eq [match]
-    end
-
-    it 'limits by 10 by default' do
-      11.times.each { Fabricate(:account, display_name: "Display Name") }
-      results = Account.search_for("display")
-      expect(results.size).to eq 10
-    end
-
-    it 'accepts arbitrary limits' do
-      2.times.each { Fabricate(:account, display_name: "Display Name") }
-      results = Account.search_for("display", 1)
-      expect(results.size).to eq 1
-    end
-
-    it 'ranks multiple matches higher' do
-      matches = [
-        { username: "username", display_name: "username" },
-        { display_name: "Display Name", username: "username", domain: "example.com" },
-      ].map(&method(:Fabricate).curry(2).call(:account))
-
-      results = Account.search_for("username")
-      expect(results).to eq matches
-    end
-  end
-
-  describe '.advanced_search_for' do
-    it 'accepts ?, \, : and space as delimiter' do
-      account = Fabricate(:account)
-      match = Fabricate(
-        :account,
-        display_name: 'A & l & i & c & e',
-        username: 'username',
-        domain: 'example.com'
-      )
-
-      results = Account.advanced_search_for('A?l\i:c e', account)
-      expect(results).to eq [match]
-    end
-
-    it 'limits by 10 by default' do
-      11.times { Fabricate(:account, display_name: "Display Name") }
-      results = Account.search_for("display")
-      expect(results.size).to eq 10
-    end
-
-    it 'accepts arbitrary limits' do
-      2.times { Fabricate(:account, display_name: "Display Name") }
-      results = Account.search_for("display", 1)
-      expect(results.size).to eq 1
-    end
-
-    it 'ranks followed accounts higher' do
-      account = Fabricate(:account)
-      match = Fabricate(:account, username: "Matching")
-      followed_match = Fabricate(:account, username: "Matcher")
-      Fabricate(:follow, account: account, target_account: followed_match)
-
-      results = Account.advanced_search_for("match", account)
-      expect(results).to eq [followed_match, match]
-      expect(results.first.rank).to be > results.last.rank
-    end
-  end
-
   describe '#statuses_count' do
     subject { Fabricate(:account) }
 
-- 
GitLab