Track dev loss in ATIS model by kl2806 · Pull Request #1907 · allenai/allennlp · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Track dev loss in ATIS model #1907

Merged
merged 3 commits into from
Oct 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions allennlp/data/dataset_readers/semantic_parsing/atis.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def text_to_instance(self, # type: ignore
try:
action_sequence = world.get_action_sequence(sql_query)
except ParseError:
action_sequence = []
logger.debug(f'Parsing error')

tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
Expand All @@ -159,12 +160,14 @@ def text_to_instance(self, # type: ignore

if sql_query_labels != None:
fields['sql_queries'] = MetadataField(sql_query_labels)
if action_sequence and not self._keep_if_unparseable:
if self._keep_if_unparseable or action_sequence:
for production_rule in action_sequence:
index_fields.append(IndexField(action_map[production_rule], action_field))
if not action_sequence:
index_fields = [IndexField(-1, action_field)]
action_sequence_field = ListField(index_fields)
fields['target_action_sequence'] = action_sequence_field
elif not self._keep_if_unparseable:
else:
# If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly
# to keep it, then we will skip it.
return None
Expand Down
11 changes: 11 additions & 0 deletions allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from allennlp.semparse.worlds import AtisWorld

class TestAtisReader(AllenNlpTestCase):
def test_atis_keep_unparseable(self):
    """Check that an unparseable SQL label yields a padding-only action sequence.

    The reader is built with ``keep_if_unparseable=True`` and given a SQL label
    that cannot be parsed. The resulting instance's ``target_action_sequence``
    must hold exactly one index field, and that field's index must be -1
    (the padding index).
    """
    atis_db = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db")
    atis_reader = AtisDatasetReader(database_file=atis_db, keep_if_unparseable=True)

    utterance = 'show me the one way flights from detroit me to westchester county'
    bogus_query = 'this is not a query that can be parsed'
    instance = atis_reader.text_to_instance(utterances=[utterance],
                                            sql_query_labels=[bogus_query])

    # Unparseable query: a single index field pointing at the padding index -1.
    action_fields = instance.fields['target_action_sequence'].field_list
    assert len(action_fields) == 1
    assert action_fields[0].sequence_index == -1

def test_atis_read_from_file(self):
data_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "atis" / "sample.json"
database_file = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db")
Expand Down
0