More tests, all pass. Ruff-ing everything up as well

This commit is contained in:
Avinash Mallya
2025-01-09 19:06:20 -06:00
parent 169c4f0c05
commit de6c341226
7 changed files with 100 additions and 49 deletions

View File

@@ -9,8 +9,8 @@ Don't know enough Python? Watch this space.
## Screenshots
<p float="left">
<img src="/assets/images/Screenshot_1.jpg" width=15%>
<img src="/assets/images/Screenshot_2.jpg" width=15%>
<img src="/assets/images/Screenshot_1.jpg" width=14%>
<img src="/assets/images/Screenshot_2.jpg" width=12.5%>
<img src="/assets/images/Screenshot_3.jpg" width=15%>
</p>
@@ -18,7 +18,9 @@ Don't know enough Python? Watch this space.
### Prerequisites
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available. The default is `qwen2.5:7b`, while this is easily configurable you don't really need a more powerful model.
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available.
The default is `tulu3:8b`. While this is easily configurable, you don't really need a more powerful model. If you have the compute, use `gemma2:27b` for higher accuracy. If you would like to choose a different model, run `pytest` and pick one that passes enough of the extraction tests.
### Python virtual environment

View File

@@ -11,6 +11,7 @@ from src.app.utils import css_code, head_html, spinner_html
ComponentType = TypeVar("ComponentType")
def calculate_splits(
item_names: list[str],
item_people: list[list[str]],
@@ -22,7 +23,7 @@ def calculate_splits(
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False
return_detailed_table: bool = False,
) -> gr.DataFrame:
"""
A simple, but long function to calculate splits for a provided receipt.
@@ -84,13 +85,17 @@ def calculate_splits(
split_subtotals, split_tips, split_taxes
)
]
split_cashback = [
- x * cashback_discount for x in split_totals_pre_cashback
]
split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
split_totals_post_cashback = [
x * (1 - cashback_discount) for x in split_totals_pre_cashback
]
first_col_names = list(item_names) + ["Subtotal", "Tip", "Tax", "Cashback", "Total"]
first_col_names = list(item_names) + [
"Subtotal",
"Tip",
"Tax",
"Cashback",
"Total",
]
splits = split_amounts + [
split_subtotals,
split_tips,
@@ -126,10 +131,11 @@ def calculate_splits(
return gr.DataFrame(full_calculation_df, visible=True)
else:
simple_calculation = (
full_calculation_df
.filter(pl.col("Item").eq("Total"))
full_calculation_df.filter(pl.col("Item").eq("Total"))
.select(pl.exclude("Total"))
.transpose(include_header=True, header_name="Person", column_names=["Split"])
.transpose(
include_header=True, header_name="Person", column_names=["Split"]
)
.filter(pl.col("Person").ne("Item"))
)
return gr.DataFrame(simple_calculation, visible=True)
@@ -264,7 +270,9 @@ class SplitAIApp:
return calculate_splits(**kwargs)
@staticmethod
def update_component_attributes(component: ComponentType, **kwargs) -> ComponentType:
def update_component_attributes(
component: ComponentType, **kwargs
) -> ComponentType:
"""
This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
the component that you want to update with its attributes changed. It seems like it
@@ -401,7 +409,10 @@ class SplitAIApp:
interactive=True,
)
self.add_cashback_discount = gr.Number(
minimum=0, maximum=100, value=0, step=0.5,
minimum=0,
maximum=100,
value=0,
step=0.5,
label="Cashback discount to apply on total",
info="Choose a number between 0% and 100%.",
interactive=True,
@@ -444,7 +455,9 @@ class SplitAIApp:
outputs=self.display_result,
)
add_item_button = gr.Button("", variant="secondary", scale=1, min_width=10)
add_item_button = gr.Button(
"", variant="secondary", scale=1, min_width=10
)
def add_item(
items: list[Item],
@@ -593,9 +606,16 @@ class SplitAIApp:
else:
self.demo.queue().launch()
def arg_parser() -> agp.ArgumentParser:
    """
    Build the command-line argument parser for the app.

    Returns:
        A parser exposing a single optional ``-m`` / ``--model`` string
        argument that names the LLM model to use.
    """
    ag = agp.ArgumentParser()
    # Register the option exactly once; adding the same flags twice makes
    # argparse raise a conflicting-option error.
    ag.add_argument(
        "-m",
        "--model",
        type=str,
        # Matches the default model named in the README and used by the
        # pytest `--ollama-model-name` option.
        default="tulu3:8b",
        help="Choose the LLM model used.",
    )
    return ag

View File

@@ -49,4 +49,4 @@ spinner_html = """
">
<span class='loader'></span>
</div>
"""
"""

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import date
from pathlib import Path
from typing import Annotated, Literal, Optional, TypedDict
@@ -18,7 +18,7 @@ class ReceiptMerchant(BaseModel):
class ReceiptDate(BaseModel):
date: datetime
date: date
class ReceiptAmount(BaseModel):
@@ -48,7 +48,7 @@ class ItemizedAmounts(TypedDict):
class ReceiptExtracted(TypedDict):
merchant: str
receipt_date: datetime
receipt_date: date
total: Amount
tip: Amount
tax: Amount

View File

@@ -37,13 +37,17 @@ class SuryaOCR:
self.rec_processor = load_rec_processor()
@overload
def ocr_image(self, image: Image, return_format: Literal["polygon"]) -> list[Polygon]: ...
def ocr_image(
self, image: Image, return_format: Literal["polygon"]
) -> list[Polygon]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["text"]) -> list[str]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["bbox"]) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["confidence"]) -> list[BBox]: ...
def ocr_image(
self, image: Image, return_format: Literal["confidence"]
) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: None) -> SplitAIOCRResult: ...
def ocr_image(

View File

@@ -2,6 +2,6 @@ def pytest_addoption(parser):
parser.addoption(
"--ollama-model-name",
action="store",
default="qwen2.5:7b",
default="tulu3:8b",
help="Specify the ollama model name to use. It should exist on the system you are running this from.",
)

View File

@@ -3,6 +3,7 @@ from datetime import date
import ollama
import pytest
from pydantic import BaseModel
from typing import Type, Any
from src.llm.extractor import LLMExtractor
@@ -17,7 +18,6 @@ from src.llm.models import (
)
@pytest.fixture
def ollama_model_name(request):
"""Must be the name of an ollama model that has already been downloaded to the system."""
@@ -54,33 +54,58 @@ def test_ollama_structured_response(ollama_model_name):
assert ModelData.model_validate(output)
def test_llm_extractor(ollama_model_name):
@pytest.fixture
def llm_extractor(ollama_model_name) -> LLMExtractor:
    """Provide an `LLMExtractor` built with the model named by the `ollama_model_name` fixture."""
    extractor = LLMExtractor(model=ollama_model_name)
    return extractor
@pytest.fixture
def receipt_string() -> str:
    """Return the entire contents of the receipt OCR test file as one string."""
    # Explicit encoding so the text read does not depend on the platform default.
    # NOTE(review): assumes the asset is stored as UTF-8 - confirm.
    with open("assets/pytest/receipt_ocr_test.txt", "r", encoding="utf-8") as f:
        return f.read()
def receipt_test_cases():
    """
    List the cases fed to the `parametrize` decorator of `test_extract_fields`.

    Each tuple is `(cls, prompt, field, expected)`, in the order that
    decorator names its arguments.
    """
    merchant_case = (ReceiptMerchant, "extract_merchant", "name", "walmart")
    date_case = (ReceiptDate, "extract_receipt_date", "date", "2017-07-28")
    # The three amount cases differ only in prompt name and expected value.
    amount_cases = [
        (ReceiptAmount, prompt_name, "amount", expected_amount)
        for prompt_name, expected_amount in (
            ("extract_total", 98.21),
            ("extract_tip", 0),
            ("extract_tax", 4.59),
        )
    ]
    return [merchant_case, date_case, *amount_cases]
@pytest.mark.parametrize("cls,prompt,field,expected", receipt_test_cases())
def test_extract_fields(
    cls: Type[BaseModel],
    prompt: str,
    field: str,
    expected: Any,
    receipt_string,
    llm_extractor,
    ollama_model_name,
):
    """
    Provides a slightly complex receipt to the LLM used to check if it is able to extract
    basic information in it. If it doesn't pass this, then you may need a different model.

    Args:
        cls: Model class passed to `extract_fields` for this case.
        prompt: Prompt name passed to `extract_fields`.
        field: Field name passed to `extract_fields`.
        expected: Value the extracted result must equal.
        receipt_string: Receipt text supplied by the fixture of the same name.
        llm_extractor: Extractor supplied by the fixture of the same name.
        ollama_model_name: Model name fixture; not read directly in this test.
    """
    result = llm_extractor.extract_fields(receipt_string, cls, prompt, field, None)
    if isinstance(expected, str):
        # String expectations are written in lower case, so normalize the
        # extracted value before comparing.
        result = result.lower()
    assert result == expected
def test_extract_items(llm_extractor, receipt_string):
    """
    Check that the LLM can pull the list of items out of a slightly complex receipt.
    A model that fails here is probably not suitable; try a different one.

    LLM output varies from run to run, so the exact items are not compared.
    Instead, a couple of control totals (item count and rounded sum of amounts)
    must fall inside approximate ranges.
    """
    items = llm_extractor.extract_fields(
        receipt_string,
        ReceiptItemized,
        "extract_receipt_items",
        "ItemizedReceipt",
        [],
    )

    item_count = len(items)
    assert 22 <= item_count <= 26

    rounded_total = round(sum(entry["amount"] for entry in items), 0)
    assert 85 <= rounded_total <= 95