Clean up temporary archive directory at exit by vlthr · Pull Request #2184 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Clean up temporary archive directory at exit #2184

Merged
merged 2 commits on Dec 14, 2018
15 changes: 10 additions & 5 deletions allennlp/models/archival.py
@@ -3,6 +3,7 @@
 """
 
 from typing import NamedTuple, Dict, Any
+import atexit
 import json
 import logging
 import os
@@ -109,7 +110,6 @@ def load_archive(archive_file: str,
     else:
         logger.info(f"loading archive file {archive_file} from cache at {resolved_archive_file}")
 
-    tempdir = None
     if os.path.isdir(resolved_archive_file):
         serialization_dir = resolved_archive_file
     else:
@@ -118,6 +118,9 @@
         logger.info(f"extracting archive file {resolved_archive_file} to temp dir {tempdir}")
         with tarfile.open(resolved_archive_file, 'r:gz') as archive:
             archive.extractall(tempdir)
+        # Postpone cleanup until exit in case the unarchived contents are needed outside
+        # this function.
+        atexit.register(_cleanup_archive_dir, tempdir)
 
         serialization_dir = tempdir
 
@@ -152,8 +155,10 @@ def load_archive(archive_file: str,
                        serialization_dir=serialization_dir,
                        cuda_device=cuda_device)
 
-    if tempdir:
-        # Clean up temp dir
-        shutil.rmtree(tempdir)
-
     return Archive(model=model, config=config)
+
+
+def _cleanup_archive_dir(path: str):
+    if os.path.exists(path):
+        logger.info("removing temporary unarchived model dir at %s", path)
+        shutil.rmtree(path)
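For readers skimming the diff, here is a minimal standalone sketch of the deferred-cleanup pattern this change adopts: `atexit.register` queues removal of the extraction directory so it happens only when the interpreter exits, rather than before `load_archive` returns. The function and archive names below are illustrative, not AllenNLP code; only the standard-library calls (`tempfile`, `tarfile`, `shutil`, `atexit`) are assumed.

```python
import atexit
import os
import shutil
import tarfile
import tempfile


def _cleanup_dir(path: str) -> None:
    # Runs at interpreter exit; by then nothing should still need the files.
    if os.path.exists(path):
        shutil.rmtree(path)


def extract_to_tempdir(archive_file: str) -> str:
    """Extract a .tar.gz archive and defer deletion of the temp dir until exit."""
    tempdir = tempfile.mkdtemp()
    with tarfile.open(archive_file, 'r:gz') as archive:
        archive.extractall(tempdir)
    # Instead of calling shutil.rmtree(tempdir) before returning, register the cleanup.
    atexit.register(_cleanup_dir, tempdir)
    return tempdir
```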
5 changes: 5 additions & 0 deletions allennlp/tests/models/archival_test.py
@@ -1,5 +1,6 @@
 # pylint: disable=invalid-name
 import copy
+import os
 
 import torch
 
@@ -93,5 +94,9 @@ def test_extra_files(self):
         # (which we don't know, but we know what it ends with).
         assert params.get('train_data_path').endswith('/fta/train_data_path')
 
+        # The temporary path should be accessible even after the load_archive
+        # function returns.
+        assert os.path.exists(params.get('train_data_path'))
+
         # The validation data path should be the same though.
         assert params.get('validation_data_path') == str(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
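At the usage level, the behaviour the new assertion pins down looks roughly like this (the archive path is hypothetical, and the example assumes the archive was built with `files_to_archive` so the config's `train_data_path` was rewritten to a file inside the extracted temp dir):

```python
import os

from allennlp.models.archival import load_archive

archive = load_archive("/path/to/model.tar.gz")  # hypothetical archive path

# Before this PR the temp dir was deleted as load_archive returned, so this path
# would already be gone; now the atexit hook removes it only at process exit.
train_path = archive.config.get('train_data_path')
assert os.path.exists(train_path)
```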