More tests, all pass. Ruff-ing everything up as well
This commit is contained in:
@@ -9,8 +9,8 @@ Don't know enough Python? Watch this space.
|
||||
## Screenshots
|
||||
|
||||
<p float="left">
|
||||
<img src="/assets/images/Screenshot_1.jpg" width=15%>
|
||||
<img src="/assets/images/Screenshot_2.jpg" width=15%>
|
||||
<img src="/assets/images/Screenshot_1.jpg" width=14%>
|
||||
<img src="/assets/images/Screenshot_2.jpg" width=12.5%>
|
||||
<img src="/assets/images/Screenshot_3.jpg" width=15%>
|
||||
</p>
|
||||
|
||||
@@ -18,7 +18,9 @@ Don't know enough Python? Watch this space.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available. The default is `qwen2.5:7b`; while this is easily configurable, you don't really need a more powerful model.
|
||||
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available.
|
||||
|
||||
The default is `tulu3:8b`; while this is easily configurable, you don't really need a more powerful model. If you have the compute, use `gemma2:27b` for higher accuracy. If you would like to choose a model yourself, run `pytest` against each candidate and pick one that passes enough of the extraction tests.
|
||||
|
||||
### Python virtual environment
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.app.utils import css_code, head_html, spinner_html
|
||||
|
||||
ComponentType = TypeVar("ComponentType")
|
||||
|
||||
|
||||
def calculate_splits(
|
||||
item_names: list[str],
|
||||
item_people: list[list[str]],
|
||||
@@ -22,7 +23,7 @@ def calculate_splits(
|
||||
tip_split_proportionally: bool,
|
||||
tax_split_proportionally: bool,
|
||||
cashback_discount: float,
|
||||
return_detailed_table: bool = False
|
||||
return_detailed_table: bool = False,
|
||||
) -> gr.DataFrame:
|
||||
"""
|
||||
A simple, but long function to calculate splits for a provided receipt.
|
||||
@@ -84,13 +85,17 @@ def calculate_splits(
|
||||
split_subtotals, split_tips, split_taxes
|
||||
)
|
||||
]
|
||||
split_cashback = [
|
||||
- x * cashback_discount for x in split_totals_pre_cashback
|
||||
]
|
||||
split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
|
||||
split_totals_post_cashback = [
|
||||
x * (1 - cashback_discount) for x in split_totals_pre_cashback
|
||||
]
|
||||
first_col_names = list(item_names) + ["Subtotal", "Tip", "Tax", "Cashback", "Total"]
|
||||
first_col_names = list(item_names) + [
|
||||
"Subtotal",
|
||||
"Tip",
|
||||
"Tax",
|
||||
"Cashback",
|
||||
"Total",
|
||||
]
|
||||
splits = split_amounts + [
|
||||
split_subtotals,
|
||||
split_tips,
|
||||
@@ -126,10 +131,11 @@ def calculate_splits(
|
||||
return gr.DataFrame(full_calculation_df, visible=True)
|
||||
else:
|
||||
simple_calculation = (
|
||||
full_calculation_df
|
||||
.filter(pl.col("Item").eq("Total"))
|
||||
full_calculation_df.filter(pl.col("Item").eq("Total"))
|
||||
.select(pl.exclude("Total"))
|
||||
.transpose(include_header=True, header_name="Person", column_names=["Split"])
|
||||
.transpose(
|
||||
include_header=True, header_name="Person", column_names=["Split"]
|
||||
)
|
||||
.filter(pl.col("Person").ne("Item"))
|
||||
)
|
||||
return gr.DataFrame(simple_calculation, visible=True)
|
||||
@@ -264,7 +270,9 @@ class SplitAIApp:
|
||||
return calculate_splits(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def update_component_attributes(component: ComponentType, **kwargs) -> ComponentType:
|
||||
def update_component_attributes(
|
||||
component: ComponentType, **kwargs
|
||||
) -> ComponentType:
|
||||
"""
|
||||
This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
|
||||
the component that you want to update with its attributes changed. It seems like it
|
||||
@@ -401,7 +409,10 @@ class SplitAIApp:
|
||||
interactive=True,
|
||||
)
|
||||
self.add_cashback_discount = gr.Number(
|
||||
minimum=0, maximum=100, value=0, step=0.5,
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
value=0,
|
||||
step=0.5,
|
||||
label="Cashback discount to apply on total",
|
||||
info="Choose a number between 0% and 100%.",
|
||||
interactive=True,
|
||||
@@ -444,7 +455,9 @@ class SplitAIApp:
|
||||
outputs=self.display_result,
|
||||
)
|
||||
|
||||
add_item_button = gr.Button("➕", variant="secondary", scale=1, min_width=10)
|
||||
add_item_button = gr.Button(
|
||||
"➕", variant="secondary", scale=1, min_width=10
|
||||
)
|
||||
|
||||
def add_item(
|
||||
items: list[Item],
|
||||
@@ -593,9 +606,16 @@ class SplitAIApp:
|
||||
else:
|
||||
self.demo.queue().launch()
|
||||
|
||||
|
||||
def arg_parser() -> agp.ArgumentParser:
    """
    Build the command-line argument parser.

    Returns:
        An `ArgumentParser` exposing a single `-m` / `--model` string option that
        selects the LLM model used; it defaults to `qwen2.5:7b`.
    """
    ag = agp.ArgumentParser()
    # Register the option exactly once: argparse raises an ArgumentError
    # ("conflicting option strings") if `-m` / `--model` is added a second time.
    ag.add_argument(
        "-m",
        "--model",
        type=str,
        default="qwen2.5:7b",
        help="Choose the LLM model used.",
    )
    return ag
|
||||
|
||||
|
||||
|
||||
@@ -49,4 +49,4 @@ spinner_html = """
|
||||
">
|
||||
<span class='loader'></span>
|
||||
</div>
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, Optional, TypedDict
|
||||
|
||||
@@ -18,7 +18,7 @@ class ReceiptMerchant(BaseModel):
|
||||
|
||||
|
||||
class ReceiptDate(BaseModel):
    # Calendar date only (no time-of-day component). The field is annotated a single
    # time so the declared type is unambiguous.
    date: date
|
||||
|
||||
|
||||
class ReceiptAmount(BaseModel):
|
||||
@@ -48,7 +48,7 @@ class ItemizedAmounts(TypedDict):
|
||||
|
||||
class ReceiptExtracted(TypedDict):
|
||||
merchant: str
|
||||
receipt_date: datetime
|
||||
receipt_date: date
|
||||
total: Amount
|
||||
tip: Amount
|
||||
tax: Amount
|
||||
|
||||
@@ -37,13 +37,17 @@ class SuryaOCR:
|
||||
self.rec_processor = load_rec_processor()
|
||||
|
||||
@overload
|
||||
def ocr_image(self, image: Image, return_format: Literal["polygon"]) -> list[Polygon]: ...
|
||||
def ocr_image(
|
||||
self, image: Image, return_format: Literal["polygon"]
|
||||
) -> list[Polygon]: ...
|
||||
@overload
|
||||
def ocr_image(self, image: Image, return_format: Literal["text"]) -> list[str]: ...
|
||||
@overload
|
||||
def ocr_image(self, image: Image, return_format: Literal["bbox"]) -> list[BBox]: ...
|
||||
@overload
|
||||
def ocr_image(self, image: Image, return_format: Literal["confidence"]) -> list[BBox]: ...
|
||||
def ocr_image(
|
||||
self, image: Image, return_format: Literal["confidence"]
|
||||
) -> list[BBox]: ...
|
||||
@overload
|
||||
def ocr_image(self, image: Image, return_format: None) -> SplitAIOCRResult: ...
|
||||
def ocr_image(
|
||||
|
||||
@@ -2,6 +2,6 @@ def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--ollama-model-name",
|
||||
action="store",
|
||||
default="qwen2.5:7b",
|
||||
default="tulu3:8b",
|
||||
help="Specify the ollama model name to use. It should exist on the system you are running this from.",
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import date
|
||||
import ollama
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Any
|
||||
|
||||
from src.llm.extractor import LLMExtractor
|
||||
|
||||
@@ -17,7 +18,6 @@ from src.llm.models import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_model_name(request):
|
||||
"""Must be the name of an ollama model that has already been downloaded to the system."""
|
||||
@@ -54,33 +54,58 @@ def test_ollama_structured_response(ollama_model_name):
|
||||
assert ModelData.model_validate(output)
|
||||
|
||||
|
||||
def test_llm_extractor(ollama_model_name):
|
||||
@pytest.fixture
def llm_extractor(ollama_model_name) -> LLMExtractor:
    """Provide an `LLMExtractor` built for the model named by `ollama_model_name`."""
    extractor = LLMExtractor(model=ollama_model_name)
    return extractor
|
||||
|
||||
|
||||
@pytest.fixture
def receipt_string() -> str:
    """Return the entire contents of the sample receipt OCR text file as one string."""
    with open("assets/pytest/receipt_ocr_test.txt", "r") as ocr_file:
        return ocr_file.read()
|
||||
|
||||
|
||||
def receipt_test_cases():
    """
    List the cases used to parametrize the single-field extraction test.

    Each entry is a `(cls, prompt, field, expected)` tuple: the response model class,
    the prompt identifier, the field read from the response and the expected value.
    """
    merchant_case = (ReceiptMerchant, "extract_merchant", "name", "walmart")
    date_case = (ReceiptDate, "extract_receipt_date", "date", "2017-07-28")
    total_case = (ReceiptAmount, "extract_total", "amount", 98.21)
    tip_case = (ReceiptAmount, "extract_tip", "amount", 0)
    tax_case = (ReceiptAmount, "extract_tax", "amount", 4.59)
    return [merchant_case, date_case, total_case, tip_case, tax_case]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls,prompt,field,expected", receipt_test_cases())
def test_extract_fields(
    cls: Type[BaseModel],
    prompt: str,
    field: str,
    expected: Any,
    receipt_string,
    llm_extractor,
    ollama_model_name,
):
    """
    Provides a slightly complex receipt to the LLM used to check if it is able to extract
    basic information in it. If it doesn't pass this, then you may need a different model.

    Each parametrized case supplies the response model class (`cls`), the prompt
    identifier, the field to read from the response and the value expected for it.
    """
    # `ollama_model_name` stays in the signature for compatibility; the extractor and the
    # receipt text both come from their fixtures rather than being rebuilt or re-read here.
    result = llm_extractor.extract_fields(receipt_string, cls, prompt, field, None)
    if isinstance(expected, str):
        # String expectations are written in lower case, so ignore the casing of the output.
        result = result.lower()
    assert result == expected
|
||||
|
||||
|
||||
def test_extract_items(llm_extractor, receipt_string):
    """
    Checks whether the LLM can pull the list of items out of a slightly complex receipt.
    A model that fails this test may need to be swapped for a different one.

    LLM output varies from run to run, so the extracted items are not compared one by
    one. Instead, a couple of control totals (the item count and the summed amount) must
    fall within approximate ranges.
    """
    extracted = llm_extractor.extract_fields(
        receipt_string, ReceiptItemized, "extract_receipt_items", "ItemizedReceipt", []
    )
    item_count = len(extracted)
    assert 22 <= item_count <= 26

    rounded_total = round(sum(entry["amount"] for entry in extracted), 0)
    assert 85 <= rounded_total <= 95
|
||||
|
||||
Reference in New Issue
Block a user