SIM-2799: Receive the model name in the ActionsRequest and override t… by fablechris · Pull Request #35 · fablestudio/fable-saga · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

SIM-2799: Receive the model name in the ActionsRequest and override t… #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits on
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demos/space_colony/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def tick(self, delta: timedelta):
for agent in self.agents.values():
self.actionsQueue.put_nowait(agent.tick(delta, self))

async def generate_actions(self, sim_agent: SimAgent, retries=0, verbose=False) -> [List[Dict[str, Any]]]:
async def generate_actions(self, sim_agent: SimAgent, retries=0, verbose=False, model_override: Optional[str] = None) -> [List[Dict[str, Any]]]:
"""Generate actions for this agent using the SAGA agent."""

print(f"Generating actions for {sim_agent.persona.id()} ...")
Expand All @@ -150,7 +150,7 @@ async def generate_actions(self, sim_agent: SimAgent, retries=0, verbose=False)
context += f"Your location is {sim_agent.location.id()}.\n"

return await self.saga_agent.generate_actions(context, sim_agent.skills,
max_tries=retries, verbose=verbose)
max_tries=retries, verbose=verbose, model_override=model_override)


class Format:
Expand Down
7 changes: 4 additions & 3 deletions fable_saga/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def __init__(self, llm: BaseChatModel = None):
path = pathlib.Path(__file__).parent.resolve()
self.prompt = load_prompt(path / "prompt_templates/generate_actions.yaml")

def chain(self) -> LLMChain:
def chain(self, model_override: Optional[str] = None) -> LLMChain:
self._llm.model_name = model_override if model_override else default_openai_model_name
return LLMChain(llm=self._llm, prompt=self.prompt)

async def generate_actions(self, context: str, skills: List[Skill], max_tries=0, verbose=False) -> GeneratedActions:
async def generate_actions(self, context: str, skills: List[Skill], max_tries=0, verbose=False, model_override: Optional[str] = None) -> GeneratedActions:
"""Generate actions for the given context and skills."""
assert context is not None and len(context) > 0, "Must provide a context."
assert skills is not None and len(skills) > 0, "Must provide at least one skill."
Expand All @@ -114,7 +115,7 @@ async def generate_actions(self, context: str, skills: List[Skill], max_tries=0,
assert skill.name is not None and len(skill.name) > 0, "Must provide a skill name."
assert skill.description is not None and len(skill.description) > 0, "Must provide a skill description."

chain = self.chain()
chain = self.chain(model_override)
chain.verbose = verbose

# Set up the callback handler.
Expand Down
3 changes: 2 additions & 1 deletion fable_saga/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ActionsRequest:
retries: int = 0
verbose: bool = False
reference: Optional[str] = None
model: Optional[str] = None


@define(slots=True)
Expand All @@ -66,7 +67,7 @@ async def generate_actions(self, req: ActionsRequest) -> ActionsResponse:
# Generate actions
try:
assert isinstance(req, ActionsRequest), f"Invalid request type: {type(req)}"
actions = await self.agent.generate_actions(req.context, req.skills, req.retries, req.verbose)
actions = await self.agent.generate_actions(req.context, req.skills, req.retries, req.verbose, req.model)
response = ActionsResponse(actions=actions, reference=req.reference)
if actions.error is not None:
response.error = f"Generation Error: {actions.error}"
Expand Down
270 changes: 260 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ packages = [{include = "fable_saga"}]

[tool.poetry.dependencies]
python = "^3.10"
python-socketio = "^5.9.0"
cattrs = "^23.1.2"
aiohttp = "^3.8.5"
langchain = {extras = ["llms"], version = "^0.0.293"}
python-dateutil = "^2.8.2"
tiktoken = "^0.5.1"
datetime = "^5.3"


[tool.poetry.group.test]
Expand Down
1 change: 1 addition & 0 deletions tests/examples/request.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"reference": "asdfasfasdf",
"context": "You are a mouse",
"model": "test_model",
"skills": [
{
"name": "goto",
Expand Down
25 changes: 22 additions & 3 deletions tests/test_saga.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from langchain.chat_models.fake import FakeListChatModel

import fable_saga
from fable_saga.server import SagaServer, ActionsRequest, ActionsResponse
from fable_saga.server import SagaServer, ActionsRequest


class FakeChatOpenAI(FakeListChatModel):
model_name: str = 'model_name_default'


@pytest.fixture
def fake_llm():
actions = json.load(open("examples/generated_actions.json"))
responses = [json.dumps(action) for action in actions]
llm = FakeListChatModel(responses=responses, sleep=0.1)
llm = FakeChatOpenAI(responses=responses, sleep=0.1)
return llm


Expand Down Expand Up @@ -48,7 +52,15 @@ async def test_generate_actions(self, fake_llm, fake_skills):

# fake_llm.callbacks = [callback_handler]
agent = fable_saga.Agent(fake_llm)
actions = await agent.generate_actions("context", fake_skills)

# Should be using the default model
test_model = 'test_model'
assert fake_llm.model_name != test_model

actions = await agent.generate_actions("context", fake_skills, model_override=test_model)

# Should be using the test model
assert fake_llm.model_name == test_model

# In our test data, we assume 2 actions are generated and are pre-sorted by score.
assert len(actions.options) == 2
Expand Down Expand Up @@ -90,9 +102,16 @@ def test_init(self, fake_llm):

@pytest.mark.asyncio
async def test_generate_actions(self, fake_llm, fake_request):

# Should be using the default model
assert fake_llm.model_name != fake_request.model

server = SagaServer(llm=fake_llm)
response = await server.generate_actions(fake_request)

# Should be using the test model
assert fake_llm.model_name == fake_request.model

# The response is a valid ActionsResponse
# Note: we don't throw exceptions in the server, but return the error in the response.
assert response.error is None
Expand Down
0