Update `sem_map` by vincentzed · Pull Request #173 · lotus-data/lotus · GitHub

Update sem_map #173


Open · wants to merge 9 commits into base: main
120 changes: 111 additions & 9 deletions .github/tests/lm_tests.py
@@ -202,12 +202,12 @@ def test_sem_extract(setup_models, model):

    for idx, row in df.iterrows():
        assert row["Name"] in row["Name_quote"], f"Name '{row['Name']}' not found in '{row['Name_quote']}'"
        assert (
            row["Sport"].lower() in row["Sport_quote"].lower()
        ), f"Sport '{row['Sport']}' not found in '{row['Sport_quote']}'"
        assert (
            str(row["Number of Championships"]) in row["Number of Championships_quote"]
        ), f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'"
        assert row["Sport"].lower() in row["Sport_quote"].lower(), (
            f"Sport '{row['Sport']}' not found in '{row['Sport_quote']}'"
        )
        assert str(row["Number of Championships"]) in row["Number of Championships_quote"], (
            f"Number of Championships '{row['Number of Championships']}' not found in '{row['Number of Championships_quote']}'"
        )


################################################################################
@@ -453,9 +453,9 @@ def test_join_cascade(setup_models):
        school, school_type = pair
        exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
        assert exists, f"Expected pair {pair} does not exist in the dataframe!"
    assert (
        stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"]
    ), stats  # a helper negative can still meet the precision target
    assert stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"], (
        stats
    )  # a helper negative can still meet the precision target
    assert stats["join_helper_positive"] == 0, stats


@@ -494,3 +494,105 @@ def test_custom_tokenizer():
    tokens = custom_lm.count_tokens("Hello, world!")
    assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
    assert tokens < 100


################################################################################
# sem_map nsample and temp tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1"))
def test_sem_map_nsample(setup_models, model):
    """Test that sem_map properly handles nsample > 1."""
    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    # Test basic sem_map operation with multiple samples
    data = {"Text": ["The sky is blue", "Water is wet"]}
    df = pd.DataFrame(data)
    user_instruction = "Describe {Text} in one sentence"

    # Generate 3 descriptions per input
    multi_df = df.sem_map(user_instruction, nsample=3)

    # Check that we have the expected columns
    assert "_map1" in multi_df.columns
    assert "_map2" in multi_df.columns
    assert "_map3" in multi_df.columns

    # Check that each column contains non-empty strings
    for i in range(1, 4):
        col = f"_map{i}"
        assert all(isinstance(val, str) and len(val) > 0 for val in multi_df[col])

    # Check whether samples differ across columns. Sampling is probabilistic,
    # so identical outputs are possible even at nonzero temperature; log the
    # observation rather than asserting on it, to keep the test deterministic.
    different_outputs = any(
        (multi_df[f"_map{i}"] != multi_df[f"_map{i + 1}"]).any() for i in range(1, 3)
    )
    lotus.logger.debug(f"Samples differ across columns: {different_outputs}")


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_sem_map_nsample_with_returns(setup_models, model):
    """Test sem_map with nsample > 1 and return_explanations/return_raw_outputs."""
    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    data = {"Text": ["The sky is blue"]}
    df = pd.DataFrame(data)
    user_instruction = "Describe {Text} in one sentence"

    # Generate 2 samples with explanations and raw outputs
    multi_df = df.sem_map(user_instruction, nsample=2, return_explanations=True, return_raw_outputs=True)

    # Check that we have the expected output columns
    assert "_map1" in multi_df.columns
    assert "_map2" in multi_df.columns

    # Check that we have the expected explanation columns
    assert "explanation_map1" in multi_df.columns
    assert "explanation_map2" in multi_df.columns

    # Check that we have the expected raw output columns
    assert "raw_output_map1" in multi_df.columns
    assert "raw_output_map2" in multi_df.columns


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_sem_map_temp(setup_models, model):
    """Test that sem_map properly handles the temp parameter."""
    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    data = {"Text": ["Creativity prompt: describe a new animal"]}
    df = pd.DataFrame(data)
    user_instruction = "Respond to {Text} with a short description"

    # Generate 5 samples with high temperature for more variation
    high_temp_df = df.sem_map(
        user_instruction,
        nsample=5,
        temp=1.0,  # High temperature
    )

    # Generate 5 samples with low temperature for less variation
    low_temp_df = df.sem_map(
        user_instruction,
        nsample=5,
        temp=0.1,  # Low temperature
    )

    # We can't make strong assertions about the specific content,
    # but we can verify columns exist
    for i in range(1, 6):
        assert f"_map{i}" in high_temp_df.columns
        assert f"_map{i}" in low_temp_df.columns

    # All outputs should be non-empty strings
    for i in range(1, 6):
        col = f"_map{i}"
        assert all(isinstance(val, str) and len(val) > 0 for val in high_temp_df[col])
        assert all(isinstance(val, str) and len(val) > 0 for val in low_temp_df[col])
32 changes: 30 additions & 2 deletions docs/sem_map.rst
@@ -3,11 +3,11 @@ sem_map

Overview
----------
This operato performs a semantic projection over an input column. The langex parameter specifies this projection in natural language.
This operator performs a semantic projection over an input column. The langex parameter specifies this projection in natural language. It can generate a single output or multiple sample outputs for each input.

Motivation
-----------
The sem_map operator is useful for performing a row-wise operations over the data.
The sem_map operator is useful for performing row-wise operations over the data. The multi-sampling capability generates diverse outputs for the same input, which is useful for creative tasks, for exploring multiple possibilities, or for understanding the variability in model outputs.

Example
----------
@@ -33,6 +33,10 @@ Example
user_instruction = "What is a similar course to {Course Name}. Be concise."
df = df.sem_map(user_instruction)
print(df)

# Example with multiple samples and temperature
df = df.sem_map(user_instruction, nsample=3, temp=0.7)
print(df)

Output:

@@ -60,3 +64,27 @@ Optional Parameters
- **suffix** : The suffix for the new columns. Defaults to "_map".
- **examples** : The examples dataframe. Defaults to None.
- **strategy** : The reasoning strategy. Defaults to None.
- **nsample** : Number of samples to generate per input. Defaults to 1.
- **temp** : Temperature for sampling. Higher values (e.g., 0.7, 1.0) increase randomness in the output, while lower values (e.g., 0.0, 0.1) make the output more deterministic. Defaults to None (using the model's default temperature).

Multiple Sample Output Structure
--------------------------------
When using ``nsample > 1``, the output dataframe will contain multiple columns, one for each sample:

- With ``nsample=3`` and ``suffix="_map"``, the output columns will be "_map1", "_map2", and "_map3"
- If ``return_explanations=True``, the explanation columns will be "explanation_map1", "explanation_map2", and "explanation_map3"
- If ``return_raw_outputs=True``, the raw output columns will be "raw_output_map1", "raw_output_map2", and "raw_output_map3"

Examples
--------

Basic example with multiple samples:

.. code-block:: python

    # Generate 5 different responses per row, with increased randomness
    df = df.sem_map("Summarize {article} in one sentence", nsample=5, temp=0.8)

    # Access the different samples
    print(df["_map1"])  # First sample for each row
    print(df["_map2"])  # Second sample for each row
26 changes: 24 additions & 2 deletions examples/op_examples/map.py
@@ -15,6 +15,28 @@
    ]
}
df = pd.DataFrame(data)

# Basic example - single output per row
print("\n===== Basic Example - Single Output =====")
user_instruction = "What is a similar course to {Course Name}. Be concise."
df = df.sem_map(user_instruction)
print(df)
df_basic = df.sem_map(user_instruction)
print(df_basic)

# Example with multiple samples - generate 3 alternatives per course
print("\n===== Multiple Samples Example =====")
user_instruction = "Suggest an alternative course to {Course Name}. Be creative."
df_multi = df.sem_map(
    user_instruction,
    nsample=3,  # Generate 3 alternatives per course
    temp=0.7,  # Higher temperature for more varied outputs
)
print(df_multi)

# Example with temperature but single sample
print("\n===== Temperature Example (Higher Creativity) =====")
user_instruction = "If {Course Name} was a book title, what would it be called? Be creative."
df_temp = df.sem_map(
    user_instruction,
    temp=1.0,  # High temperature for maximum creativity
)
print(df_temp)
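
For reference, a minimal sanity check that could be appended to the multi-sample example above (a sketch; the expected names follow the default "_map" suffix documented in docs/sem_map.rst):

# Sketch: nsample=3 with the default suffix should yield three sample columns
expected_columns = {"_map1", "_map2", "_map3"}
assert expected_columns.issubset(set(df_multi.columns)), df_multi.columns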
59 changes: 54 additions & 5 deletions lotus/models/lm.py
@@ -116,10 +116,20 @@ def __call__(
            if lotus.settings.enable_cache
            else uncached_responses
        )
        outputs = [self._get_top_choice(resp) for resp in all_responses]
        logprobs = (
            [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
        )
        n = all_kwargs.get("n", 1)

        if n <= 1:
            outputs = [self._get_top_choice(resp) for resp in all_responses]
            logprobs = (
                [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
            )
        else:
            outputs = [self._get_multiple_choices(resp, n) for resp in all_responses]
            logprobs = (
                [self._get_multiple_choices_logprobs(resp, n) for resp in all_responses]
                if all_kwargs.get("logprobs")
                else None
            )

        return LMOutput(outputs=outputs, logprobs=logprobs)
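
With this change, a caller passing n > 1 should receive a list of sampled completions per prompt rather than a single string (a hypothetical usage sketch; assumes lm is an instance of this class and that keyword arguments such as n and temperature are forwarded to the underlying completion call):

# Hypothetical sketch: with n=3, each entry of LMOutput.outputs is a
# list of up to 3 sampled completions for the corresponding prompt.
lm_output = lm(messages, n=3, temperature=0.8)
for samples in lm_output.outputs:  # one entry per input prompt
    for text in samples:  # one sampled completion per choice
        print(text)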

@@ -195,7 +205,9 @@ def _update_stats(self, response: ModelResponse, is_cached: bool = False):
        except Exception as e:
            # Handle any other unexpected errors when calculating cost
            lotus.logger.debug(f"Unexpected error calculating completion cost: {e}")
            warnings.warn("Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details.")
            warnings.warn(
                "Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details."
            )

            cost = None

@@ -215,12 +227,49 @@ def _get_top_choice(self, response: ModelResponse) -> str:
            raise ValueError(f"No content in response: {response}")
        return choice.message.content

    def _get_multiple_choices(self, response: ModelResponse, n: int) -> list[str]:
        """Get multiple choices from a response, up to n choices."""
        choices = []
        available_choices = min(n, len(response.choices))

        for i in range(available_choices):
            choice = response.choices[i]
            assert isinstance(choice, Choices)
            if choice.message.content is None:
                lotus.logger.warning(f"No content in choice {i} of response: {response}")
                continue
            choices.append(choice.message.content)

        if available_choices < n:
            lotus.logger.warning(f"Requested {n} samples but only got {available_choices}")

        return choices

    def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]:
        choice = response.choices[0]
        assert isinstance(choice, Choices)
        logprobs = choice.logprobs["content"]
        return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs]

    def _get_multiple_choices_logprobs(self, response: ModelResponse, n: int) -> list[list[ChatCompletionTokenLogprob]]:
        """Get logprobs for multiple choices from a response, up to n choices."""
        all_logprobs = []
        available_choices = min(n, len(response.choices))

        for i in range(available_choices):
            choice = response.choices[i]
            assert isinstance(choice, Choices)

            if not hasattr(choice, "logprobs") or not choice.logprobs or "content" not in choice.logprobs:
                lotus.logger.warning(f"No logprobs in choice {i} of response: {response}")
                all_logprobs.append([])
                continue

            choice_logprobs = choice.logprobs["content"]
            all_logprobs.append([ChatCompletionTokenLogprob(**logprob) for logprob in choice_logprobs])

        return all_logprobs
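
The logprobs path mirrors this shape (a hypothetical sketch; assumes logprobs was requested in the call, so each prompt maps to a per-choice list of token logprobs):

# Hypothetical sketch: with n=2 and logprobs enabled, lm_output.logprobs is a
# list (one entry per prompt) of lists (one per choice) of ChatCompletionTokenLogprob.
lm_output = lm(messages, n=2, logprobs=True)
for per_prompt in lm_output.logprobs:
    for per_choice in per_prompt:
        for token_logprob in per_choice:
            print(token_logprob.token, token_logprob.logprob)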

    def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade:
        all_tokens = []
        all_confidences = []