diff --git a/tableaudocumentapi/xfile.py b/tableaudocumentapi/xfile.py index 8d2bb9f..5a8a818 100644 --- a/tableaudocumentapi/xfile.py +++ b/tableaudocumentapi/xfile.py @@ -61,9 +61,9 @@ def _register_all_namespaces(): def find_file_in_zip(zip_file): - '''Returns the twb/tds file from a Tableau packaged file format. Packaged - files can contain cache entries which are also valid XML, so only look for - files with a .tds or .twb extension. + '''Returns the twb/tds filename and parsed XML tree from a Tableau packaged + file format. Packaged files can contain cache entries which are also valid + XML, so only look for files with a .tds or .twb extension. ''' candidate_files = filter(lambda x: x.split('.')[-1] in ('twb', 'tds'), @@ -72,17 +72,18 @@ def find_file_in_zip(zip_file): for filename in candidate_files: with zip_file.open(filename) as xml_candidate: try: - ET.parse(xml_candidate) - return filename + xml_tree = ET.parse(xml_candidate) + return filename, xml_tree except ET.ParseError: # That's not an XML file by gosh pass + return None, None + def get_xml_from_archive(filename): with zipfile.ZipFile(filename, allowZip64=True) as zf: - with zf.open(find_file_in_zip(zf)) as xml_file: - xml_tree = ET.parse(xml_file) + _, xml_tree = find_file_in_zip(zf) return xml_tree @@ -113,7 +114,7 @@ def save_into_archive(xml_tree, filename, new_filename=None): # Extract to temp directory with temporary_directory() as temp_path: with zipfile.ZipFile(filename, allowZip64=True) as zf: - xml_file = find_file_in_zip(zf) + xml_file, _ = find_file_in_zip(zf) zf.extractall(temp_path) # Write the new version of the file to the temp directory xml_tree.write(os.path.join( diff --git a/test/test_xfile.py b/test/test_xfile.py index c9686f5..a10088e 100644 --- a/test/test_xfile.py +++ b/test/test_xfile.py @@ -24,11 +24,13 @@ class XFileEdgeTests(unittest.TestCase): def test_find_file_in_zip_no_xml_file(self): badzip = zipfile.ZipFile(BAD_ZIP_FILE) - self.assertIsNone(find_file_in_zip(badzip)) + xml_file, xml_tree = find_file_in_zip(badzip) + self.assertIsNone(xml_file) def test_only_find_twbs(self): twb_from_twbx_with_cache = zipfile.ZipFile(TWBX_WITH_CACHE_FILES) - self.assertEqual(find_file_in_zip(twb_from_twbx_with_cache), 'Superstore.twb') + xml_file, xml_tree = find_file_in_zip(twb_from_twbx_with_cache) + self.assertEqual(xml_file, 'Superstore.twb') class Namespacing(unittest.TestCase):