Merge pull request #2 from avimallu/fix_and_create_tests

Fix failing tests and create new ones
This commit is contained in:
avimallu
2025-01-09 19:07:21 -06:00
committed by GitHub
12 changed files with 234 additions and 78 deletions

View File

@@ -9,8 +9,8 @@ Don't know enough Python? Watch this space.
## Screenshots
<p float="left">
<img src="/assets/images/Screenshot_1.jpg" width=15%>
<img src="/assets/images/Screenshot_2.jpg" width=15%>
<img src="/assets/images/Screenshot_1.jpg" width=14%>
<img src="/assets/images/Screenshot_2.jpg" width=12.5%>
<img src="/assets/images/Screenshot_3.jpg" width=15%>
</p>
@@ -18,7 +18,9 @@ Don't know enough Python? Watch this space.
### Prerequisites
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available. The default is `qwen2.5:7b`, while this is easily configurable you don't really need a more powerful model.
You will need [`ollama`](https://ollama.com/) installed and running with a model of your choice available.
The default is `tulu3:8b`, while this is easily configurable you don't really need a more powerful model. If you have the compute, then use `gemma2:27b` for higher accuracy. If you would like to choose, run `pytest` and choose the ones that pass enough of the extraction tests.
### Python virtual environment

BIN
assets/pytest/OCR_test.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

View File

@@ -0,0 +1,54 @@
ave money. Live better.
( 330 ) 339 - 3991
MANAGER DIANA EARNEST
231 BLUEBELL DR SW
NEW PHILADELPHIA OH 44663
ST# 02115 OP# 009044 TE# 44 TR# 01301
PET TOY 004747571658 1.97 X
FLOPPY PUPPY 004747514846 1.97 X
SSSUPREME S 070060332153 4.97 X
2 . 5 SQUEAK 084699803238 5.92 X
MUNCHY DMBEL 068113108796 3.77 X
DOG TREAT 007119013654 2.92 x
PED PCH 1 002310011802 0.50 X
PED PCH T 002310011802 0.50
COUPON 23100 052310037000 1.00-0
HNYMD SMORES 088491226837 F 3.98
FRENCH DRSNG 004132100655 F 1 . 98
3 ORANGES 001466835001 F 5.47
BABY CARROTS 003338366602 I 1 . 48
COLLARDS 0000000004614KI 1.24 N
CALZONE 005208362080 F 2 . 50 O
MM RVW MNT 003399105848 19.77 X
STKOBRLPLABL 001558679414 1.97 x
STKOBRLPLABL 001558679414 1.97
STKO SUNFLWR 001558679410 0.97
STKO SUNFLWR 001558679410 0.97
STKO SUNFLWR 001558679410 0.97
STKO SUNFLWR 001558679410 0.97 X
BLING BEADS 076594060699 0.97
GREAT VALUE 007874203191 F 9.97
LIPTON 001200011224 F 4 . 48
DRY DOG 002310011035 12.44
SUBTOTAL 93.62
TAX l 6.750 % 4.59
TOTAL 98.21
VISA TEND 98.21
US DEBIT * * * * * * * * * * * * a I I 0
APPROVAL # 572868
REF # 720900544961
TRANS ID - 387209239650894
VALIDATION - 87HS
PAYMENT SERVICE - E
AID A0000000980840
TC 51319CA81DC22BC7
TERMINAL # SC010764
*Signature Verified
07/28/17 02 : 39 : 48
CHANGE DUE 0.00
ITEMS SOLD 2 5
0223 1059 8001 e I 4 C
<EFBFBD><EFBFBD><EFBFBD><EFBFBD> <20><><EFBFBD><EFBFBD><EFBFBD> <20> <20> <20><><EFBFBD><EFBFBD><EFBFBD> <20><><EFBFBD><EFBFBD> <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD> <20><> <20><> <20> <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
Prices You Can Trust. Every
07/28/17 02 : 39 : 48
* CUSTOMER COPY ***

View File

@@ -11,6 +11,7 @@ from src.app.utils import css_code, head_html, spinner_html
ComponentType = TypeVar("ComponentType")
def calculate_splits(
item_names: list[str],
item_people: list[list[str]],
@@ -22,7 +23,7 @@ def calculate_splits(
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False
return_detailed_table: bool = False,
) -> gr.DataFrame:
"""
A simple, but long function to calculate splits for a provided receipt.
@@ -84,13 +85,17 @@ def calculate_splits(
split_subtotals, split_tips, split_taxes
)
]
split_cashback = [
- x * cashback_discount for x in split_totals_pre_cashback
]
split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
split_totals_post_cashback = [
x * (1 - cashback_discount) for x in split_totals_pre_cashback
]
first_col_names = list(item_names) + ["Subtotal", "Tip", "Tax", "Cashback", "Total"]
first_col_names = list(item_names) + [
"Subtotal",
"Tip",
"Tax",
"Cashback",
"Total",
]
splits = split_amounts + [
split_subtotals,
split_tips,
@@ -126,10 +131,11 @@ def calculate_splits(
return gr.DataFrame(full_calculation_df, visible=True)
else:
simple_calculation = (
full_calculation_df
.filter(pl.col("Item").eq("Total"))
full_calculation_df.filter(pl.col("Item").eq("Total"))
.select(pl.exclude("Total"))
.transpose(include_header=True, header_name="Person", column_names=["Split"])
.transpose(
include_header=True, header_name="Person", column_names=["Split"]
)
.filter(pl.col("Person").ne("Item"))
)
return gr.DataFrame(simple_calculation, visible=True)
@@ -264,7 +270,9 @@ class SplitAIApp:
return calculate_splits(**kwargs)
@staticmethod
def update_component_attributes(component: ComponentType, **kwargs) -> ComponentType:
def update_component_attributes(
component: ComponentType, **kwargs
) -> ComponentType:
"""
This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
the component that you want to update with its attributes changed. It seems like it
@@ -401,7 +409,10 @@ class SplitAIApp:
interactive=True,
)
self.add_cashback_discount = gr.Number(
minimum=0, maximum=100, value=0, step=0.5,
minimum=0,
maximum=100,
value=0,
step=0.5,
label="Cashback discount to apply on total",
info="Choose a number between 0% and 100%.",
interactive=True,
@@ -444,7 +455,9 @@ class SplitAIApp:
outputs=self.display_result,
)
add_item_button = gr.Button("", variant="secondary", scale=1, min_width=10)
add_item_button = gr.Button(
"", variant="secondary", scale=1, min_width=10
)
def add_item(
items: list[Item],
@@ -593,9 +606,16 @@ class SplitAIApp:
else:
self.demo.queue().launch()
def arg_parser() -> agp.ArgumentParser:
ag = agp.ArgumentParser()
ag.add_argument("-m", "--model", type=str, default="qwen2.5:7b", help="Choose the LLM model used.")
ag.add_argument(
"-m",
"--model",
type=str,
default="qwen2.5:7b",
help="Choose the LLM model used.",
)
return ag

View File

@@ -49,4 +49,4 @@ spinner_html = """
">
<span class='loader'></span>
</div>
"""
"""

View File

@@ -1,26 +1,29 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Type, Literal, overload, Callable
from typing import Any, Callable, Literal, Type, overload
import yaml
from ollama import chat
from pydantic import BaseModel
from src.llm.models import (
ItemizedAmounts,
Message,
ReceiptAmount,
Amount,
ReceiptDate,
ReceiptExtracted,
ReceiptItemized,
ItemizedAmounts,
ReceiptMerchant,
)
class LLMExtractor:
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml", chat_function: Callable = chat):
def __init__(
self,
model: str,
prompt_path: str | Path = "./src/llm/prompts.yaml",
chat_function: Callable = chat,
):
"""
Args:
@@ -51,7 +54,9 @@ class LLMExtractor:
response = self.chat_function(
model=self.model,
messages=messages,
format=structured_output_format.model_json_schema(),
format=structured_output_format.model_json_schema()
if structured_output_format is not None
else None,
options={"temperature": 0},
)
message = response.message.content
@@ -75,8 +80,8 @@ class LLMExtractor:
receipt_str: str,
structured_response_model: Type[BaseModel],
prompt_name: Literal[
"extract_merchant",
"extract_receipt_date",
"merchant",
"receipt_date",
],
key_to_extract: str | None = None,
) -> str: ...
@@ -86,19 +91,19 @@ class LLMExtractor:
receipt_str: str,
structured_response_model: Type[BaseModel],
prompt_name: Literal[
"extract_receipt_total",
"extract_receipt_tip",
"extract_receipt_tax",
"total",
"tip",
"tax",
],
key_to_extract: str | None = None,
default_value: Any = None
default_value: Any = None,
) -> float: ...
@overload
def extract_fields(
self,
receipt_str: str,
structured_response_model: Type[BaseModel],
prompt_name: Literal["extract_receipt_items"],
prompt_name: Literal["receipt_items"],
key_to_extract: str | None = None,
default_value: Any = None,
) -> list[ItemizedAmounts]: ...
@@ -107,12 +112,12 @@ class LLMExtractor:
receipt_str: str,
structured_response_model: Type[BaseModel],
prompt_name: Literal[
"extract_merchant",
"extract_receipt_date",
"extract_receipt_total",
"extract_receipt_tip",
"extract_receipt_tax",
"extract_receipt_items",
"merchant",
"receipt_date",
"total",
"tip",
"tax",
"receipt_items",
],
key_to_extract: str | None = None,
default_value: Any = None,

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import date
from pathlib import Path
from typing import Annotated, Literal, Optional, TypedDict
@@ -18,7 +18,7 @@ class ReceiptMerchant(BaseModel):
class ReceiptDate(BaseModel):
date: datetime
date: date
class ReceiptAmount(BaseModel):
@@ -48,7 +48,7 @@ class ItemizedAmounts(TypedDict):
class ReceiptExtracted(TypedDict):
merchant: str
receipt_date: datetime
receipt_date: date
total: Amount
tip: Amount
tax: Amount

View File

@@ -37,13 +37,17 @@ class SuryaOCR:
self.rec_processor = load_rec_processor()
@overload
def ocr_image(self, image: Image, return_format: Literal["polygon"]) -> list[Polygon]: ...
def ocr_image(
self, image: Image, return_format: Literal["polygon"]
) -> list[Polygon]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["text"]) -> list[str]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["bbox"]) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["confidence"]) -> list[BBox]: ...
def ocr_image(
self, image: Image, return_format: Literal["confidence"]
) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: None) -> SplitAIOCRResult: ...
def ocr_image(

7
tests/conftest.py Normal file
View File

@@ -0,0 +1,7 @@
def pytest_addoption(parser):
parser.addoption(
"--ollama-model-name",
action="store",
default="tulu3:8b",
help="Specify the ollama model name to use. It should exist on the system you are running this from.",
)

View File

@@ -1,54 +1,111 @@
from src.llm.extractor import LLMExtractor
from datetime import date
import ollama
import pytest
from pydantic import BaseModel
from typing import Type, Any
def test_ollama_response():
from src.llm.extractor import LLMExtractor
from src.llm.models import (
ItemizedAmounts,
Message,
ReceiptAmount,
ReceiptDate,
ReceiptExtracted,
ReceiptItemized,
ReceiptMerchant,
)
@pytest.fixture
def ollama_model_name(request):
"""Must be the name of an ollama model that has already been downloaded to the system."""
param = request.config.getoption("--ollama-model-name")
return param
def test_ollama_response(ollama_model_name: str):
"""Make sure that there is a response (and not an error from ollama)."""
messages: list[Message] = [{"role": "user", "content": "Hello"}]
output = get_chat_response(model="gemma2:27b", messages=messages)
llm_extractor = LLMExtractor(model=ollama_model_name)
output = llm_extractor.get_chat_response(messages=messages)
assert isinstance(output, str)
def test_ollama_structured_response():
def test_ollama_invalid_model_name():
messages: list[Message] = [{"role": "user", "content": "Hello"}]
llm_extractor = LLMExtractor(model="ba-ba-black-sheep")
with pytest.raises(ollama._types.ResponseError):
llm_extractor.get_chat_response(messages=messages)
def test_ollama_structured_response(ollama_model_name):
class ModelData(BaseModel):
name: str
knowledge_cutoff: date
messages: list[Message] = [{"role": "user", "content": "Tell me about yourself"}]
output = get_chat_response(
model="gemma2:27b",
llm_extractor = LLMExtractor(model=ollama_model_name)
output = llm_extractor.get_chat_response(
messages=messages,
structured_output_format=ModelData.model_json_schema(),
structured_output_format=ModelData,
)
assert ModelData.model_validate(output)
def test_ollama_image_upload():
messages: list[Message] = [
{
"role": "user",
"content": "Run OCR on this image.",
"images": ["./data/pytest/image_upload_test.png"],
}
@pytest.fixture
def llm_extractor(ollama_model_name) -> LLMExtractor:
return LLMExtractor(model=ollama_model_name)
@pytest.fixture
def receipt_string() -> str:
with open("assets/pytest/receipt_ocr_test.txt", "r") as f:
receipt_string = "".join(f.readlines())
return receipt_string
def receipt_test_cases():
return [
(ReceiptMerchant, "extract_merchant", "name", "walmart"),
(ReceiptDate, "extract_receipt_date", "date", "2017-07-28"),
(ReceiptAmount, "extract_total", "amount", 98.21),
(ReceiptAmount, "extract_tip", "amount", 0),
(ReceiptAmount, "extract_tax", "amount", 4.59),
]
output = get_chat_response(
model="llava:34b",
messages=messages,
@pytest.mark.parametrize("cls,prompt,field,expected", receipt_test_cases())
def test_extract_fields(
cls: Type[BaseModel],
prompt: str,
field: str,
expected: Any,
receipt_string,
llm_extractor,
ollama_model_name,
):
"""
Provides a slightly complex receipt to the LLM used to check if it is able to extract
basic information in it. If it doesn't pass this, then you may need a different model.
"""
result = llm_extractor.extract_fields(receipt_string, cls, prompt, field, None)
if isinstance(expected, str):
result = result.lower()
assert result == expected
def test_extract_items(llm_extractor, receipt_string):
"""
Provides a slightly complex receipt to the LLM used to check if it is able to extract
the list of items in it. If it doesn't pass this, then you may need a different model.
The exact items cannot be always identical as there will be some inherent variance in LLM output.
Therefore, check if a few control totals match, at least approximately.
"""
items = llm_extractor.extract_fields(
receipt_string, ReceiptItemized, "extract_receipt_items", "ItemizedReceipt", []
)
print(output)
assert False
def test_llm_extractor():
llm_extractor = LLMExtractor(model="gemma2:27b")
with open("data/pytest/receipt_ocr_test.txt", "r") as f:
receipt_ocr_text = "".join(f.readlines())
assert llm_extractor.extract_merchant_name(receipt_ocr_text) == 'Walmart'
assert llm_extractor.extract_receipt_date(receipt_ocr_text) == date.fromisoformat('2017-07-28')
assert llm_extractor.extract_receipt_total_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 98.21}
assert llm_extractor.extract_receipt_tip_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 0}
assert llm_extractor.extract_receipt_tax_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 4.59}
items = llm_extractor.extract_receipt_items(receipt_ocr_text)
# The exact items in these cannot be always identical as there will be some inherent variance in LLM output.
# Thus, we check if a few control totals match, at least approximately.
assert len(items) == 26
assert round(sum(x["amount"] for x in items), 0) == 94
assert 22 <= len(items) <= 26
assert 85 <= round(sum(x["amount"] for x in items), 0) <= 95

View File

@@ -1,7 +1,14 @@
from PIL.Image import open as pil_open
from src.ocr.surya import SuryaOCR
def test_ocr_sample_image():
image_path = "data/pytest/image_upload_test.png"
output = ocr_image(image_path=image_path)
assert output is not None
"""
Any OCR package must be able to read this image correctly.
This was created by an text-to-image service.
"""
image_path = pil_open("assets/pytest/OCR_test.png")
ocr = SuryaOCR()
output = ocr.ordered_ocr_text(image_path)
assert output == "This is Line 1\nThis is Line 2\nThis is Line 5"