diff --git a/CHANGES.rst b/CHANGES.rst
index 3fc8314..c3796a7 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -13,6 +13,9 @@ Removed
 Fixed
 -----
 
+- Very large tables are now sorted without crashing. This is done by merge sorting
+  in temporary files.
+
 1.0.0_ / 2021-09-11
 ====================
 
diff --git a/mypy.ini b/mypy.ini
index ed30dab..bbb2f00 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -32,6 +32,9 @@ strict_equality = True
 disallow_any_decorated = False
 disallow_untyped_defs = False
 
+[mypy-pgtricks.mergesort]
+disallow_any_explicit = False
+
 [mypy-pytest.*]
 ignore_missing_imports = True
 
diff --git a/pgtricks/mergesort.py b/pgtricks/mergesort.py
new file mode 100644
index 0000000..f22f020
--- /dev/null
+++ b/pgtricks/mergesort.py
@@ -0,0 +1,76 @@
+"""Merge sort implementation to handle large files by sorting them in partitions."""
+
+from __future__ import annotations
+
+import sys
+from heapq import merge
+from tempfile import TemporaryFile
+from typing import IO, Any, Callable, Iterable, Iterator, cast
+
+
+class MergeSort(Iterable[str]):
+    """Merge sort implementation to handle large files by sorting them in partitions."""
+
+    def __init__(
+        self,
+        key: Callable[[str], Any] = str,
+        directory: str = ".",
+        max_memory: int = 190,
+    ) -> None:
+        """Initialize the merge sort object."""
+        self._key = key
+        self._directory = directory
+        self._max_memory = max_memory
+        # Use binary mode to avoid newline conversion on Windows.
+        self._partitions: list[IO[bytes]] = []
+        self._iterating: Iterable[str] | None = None
+        self._buffer: list[str] = []
+        self._memory_counter: int = sys.getsizeof(self._buffer)
+        self._flush()
+
+    def append(self, line: str) -> None:
+        """Append a line to the set of lines to be sorted."""
+        if self._iterating:
+            message = "Can't append lines after starting to sort"
+            raise ValueError(message)
+        self._memory_counter -= sys.getsizeof(self._buffer)
+        self._buffer.append(line)
+        self._memory_counter += sys.getsizeof(self._buffer)
+        self._memory_counter += sys.getsizeof(line)
+        if self._memory_counter >= self._max_memory:
+            self._flush()
+
+    def _flush(self) -> None:
+        if self._buffer:
+            # Use binary mode to avoid newline conversion on Windows.
+            self._partitions.append(TemporaryFile(mode="w+b", dir=self._directory))
+            self._partitions[-1].writelines(
+                line.encode("UTF-8") for line in sorted(self._buffer, key=self._key)
+            )
+        self._buffer = []
+        self._memory_counter = sys.getsizeof(self._buffer)
+
+    def __next__(self) -> str:
+        """Return the next line in the sorted list of lines."""
+        if not self._iterating:
+            if self._partitions:
+                # At least one partition has already been flushed to disk.
+                # Iterate the merge sort for all partitions.
+                self._flush()
+                for partition in self._partitions:
+                    partition.seek(0)
+                self._iterating = merge(
+                    *[
+                        (line.decode("UTF-8") for line in partition)
+                        for partition in self._partitions
+                    ],
+                    key=self._key,
+                )
+            else:
+                # All lines fit in memory. Iterate the list of lines directly.
+                self._iterating = iter(sorted(self._buffer, key=self._key))
+        return next(cast(Iterator[str], self._iterating))
+
+    def __iter__(self) -> Iterator[str]:
+        """Return the iterator object for the sorted list of lines."""
+        return self
diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py
index 56908ea..aab1258 100755
--- a/pgtricks/pg_dump_splitsort.py
+++ b/pgtricks/pg_dump_splitsort.py
@@ -1,15 +1,23 @@
 #!/usr/bin/env python
 
+from __future__ import annotations
+
 import functools
+import io
 import os
 import re
-import sys
-from typing import IO, List, Match, Optional, Pattern, Tuple, Union, cast
+from argparse import ArgumentParser
+from typing import IO, Iterable, Match, Pattern, cast
+
+from pgtricks.mergesort import MergeSort
 
 COPY_RE = re.compile(r'COPY .*? \(.*?\) FROM stdin;\n$')
+KIBIBYTE, MEBIBYTE, GIBIBYTE = 2**10, 2**20, 2**30
+MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE}
 
 
-def try_float(s1: str, s2: str) -> Union[Tuple[str, str], Tuple[float, float]]:
+def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]:
+    """Convert two strings to floats. Return original ones on conversion error."""
     if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-':
         # optimization
         return s1, s2
@@ -22,7 +30,8 @@
 def linecomp(l1: str, l2: str) -> int:
     p1 = l1.split('\t', 1)
     p2 = l2.split('\t', 1)
-    v1, v2 = cast(Tuple[float, float], try_float(p1[0], p2[0]))
+    # TODO: unquote cast after support for Python 3.8 is dropped
+    v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0]))
     result = (v1 > v2) - (v1 < v2)
     # modifying a line to see whether Darker works:
     if not result and len(p1) == len(p2) == 2:
@@ -37,9 +46,10 @@
 
 class Matcher(object):
     def __init__(self) -> None:
-        self._match: Optional[Match[str]] = None
+        self._match: Match[str] | None = None
 
-    def match(self, pattern: Pattern[str], data: str) -> Optional[Match[str]]:
+    def match(self, pattern: Pattern[str], data: str) -> Match[str] | None:
+        """Match the regular expression pattern against the data."""
         self._match = pattern.match(data)
         return self._match
 
@@ -49,34 +59,44 @@ def group(self, group1: str) -> str:
         return self._match.group(group1)
 
 
-def split_sql_file(sql_filepath: str) -> None:
-
+def split_sql_file(  # noqa: C901 too complex
+    sql_filepath: str,
+    max_memory: int = 100 * MEBIBYTE,
+) -> None:
+    """Split a SQL file so that each COPY statement is in its own file."""
     directory = os.path.dirname(sql_filepath)
 
-    output: Optional[IO[str]] = None
-    buf: List[str] = []
+    # `output` needs to be instantiated before the inner functions are defined.
+    # Assign it a dummy string I/O object so type checking is happy.
+    # This will be replaced with the prologue SQL file object.
+    output: IO[str] = io.StringIO()
+    buf: list[str] = []
 
     def flush() -> None:
-        cast(IO[str], output).writelines(buf)
+        output.writelines(buf)
         buf[:] = []
 
+    def writelines(lines: Iterable[str]) -> None:
+        if buf:
+            flush()
+        output.writelines(lines)
+
     def new_output(filename: str) -> IO[str]:
         if output:
             output.close()
         return open(os.path.join(directory, filename), 'w')
 
-    copy_lines: Optional[List[str]] = None
+    sorted_data_lines: MergeSort | None = None
     counter = 0
     output = new_output('0000_prologue.sql')
     matcher = Matcher()
 
     for line in open(sql_filepath):
-        if copy_lines is None:
+        if sorted_data_lines is None:
             if line in ('\n', '--\n'):
                 buf.append(line)
             elif line.startswith('SET search_path = '):
-                flush()
-                buf.append(line)
+                writelines([line])
             else:
                 if matcher.match(DATA_COMMENT_RE, line):
                     counter += 1
@@ -86,28 +106,54 @@ def new_output(filename: str) -> IO[str]:
                             schema=matcher.group('schema'),
                             table=matcher.group('table')))
                 elif COPY_RE.match(line):
-                    copy_lines = []
+                    sorted_data_lines = MergeSort(
+                        key=functools.cmp_to_key(linecomp),
+                        max_memory=max_memory,
+                    )
                 elif SEQUENCE_SET_RE.match(line):
                     pass
                 elif 1 <= counter < 9999:
                     counter = 9999
                     output = new_output('%04d_epilogue.sql' % counter)
-                buf.append(line)
-                flush()
+                writelines([line])
         else:
-            if line == '\\.\n':
-                copy_lines.sort(key=functools.cmp_to_key(linecomp))
-                buf.extend(copy_lines)
-                buf.append(line)
-                flush()
-                copy_lines = None
+            if line == "\\.\n":
+                writelines(sorted_data_lines)
+                writelines(line)
+                sorted_data_lines = None
             else:
-                copy_lines.append(line)
+                sorted_data_lines.append(line)
     flush()
 
 
+def memory_size(size: str) -> int:
+    """Parse a human-readable memory size.
+
+    :param size: The memory size to parse, e.g. "100MB".
+    :return: The memory size in bytes.
+    :raise ValueError: If the memory size is invalid.
+
+    """
+    match = re.match(r"([\d._]+)\s*([kmg]?)b?", size.lower().strip())
+    if not match:
+        message = f"Invalid memory size: {size}"
+        raise ValueError(message)
+    return int(float(match.group(1)) * MEMORY_UNITS[match.group(2)])
+
+
 def main() -> None:
-    split_sql_file(sys.argv[1])
+    parser = ArgumentParser(description="Split a SQL file into smaller files.")
+    parser.add_argument("sql_filepath", help="The SQL file to split.")
+    parser.add_argument(
+        "-m",
+        "--max-memory",
+        default=100 * MEBIBYTE,
+        type=memory_size,
+        help="Max memory to use, e.g. 50_000, 200000000, 100kb, 100MB (default), 2Gig.",
+    )
+    args = parser.parse_args()
+
+    split_sql_file(args.sql_filepath, args.max_memory)
 
 
 if __name__ == '__main__':
diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py
new file mode 100644
index 0000000..6f7c0b6
--- /dev/null
+++ b/pgtricks/tests/test_mergesort.py
@@ -0,0 +1,110 @@
+"""Tests for the `pgtricks.mergesort` module."""
+
+import functools
+from types import GeneratorType
+from typing import Iterable, cast
+
+import pytest
+
+from pgtricks.mergesort import MergeSort
+from pgtricks.pg_dump_splitsort import linecomp
+
+# This is the biggest amount of memory which can't hold two one-character lines on any
+# platform. On Windows it's slightly smaller than on Unix.
+JUST_BELOW_TWO_SHORT_LINES = 174
+
+
+@pytest.mark.parametrize("lf", ["\n", "\r\n"])
+def test_mergesort_append(tmpdir, lf):
+    """Test appending lines to the merge sort object."""
+    m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
+    m.append(f"1{lf}")
+    assert m._buffer == [f"1{lf}"]
+    m.append(f"2{lf}")
+    assert m._buffer == []
+    m.append(f"3{lf}")
+    assert m._buffer == [f"3{lf}"]
+    assert len(m._partitions) == 1
+    pos = m._partitions[0].tell()
+    m._partitions[0].seek(0)
+    assert m._partitions[0].read() == f"1{lf}2{lf}".encode()
+    assert pos == len(f"1{lf}2{lf}")
+
+
+@pytest.mark.parametrize("lf", ["\n", "\r\n"])
+def test_mergesort_flush(tmpdir, lf):
+    """Test flushing the buffer to disk."""
+    m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
+    for value in [1, 2, 3]:
+        m.append(f"{value}{lf}")
+    m._flush()
+    assert len(m._partitions) == 2
+    assert m._partitions[0].tell() == len(f"1{lf}2{lf}")
+    m._partitions[0].seek(0)
+    assert m._partitions[0].read() == f"1{lf}2{lf}".encode()
+    pos = m._partitions[1].tell()
+    m._partitions[1].seek(0)
+    assert m._partitions[1].read() == f"3{lf}".encode()
+    assert pos == len(f"3{lf}")
+
+
+@pytest.mark.parametrize("lf", ["\n", "\r\n"])
+def test_mergesort_iterate_disk(tmpdir, lf):
+    """Test iterating over the sorted lines on disk."""
+    m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES)
+    for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
+        m.append(f"{value}{lf}")
+    assert next(m) == f"1{lf}"
+    assert isinstance(m._iterating, GeneratorType)
+    assert next(m) == f"1{lf}"
+    assert next(m) == f"2{lf}"
+    assert next(m) == f"3{lf}"
+    assert next(m) == f"3{lf}"
+    assert next(m) == f"4{lf}"
+    assert next(m) == f"4{lf}"
+    assert next(m) == f"5{lf}"
+    assert next(m) == f"5{lf}"
+    assert next(m) == f"6{lf}"
+    assert next(m) == f"8{lf}"
+    assert next(m) == f"9{lf}"
+    with pytest.raises(StopIteration):
+        next(m)
+
+
+@pytest.mark.parametrize("lf", ["\n", "\r\n"])
+def test_mergesort_iterate_memory(tmpdir, lf):
+    """Test iterating over the sorted lines when all lines fit in memory."""
+    m = MergeSort(
+        directory=tmpdir,
+        max_memory=1000000,
+        key=functools.cmp_to_key(linecomp),
+    )
+    for value in [3, 1, 4, 1, 5, 9, 2, 10, 6, 5, 3, 8, 4]:
+        m.append(f"{value}{lf}")
+    assert next(m) == f"1{lf}"
+    assert not isinstance(m._iterating, GeneratorType)
+    assert iter(cast(Iterable[str], m._iterating)) is m._iterating
+    assert next(m) == f"1{lf}"
+    assert next(m) == f"2{lf}"
+    assert next(m) == f"3{lf}"
+    assert next(m) == f"3{lf}"
+    assert next(m) == f"4{lf}"
+    assert next(m) == f"4{lf}"
+    assert next(m) == f"5{lf}"
+    assert next(m) == f"5{lf}"
+    assert next(m) == f"6{lf}"
+    assert next(m) == f"8{lf}"
+    assert next(m) == f"9{lf}"
+    assert next(m) == f"10{lf}"
+    with pytest.raises(StopIteration):
+        next(m)
+
+
+@pytest.mark.parametrize("lf", ["\n", "\r\n"])
+def test_mergesort_key(tmpdir, lf):
+    """Test sorting lines based on a key function."""
+    m = MergeSort(directory=tmpdir, key=lambda line: -int(line[0]))
+    for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
+        m.append(f"{value}{lf}")
+    result = "".join(value[0] for value in m)
+    assert result == "986554433211"
diff --git a/pgtricks/tests/test_pg_dump_splitsort.py b/pgtricks/tests/test_pg_dump_splitsort.py
index 3305c03..74e6b56 100644
--- a/pgtricks/tests/test_pg_dump_splitsort.py
+++ b/pgtricks/tests/test_pg_dump_splitsort.py
@@ -1,8 +1,9 @@
 from functools import cmp_to_key
+from textwrap import dedent
 
 import pytest
 
-from pgtricks.pg_dump_splitsort import linecomp, try_float
+from pgtricks.pg_dump_splitsort import linecomp, memory_size, split_sql_file, try_float
 
 
 @pytest.mark.parametrize(
@@ -101,3 +102,111 @@ def test_linecomp_by_sorting():
         [r'\N', r'\N', r'\N'],
         [r'\N', 'foo', '.42'],
     ]
+
+
+PROLOGUE = dedent(
+    """
+
+    --
+    -- Name: table1; Type: TABLE; Schema: public; Owner:
+    --
+
+    (information for table1 goes here)
+    """,
+)
+
+TABLE1_COPY = dedent(
+    r"""
+
+    -- Data for Name: table1; Type: TABLE DATA; Schema: public;
+
+    COPY foo (id) FROM stdin;
+    3
+    1
+    4
+    1
+    5
+    9
+    2
+    6
+    5
+    3
+    8
+    4
+    \.
+    """,
+)
+
+TABLE1_COPY_SORTED = dedent(
+    r"""
+
+    -- Data for Name: table1; Type: TABLE DATA; Schema: public;
+
+    COPY foo (id) FROM stdin;
+    1
+    1
+    2
+    3
+    3
+    4
+    4
+    5
+    5
+    6
+    8
+    9
+    \.
+    """,
+)
+
+EPILOGUE = dedent(
+    """
+    -- epilogue
+    """,
+)
+
+
+def test_split_sql_file(tmpdir):
+    """Test splitting a SQL file with COPY statements."""
+    sql_file = tmpdir / "test.sql"
+    sql_file.write(PROLOGUE + TABLE1_COPY + EPILOGUE)
+
+    split_sql_file(sql_file, max_memory=190)
+
+    split_files = sorted(path.relto(tmpdir) for path in tmpdir.listdir())
+    assert split_files == [
+        "0000_prologue.sql",
+        "0001_public.table1.sql",
+        "9999_epilogue.sql",
+        "test.sql",
+    ]
+    assert (tmpdir / "0000_prologue.sql").read() == PROLOGUE
+    assert (tmpdir / "0001_public.table1.sql").read() == TABLE1_COPY_SORTED
+    assert (tmpdir / "9999_epilogue.sql").read() == EPILOGUE
+
+
+@pytest.mark.parametrize(
+    ("size", "expect"),
+    [
+        ("0", 0),
+        ("1", 1),
+        ("1k", 1024),
+        ("1m", 1024**2),
+        ("1g", 1024**3),
+        ("100_000K", 102400000),
+        ("1.5M", 1536 * 1024),
+        ("1.5G", 1536 * 1024**2),
+        ("1.5", 1),
+        ("1.5 kibibytes", 1536),
+        ("1.5 Megabytes", 1024 * 1536),
+        ("1.5 Gigs", 1024**2 * 1536),
+        ("1.5KB", 1536),
+        (".5MB", 512 * 1024),
+        ("20GB", 20 * 1024**3),
+    ],
+)
+def test_memory_size(size, expect):
+    """Test parsing human-readable memory sizes with `memory_size`."""
+    result = memory_size(size)
+
+    assert result == expect
diff --git a/pyproject.toml b/pyproject.toml
index 086fffd..6f3739b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,9 @@ ignore = [
     "ANN201",  # Missing return type annotation for public function
     #"ANN204", # Missing return type annotation for special method `__init__`
     #"C408",   # Unnecessary `dict` call (rewrite as a literal)
+    "PLR2004", # Magic value used in comparison
     "S101",    # Use of `assert` detected
+    "SLF001",  # Private member accessed
 ]
 
 [tool.ruff.lint.isort]
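
Usage sketch (illustration only, not part of the patch): the new MergeSort class is driven the same way split_sql_file() above drives it — lines are appended under a memory budget, spilled into sorted temporary files when the budget is exceeded, and read back in merged order by iterating. The input file name below is hypothetical.

    import functools

    from pgtricks.mergesort import MergeSort
    from pgtricks.pg_dump_splitsort import linecomp, memory_size

    # Sort tab-separated COPY data lines while buffering roughly 100 MiB in memory;
    # anything beyond the limit is written to sorted temporary files and heap-merged back.
    sorter = MergeSort(key=functools.cmp_to_key(linecomp), max_memory=memory_size("100MB"))
    with open("table_data.sql") as table_data:  # hypothetical input file
        for line in table_data:
            sorter.append(line)
    for line in sorter:  # iteration lazily merges the on-disk partitions
        print(line, end="")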