Use tempfile for unit test by tushuhei · Pull Request #114 · google/budoux · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Use tempfile for unit test #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
import os
import sys
import tempfile
import typing
import unittest

Expand Down Expand Up @@ -73,11 +74,10 @@ def test_cmdargs_full(self) -> None:


class TestPreprocess(unittest.TestCase):
ENTRIES_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'entries_test.txt'))

def test_standard_setup(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
entries_file_path = tempfile.NamedTemporaryFile().name
with open(entries_file_path, 'w') as f:
f.write(('1\tfoo\tbar\n'
'-1\tfoo\n'
'1\tfoo\tbar\tbaz\n'
Expand All @@ -90,30 +90,29 @@ def test_standard_setup(self) -> None:
# 1 1 1 1
# 1 1 1 0
# -1 0 0 1
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1)
rows, cols, Y, features = train.preprocess(entries_file_path, 1)
self.assertEqual(features, ['foo', 'bar', 'baz'])
self.assertEqual(Y.tolist(), [True, False, True, True, False])
self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4])
self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2])
os.remove(entries_file_path)

def test_skip_invalid_rows(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
entries_file_path = tempfile.NamedTemporaryFile().name
with open(entries_file_path, 'w') as f:
f.write(('\n1\tfoo\tbar\n'
'-1\n\n'
'-1\tfoo\n\n'))
# The input matrix X and the target vector Y should look like below now:
# Y X(foo bar)
# 1 1 1
# -1 1 0
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0)
rows, cols, Y, features = train.preprocess(entries_file_path, 0)
self.assertEqual(features, ['foo', 'bar'])
self.assertEqual(Y.tolist(), [True, False])
self.assertEqual(rows.tolist(), [0, 0, 1])
self.assertEqual(cols.tolist(), [0, 1, 0])

def tearDown(self) -> None:
if (os.path.exists(self.ENTRIES_FILE_PATH)):
os.remove(self.ENTRIES_FILE_PATH)
os.remove(entries_file_path)


class TestSplitData(unittest.TestCase):
Expand Down Expand Up @@ -205,12 +204,10 @@ def test_standard_setup1(self) -> None:


class TestFit(unittest.TestCase):
WEIGHTS_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'weights_test.txt'))
LOG_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'train_test.log'))

def test_fit(self) -> None:
weights_file_path = tempfile.NamedTemporaryFile().name
log_file_path = tempfile.NamedTemporaryFile().name
# Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly
# correlates with Y in a negative way.
X = np.array([
Expand All @@ -225,8 +222,8 @@ def test_fit(self) -> None:
iters = 5
out_span = 2
scores = train.fit(rows, cols, rows, cols, Y, Y, features, iters,
self.WEIGHTS_FILE_PATH, self.LOG_FILE_PATH, out_span)
with open(self.WEIGHTS_FILE_PATH) as f:
weights_file_path, log_file_path, out_span)
with open(weights_file_path) as f:
weights = [
line.split('\t') for line in f.read().splitlines() if line.strip()
]
Expand All @@ -238,7 +235,7 @@ def test_fit(self) -> None:
iters,
msg='The number of lines should equal to the iteration count.')

with open(self.LOG_FILE_PATH) as f:
with open(log_file_path) as f:
log = [line.split('\t') for line in f.read().splitlines() if line.strip()]
self.assertEqual(
len(log),
Expand All @@ -257,10 +254,8 @@ def test_fit(self) -> None:
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))

def tearDown(self) -> None:
os.remove(self.WEIGHTS_FILE_PATH)
os.remove(self.LOG_FILE_PATH)
os.remove(weights_file_path)
os.remove(log_file_path)


if __name__ == '__main__':
Expand Down
0