Simplify LLMExtractor class

This commit is contained in:
Avinash Mallya
2025-01-08 18:10:33 -06:00
parent 1647df3bd2
commit 3ef02a14bb
2 changed files with 31 additions and 20 deletions

18
.idea/workspace.xml generated
View File

@@ -27,16 +27,16 @@
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">{
&quot;keyToString&quot;: {
&quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
&quot;RunOnceActivity.git.unshallow&quot;: &quot;true&quot;,
&quot;SHARE_PROJECT_CONFIGURATION_FILES&quot;: &quot;true&quot;,
&quot;git-widget-placeholder&quot;: &quot;main&quot;,
&quot;last_opened_file_path&quot;: &quot;/Users/avimallu/Documents/SplitAI&quot;,
&quot;settings.editor.selected.configurable&quot;: &quot;preferences.pluginManager&quot;
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"RunOnceActivity.ShowReadmeOnStart": "true",
"RunOnceActivity.git.unshallow": "true",
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
"git-widget-placeholder": "refactoring",
"last_opened_file_path": "/Users/avimallu/Documents/SplitAI",
"settings.editor.selected.configurable": "preferences.pluginManager"
}
}</component>
}]]></component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">

View File

@@ -1,7 +1,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Type, Literal, overload
from typing import Any, Type, Literal, overload, Callable
import yaml
from ollama import chat
@@ -20,15 +20,17 @@ from src.llm.models import (
class LLMExtractor:
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml"):
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml", chat_function: Callable = chat):
"""
Args:
model: A model name that has been downloaded by ``ollama``.
prompt_path: A str or Path to a YAML file with various prompts.
chat_function: Callable used to send chat requests; defaults to ``ollama.chat``, so it normally does not need to be passed.
"""
self.model = model
self.prompt_path = Path(prompt_path)
self.chat_function = chat_function
with open(self.prompt_path, "r") as f:
self.prompts = yaml.safe_load(f)
@@ -46,7 +48,7 @@ class LLMExtractor:
Returns:
Depends on the model chosen and on whether `structured_output_format` is provided.
"""
response = chat(
response = self.chat_function(
model=self.model,
messages=messages,
format=structured_output_format.model_json_schema(),
@@ -54,7 +56,10 @@ class LLMExtractor:
)
message = response.message.content
if structured_output_format is not None:
return json.loads(message)
try:
return json.loads(message)
except json.JSONDecodeError:
return None
else:
return message
@@ -86,6 +91,7 @@ class LLMExtractor:
"extract_receipt_tax",
],
key_to_extract: str | None = None,
default_value: Any = None
) -> float: ...
@overload
def extract_fields(
@@ -94,6 +100,7 @@ class LLMExtractor:
structured_response_model: Type[BaseModel],
prompt_name: Literal["extract_receipt_items"],
key_to_extract: str | None = None,
default_value: Any = None,
) -> list[ItemizedAmounts]: ...
def extract_fields(
self,
@@ -108,6 +115,7 @@ class LLMExtractor:
"extract_receipt_items",
],
key_to_extract: str | None = None,
default_value: Any = None,
) -> str | float | list[ItemizedAmounts] | Any:
messages: list[Message] = [
{
@@ -118,6 +126,8 @@ class LLMExtractor:
}
]
output = self.get_chat_response(messages, structured_response_model)
if output is None:
return default_value
if key_to_extract is not None:
output = output[key_to_extract]
return output
@@ -129,13 +139,14 @@ class LLMExtractor:
structured_output_model,
f"extract_{field}",
key_to_extract,
default_value,
)
for field, structured_output_model, key_to_extract in [
("merchant", ReceiptMerchant, "name"),
("receipt_date", ReceiptDate, "date"),
("total", ReceiptAmount, "amount"),
("tip", ReceiptAmount, "amount"),
("tax", ReceiptAmount, "amount"),
("receipt_items", ReceiptItemized, "ItemizedReceipt"),
for field, structured_output_model, key_to_extract, default_value in [
("merchant", ReceiptMerchant, "name", None),
("receipt_date", ReceiptDate, "date", None),
("total", ReceiptAmount, "amount", None),
("tip", ReceiptAmount, "amount", None),
("tax", ReceiptAmount, "amount", None),
("receipt_items", ReceiptItemized, "ItemizedReceipt", []),
]
}