[feature & bug fix] Update openai library and add support to gpt4 by pedMatias · Pull Request #3 · mljar/plotai · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[feature & bug fix] Update openai library and add support to gpt4 #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ plot = PlotAI(df)
plot.make("make a scatter plot")
```

By default the library uses `gpt-3.5-turbo`. You can select a different OpenAI model with the `model_version` argument:

```python
# import PlotAI
from plotai import PlotAI

# create PlotAI object; pass the pandas DataFrame by keyword (df=...),
# because the first positional parameter of PlotAI is model_version
plot = PlotAI(df=df, model_version="gpt-4")

# make a plot, just tell what you want
plot.make("make a scatter plot")
```

## More examples

#### Analyze the GDP dataset
Expand Down
19 changes: 11 additions & 8 deletions plotai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import openai

from dotenv import load_dotenv

load_dotenv()


class ChatGPT():
class ChatGPT:

temperature = 0
max_tokens = 1000
Expand All @@ -14,14 +15,15 @@ class ChatGPT():
presence_penalty = 0.6
model = "gpt-3.5-turbo"

def __init__(self):
def __init__(self, model: str):
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
raise Exception("Please set OPENAI_API_KEY environment variable."
"You can obtain API key from https://platform.openai.com/account/api-keys")
raise Exception(
"Please set OPENAI_API_KEY environment variable."
"You can obtain API key from https://platform.openai.com/account/api-keys"
)
openai.api_key = api_key

self.model = model

@property
def _default_params(self):
Expand All @@ -35,6 +37,7 @@ def _default_params(self):
}

def chat(self, prompt):
client = openai.OpenAI()

params = {
**self._default_params,
Expand All @@ -45,5 +48,5 @@ def chat(self, prompt):
}
],
}
response = openai.ChatCompletion.create(**params)
return response["choices"][0]["message"]["content"]
response = client.chat.completions.create(**params)
return response.choices[0].message.content
15 changes: 8 additions & 7 deletions plotai/plotai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from plotai.code.executor import Executor
from plotai.code.logger import Logger


class PlotAI:

def __init__(self, *args, **kwargs):
def __init__(self, model_version: str = "gpt-3.5-turbo", *args, **kwargs):

# OpenAI Model Version
self.model_version = model_version
# DataFrame to plot
self.df, self.x, self.y, self.z = None, None, None, None
if len(args) > 1:
for i in range(len(args)):
Expand All @@ -34,13 +38,11 @@ def __init__(self, *args, **kwargs):
setattr(self, k, kwargs[k])

def make(self, prompt):

df, x, y, z = self.df, self.x, self.y, self.z
p = Prompt(prompt, self.df, self.x, self.y, self.z)
p = Prompt(prompt, self.df, self.x, self.y, self.z)

Logger().log({"title": "Prompt", "details": p.value})

response = ChatGPT().chat(p.value)
response = ChatGPT(model=self.model_version).chat(p.value)

Logger().log({"title": "Response", "details": response})

Expand All @@ -49,8 +51,7 @@ def make(self, prompt):
if error is not None:
Logger().log({"title": "Error in code execution", "details": error})


# p_again = Prompt(prompt, self.df, self.x, self.y, self.z, previous_code=response, previous_error=error)
# p_again = Prompt(prompt, self.df, self.x, self.y, self.z, previous_code=response, previous_error=error)

# Logger().log({"title": "Prompt with fix", "details": p_again.value})

Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
matplotlib
pandas
numpy
openai
python-dotenv
matplotlib~=3.8.3
pandas~=2.2.0
numpy~=1.26.4
openai~=1.12.0
python-dotenv~=1.0.1
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line nu 8000 mber Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="plotai",
version="0.0.2",
version="0.0.3",
description="Create plots in Python with AI",
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -21,7 +21,7 @@
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=open("requirements.txt").readlines(),
include_package_data=True,
python_requires='>=3.7.1',
python_requires=">=3.7.1",
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3.7",
Expand All @@ -37,6 +37,6 @@
"matplotlib",
"llm",
"openai",
"mljar"
"mljar",
],
)
6 changes: 6 additions & 0 deletions tests/test_plotai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ def test_pass_data(self):
df2 = pd.DataFrame({"x":np.random.rand(rows), "y": np.random.rand(rows)})
plot = PlotAI(df=df2)
#plot.make("Plot a scatter plot")

def test_gpt4(self):
    """PlotAI should keep the requested OpenAI model id for its ChatGPT calls."""
    rows = 100
    df2 = pd.DataFrame({"x": np.random.rand(rows), "y": np.random.rand(rows)})
    # The OpenAI model id is "gpt-4" (with a hyphen); "gpt4" is not a valid id.
    # df is passed by keyword because model_version is PlotAI's first positional parameter.
    plot = PlotAI(df=df2, model_version="gpt-4")
    # make() forwards this value as ChatGPT(model=self.model_version).
    assert plot.model_version == "gpt-4"
    #plot.make("Plot a scatter plot")
0