feat(cli): add hashing group with journal-entry
Hopefully there are enough operator toggles to tune this run, as there is a large number of records in the database. This should allow for balancing load vs. locking.

Signed-off-by: Mike Fiedler <miketheman@gmail.com>
1 parent a091728, commit ae36a83
Showing 2 changed files with 251 additions and 0 deletions.
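The operator toggles mentioned in the commit message are the Click options on the new journal-entry command in the diff below. A minimal sketch of listing them from a Warehouse development environment (illustrative only, not part of this commit):

from click.testing import CliRunner

from warehouse.cli import hashing

# Print the command's options (--salt, --batch-size, --sleep-time,
# --continue-until-done) without touching the database; --help is eager,
# so no salt prompt and no database session are triggered.
runner = CliRunner()
print(runner.invoke(hashing.journal_entry, ["--help"]).output)

Roughly, a larger --batch-size hashes more rows per transaction (more lock time per batch), while --sleep-time spaces the batches apart; that is the load-versus-locking balance the commit message refers to.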
@@ -0,0 +1,116 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib

import pretend

from warehouse import db
from warehouse.cli import hashing

from ...common.db.packaging import JournalEntry, JournalEntryFactory


def remote_addr_salty_hash(remote_addr, salt):
    return hashlib.sha256(f"{remote_addr}{salt}".encode()).hexdigest()


class TestHashingJournalEntry:
    def test_no_records_to_hash(self, cli, db_request, monkeypatch):
        engine = pretend.stub()
        config = pretend.stub(registry={"sqlalchemy.engine": engine})
        session_cls = pretend.call_recorder(lambda bind: db_request.db)
        monkeypatch.setattr(db, "Session", session_cls)

        assert db_request.db.query(JournalEntry).count() == 0

        args = ["--salt", "test"]

        result = cli.invoke(hashing.journal_entry, args, obj=config)

        assert result.exit_code == 0
        assert result.output.strip() == "No rows to hash. Done!"

    def test_hashes_records(self, cli, db_request, remote_addr, monkeypatch):
        engine = pretend.stub()
        config = pretend.stub(registry={"sqlalchemy.engine": engine})
        session_cls = pretend.call_recorder(lambda bind: db_request.db)
        monkeypatch.setattr(db, "Session", session_cls)

        # create some JournalEntry records with unhashed ip addresses
        JournalEntryFactory.create_batch(3, submitted_from=remote_addr)
        assert db_request.db.query(JournalEntry).count() == 3

        salt = "NaCl"
        salted_hash = remote_addr_salty_hash(remote_addr, salt)

        args = [
            "--salt",
            salt,
            "--batch-size",
            "2",
        ]

        result = cli.invoke(hashing.journal_entry, args, obj=config)

        assert result.exit_code == 0
        assert result.output.strip() == "Hashing 2 rows...\nHashed 2 rows"
        # check that one ip address remains unhashed and two have been hashed
        assert (
            db_request.db.query(JournalEntry)
            .filter_by(submitted_from=remote_addr)
            .one()
        )
        assert (
            db_request.db.query(JournalEntry)
            .filter_by(submitted_from=salted_hash)
            .count()
            == 2
        )

    def test_continue_until_done(self, cli, db_request, remote_addr, monkeypatch):
        engine = pretend.stub()
        config = pretend.stub(registry={"sqlalchemy.engine": engine})
        session_cls = pretend.call_recorder(lambda bind: db_request.db)
        monkeypatch.setattr(db, "Session", session_cls)

        # create some JournalEntry records with unhashed ip addresses
        JournalEntryFactory.create_batch(3, submitted_from=remote_addr)

        salt = "NaCl"
        salted_hash = remote_addr_salty_hash(remote_addr, salt)

        args = [
            "--salt",
            salt,
            "--batch-size",
            "1",
            "--sleep-time",
            "0",
            "--continue-until-done",
        ]

        result = cli.invoke(hashing.journal_entry, args, obj=config)

        assert result.exit_code == 0
        # check that all the ip addresses have been hashed
        assert (
            db_request.db.query(JournalEntry)
            .filter_by(submitted_from=salted_hash)
            .count()
            == 3
        )
        assert (
            db_request.db.query(JournalEntry)
            .filter_by(submitted_from=remote_addr)
            .count()
            == 0
        )
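The remote_addr_salty_hash helper above mirrors the salted sha256 the command applies to each row, and both files rely on a sha256 hexdigest being 64 characters long. A standalone sanity check of that property (the address and salt here are made-up example values):

import hashlib

remote_addr = "203.0.113.7"  # made-up example address, not from the diff
salt = "NaCl"

digest = hashlib.sha256(f"{remote_addr}{salt}".encode()).hexdigest()
# Matches the CLI's (row.submitted_from + salt).encode("utf8") spelling
assert digest == hashlib.sha256((remote_addr + salt).encode("utf8")).hexdigest()
# 64-character digests fail the length < 63 filter below, so hashed rows are never re-hashed
assert len(digest) == 64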
@@ -0,0 +1,135 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import time

import click

from warehouse.cli import warehouse


@warehouse.group()
def hashing():
    """
    Run Hashing operations for Warehouse data
    """


@hashing.command()
@click.option(
    "-s",
    "--salt",
    prompt=True,
    hide_input=True,
    help="Pass value instead of prompting for salt",
)
@click.option(
    "-b",
    "--batch-size",
    default=10_000,
    show_default=True,
    help="Number of rows to hash at a time",
)
@click.option(
    "-st",
    "--sleep-time",
    default=1,
    show_default=True,
    help="Number of seconds to sleep between batches",
)
@click.option(
    "--continue-until-done",
    is_flag=True,
    default=False,
    help="Continue hashing until all rows are hashed",
)
@click.pass_obj
def journal_entry(
    config,
    salt: str,
    batch_size: int,
    sleep_time: int,
    continue_until_done: bool,
):
    """
    Hash `journals.submitted_from` column with salt
    """
    # Imported here because we don't want to trigger an import from anything
    # but warehouse.cli at the module scope.
    from warehouse.db import Session

    # This lives in the outer function so we only create a single session per
    # invocation of the CLI command.
    session = Session(bind=config.registry["sqlalchemy.engine"])

    _hash_journal_entries_submitted_from(
        session, salt, batch_size, sleep_time, continue_until_done
    )


def _hash_journal_entries_submitted_from(
    session,
    salt: str,
    batch_size: int,
    sleep_time: int,
    continue_until_done: bool,
) -> None:
    """
    Perform hashing of the `journals.submitted_from` column

    Broken out from the CLI command so that it can be called recursively.
    """
    from sqlalchemy import func, select

    from warehouse.packaging.models import JournalEntry

    # Get rows a batch at a time, only if the row hasn't already been hashed
    # (a sha256 hexdigest is 64 characters; raw ip addresses are far shorter)
    unhashed_rows = session.scalars(
        select(JournalEntry)
        .where(func.length(JournalEntry.submitted_from) < 63)
        .order_by(JournalEntry.submitted_date)
        .limit(batch_size)
    ).all()

    # If there are no rows to hash, we're done
    if not unhashed_rows:
        click.echo("No rows to hash. Done!")
        return

    how_many = len(unhashed_rows)

    # Hash the value rows
    click.echo(f"Hashing {how_many} rows...")
    for row in unhashed_rows:
        row.submitted_from = hashlib.sha256(
            (row.submitted_from + salt).encode("utf8")
        ).hexdigest()

    # Update the rows
    session.add_all(unhashed_rows)
    session.commit()

    # If there are more rows to hash, recurse until done
    if continue_until_done and how_many == batch_size:
        click.echo(f"Hashed {batch_size} rows. Sleeping for {sleep_time} second(s)...")
        time.sleep(sleep_time)
        _hash_journal_entries_submitted_from(
            session,
            salt,
            batch_size,
            sleep_time,
            continue_until_done,
        )
    else:
        click.echo(f"Hashed {how_many} rows")
        return
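Because the batching is written recursively, it can help to read it as the equivalent loop. A rough iterative sketch of the same behavior (illustrative only; hash_in_batches is not part of the commit):

import hashlib
import time

from sqlalchemy import func, select


def hash_in_batches(session, salt, batch_size, sleep_time, continue_until_done):
    # Iterative restatement of _hash_journal_entries_submitted_from above.
    from warehouse.packaging.models import JournalEntry

    while True:
        rows = session.scalars(
            select(JournalEntry)
            .where(func.length(JournalEntry.submitted_from) < 63)
            .order_by(JournalEntry.submitted_date)
            .limit(batch_size)
        ).all()
        if not rows:
            print("No rows to hash. Done!")
            return

        print(f"Hashing {len(rows)} rows...")
        for row in rows:
            row.submitted_from = hashlib.sha256(
                (row.submitted_from + salt).encode("utf8")
            ).hexdigest()
        session.add_all(rows)
        session.commit()

        # A short batch means the table is drained; otherwise pause and go again.
        if continue_until_done and len(rows) == batch_size:
            print(f"Hashed {batch_size} rows. Sleeping for {sleep_time} second(s)...")
            time.sleep(sleep_time)
            continue

        print(f"Hashed {len(rows)} rows")
        return

Either form stops after a single batch unless --continue-until-done is passed, which is what lets an operator cap how much work a single run does.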