Fixes to CSV encoding/line endings/dialect inference #432

Merged · 5 commits · Apr 7, 2021
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -1,10 +1,10 @@
 repos:
 - repo: https://github.com/ambv/black
-  rev: stable
+  rev: '20.8b1'
   hooks:
   - id: black
-    language_version: python3.7
+    language_version: python3.8
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: 'master'
+  rev: 'v0.812'
   hooks:
   - id: mypy
11 changes: 11 additions & 0 deletions splitgraph/ingestion/common.py
@@ -244,6 +244,17 @@ def dedupe_sg_schema(schema_spec: TableSchema, prefix_len: int = 59) -> TableSchema:
     return result
 
 
+def generate_column_names(schema_spec: TableSchema, prefix: str = "col_") -> TableSchema:
+    """Replace empty column names with autogenerated ones"""
+    result = []
+    for i, column in enumerate(schema_spec):
+        if column.name:
+            result.append(column)
+        else:
+            result.append(column._replace(name=f"{prefix}{i+1}"))
+    return result
+
+
 def _format_jsonschema(prop, schema, required):
     if prop == "tables":
         return """tables: Tables to mount (default all). If a list, will import only these tables.
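A quick illustration of the new helper (not part of the diff): `generate_column_names` keeps named columns as-is and fills in empty names positionally, with 1-based numbering. This sketch assumes the `TableColumn` named tuple from `splitgraph.core.types` takes `(ordinal, name, pg_type, is_pk)`; treat the exact shape as an assumption.

```python
from splitgraph.core.types import TableColumn
from splitgraph.ingestion.common import generate_column_names

schema = [
    TableColumn(1, "id", "integer", True),
    TableColumn(2, "", "text", False),  # empty name, e.g. a headerless CSV column
]

# The unnamed column sits at index 1, so it becomes "col_2" (prefix + i + 1).
print([c.name for c in generate_column_names(schema)])
# ['id', 'col_2']
```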
25 changes: 17 additions & 8 deletions splitgraph/ingestion/csv/__init__.py
@@ -50,7 +50,8 @@ def copy_csv_buffer(
     copy_command += SQL(")")
 
     cur.copy_expert(
-        cur.mogrify(copy_command, extra_args), data,
+        cur.mogrify(copy_command, extra_args),
+        data,
     )


@@ -72,12 +73,6 @@ class CSVDataSource(ForeignDataWrapperDataSource):
     params_schema = {
         "type": "object",
         "properties": {
-            "tables": {
-                "type": "object",
-                "additionalProperties": {
-                    "options": {"type": "object", "additionalProperties": {"type": "string"}},
-                },
-            },
             "url": {"type": "string", "description": "HTTP URL to the CSV file"},
             "s3_endpoint": {
                 "type": "string",
@@ -95,6 +90,15 @@ class CSVDataSource(ForeignDataWrapperDataSource):
                 "type": "boolean",
                 "description": "Detect the CSV file's dialect (separator, quoting characters etc) automatically",
             },
+            "autodetect_encoding": {
+                "type": "boolean",
+                "description": "Detect the CSV file's encoding automatically",
+            },
+            "autodetect_sample_size": {
+                "type": "integer",
+                "description": "Sample size, in bytes, for encoding/dialect/header detection",
+            },
+            "encoding": {"type": "string", "description": "Encoding of the CSV file"},
             "header": {
                 "type": "boolean",
                 "description": "First line of the CSV file is its header",
@@ -129,7 +133,8 @@ class CSVDataSource(ForeignDataWrapperDataSource):
         "s3_bucket": "data",
         "s3_object_prefix": "csv_files/current/",
         "autodetect_header": true,
-        "autodetect_dialect": true
+        "autodetect_dialect": true,
+        "autodetect_encoding": true
     }
     EOF
     ```
@@ -180,9 +185,13 @@ def get_server_options(self):
             "url",
             "autodetect_dialect",
             "autodetect_header",
+            "autodetect_encoding",
+            "autodetect_sample_size",
+            "encoding",
             "header",
             "separator",
             "quotechar",
+            "dialect",
         ]:
             if k in self.params:
                 options[k] = str(self.params[k])
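One detail worth calling out in `get_server_options`: foreign server options must be strings, so every parameter goes through `str()` and a boolean `True` arrives as `"True"`. The `get_bool` helper in the new `splitgraph/ingestion/csv/common.py` (below) lowercases before comparing, so the round trip is lossless. A minimal sketch:

```python
from splitgraph.ingestion.csv.common import get_bool

print(get_bool({"header": str(True)}, "header"))  # True ("True".lower() == "true")
print(get_bool({"header": "false"}, "header"))    # False
print(get_bool({}, "header"))                     # True (key absent, default used)
```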
98 changes: 98 additions & 0 deletions splitgraph/ingestion/csv/common.py
@@ -0,0 +1,98 @@
+import csv
+import io
+from typing import Optional, Dict, Tuple, NamedTuple, Union, Type, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import _csv
+
+import chardet
+
+from splitgraph.commandline.common import ResettableStream
+
+
+class CSVOptions(NamedTuple):
+    autodetect_header: bool = True
+    autodetect_dialect: bool = True
+    autodetect_encoding: bool = True
+    autodetect_sample_size: int = 65536
+    delimiter: str = ","
+    quotechar: str = '"'
+    dialect: Optional[Union[str, Type[csv.Dialect]]] = "excel"
+    header: bool = True
+    encoding: str = "utf-8"
+
+    @classmethod
+    def from_fdw_options(cls, fdw_options):
+        return cls(
+            autodetect_header=get_bool(fdw_options, "autodetect_header"),
+            autodetect_dialect=get_bool(fdw_options, "autodetect_dialect"),
+            autodetect_encoding=get_bool(fdw_options, "autodetect_encoding"),
+            autodetect_sample_size=int(fdw_options.get("autodetect_sample_size", 65536)),
+            header=get_bool(fdw_options, "header"),
+            delimiter=fdw_options.get("delimiter", ","),
+            quotechar=fdw_options.get("quotechar", '"'),
+            dialect=fdw_options.get("dialect"),
+            encoding=fdw_options.get("encoding", "utf-8"),
+        )
+
+    def to_csv_kwargs(self):
+        if self.dialect:
+            return {"dialect": self.dialect}
+        return {"delimiter": self.delimiter, "quotechar": self.quotechar}
+
+
+def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions:
+    """Autodetect the CSV dialect, encoding, header etc."""
+    if not (
+        csv_options.autodetect_encoding
+        or csv_options.autodetect_header
+        or csv_options.autodetect_dialect
+    ):
+        return csv_options
+
+    data = stream.read(csv_options.autodetect_sample_size)
+    assert data
+
+    if csv_options.autodetect_encoding:
+        encoding = chardet.detect(data)["encoding"]
+        if encoding == "ascii":
+            # ASCII is a subset of UTF-8. For safety, if chardet detected
+            # the encoding as ASCII, use UTF-8 (a valid ASCII file is a valid UTF-8 file,
+            # but not vice versa)
+            encoding = "utf-8"
+        csv_options = csv_options._replace(encoding=encoding)
+
+    sample = data.decode(csv_options.encoding)
+    # Emulate universal newlines mode (convert \r, \r\n, \n into \n)
+    sample = "\n".join(sample.splitlines())
+
+    if csv_options.autodetect_dialect:
+        dialect = csv.Sniffer().sniff(sample)
+        csv_options = csv_options._replace(dialect=dialect)
+
+    if csv_options.autodetect_header:
+        has_header = csv.Sniffer().has_header(sample)
+        csv_options = csv_options._replace(header=has_header)
+
+    return csv_options
+
+
+def get_bool(params: Dict[str, str], key: str, default: bool = True) -> bool:
+    if key not in params:
+        return default
+    return params[key].lower() == "true"
+
+
+def make_csv_reader(
+    response: io.IOBase, csv_options: CSVOptions
+) -> Tuple[CSVOptions, "_csv._reader"]:
+    stream = ResettableStream(response)
+    csv_options = autodetect_csv(stream, csv_options)
+
+    stream.reset()
+    # https://docs.python.org/3/library/csv.html#id3
+    # Open with newline="" for universal newlines
+    io_stream = io.TextIOWrapper(io.BufferedReader(stream), encoding=csv_options.encoding, newline="")  # type: ignore
+
+    reader = csv.reader(io_stream, **csv_options.to_csv_kwargs())
+    return csv_options, reader
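To see the new detection pipeline end to end, here is an illustrative run of `make_csv_reader` against an in-memory stream. The sample data and expected outputs are assumptions, and `csv.Sniffer` heuristics can misfire on tiny samples; note the `\r\n` line endings exercise the universal-newlines handling this PR adds:

```python
import io

from splitgraph.ingestion.csv.common import CSVOptions, make_csv_reader

data = "name;city\r\nAnna;Berlin\r\nBob;Paris\r\n".encode("utf-8")

# Defaults turn all autodetection on: chardet picks the encoding (pure-ASCII
# input is upgraded to UTF-8), the sniffer infers the ';' dialect and header.
options, reader = make_csv_reader(io.BytesIO(data), CSVOptions())

print(options.encoding)  # 'utf-8'
print(options.header)    # True, assuming the sniffer's heuristic agrees
print(list(reader))      # [['name', 'city'], ['Anna', 'Berlin'], ['Bob', 'Paris']]
```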
124 changes: 35 additions & 89 deletions splitgraph/ingestion/csv/fdw.py
@@ -1,18 +1,17 @@
-import csv
 import gzip
-import io
 import logging
 import os
 from copy import deepcopy
 from itertools import islice
-from typing import Tuple, Dict, Any
+from typing import Tuple
 
 import requests
 from minio import Minio
 
 import splitgraph.config
 from splitgraph.commandline import get_exception_name
-from splitgraph.commandline.common import ResettableStream
+from splitgraph.ingestion.common import generate_column_names
+from splitgraph.ingestion.csv.common import CSVOptions, get_bool, make_csv_reader
 from splitgraph.ingestion.inference import infer_sg_schema
 
 try:
@@ -37,61 +36,18 @@
 _PG_LOGLEVEL = logging.INFO
 
 
-def get_bool(params: Dict[str, str], key: str, default: bool = True) -> bool:
-    if key not in params:
-        return default
-    return params[key].lower() == "true"
-
-
-def make_csv_reader(
-    response: io.IOBase,
-    autodetect_header: bool = False,
-    autodetect_dialect: bool = False,
-    delimiter: str = ",",
-    quotechar: str = '"',
-    header: bool = True,
-    encoding: str = "utf-8",
-):
-    stream = ResettableStream(response)
-    if autodetect_header or autodetect_dialect:
-        data = stream.read(2048)
-        assert data
-        sniffer_sample = data.decode(encoding)
-
-        dialect = csv.Sniffer().sniff(sniffer_sample)
-        has_header = csv.Sniffer().has_header(sniffer_sample)
-
-    stream.reset()
-    io_stream = io.TextIOWrapper(io.BufferedReader(stream))  # type: ignore
-    csv_kwargs: Dict[str, Any] = (
-        {"dialect": dialect}
-        if autodetect_dialect
-        else {"delimiter": delimiter, "quotechar": quotechar}
-    )
-
-    if not autodetect_header:
-        has_header = header
-
-    reader = csv.reader(io_stream, **csv_kwargs,)
-    return has_header, reader
-
-
-def _get_table_definition(response, fdw_options, table_name, table_options, encoding="utf-8"):
-    has_header, reader = make_csv_reader(
-        response,
-        autodetect_header=get_bool(fdw_options, "autodetect_header"),
-        autodetect_dialect=get_bool(fdw_options, "autodetect_dialect"),
-        header=get_bool(fdw_options, "header"),
-        delimiter=fdw_options.get("delimiter", ","),
-        quotechar=fdw_options.get("quotechar", '"'),
-        encoding=encoding,
-    )
+def _get_table_definition(response, fdw_options, table_name, table_options):
+    csv_options, reader = make_csv_reader(response, CSVOptions.from_fdw_options(fdw_options))
     sample = list(islice(reader, 1000))
 
-    if not has_header:
-        sample = [[str(i) for i in range(len(sample))]] + sample
+    if not csv_options.header:
+        sample = [[""] * len(sample)] + sample
 
     sg_schema = infer_sg_schema(sample, None, None)
 
+    # For nonexistent column names: replace with autogenerated ones (can't have empty column names)
+    sg_schema = generate_column_names(sg_schema)
+
     # Build Multicorn TableDefinition. ColumnDefinition takes in type OIDs,
     # typmods and other internal PG stuff but other FDWs seem to get by with just
     # the textual type name.
@@ -125,10 +81,10 @@ def explain(self, quals, columns, sortkeys=None, verbose=False):
             f"Object ID: {self.s3_object}",
         ]
 
-    def _read_csv(self, csv_reader, header=True):
+    def _read_csv(self, csv_reader, csv_options):
         header_skipped = False
         for row in csv_reader:
-            if not header_skipped and header:
+            if not header_skipped and csv_options.header:
                 header_skipped = True
                 continue
             # CSVs don't really distinguish NULLs and empty strings well. We know
@@ -149,31 +105,24 @@ def execute(self, quals, columns, sortkeys=None):
             stream = response.raw
             if response.headers.get("Content-Encoding") == "gzip":
                 stream = gzip.GzipFile(fileobj=stream)
-            has_header, reader = make_csv_reader(
-                stream,
-                self.autodetect_header,
-                self.autodetect_dialect,
-                self.delimiter,
-                self.quotechar,
-                self.header,
-                encoding=response.encoding,
-            )
-            yield from self._read_csv(reader, header=has_header)
+
+            csv_options = self.csv_options
+            if csv_options.encoding is None and not csv_options.autodetect_encoding:
+                csv_options = csv_options._replace(encoding=response.encoding)
+
+            csv_options, reader = make_csv_reader(stream, csv_options)
+            yield from self._read_csv(reader, csv_options)
         else:
             response = None
             try:
                 response = self.s3_client.get_object(
                     bucket_name=self.s3_bucket, object_name=self.s3_object
                 )
-                has_header, reader = make_csv_reader(
-                    response,
-                    self.autodetect_header,
-                    self.autodetect_dialect,
-                    self.delimiter,
-                    self.quotechar,
-                    self.header,
-                )
-                yield from self._read_csv(reader, header=has_header)
+                csv_options = self.csv_options
+                if csv_options.encoding is None and not csv_options.autodetect_encoding:
+                    csv_options = csv_options._replace(autodetect_encoding=True)
+                csv_options, reader = make_csv_reader(response, csv_options)
+                yield from self._read_csv(reader, csv_options)
             finally:
                 if response:
                     response.close()
@@ -198,11 +147,7 @@ def import_schema(cls, schema, srv_options, options, restriction_type, restricts):
             stream = response.raw
             if response.headers.get("Content-Encoding") == "gzip":
                 stream = gzip.GzipFile(fileobj=stream)
-            return [
-                _get_table_definition(
-                    stream, fdw_options, "data", None, encoding=response.encoding
-                )
-            ]
+            return [_get_table_definition(stream, fdw_options, "data", None)]
 
         # Get S3 options
         client, bucket, prefix = cls._get_s3_params(fdw_options)
@@ -233,7 +178,14 @@ def import_schema(cls, schema, srv_options, options, restriction_type, restricts):
             response = None
             try:
                 response = client.get_object(bucket, o)
-                result.append(_get_table_definition(response, fdw_options, o, {"s3_object": o},))
+                result.append(
+                    _get_table_definition(
+                        response,
+                        fdw_options,
+                        o,
+                        {"s3_object": o},
+                    )
+                )
             except Exception as e:
                 logging.error(
                     "Error scanning object %s, ignoring: %s: %s", o, get_exception_name(e), e
@@ -280,13 +232,7 @@ def __init__(self, fdw_options, fdw_columns):
         # The foreign datawrapper columns (name -> ColumnDefinition).
         self.fdw_columns = fdw_columns
 
-        self.delimiter = fdw_options.get("delimiter", ",")
-        self.quotechar = fdw_options.get("quotechar", '"')
-
-        self.header = get_bool(fdw_options, "header")
-
-        self.autodetect_header = get_bool(fdw_options, "autodetect_header")
-        self.autodetect_dialect = get_bool(fdw_options, "autodetect_dialect")
+        self.csv_options = CSVOptions.from_fdw_options(fdw_options)
 
         # For HTTP: use full URL
         if fdw_options.get("url"):
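Finally, the `__init__` consolidation means a single `CSVOptions` value replaces the five attributes the FDW previously tracked by hand. A sketch of how the stringified FDW options map into it (when no `dialect` is given, `from_fdw_options` leaves it as `None`, so `to_csv_kwargs` falls back to delimiter/quotechar):

```python
from splitgraph.ingestion.csv.common import CSVOptions

fdw_options = {"autodetect_dialect": "false", "delimiter": ";", "encoding": "latin-1"}
opts = CSVOptions.from_fdw_options(fdw_options)

print(opts.encoding, opts.delimiter)  # latin-1 ;
print(opts.to_csv_kwargs())           # {'delimiter': ';', 'quotechar': '"'}
```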