diff --git a/geonode/tests/test_utils.py b/geonode/tests/test_utils.py
index da8ada03a2e..53ba59924b5 100644
--- a/geonode/tests/test_utils.py
+++ b/geonode/tests/test_utils.py
@@ -1,9 +1,15 @@
+import os
+import shutil
+import zipfile
+import tempfile
+
+from osgeo import ogr
 from datetime import datetime, timedelta
 from unittest.mock import patch
 
 from geonode.br.management.commands.utils.utils import ignore_time
 from geonode.tests.base import GeoNodeBaseTestSupport
-from geonode.utils import copy_tree
+from geonode.utils import copy_tree, fixup_shp_columnnames, unzip_file
 
 
 class TestCopyTree(GeoNodeBaseTestSupport):
@@ -71,3 +77,35 @@ def test_backup_of_child_directories(
         """
         copy_tree('/src', '/dst', ignore=ignore_time('>=', datetime.now().isoformat()))
         self.assertTrue(patch_shutil_copytree.called)
+
+
+class TestFixupShp(GeoNodeBaseTestSupport):
+    def test_fixup_shp_columnnames(self):
+        project_root = os.path.abspath(os.path.dirname(__file__))
+        layer_zip = os.path.join(project_root, "data", "ming_female_1.zip")
+
+        self.assertTrue(zipfile.is_zipfile(layer_zip))
+
+        layer_shp = unzip_file(layer_zip)
+
+        expected_fieldnames = [
+            "ID", "_f", "__1", "__2", "m", "_", "_M2", "_M2_1", "l", "x", "y", "_WU", "_1",
+        ]
+        fixup_shp_columnnames(layer_shp, "windows-1258")
+
+        inDriver = ogr.GetDriverByName("ESRI Shapefile")
+        inDataSource = inDriver.Open(layer_shp, 0)
+        inLayer = inDataSource.GetLayer()
+        inLayerDefn = inLayer.GetLayerDefn()
+
+        self.assertEqual(inLayerDefn.GetFieldCount(), len(expected_fieldnames))
+
+        for i, fn in enumerate(expected_fieldnames):
+            self.assertEqual(inLayerDefn.GetFieldDefn(i).GetName(), fn)
+
+        inDataSource.Destroy()
+
+        # Cleanup temp dir
+        shp_parent = os.path.dirname(layer_shp)
+        if shp_parent.startswith(tempfile.gettempdir()):
+            shutil.rmtree(shp_parent)
diff --git a/geonode/utils.py b/geonode/utils.py
index f08574a812e..c0bb12643bb 100755
--- a/geonode/utils.py
+++ b/geonode/utils.py
@@ -1118,6 +1118,32 @@ def check_shp_columnnames(layer):
     return fixup_shp_columnnames(inShapefile, layer.charset)
 
 
+def clone_shp_field_defn(srcFieldDefn, name):
+    """
+    Clone an existing ogr.FieldDefn with a new name
+    """
+    dstFieldDefn = ogr.FieldDefn(name, srcFieldDefn.GetType())
+    dstFieldDefn.SetWidth(srcFieldDefn.GetWidth())
+    dstFieldDefn.SetPrecision(srcFieldDefn.GetPrecision())
+
+    return dstFieldDefn
+
+
+def rename_shp_columnnames(inLayer, fieldnames):
+    """
+    Rename columns in a layer to those specified in the given mapping
+    """
+    inLayerDefn = inLayer.GetLayerDefn()
+
+    for i in range(inLayerDefn.GetFieldCount()):
+        srcFieldDefn = inLayerDefn.GetFieldDefn(i)
+        dstFieldName = fieldnames.get(srcFieldDefn.GetName())
+
+        if dstFieldName is not None:
+            dstFieldDefn = clone_shp_field_defn(srcFieldDefn, dstFieldName)
+            inLayer.AlterFieldDefn(i, dstFieldDefn, ogr.ALTER_NAME_FLAG)
+
+
 def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
     """ Try to fix column names and warn the user
     """
@@ -1125,6 +1151,7 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
 
     if not tempdir:
         tempdir = tempfile.mkdtemp()
+
     if is_zipfile(inShapefile):
         inShapefile = unzip_file(inShapefile, '.shp', tempdir=tempdir)
 
@@ -1135,8 +1162,9 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
         tb = traceback.format_exc()
         logger.debug(tb)
         inDataSource = None
+
     if inDataSource is None:
-        logger.debug('Could not open %s' % (inShapefile))
+        logger.debug("Could not open %s", inShapefile)
         return False, None, None
     else:
         inLayer = inDataSource.GetLayer()
@@ -1153,7 +1181,7 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
     list_col_original = []
     list_col = {}
 
-    for i in range(0, inLayerDefn.GetFieldCount()):
+    for i in range(inLayerDefn.GetFieldCount()):
         try:
             field_name = inLayerDefn.GetFieldDefn(i).GetName()
             if a.match(field_name):
@@ -1162,21 +1190,12 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
             logger.exception(e)
             return True, None, None
 
-    for i in range(0, inLayerDefn.GetFieldCount()):
+    for i in range(inLayerDefn.GetFieldCount()):
         try:
             field_name = inLayerDefn.GetFieldDefn(i).GetName()
             if not a.match(field_name):
                 # once the field_name contains Chinese, to use slugify_zh
-                has_ch = False
-                for ch in field_name:
-                    try:
-                        if '\u4e00' <= ch.decode("utf-8", "surrogateescape") <= '\u9fff':
-                            has_ch = True
-                            break
-                    except Exception:
-                        has_ch = True
-                        break
-                if has_ch:
+                if any('\u4e00' <= ch <= '\u9fff' for ch in field_name):
                     new_field_name = slugify_zh(field_name, separator='_')
                 else:
                     new_field_name = slugify(field_name)
@@ -1189,7 +1208,8 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
                     if new_field_name.endswith('_' + str(j)):
                         j += 1
                         new_field_name = new_field_name[:-2] + '_' + str(j)
-                list_col.update({field_name: new_field_name})
+                if field_name != new_field_name:
+                    list_col[field_name] = new_field_name
         except Exception as e:
             logger.exception(e)
             return True, None, None
@@ -1198,9 +1218,9 @@ def fixup_shp_columnnames(inShapefile, charset, tempdir=None):
         return True, None, None
     else:
         try:
-            for key in list_col.keys():
-                qry = "ALTER TABLE \"{}\" RENAME COLUMN \"{}\" TO \"{}\"".format(inLayer.GetName(), key, list_col[key])
-                inDataSource.ExecuteSQL(qry)
+            rename_shp_columnnames(inLayer, list_col)
+            inDataSource.SyncToDisk()
+            inDataSource.Destroy()
         except Exception as e:
             logger.exception(e)
             raise GeoNodeException(