Skip to content

Commit

Permalink
Merge pull request #13 from akaihola/large-tables
Browse files Browse the repository at this point in the history
Support large tables which don't fit in RAM
  • Loading branch information
akaihola authored Apr 22, 2024
2 parents c5ba05b + 84d50b9 commit 230787a
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 27 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ Removed
Fixed
-----

- Very large tables are now sorted without crashing. This is done by merge sorting
in temporary files.


1.0.0_ / 2021-09-11
====================
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ strict_equality = True
disallow_any_decorated = False
disallow_untyped_defs = False

[mypy-pgtricks.mergesort]
disallow_any_explicit = False

[mypy-pytest.*]
ignore_missing_imports = True

Expand Down
76 changes: 76 additions & 0 deletions pgtricks/mergesort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Merge sort implementation to handle large files by sorting them in partitions."""

from __future__ import annotations

import sys
from heapq import merge
from tempfile import TemporaryFile
from typing import IO, Any, Callable, Iterable, Iterator, cast


class MergeSort(Iterable[str]):
"""Merge sort implementation to handle large files by sorting them in partitions."""

def __init__(
self,
key: Callable[[str], Any] = str,
directory: str = ".",
max_memory: int = 190,
) -> None:
"""Initialize the merge sort object."""
self._key = key
self._directory = directory
self._max_memory = max_memory
# Use binary mode to avoid newline conversion on Windows.
self._partitions: list[IO[bytes]] = []
self._iterating: Iterable[str] | None = None
self._buffer: list[str] = []
self._memory_counter: int = sys.getsizeof(self._buffer)
self._flush()

def append(self, line: str) -> None:
"""Append a line to the set of lines to be sorted."""
if self._iterating:
message = "Can't append lines after starting to sort"
raise ValueError(message)
self._memory_counter -= sys.getsizeof(self._buffer)
self._buffer.append(line)
self._memory_counter += sys.getsizeof(self._buffer)
self._memory_counter += sys.getsizeof(line)
if self._memory_counter >= self._max_memory:
self._flush()

def _flush(self) -> None:
if self._buffer:
# Use binary mode to avoid newline conversion on Windows.
self._partitions.append(TemporaryFile(mode="w+b", dir=self._directory))
self._partitions[-1].writelines(
line.encode("UTF-8") for line in sorted(self._buffer, key=self._key)
)
self._buffer = []
self._memory_counter = sys.getsizeof(self._buffer)

def __next__(self) -> str:
"""Return the next line in the sorted list of lines."""
if not self._iterating:
if self._partitions:
# At least one partition has already been flushed to disk.
# Iterate the merge sort for all partitions.
self._flush()
for partition in self._partitions:
partition.seek(0)
self._iterating = merge(
*[
(line.decode("UTF-8") for line in partition)
for partition in self._partitions
],
key=self._key,
)
else:
# All lines fit in memory. Iterate the list of lines directly.
self._iterating = iter(sorted(self._buffer, key=self._key))
return next(cast(Iterator[str], self._iterating))

def __iter__(self) -> Iterator[str]:
"""Return the iterator object for the sorted list of lines."""
return self
98 changes: 72 additions & 26 deletions pgtricks/pg_dump_splitsort.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
#!/usr/bin/env python

from __future__ import annotations

import functools
import io
import os
import re
import sys
from typing import IO, List, Match, Optional, Pattern, Tuple, Union, cast
from argparse import ArgumentParser
from typing import IO, Iterable, Match, Pattern, cast

from pgtricks.mergesort import MergeSort

COPY_RE = re.compile(r'COPY .*? \(.*?\) FROM stdin;\n$')
KIBIBYTE, MEBIBYTE, GIBIBYTE = 2**10, 2**20, 2**30
MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE}


def try_float(s1: str, s2: str) -> Union[Tuple[str, str], Tuple[float, float]]:
def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]:
"""Convert two strings to floats. Return original ones on conversion error."""
if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-':
# optimization
return s1, s2
Expand All @@ -22,7 +30,8 @@ def try_float(s1: str, s2: str) -> Union[Tuple[str, str], Tuple[float, float]]:
def linecomp(l1: str, l2: str) -> int:
p1 = l1.split('\t', 1)
p2 = l2.split('\t', 1)
v1, v2 = cast(Tuple[float, float], try_float(p1[0], p2[0]))
# TODO: unquote cast after support for Python 3.8 is dropped
v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0]))
result = (v1 > v2) - (v1 < v2)
# modifying a line to see whether Darker works:
if not result and len(p1) == len(p2) == 2:
Expand All @@ -37,9 +46,10 @@ def linecomp(l1: str, l2: str) -> int:

class Matcher(object):
def __init__(self) -> None:
self._match: Optional[Match[str]] = None
self._match: Match[str] | None = None

def match(self, pattern: Pattern[str], data: str) -> Optional[Match[str]]:
def match(self, pattern: Pattern[str], data: str) -> Match[str] | None:
"""Match the regular expression pattern against the data."""
self._match = pattern.match(data)
return self._match

Expand All @@ -49,34 +59,44 @@ def group(self, group1: str) -> str:
return self._match.group(group1)


def split_sql_file(sql_filepath: str) -> None:

def split_sql_file( # noqa: C901 too complex
sql_filepath: str,
max_memory: int = 100 * MEBIBYTE,
) -> None:
"""Split a SQL file so that each COPY statement is in its own file."""
directory = os.path.dirname(sql_filepath)

output: Optional[IO[str]] = None
buf: List[str] = []
# `output` needs to be instantiated before the inner functions are defined.
# Assign it a dummy string I/O object so type checking is happy.
# This will be replaced with the prologue SQL file object.
output: IO[str] = io.StringIO()
buf: list[str] = []

def flush() -> None:
cast(IO[str], output).writelines(buf)
output.writelines(buf)
buf[:] = []

def writelines(lines: Iterable[str]) -> None:
if buf:
flush()
output.writelines(lines)

def new_output(filename: str) -> IO[str]:
if output:
output.close()
return open(os.path.join(directory, filename), 'w')

copy_lines: Optional[List[str]] = None
sorted_data_lines: MergeSort | None = None
counter = 0
output = new_output('0000_prologue.sql')
matcher = Matcher()

for line in open(sql_filepath):
if copy_lines is None:
if sorted_data_lines is None:
if line in ('\n', '--\n'):
buf.append(line)
elif line.startswith('SET search_path = '):
flush()
buf.append(line)
writelines([line])
else:
if matcher.match(DATA_COMMENT_RE, line):
counter += 1
Expand All @@ -86,28 +106,54 @@ def new_output(filename: str) -> IO[str]:
schema=matcher.group('schema'),
table=matcher.group('table')))
elif COPY_RE.match(line):
copy_lines = []
sorted_data_lines = MergeSort(
key=functools.cmp_to_key(linecomp),
max_memory=max_memory,
)
elif SEQUENCE_SET_RE.match(line):
pass
elif 1 <= counter < 9999:
counter = 9999
output = new_output('%04d_epilogue.sql' % counter)
buf.append(line)
flush()
writelines([line])
else:
if line == '\\.\n':
copy_lines.sort(key=functools.cmp_to_key(linecomp))
buf.extend(copy_lines)
buf.append(line)
flush()
copy_lines = None
if line == "\\.\n":
writelines(sorted_data_lines)
writelines(line)
sorted_data_lines = None
else:
copy_lines.append(line)
sorted_data_lines.append(line)
flush()


def memory_size(size: str) -> int:
"""Parse a human-readable memory size.
:param size: The memory size to parse, e.g. "100MB".
:return: The memory size in bytes.
:raise ValueError: If the memory size is invalid.
"""
match = re.match(r"([\d._]+)\s*([kmg]?)b?", size.lower().strip())
if not match:
message = f"Invalid memory size: {size}"
raise ValueError(message)
return int(float(match.group(1)) * MEMORY_UNITS[match.group(2)])


def main() -> None:
split_sql_file(sys.argv[1])
parser = ArgumentParser(description="Split a SQL file into smaller files.")
parser.add_argument("sql_filepath", help="The SQL file to split.")
parser.add_argument(
"-m",
"--max-memory",
default=100 * MEBIBYTE,
type=memory_size,
help="Max memory to use, e.g. 50_000, 200000000, 100kb, 100MB (default), 2Gig.",
)
args = parser.parse_args()

split_sql_file(args.sql_filepath, args.max_memory)


if __name__ == '__main__':
Expand Down
110 changes: 110 additions & 0 deletions pgtricks/tests/test_mergesort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Tests for the `pgtricks.mergesort` module."""

import functools
from types import GeneratorType
from typing import Iterable, cast

import pytest

from pgtricks.mergesort import MergeSort
from pgtricks.pg_dump_splitsort import linecomp

# This is the biggest amount of memory which can't hold two one-character lines on any
# platform. On Windows it's slightly smaller than on Unix.
JUST_BELOW_TWO_SHORT_LINES = 174


@pytest.mark.parametrize("lf", ["\n", "\r\n"])
def test_mergesort_append(tmpdir, lf):
"""Test appending lines to the merge sort object."""
m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
m.append(f"1{lf}")
assert m._buffer == [f"1{lf}"]
m.append(f"2{lf}")
assert m._buffer == []
m.append(f"3{lf}")
assert m._buffer == [f"3{lf}"]
assert len(m._partitions) == 1
pos = m._partitions[0].tell()
m._partitions[0].seek(0)
assert m._partitions[0].read() == f"1{lf}2{lf}".encode()
assert pos == len(f"1{lf}2{lf}")


@pytest.mark.parametrize("lf", ["\n", "\r\n"])
def test_mergesort_flush(tmpdir, lf):
"""Test flushing the buffer to disk."""
m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
for value in [1, 2, 3]:
m.append(f"{value}{lf}")
m._flush()
assert len(m._partitions) == 2
assert m._partitions[0].tell() == len(f"1{lf}2{lf}")
m._partitions[0].seek(0)
assert m._partitions[0].read() == f"1{lf}2{lf}".encode()
pos = m._partitions[1].tell()
m._partitions[1].seek(0)
assert m._partitions[1].read() == f"3{lf}".encode()
assert pos == len(f"3{lf}")


@pytest.mark.parametrize("lf", ["\n", "\r\n"])
def test_mergesort_iterate_disk(tmpdir, lf):
"""Test iterating over the sorted lines on disk."""
m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
m.append(f"{value}{lf}")
assert next(m) == f"1{lf}"
assert isinstance(m._iterating, GeneratorType)
assert next(m) == f"1{lf}"
assert next(m) == f"2{lf}"
assert next(m) == f"3{lf}"
assert next(m) == f"3{lf}"
assert next(m) == f"4{lf}"
assert next(m) == f"4{lf}"
assert next(m) == f"5{lf}"
assert next(m) == f"5{lf}"
assert next(m) == f"6{lf}"
assert next(m) == f"8{lf}"
assert next(m) == f"9{lf}"
with pytest.raises(StopIteration):
next(m)


@pytest.mark.parametrize("lf", ["\n", "\r\n"])
def test_mergesort_iterate_memory(tmpdir, lf):
"""Test iterating over the sorted lines when all lines fit in memory."""
m = MergeSort(
directory=tmpdir,
max_memory=1000000,
key=functools.cmp_to_key(linecomp),
)
for value in [3, 1, 4, 1, 5, 9, 2, 10, 6, 5, 3, 8, 4]:
m.append(f"{value}{lf}")
assert next(m) == f"1{lf}"
assert not isinstance(m._iterating, GeneratorType)
assert iter(cast(Iterable[str], m._iterating)) is m._iterating
assert next(m) == f"1{lf}"
assert next(m) == f"2{lf}"
assert next(m) == f"3{lf}"
assert next(m) == f"3{lf}"
assert next(m) == f"4{lf}"
assert next(m) == f"4{lf}"
assert next(m) == f"5{lf}"
assert next(m) == f"5{lf}"
assert next(m) == f"6{lf}"
assert next(m) == f"8{lf}"
assert next(m) == f"9{lf}"
assert next(m) == f"10{lf}"
with pytest.raises(StopIteration):
next(m)


@pytest.mark.parametrize("lf", ["\n", "\r\n"])
def test_mergesort_key(tmpdir, lf):
"""Test sorting lines based on a key function."""
m = MergeSort(directory=tmpdir, key=lambda line: -int(line[0]))
for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
m.append(f"{value}{lf}")
result = "".join(value[0] for value in m)
assert result == "986554433211"
Loading

0 comments on commit 230787a

Please sign in to comment.