Skip to content

Commit

Permalink
Merge pull request #24 from Datawheel/demo
Browse files Browse the repository at this point in the history
merge changes in demo into main
  • Loading branch information
alebjanes authored Jul 19, 2024
2 parents 815e0cb + ea24060 commit 3a07305
Show file tree
Hide file tree
Showing 7 changed files with 826 additions and 73 deletions.
24 changes: 18 additions & 6 deletions api/setup/load_drilldowns_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
import requests
import urllib.parse

from config import POSTGRES_ENGINE, SCHEMA_DRILLDOWNS, DRILLDOWNS_TABLE_NAME, TESSERACT_API, TABLES_PATH
from config import POSTGRES_ENGINE, SCHEMA_DRILLDOWNS, TESSERACT_API, TABLES_PATH
from utils.similarity_search import embedding
from sqlalchemy import text as sql_text

def create_table(table_name=DRILLDOWNS_TABLE_NAME, schema_name=SCHEMA_DRILLDOWNS, embedding_size=384):
POSTGRES_ENGINE.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
POSTGRES_ENGINE.execute(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (drilldown_id text, drilldown_name text, cube_name text, drilldown text, embedding vector({embedding_size}))")
# Identifier of the embedding model; passed as `model=` to embedding() in
# load_data_to_db below.
embedding_model = "sfr-embedding-mistral:q8_0"
# Dimensionality of the vectors this model produces; used as the pgvector
# column size in create_table.
embedding_size = 4096
# NOTE(review): this hard-coded name replaces the DRILLDOWNS_TABLE_NAME that
# was previously imported from config — confirm the override is intentional.
DRILLDOWNS_TABLE_NAME = "drilldowns_sfr"

def create_table(table_name=DRILLDOWNS_TABLE_NAME, schema_name=SCHEMA_DRILLDOWNS, embedding_size=embedding_size):
    """Ensure the drilldowns schema and its pgvector-backed table exist.

    Issues idempotent DDL (``CREATE ... IF NOT EXISTS``) for the schema and
    for a table holding drilldown metadata plus an ``embedding`` vector
    column of the given dimensionality, committing once at the end.

    NOTE(review): schema/table names are interpolated straight into the SQL;
    they are assumed to come from trusted config, never from user input.
    """
    ddl_statements = (
        f"CREATE SCHEMA IF NOT EXISTS {schema_name}",
        (
            f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} "
            f"(drilldown_id text, drilldown_name text, cube_name text, "
            f"drilldown text, embedding vector({embedding_size}))"
        ),
    )
    with POSTGRES_ENGINE.connect() as conn:
        for statement in ddl_statements:
            conn.execute(sql_text(statement))
        conn.commit()

def get_data_from_api(api_url):
try:
Expand All @@ -31,13 +41,15 @@ def prepare_dataframe(df, measure_name, cube_name, drilldown_name, drilldown_uni
df.dropna(subset=['drilldown_name', 'drilldown_id'], how='all', inplace=True)
df = df[['drilldown_id', 'drilldown_name', 'cube_name', 'drilldown']]
df['drilldown_name'] = df['drilldown_name'].astype(str)
df["embedding"] = ""
df['embedding'] = df['embedding'].astype(object)
print(df.head())
return df

def load_data_to_db(api_url, measure_name, cube_name, drilldown_name, drilldown_unique_name=None, schema_name=SCHEMA_DRILLDOWNS, db_table_name=DRILLDOWNS_TABLE_NAME):
    """Fetch drilldown members from the API, embed them, and append to Postgres.

    Pipeline: API fetch -> dataframe normalization via prepare_dataframe ->
    embedding of the 'drilldown_name' column with the module-level
    embedding_model -> bulk append into ``schema_name.db_table_name``
    through pandas ``to_sql``.
    """
    raw_frame = get_data_from_api(api_url)
    prepared = prepare_dataframe(raw_frame, measure_name, cube_name, drilldown_name, drilldown_unique_name)
    with_vectors = embedding(prepared, 'drilldown_name', model=embedding_model)
    with_vectors.to_sql(
        db_table_name,
        con=POSTGRES_ENGINE,
        schema=schema_name,
        if_exists='append',
        index=False,
    )

def main(include_cubes=False):
Expand Down Expand Up @@ -78,5 +90,5 @@ def main(include_cubes=False):
load_data_to_db(api_url, measure, cube_name, drilldown_name, drilldown_unique_name)

if __name__ == "__main__":
include_cubes = False # if set to False it will upload the drilldowns of all cubes in the schema.json
include_cubes = ['trade_i_baci_a_96'] # if set to False it will upload the drilldowns of all cubes in the schema.json
main(include_cubes)
10 changes: 6 additions & 4 deletions api/src/api_data_request/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(

if measures:
self.add_measure(measures)
else: self.measures = table.measures
else:
self.measures = table.measures

if cuts:
cuts_processing(cuts, table, self)
Expand Down Expand Up @@ -150,7 +151,8 @@ def build_api(self) -> str:
query_params.append(f"{key}={','.join(values)}")
if self.drilldowns:
query_params.append("drilldowns=" + ",".join(self.drilldowns))
else: query_params.append("drilldowns=Year")
else:
query_params.append("drilldowns=Year")

if self.measures:
query_params.append("measures=" + ",".join(self.measures))
Expand Down Expand Up @@ -269,5 +271,5 @@ def cuts_processing(cuts: List[str], table: Table, api: ApiBuilder):
api.add_drilldown(cut)
elif "HS" in cut:
api.add_drilldown(cut)

else: api.drilldowns.discard(cut)
else:
api.drilldowns.discard(cut)
29 changes: 14 additions & 15 deletions api/src/api_data_request/api_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,44 @@ def _get_api_components_messages(

response_part = """
{
"drilldowns": "",
"measures": "",
"filters": "",
"explanation": ""
"explanation": "Your explanation here",
"drilldowns": ["list", "of", "level", "names"],
"measures": ["list", "of", "measures"],
"filters": ["level = filtered_value"]
}
"""

if(model_author == "openai"):

message = f"""
You're a data scientist working with OLAP cubes. Given dimensions and measures in JSON format, identify the drilldowns, measures, and filters for querying the cube via API.
You're a data scientist working with OLAP cubes. Given dimensions and measures in JSON format, identify the appropriate drilldowns, measures, and filters for querying the cube via API.
**Dimensions:**
{table.prompt_get_dimensions()}
**Measures:**
{table.get_measures_description()}
Your response should be in JSON format with:
Your response should be in JSON format and include:
- "drilldowns": List of specific levels within each dimension for drilldowns (ONLY the level names).
- "measures": List of relevant measures.
- "filters": List of filters in 'level = filtered_value' format.
- "explanation": one to two sentence comment explaining why the chosen drilldowns and cuts are relevant goes here, double checking that the levels exist in the JSON given above.
- "explanation": A brief comment (one to two sentences) explaining why the chosen drilldowns and filters are relevant, ensuring the levels exist in the provided JSON.
- "drilldowns": A list of specific levels within each dimension for drilldowns (only the level names).
- "measures": A list of relevant measures.
- "filters": A list of filters in 'level = filtered_value' format.
Response format:
```
{response_part}
```
Provide only the required lists, and adhere to these rules:
Please adhere to these rules:
- Prioritize the HS4 level for products, but choose other levels if they are more appropriate.
- Apply filters only to the most relevant or granular level within the same parent dimension.
- For year or month ranges, specify each separately.
- Double check that the drilldowns and cuts contain ONLY the level names, and not the dimension.
- For filters, just write the general name, as it will be matched to its ID later on.
- Ensure that the drilldowns and filters contain only the level names, not the dimension names.
- For filters, use general names as they will be matched to their IDs later on.
"""

else:
Expand Down Expand Up @@ -211,9 +212,7 @@ def get_api_params_from_lm(
}

response = requests.post(url, json=payload)
print(response.text)
response = parse_response(response.text)
print(response)
params = extract_text_from_markdown_triple_backticks(response)
tokens = ""

Expand Down
21 changes: 11 additions & 10 deletions api/src/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import time

from os import getenv

from table_selection.table_selector import *
from table_selection.table import *
from api_data_request.api_generator import *
from data_analysis.data_analysis import *
from utils.logs import insert_logs
from utils.functions import set_to_string, clean_string
from typing import Dict, Generator, Tuple

from table_selection.table_selector import request_tables_to_lm_from_db
from table_selection.table import TableManager
from api_data_request.api_generator import get_api_params_from_lm
from api_data_request.api import ApiBuilder
from data_analysis.data_analysis import agent_answer
#from utils.logs import *
from utils.functions import clean_string, set_to_string
from config import TABLES_PATH


Expand All @@ -23,8 +24,8 @@ def get_api(
if token_tracker is None:
token_tracker = {}

if step == "request_tables_to_lm_from_db":
print("quest_tables_to_lm_from_db")
if step == 'request_tables_to_lm_from_db':
print("request_tables_to_lm_from_db")
start_time = time.time()
manager = TableManager(TABLES_PATH)
table, form_json, token_tracker = request_tables_to_lm_from_db(natural_language_query, manager, token_tracker)
Expand Down
45 changes: 24 additions & 21 deletions api/src/data_analysis/data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,36 @@ def agent_answer(
- An updated token_tracker dictionary with new token usage information.
"""
prompt = f"""
You are an expert data analyst working for the Observatory of Economic Complexity, whose goal is to give an answer, as accurate and complete as possible, to the following user's question using the given dataframe.
---------------\n
{natural_language_query}
\n---------------\n
Take into consideration the data type and formatting of the columns.
It's possible that any product/service or other variables the user is referring to appears with a different name in the dataframe. Explain this in your answer in a polite manner, but always trying to give an answer with the available data.
If you can't answer the question with the provided data, please answer with "I can't answer your question with the available data".
You are an expert data analyst working for the Observatory of Economic Complexity. Your goal is to provide an accurate and complete answer to the following user's question using the given dataframe.
You can complement your answer with any content found in the Observatory of Economic Complexity.
Notice that this dataframe was extracted with the following API (you can see the drilldowns, measures and cuts that have been applied to extract the data):
{api_url}
User's Question:
{natural_language_query}
Lets think it through step by step.
Avoid any further comments not related to the question itself.
Take into consideration the data type and formatting of the columns. If a product, service, or other variable referred to by the user appears under a different name in the dataframe, explain this politely and provide an answer using the available data.
If you cannot answer the question with the provided data, respond with "I can't answer your question with the available data."
You can complement your answer with any content found in the Observatory of Economic Complexity. Note that this dataframe was extracted using the following API (you can see the drilldowns, measures, and cuts applied to extract the data):
{api_url}
Guidelines:
1. Think through the answer step by step.
2. Avoid any comments unrelated to the question.
3. Always provide the corresponding trade value, and quantity if required.
4. All quantities are in metric tons, and trade value is in USD.
"""

simple_prompt = f"""
You are an expert data analyst working for the Observatory of Economic Complexity, whose goal is to
give an answer, as accurate and complete as possible, to the following user's question using the
given dataframe.
You are an expert data analyst working for the Observatory of Economic Complexity, whose goal is to
give an answer, as accurate and complete as possible, to the following user's question using the
given dataframe.
Here is the question:
{natural_language_query}
Here is the question:
{natural_language_query}
Take into consideration the data type and formatting of the columns.
It's possible that any product/service or other variables the user is referring to appears with a different name in the dataframe. Explain this in your answer in a polite manner, but always trying to give an answer with the available data.
If you can't answer the question with the provided data, please answer with "I can't answer your question with the available data".
Avoid any further comments not related to the question itself.
Take into consideration the data type and formatting of the columns.
It's possible that any product/service or other variables the user is referring to appears with a different name in the dataframe. Explain this in your answer in a polite manner, but always trying to give an answer with the available data.
If you can't answer the question with the provided data, please answer with "I can't answer your question with the available data".
Avoid any further comments not related to the question itself.
"""

llm = ChatOpenAI(model_name=model, temperature=0, openai_api_key=OPENAI_KEY, callbacks=[cb])
Expand Down
Loading

0 comments on commit 3a07305

Please sign in to comment.