Simplify LLMExtractor class
This commit is contained in:
18
.idea/workspace.xml
generated
18
.idea/workspace.xml
generated
@@ -27,16 +27,16 @@
|
||||
<option name="hideEmptyMiddlePackages" value="true" />
|
||||
<option name="showLibraryContents" value="true" />
|
||||
</component>
|
||||
<component name="PropertiesComponent">{
|
||||
"keyToString": {
|
||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||
"RunOnceActivity.git.unshallow": "true",
|
||||
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
||||
"git-widget-placeholder": "main",
|
||||
"last_opened_file_path": "/Users/avimallu/Documents/SplitAI",
|
||||
"settings.editor.selected.configurable": "preferences.pluginManager"
|
||||
<component name="PropertiesComponent"><![CDATA[{
|
||||
"keyToString": {
|
||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||
"RunOnceActivity.git.unshallow": "true",
|
||||
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
||||
"git-widget-placeholder": "refactoring",
|
||||
"last_opened_file_path": "/Users/avimallu/Documents/SplitAI",
|
||||
"settings.editor.selected.configurable": "preferences.pluginManager"
|
||||
}
|
||||
}</component>
|
||||
}]]></component>
|
||||
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
||||
<component name="TaskManager">
|
||||
<task active="true" id="Default" summary="Default task">
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Type, Literal, overload
|
||||
from typing import Any, Type, Literal, overload, Callable
|
||||
|
||||
import yaml
|
||||
from ollama import chat
|
||||
@@ -20,15 +20,17 @@ from src.llm.models import (
|
||||
|
||||
|
||||
class LLMExtractor:
|
||||
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml"):
|
||||
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml", chat_function: Callable = chat):
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: A model name that has been downloaded by ``ollama``.
|
||||
prompt_path: A str or Path to a YAML file with various prompts.
|
||||
chat_function: Typically ollama.chat, input not required.
|
||||
"""
|
||||
self.model = model
|
||||
self.prompt_path = Path(prompt_path)
|
||||
self.chat_function = chat_function
|
||||
with open(self.prompt_path, "r") as f:
|
||||
self.prompts = yaml.safe_load(f)
|
||||
|
||||
@@ -46,7 +48,7 @@ class LLMExtractor:
|
||||
Returns:
|
||||
Depends on the model chosen, and if `structured_output_format` is provided.
|
||||
"""
|
||||
response = chat(
|
||||
response = self.chat_function(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
format=structured_output_format.model_json_schema(),
|
||||
@@ -54,7 +56,10 @@ class LLMExtractor:
|
||||
)
|
||||
message = response.message.content
|
||||
if structured_output_format is not None:
|
||||
return json.loads(message)
|
||||
try:
|
||||
return json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
else:
|
||||
return message
|
||||
|
||||
@@ -86,6 +91,7 @@ class LLMExtractor:
|
||||
"extract_receipt_tax",
|
||||
],
|
||||
key_to_extract: str | None = None,
|
||||
default_value: Any = None
|
||||
) -> float: ...
|
||||
@overload
|
||||
def extract_fields(
|
||||
@@ -94,6 +100,7 @@ class LLMExtractor:
|
||||
structured_response_model: Type[BaseModel],
|
||||
prompt_name: Literal["extract_receipt_items"],
|
||||
key_to_extract: str | None = None,
|
||||
default_value: Any = None,
|
||||
) -> list[ItemizedAmounts]: ...
|
||||
def extract_fields(
|
||||
self,
|
||||
@@ -108,6 +115,7 @@ class LLMExtractor:
|
||||
"extract_receipt_items",
|
||||
],
|
||||
key_to_extract: str | None = None,
|
||||
default_value: Any = None,
|
||||
) -> str | float | list[ItemizedAmounts] | Any:
|
||||
messages: list[Message] = [
|
||||
{
|
||||
@@ -118,6 +126,8 @@ class LLMExtractor:
|
||||
}
|
||||
]
|
||||
output = self.get_chat_response(messages, structured_response_model)
|
||||
if output is None:
|
||||
return default_value
|
||||
if key_to_extract is not None:
|
||||
output = output[key_to_extract]
|
||||
return output
|
||||
@@ -129,13 +139,14 @@ class LLMExtractor:
|
||||
structured_output_model,
|
||||
f"extract_{field}",
|
||||
key_to_extract,
|
||||
default_value,
|
||||
)
|
||||
for field, structured_output_model, key_to_extract in [
|
||||
("merchant", ReceiptMerchant, "name"),
|
||||
("receipt_date", ReceiptDate, "date"),
|
||||
("total", ReceiptAmount, "amount"),
|
||||
("tip", ReceiptAmount, "amount"),
|
||||
("tax", ReceiptAmount, "amount"),
|
||||
("receipt_items", ReceiptItemized, "ItemizedReceipt"),
|
||||
for field, structured_output_model, key_to_extract, default_value in [
|
||||
("merchant", ReceiptMerchant, "name", None),
|
||||
("receipt_date", ReceiptDate, "date", None),
|
||||
("total", ReceiptAmount, "amount", None),
|
||||
("tip", ReceiptAmount, "amount", None),
|
||||
("tax", ReceiptAmount, "amount", None),
|
||||
("receipt_items", ReceiptItemized, "ItemizedReceipt", []),
|
||||
]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user