Skip to content
This repository has been archived by the owner on Sep 3, 2022. It is now read-only.

Commit

Permalink
CsvDataSet no longer globs files in init. (#187)
Browse files Browse the repository at this point in the history
* CsvDataSet no longer globs files in init.

* removed file_io, that fix will be done later

* removed junk lines

* sample uses .file

* fixed csv dataset def files()

* Update _dataset.py
  • Loading branch information
brandondutra authored and qimingj committed Feb 22, 2017
1 parent cef4eae commit b81244f
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions datalab/mlalpha/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
class CsvDataSet(object):
"""DataSet based on CSV files and schema."""

def __init__(self, files, schema=None, schema_file=None):
def __init__(self, file_pattern, schema=None, schema_file=None):
"""
Args:
files: A list of CSV files. Can contain wildcards in file names. Can be local or GCS path.
file_pattern: A list of CSV file paths, or a single string. Can contain wildcards in
file names. Can be local or GCS path.
schema: A BigQuery schema object in the form of
[{'name': 'col1', 'type': 'STRING'},
{'name': 'col2', 'type': 'INTEGER'}]
Expand All @@ -59,21 +60,32 @@ def __init__(self, files, schema=None, schema_file=None):
else:
with ml.util._file.open_local_or_gcs(schema_file, 'r') as f:
self._schema = json.load(f)

if isinstance(files, basestring):
files = [files]
self._files = []
for file in files:
# glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
self._files += [str(x) for x in ml.util._file.glob_files(file)]
self._input_files = files

self._glob_files = []


@property
def input_files(self):
    """Returns the file list that was given to this class without globbing files.

    The property is intentionally named differently from the ``_input_files``
    attribute that backs it. A read-only property called ``_input_files``
    would shadow that attribute: reading it would recurse forever, and the
    ``self._input_files = ...`` assignment in ``__init__`` would fail because
    the property has no setter.
    """
    return self._input_files

@property
def files(self):
    """Returns the concrete file paths matched by the input file patterns.

    Globbing is deferred until the first access rather than done in
    ``__init__``. The expanded list is stored in ``self._glob_files`` and
    reused on later accesses for as long as it is non-empty; while it is
    empty, every access globs again.
    """
    if not self._glob_files:
        for pattern in self._input_files:
            # glob_files() returns unicode strings, which DataFlow does not
            # handle well, so convert each match with str().
            self._glob_files += [str(x) for x in ml.util._file.glob_files(pattern)]

    return self._glob_files

@property
def schema(self):
    """Returns the schema stored on this dataset as ``self._schema``."""
    return self._schema

def sample(self, n):
""" Samples data into a Pandas DataFrame.
Args:
Expand All @@ -85,7 +97,7 @@ def sample(self, n):
"""
row_total_count = 0
row_counts = []
for file in self._files:
for file in self.files:
with ml.util._file.open_local_or_gcs(file, 'r') as f:
num_lines = sum(1 for line in f)
row_total_count += num_lines
Expand All @@ -108,7 +120,7 @@ def sample(self, n):
# Note that random.sample will raise Exception if skip_count is greater than rows count.
skip_all = sorted(random.sample(xrange(0, row_total_count), skip_count))
dfs = []
for file, row_count in zip(self._files, row_counts):
for file, row_count in zip(self.files, row_counts):
skip = [x for x in skip_all if x < row_count]
skip_all = [x - row_count for x in skip_all if x >= row_count]
with ml.util._file.open_local_or_gcs(file, 'r') as f:
Expand Down

0 comments on commit b81244f

Please sign in to comment.