Initial stylized working version

This commit is contained in:
avimallu
2025-01-04 12:17:14 -06:00
commit 39a3fbe0aa
22 changed files with 3162 additions and 0 deletions

13
.gitignore vendored Normal file
View File

@@ -0,0 +1,13 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# Data Files
data/*

15
.idea/SplitAI.iml generated Normal file
View File

@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.11 (SplitAI)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
<option name="requirementsPath" value="" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="GOOGLE" />
<option name="myDocStringFormat" value="Google" />
</component>
</module>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.11 (SplitAI)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (SplitAI)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/../SplitAI/.idea/SplitAI.iml" filepath="$PROJECT_DIR$/../SplitAI/.idea/SplitAI.iml" />
</modules>
</component>
</project>

8
.idea/ruff.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="RuffConfigService">
<option name="runRuffOnSave" value="true" />
<option name="useRuffFormat" value="true" />
<option name="useRuffServer" value="true" />
</component>
</project>

11
.idea/ryecharm-overrides.xml generated Normal file
View File

@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="insyncwithfoo.ryecharm.configurations.ruff.Override">
<option name="names">
<map>
<entry key="crossPlatformExecutableResolution" value="true" />
<entry key="executable" value="true" />
</map>
</option>
</component>
</project>

6
.idea/ryecharm.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="insyncwithfoo.ryecharm.configurations.ruff.Local">
<option name="executable" value="ruff" />
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

70
.idea/workspace.xml generated Normal file
View File

@@ -0,0 +1,70 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="b6145d7e-32b4-4917-b6ec-19187eb432a7" name="Changes" comment="">
<change afterPath="$PROJECT_DIR$/.gitignore" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/SplitAI.iml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/ruff.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/ryecharm-overrides.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/ryecharm.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/.python-version" afterDir="false" />
<change afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
<change afterPath="$PROJECT_DIR$/pyproject.toml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/app/gradio_ui.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/app/split_ai.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/app/utils.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/llm/ollama.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/llm/prompts.yaml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/ocr/surya.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/tests/test_ollama.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/tests/test_surya.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/uv.lock" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProblemsViewState">
<option name="selectedTabId" value="CurrentFile" />
</component>
<component name="ProjectColorInfo">{
&quot;associatedIndex&quot;: 6
}</component>
<component name="ProjectId" id="2qtfzFBMMaI540YQLSoIQUrWHCr" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">{
&quot;keyToString&quot;: {
&quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
&quot;RunOnceActivity.git.unshallow&quot;: &quot;true&quot;,
&quot;SHARE_PROJECT_CONFIGURATION_FILES&quot;: &quot;true&quot;,
&quot;git-widget-placeholder&quot;: &quot;main&quot;,
&quot;last_opened_file_path&quot;: &quot;/Users/avimallu/Documents/SplitAI&quot;,
&quot;settings.editor.selected.configurable&quot;: &quot;preferences.pluginManager&quot;
}
}</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="b6145d7e-32b4-4917-b6ec-19187eb432a7" name="Changes" comment="" />
<created>1735487052741</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1735487052741</updated>
</task>
<servers />
</component>
</project>

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.11

0
README.md Normal file
View File

21
pyproject.toml Normal file
View File

@@ -0,0 +1,21 @@
[project]
name = "SplitAI"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"gradio>=5.9.1",
"great-tables>=0.15.0",
"more-itertools>=10.5.0",
"ollama>=0.4.4",
"polars>=1.18.0",
"selenium>=4.27.1",
"surya-ocr>=0.8.1",
]
[dependency-groups]
dev = [
"pytest>=8.3.4",
"ruff>=0.8.4",
]

596
src/app/gradio_ui.py Normal file
View File

@@ -0,0 +1,596 @@
from datetime import datetime
from typing import Literal, TypedDict, TypeVar
import gradio as gr
import polars as pl
from PIL.Image import Image
from src.app.split_ai import ReceiptReader
from src.app.utils import css_code, head_html, spinner_html
ComponentType = TypeVar("ComponentType")
def calculate_splits(
item_names: list[str],
item_people: list[list[str]],
item_amounts: list[float],
total: float,
tip: float,
tax: float,
people_list: list[str],
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False
) -> gr.DataFrame:
"""
A simple, but long function to calculate splits for a provided receipt.
Args:
item_names: Names of the items being split.
item_people: A list of people for each item who are splitting its cost.
item_amounts: Amounts of the items being split.
total: The total amount in the receipt
tip: The tip in the receipt
tax: The tax in the receipt
people_list: The total number of people splitting the receipt.
tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
cashback_discount: The total will be reduced by this percentage value.
return_detailed_table: Indicator to return full calculation table or a simplified one.
Returns:
A DataFrame form of the provided values along with their calculated splits or a simplified version.
"""
split_count = 0
unsplit_names = []
checkbox_count = len(item_people)
for name, split in zip(item_names, item_people):
if len(split) > 0:
split_count += 1
else:
unsplit_names.append(name)
if split_count != checkbox_count:
gr.Warning(
f"Pending splits: {','.join(unsplit_names)}",
title="Can't show splits yet",
)
return gr.DataFrame(pl.DataFrame(), visible=False)
else:
# Deliberately avoiding going the numpy route here since the data is very small anyway.
split_arrays: list[list[float]] = []
for split in item_people:
split_array = [1 / len(split) if x in split else 0.0 for x in people_list]
split_arrays.append(split_array)
split_amounts: list[list[float]] = []
for split_array, amount in zip(split_arrays, item_amounts):
split_amount = [amount * split for split in split_array]
split_amounts.append(split_amount)
split_subtotals = [sum(x) for x in zip(*split_amounts)]
subtotal = total - tip - tax
split_tips = [
x / subtotal * tip if tip_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_taxes = [
x / subtotal * tax if tax_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_totals_pre_cashback = [
split_subtotal + split_tip + split_tax
for split_subtotal, split_tip, split_tax in zip(
split_subtotals, split_tips, split_taxes
)
]
split_cashback = [
- x * cashback_discount for x in split_totals_pre_cashback
]
split_totals_post_cashback = [
x * (1 - cashback_discount) for x in split_totals_pre_cashback
]
first_col_names = list(item_names) + ["Subtotal", "Tip", "Tax", "Cashback", "Total"]
splits = split_amounts + [
split_subtotals,
split_tips,
split_taxes,
split_cashback,
split_totals_post_cashback,
]
horizontal_totals = list(item_amounts) + [
subtotal,
tip,
tax,
sum(split_cashback),
sum(split_totals_post_cashback),
]
full_calculation_df = (
pl.DataFrame(
{
"Item": first_col_names,
"splits": splits,
"Total": horizontal_totals,
},
schema={
"Item": pl.String,
"splits": pl.List(pl.Float64),
"Total": pl.Float64,
},
)
.with_columns(pl.col("splits").list.to_struct(fields=people_list))
.unnest("splits")
.with_columns(pl.col(pl.Float64).round(2))
)
if return_detailed_table:
return gr.DataFrame(full_calculation_df, visible=True)
else:
simple_calculation = (
full_calculation_df
.filter(pl.col("Item").eq("Total"))
.select(pl.exclude("Total"))
.transpose(include_header=True, header_name="Person", column_names=["Split"])
.filter(pl.col("Person").ne("Item"))
)
return gr.DataFrame(simple_calculation, visible=True)
class Item(TypedDict):
name: str
amount: float
class ItemSplitter:
def __init__(
self,
item: Item,
people_list: list[str],
) -> None:
self.people_list_state = people_list
self.item = item
self.no_interaction_kwargs = {"interactive": False, "min_width": 10}
self.interaction_kwargs = self.no_interaction_kwargs | {
"container": False,
"interactive": True,
}
def name_textbox(self, item_name: str) -> gr.Textbox:
return gr.Textbox(
item_name, show_label=False, scale=8, **self.interaction_kwargs
)
def amount_number(self, item_amount: float) -> gr.Number:
return gr.Number(
value=item_amount, precision=2, scale=3, **self.interaction_kwargs
)
def split_status_button(
self, choices: list[str] | None = None, status: Literal["⚠️", "🆗"] = "⚠️"
) -> gr.Button:
if choices is not None:
if len(choices) == 0:
status = "⚠️"
else:
status = "🆗"
else:
choices = []
variant: Literal["huggingface", "primary"] = (
"huggingface" if (status == "⚠️") | (len(choices) == 0) else "primary"
)
return gr.Button(
value=status, variant=variant, scale=1, **self.no_interaction_kwargs
)
def delete_item(self, item_list: list[Item]) -> list[Item]:
item_list.remove(self.item)
return item_list
def delete_item_button(self) -> gr.Button:
kwargs = self.no_interaction_kwargs | {"interactive": True}
return gr.Button(value="", variant="stop", **kwargs)
def people_list_checkbox(self, people_list: list[str]) -> gr.CheckboxGroup:
return gr.CheckboxGroup(choices=people_list, **self.interaction_kwargs)
def generate(self) -> tuple[gr.Textbox, gr.CheckboxGroup, gr.Number, gr.Button]:
return self.generate_mobile()
def generate_mobile(self):
with gr.Row(variant="default", equal_height=True):
item_name_textbox = self.name_textbox(self.item["name"])
item_amount_number = self.amount_number(self.item["amount"])
split_status_button = self.split_status_button(status="⚠️")
delete_item_button = self.delete_item_button()
people_list_checkbox = self.people_list_checkbox(self.people_list_state)
people_list_checkbox.change(
lambda x: self.split_status_button(choices=x),
people_list_checkbox,
split_status_button,
)
return (
item_name_textbox,
people_list_checkbox,
item_amount_number,
delete_item_button,
)
class SplitAIApp:
valid_split_variant: Literal["primary"] = "primary"
invalid_split_variant: Literal["huggingface"] = "huggingface"
def __init__(self, llm_model: str):
self.receipt_reader = ReceiptReader(llm_model)
self.demo = self.create_app()
@staticmethod
def prepare_calculate_splits_kwargs(num_records: int, *all_values) -> gr.DataFrame:
"""
This method is necessary because only a list[gr.Component] or similar can be sent as
`inputs` to an event listener. Therefore, it is unpacked here and prepared into a
dictionary based on how it is sent by the event. This method is specifically for
the `get_split_button.click` event listener.
Args:
num_records: The number of items present to split.
*all_values: A list of components to forward to `calculate_splits`.
Returns:
gr.DataFrame
"""
kwargs = {
"item_names": all_values[:num_records],
"item_people": all_values[num_records : num_records * 2],
"item_amounts": all_values[num_records * 2 : num_records * 3],
}
additional_kwargs = {
k: v
for k, v in zip(
[
"total",
"tip",
"tax",
"people_list",
"tip_split_proportionally",
"tax_split_proportionally",
"cashback_discount",
"return_detailed_table",
],
tuple(all_values[num_records * 3 :]),
)
}
additional_kwargs["cashback_discount"] /= 100
kwargs.update(additional_kwargs)
return calculate_splits(**kwargs)
@staticmethod
def update_component_attributes(component: ComponentType, **kwargs) -> ComponentType:
"""
This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
the component that you want to update with its attributes changed. It seems like it
doesn't replace the component, but updates it this way. Very weird behavior.
Args:
component: The gradio component to update attributes for.
**kwargs: (attribute, value) pairs to update in child.
Returns:
A new instance of child's class with the updated attributes.
"""
gradio_class = type(component)
try:
return gradio_class(**kwargs)
except Exception as err:
print(
f"The Gradio component {gradio_class} does not have one of the provided attribute keys."
)
raise err
@staticmethod
def validate_people_list(people_textbox) -> tuple[gr.Image, list]:
if "," in people_textbox and people_textbox[-1] != ",":
people_list = [x.strip() for x in people_textbox.split(",")]
return gr.Image(interactive=True), people_list
else:
gr.Warning("You need to enter a list of names separated by commas.")
return gr.Image(interactive=False), []
def create_app(self) -> gr.Blocks():
# `head_html` required to prevent iOS from scaling the UI when clicking on a textbox.
with gr.Blocks(
css=css_code,
head=head_html,
theme="JohnSmith9982/small_and_pretty",
fill_width=True,
) as split_app:
with gr.Column():
self.people_textbox = gr.Textbox(
placeholder="Split names with a comma",
label="Who all are splitting this receipt?",
lines=1,
autofocus=True,
submit_btn="Submit",
)
self.people_list = gr.State([])
self.image_uploader = gr.Image(
show_label=False, scale=1, type="pil", interactive=False
)
self.people_textbox.submit(
SplitAIApp.validate_people_list,
[self.people_textbox],
[self.image_uploader, self.people_list],
)
with gr.Column():
with gr.Column():
with gr.Row():
self.merchant = gr.Textbox(
interactive=True,
label="Merchant Name",
min_width=20,
visible=False,
scale=2,
)
self.receipt_date = gr.DateTime(
interactive=True,
include_time=False,
type="datetime",
label="Date",
min_width=20,
visible=False,
scale=1,
)
with gr.Row():
self.total_amount = gr.Number(
interactive=True,
label="Total",
min_width=20,
visible=False,
precision=2,
)
self.tip_amount = gr.Number(
interactive=True,
label="Tip",
min_width=20,
visible=False,
precision=2,
)
self.tax_amount = gr.Number(
interactive=True,
label="Tax",
min_width=20,
visible=False,
precision=2,
)
self.items = gr.State([])
@gr.render(inputs=[self.items, self.people_list])
def render_items(items: list[Item], people_list: list[str]):
item_names = []
item_peoples = []
item_amounts = []
for key, item in enumerate(items):
with gr.Column(variant="compact"):
splitter = ItemSplitter(item, people_list)
item_name, item_people, item_amount, delete_item_button = (
splitter.generate()
)
# This event needs to be defined outside the ItemSplitter class
# because it references a gr.State variable. All Gradio components
# can be properly pass ONLY via event listeners, as their state is
# managed by Gradio outside the flow of the Python app.
delete_item_button.click(
splitter.delete_item, self.items, self.items
)
item_names.append(item_name)
item_peoples.append(item_people)
item_amounts.append(item_amount)
self.split_tip_proportionally = gr.Checkbox(
value=True,
label="Split tip proportional to other costs",
info="If unchecked, will split equally.",
interactive=True,
)
self.split_tax_proportionally = gr.Checkbox(
value=True,
label="Split tax proportional to other costs",
info="If unchecked, will split equally.",
interactive=True,
)
self.add_cashback_discount = gr.Slider(
minimum=0, maximum=25, value=0, step=0.5,
label="Cashback discount to apply on total",
info="Choose a number between 0% and 100%.",
interactive=True,
)
self.show_detailed_table = gr.Checkbox(
value=False,
label="Show a detailed calculation table",
info="If unchecked, will just show the splits.",
interactive=True,
)
with gr.Row():
self.integrity_markdown = gr.Markdown(
show_label=False, value="", visible=False
)
with gr.Row():
get_splits_button = gr.Button(
"Get Splits", variant="primary", scale=5, min_width=10
)
get_splits_button.click(
lambda *x: SplitAIApp.prepare_calculate_splits_kwargs(
len(item_names), *x
),
inputs=(
item_names
+ item_peoples
+ item_amounts
+ [
self.total_amount,
self.tip_amount,
self.tax_amount,
self.people_list,
self.split_tip_proportionally,
self.split_tax_proportionally,
self.add_cashback_discount,
self.show_detailed_table,
]
),
outputs=self.display_result,
)
add_item_button = gr.Button("", variant="secondary", scale=1, min_width=10)
def add_item(
items: list[Item],
):
new_item_name = f"Item {len(items)+1}"
return items + [
{
"name": new_item_name,
"amount": 0.0,
}
]
add_item_button.click(
add_item,
inputs=[self.items],
outputs=[self.items],
)
# Constantly keep track of whether totals match or not.
def integrity_checker(*args) -> gr.Markdown:
items = args[: len(args) - 3]
tip_amount, tax_amount, total_amount = args[len(args) - 3 :]
subtotal = sum(items)
if subtotal + tip_amount + tax_amount != total_amount:
return gr.Markdown(
f"⚠️ Looks like the total ({total_amount}) doesn't match the value of subtotal ({subtotal}) + tip ({tip_amount}) + tax ({tax_amount}) ⚠️",
show_label=False,
visible=True,
)
else:
return gr.Markdown(visible=False)
gr.on(
triggers=[x.change for x in item_amounts]
+ [
self.tip_amount.change,
self.tax_amount.change,
self.total_amount.change,
],
fn=integrity_checker,
inputs=item_amounts
+ [
self.tip_amount,
self.tax_amount,
self.total_amount,
],
outputs=[self.integrity_markdown],
)
self.display_result = gr.DataFrame(value=None, visible=False)
self.spinner_html = gr.HTML(
spinner_html,
visible=False,
padding=False,
)
self.image_uploader.upload(
lambda: gr.HTML(visible=True), inputs=None, outputs=self.spinner_html
).then(
self.process_image,
inputs=[self.image_uploader, self.items],
outputs=[
self.merchant,
self.receipt_date,
self.total_amount,
self.tip_amount,
self.tax_amount,
self.items,
],
show_progress="hidden",
).then(
lambda: gr.HTML(visible=False), inputs=None, outputs=self.spinner_html
)
return split_app
def process_image(
self,
image: Image,
items: gr.State,
): # -> gr.State:
# gr.Info("Running OCR on image of receipt.")
# receipt_string = self.receipt_reader.get_ordered_text(image)
# gr.Info("Extracting components. Please be patient.")
# receipt_extracted = self.receipt_reader.extract_components(receipt_string)
receipt_extracted = {
"merchant": "FUBAR",
"receipt_date": datetime.now(),
"total": {"amount": 15},
"tip": {"amount": 0},
"tax": {"amount": 3},
"item_amounts": [
{"name": "PET TOY", "currency": "$", "amount": 2},
{"name": "FLOPPY PUPPY", "currency": "$", "amount": 4},
{"name": "SSSUPREME S", "currency": "$", "amount": 6},
],
}
key_value_updates = [
{
"component": self.merchant,
"kwargs": {"value": receipt_extracted["merchant"], "visible": True},
},
{
"component": self.receipt_date,
"kwargs": {
"value": receipt_extracted["receipt_date"],
"visible": True,
},
},
{
"component": self.total_amount,
"kwargs": {
"value": receipt_extracted["total"]["amount"],
"visible": True,
},
},
{
"component": self.tip_amount,
"kwargs": {
"value": receipt_extracted["tip"]["amount"],
"visible": True,
},
},
{
"component": self.tax_amount,
"kwargs": {
"value": receipt_extracted["tax"]["amount"],
"visible": True,
},
},
]
out = [
self.update_component_attributes(x["component"], **x["kwargs"])
for x in key_value_updates
]
items += [
{"name": x["name"], "amount": x["amount"]}
for x in receipt_extracted["item_amounts"]
]
out += [items]
return out
def launch(self, expose_to_local_network: bool = False):
if expose_to_local_network:
self.demo.queue().launch(server_name="0.0.0.0", server_port=7860)
else:
self.demo.queue().launch()
if __name__ == "__main__":
demo = SplitAIApp("qwen2.5")
demo.launch(True)

22
src/app/split_ai.py Normal file
View File

@@ -0,0 +1,22 @@
from src.llm.ollama import LLMExtractor, ReceiptExtracted
from src.ocr.surya import SuryaOCR
from PIL.Image import Image
from PIL.Image import open as PIL_open
class ReceiptReader:
def __init__(self, llm_model: str):
self.ocr_model = SuryaOCR()
self.llm_model = LLMExtractor(model=llm_model)
def get_ordered_text(self, image: Image) -> str:
return self.ocr_model.ordered_ocr_text(image)
def extract_components(self, receipt_string: str) -> ReceiptExtracted:
return self.llm_model.forward(receipt_string, enable_alerts=False)
if __name__ == "__main__":
image = PIL_open("data/pytest/image_upload_test.png")
sr = ReceiptReader()
print(sr.extract_components(image))

52
src/app/utils.py Normal file
View File

@@ -0,0 +1,52 @@
head_html = """
<meta name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
"""
css_code = """
input {
font-size: 16px !important;
}
input[type="number"] {
inputmode: numeric;
pattern: "[0-9]*";
}
#fullwidth-checkgroup .gr-checkgroup {
width: 100%;
}
footer {
display: none !important;
}
.loader {
width: 48px;
height: 48px;
border: 5px solid #FFF;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
"""
spinner_html = """
<div style="
display: flex;
justify-content: center;
align-items: center;
height: 10vh;
width: 100%;
">
<span class='loader'></span>
</div>
"""

196
src/llm/ollama.py Normal file
View File

@@ -0,0 +1,196 @@
import json
import yaml
import gradio as gr
from ollama import chat
from pathlib import Path
from datetime import datetime
from pydantic import BaseModel
from typing import TypedDict, Optional, Literal, Any, Type, Annotated
class Message(TypedDict):
role: Literal["user", "assistant", "system", "tool"]
content: Optional[str]
images: Optional[list[str] | bytes | list[Path]]
class ReceiptMerchant(BaseModel):
name: str
class ReceiptDate(BaseModel):
date: datetime
def is_valid_currency_code(value: str) -> bool:
return len(value) == 3 and value.isupper() and value.isalpha()
class ReceiptAmount(BaseModel):
currency: Annotated[str, is_valid_currency_code]
amount: float
class Amount(TypedDict):
currency: Annotated[str, is_valid_currency_code]
amount: float
class ReceiptItemized(BaseModel):
class ReceiptLineItemAmount(BaseModel):
name: str
currency: Annotated[str, is_valid_currency_code]
amount: float
ItemizedReceipt: list[ReceiptLineItemAmount]
class ItemizedAmounts(TypedDict):
name: str
currency: Annotated[str, is_valid_currency_code]
amount: float
class ReceiptExtracted(TypedDict):
merchant: str
receipt_date: datetime
total_amount: Amount
tip_amount: Amount
tax_amount: Amount
item_amounts: list[ItemizedAmounts]
class LLMExtractor:
def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml"):
"""
Args:
model: A model name that has been downloaded by ``ollama``.
prompt_path: A str or Path to a YAML file with various prompts.
"""
self.model = model
self.prompt_path = Path(prompt_path)
with open(self.prompt_path, "r") as f:
self.prompts = yaml.safe_load(f)
def get_chat_response(
self,
messages: list[Message],
structured_output_format: Type[BaseModel] | None = None,
) -> Any:
"""
Args:
messages: The conversation so far.
structured_output_format: The dictionary format of the pydantic model schema. Can be ``None``.
Returns:
Depends on the model chosen, and if `structured_output_format` is provided.
"""
response = chat(
model=self.model, messages=messages, format=structured_output_format.model_json_schema(),
options={"temperature": 0}
)
message = response.message.content
if structured_output_format is not None:
return json.loads(message)
else:
return message
def load_substituted_prompt(self, prompt_name: str, **kwargs) -> str:
prompt = self.prompts[prompt_name]
for patt, repl in kwargs.items():
prompt = prompt.replace(f"[[ {patt} ]]", repl)
return prompt
def extract_merchant_name(self, receipt_string: str) -> str:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_merchant_name", **{"receipt_string": receipt_string}
),
},
]
return self.get_chat_response(messages, ReceiptMerchant)['name']
def extract_receipt_date(self, receipt_string: str) -> datetime:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_receipt_date", **{"receipt_string": receipt_string}
),
},
]
receipt_date = self.get_chat_response(messages, ReceiptDate)['date']
return datetime.fromisoformat(receipt_date)
def extract_receipt_total_amount(self, receipt_string: str) -> Amount:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_receipt_total_amount", **{"receipt_string": receipt_string}
),
},
]
currency_amount = self.get_chat_response(messages, ReceiptAmount)
return currency_amount
def extract_receipt_tax_amount(self, receipt_string: str) -> Amount:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_receipt_tax_amount", **{"receipt_string": receipt_string}
),
},
]
currency_amount = self.get_chat_response(messages, ReceiptAmount)
return currency_amount
def extract_receipt_tip_amount(self, receipt_string: str) -> Amount:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_receipt_tip_amount", **{"receipt_string": receipt_string}
),
},
]
currency_amount = self.get_chat_response(messages, ReceiptAmount)
return currency_amount
def extract_receipt_items(self, receipt_string: str) -> list[ItemizedAmounts]:
messages: list[Message] = [
{
"role": "system",
"content": self.load_substituted_prompt(
"extract_receipt_items", **{"receipt_string": receipt_string}
),
},
]
itemized_amounts = self.get_chat_response(messages, ReceiptItemized)['ItemizedReceipt']
return itemized_amounts
def forward(self, receipt_string, enable_alerts:bool=False) -> ReceiptExtracted:
merchant = self.extract_merchant_name(receipt_string)
# gr.Info("Extracted merchant name.")
receipt_date = self.extract_receipt_date(receipt_string)
# gr.Info("Extracted receipt date.")
total_amount = self.extract_receipt_total_amount(receipt_string)
# gr.Info("Extracted total amount.")
tip_amount = self.extract_receipt_tip_amount(receipt_string)
# gr.Info("Extracted tip amount.")
tax_amount = self.extract_receipt_tax_amount(receipt_string)
# gr.Info("Extracted tax amount. Extracting individual items - this will take time.")
item_amounts = self.extract_receipt_items(receipt_string)
# gr.Info("Extracted individual items. Populating items now!")
return {
"merchant": merchant,
"receipt_date": receipt_date,
"total": total_amount,
"tip": tip_amount,
"tax": tax_amount,
"item_amounts": item_amounts
}
# return {
# "merchant": self.extract_merchant_name(receipt_string),
# "receipt_date": self.extract_receipt_date(receipt_string),
# "total": self.extract_receipt_total_amount(receipt_string),
# "tip": self.extract_receipt_tip_amount(receipt_string),
# "tax": self.extract_receipt_tax_amount(receipt_string),
# "item_amounts": self.extract_receipt_items(receipt_string)
# }

89
src/llm/prompts.yaml Normal file
View File

@@ -0,0 +1,89 @@
extract_merchant_name: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify the merchant name.
For restaurants, this will be the restaurant name.
For a grocery store or supermarket, this will be the supermarket's name.
For other retailers, it will be the retailer's name.
If you can't find a name, provide `null` as the response.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract the name of the merchant from this receipt using your knowledge of the industry and the receipt itself.
extract_receipt_date: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify the date (in YYYY-MM-DD format) on the receipt.
If you can't find a date, provide `null` as the response.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract the date of this receipt using your knowledge of the industry and the receipt itself.
extract_receipt_total_amount: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify the total amount on the receipt, as well as the three letter currency code (ISO 4217).
If you cannot find the total amount, provide '0' as the total amount.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract the currency and amount of the total in this receipt using your knowledge of the industry
and the receipt itself.
extract_receipt_tax_amount: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify the overall tax amount on the receipt, as well as the three letter currency code (ISO 4217).
If you cannot find the tax amount, provide '0' for the tax amount.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract the currency and amount of the overall tax in this receipt using your knowledge of the industry
and the receipt itself.
extract_receipt_tip_amount: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify the tip paid (including gratuity, but not tax) on the receipt, as well as the three letter
currency code (ISO 4217). If there is no tip, provide '0' as the tip.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract the currency and amount of the overall tip in this receipt using your knowledge of the industry
and the receipt itself.
extract_receipt_items: >
You are an experience cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
You are trying to identify each of the items present, and the amount paid for those items.
The receipt string is provided below:
```
[[ receipt_string ]]
```
Extract a list of items by their name, currency (3 letter code) and amount paid for it using your knowledge of the industry
and the receipt itself.

185
src/ocr/surya.py Normal file
View File

@@ -0,0 +1,185 @@
from pathlib import Path
from PIL.Image import open as PIL_open
from PIL.Image import Image
from surya.ocr import run_ocr
from surya.model.detection.model import (
load_model as load_det_model,
load_processor as load_det_processor,
)
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from typing import Literal, TypedDict, overload, Optional
from collections import defaultdict
from more_itertools import bucket
Polygon = list[list[float]]
BBox = list[float]
class SuryaOCRResult(TypedDict):
polygon: Polygon
confidence: Optional[float]
text: str
bbox: BBox
class SplitAIOCRResult(TypedDict):
polygons: list[Polygon]
texts: list[str]
bboxes: list[BBox]
class SuryaOCR:
def __init__(self):
self.det_processor = load_det_processor()
self.det_model = load_det_model()
self.rec_model = load_rec_model()
self.rec_processor = load_rec_processor()
@overload
def ocr_image(self, image: Image, return_format: Literal["polygon"]) -> list[Polygon]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["text"]) -> list[str]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["bbox"]) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: Literal["confidence"]) -> list[BBox]: ...
@overload
def ocr_image(self, image: Image, return_format: None) -> SplitAIOCRResult: ...
def ocr_image(
self,
image: Image,
return_format: Literal["polygon", "text", "bbox", "confidence"] | None,
) -> SplitAIOCRResult | list[Polygon] | list[str] | list[BBox]:
"""
Specify either a path to an image or image file itself (as a PIL image) to run OCR on.
If both are provided, then the image file is prioritized.
Args:
image: the PIL image for Surya to process
return_format: Specify one of the allowed values to return only this key from SuryaOCR's output.
Returns:
OCRResult, in Surya's format. See `<https://github.com/VikParuchuri/surya>` for details.
"""
images = [image]
langs = [["en"]]
ocr_output = run_ocr(
images,
langs,
self.det_model,
self.det_processor,
self.rec_model,
self.rec_processor,
)
ocr_as_list: list[SuryaOCRResult] = ocr_output[0].model_dump()["text_lines"]
if isinstance(return_format, str):
return [x[return_format] for x in ocr_as_list]
else:
polygons = [x["polygon"] for x in ocr_as_list]
texts = [x["text"] for x in ocr_as_list]
bboxes = [x["bbox"] for x in ocr_as_list]
return {
"polygons": polygons,
"texts": texts,
"bboxes": bboxes,
}
def ordered_ocr_text(self, image: Image) -> str:
split_ocr_result = self.ocr_image(image, None)
ordered_text: list[str] = []
for group in SuryaOCR.order_polygons(split_ocr_result["polygons"]):
line = [split_ocr_result["texts"][x] for x in group]
ordered_text += [" ".join(line)]
return "\n".join(ordered_text)
@staticmethod
def get_centroid_of_bounding_polygon(
polygon: Polygon,
) -> tuple[float, float]:
"""
Function name is self-explanatory. Invariant of rotation. Any bounding polygon whose centroid differs
significantly to another one is unlikely to be in the same line.
Args:
polygon: `[[x_1, y_1], [x_2, y_2], [x_3, y_3], [x_4, y_4]]` coordinates of
the polygon, clockwise from top left.
Returns:
The centroid of the bounding polygon.
"""
(x_1, y_1), (x_2, y_2), (x_3, y_3), (x_4, y_4) = tuple(
tuple(x) for x in polygon
)
return (x_1 + x_4) / 2, (y_1 + y_2) / 2
@staticmethod
def get_height_of_bounding_polygon(polygon: Polygon) -> float:
"""
Args:
polygon: `[[x_1, y_1], [x_2, y_2], [x_3, y_3], [x_4, y_4]]` coordinates of
the polygon, clockwise from top left.
Returns:
The height of the bounding polygon.
"""
(x_1, y_1), (x_2, y_2), (x_3, y_3), (x_4, y_4) = tuple(
tuple(x) for x in polygon
)
return y_4 - y_1
@staticmethod
def order_polygons(polygons: list[Polygon]) -> list[list[int]]:
"""
Given a list of bounding polygons (from an OCR framework) for recognized text, attempt
to reorder them (left-to-right, top-to-bottom) that their text contents are expected to
be read.
This is intended to work independent of the orientation of the receipts, but currently
is at a PoC stage where it is assumed that the receipt is horizontally positioned.
Args:
polygons: Bounding polygons of recognized text.
Returns:
A ``list[int]`` with the estimated line numbers.
"""
x_midpoints, y_midpoints = zip(
*[SuryaOCR.get_centroid_of_bounding_polygon(x) for x in polygons]
)
heights = [SuryaOCR.get_height_of_bounding_polygon(x) for x in polygons]
threshold = 0.6
y_ranges = [
(
midpoint - height * threshold * 0.5,
midpoint,
midpoint + height * threshold * 0.5,
)
for midpoint, height in zip(y_midpoints, heights)
]
# Assign line groups
line_groups: dict[int, int | None] = defaultdict(lambda: None)
for idx, i_range in enumerate(y_ranges):
if idx not in line_groups:
line_groups[idx] = idx
for jdx, j_range in [
(jdx, j_range) for jdx, j_range in enumerate(y_ranges) if jdx > idx
]:
if (i_range[0] <= j_range[1] <= i_range[2]) and jdx not in line_groups:
line_groups[jdx] = idx
# Reorder by x_midpoints within a group:
line_groups_reversed = [(group, key) for key, group in line_groups.items()]
bucketed_line_groups = bucket(line_groups_reversed, key=lambda x: x[0])
ordered_text = []
for key in sorted(list(bucketed_line_groups)):
idx_group = list(bucketed_line_groups[key])
sorted_list = sorted(idx_group, key=lambda x: x_midpoints[x[1]])
sorted_idx = [x[1] for x in sorted_list]
ordered_text += [sorted_idx]
return ordered_text

54
tests/test_ollama.py Normal file
View File

@@ -0,0 +1,54 @@
from src.llm.ollama import LLMExtractor
from datetime import date
def test_ollama_response():
messages: list[Message] = [{"role": "user", "content": "Hello"}]
output = get_chat_response(model="gemma2:27b", messages=messages)
assert isinstance(output, str)
def test_ollama_structured_response():
class ModelData(BaseModel):
name: str
knowledge_cutoff: date
messages: list[Message] = [{"role": "user", "content": "Tell me about yourself"}]
output = get_chat_response(
model="gemma2:27b",
messages=messages,
structured_output_format=ModelData.model_json_schema(),
)
assert ModelData.model_validate(output)
def test_ollama_image_upload():
messages: list[Message] = [
{
"role": "user",
"content": "Run OCR on this image.",
"images": ["./data/pytest/image_upload_test.png"],
}
]
output = get_chat_response(
model="llava:34b",
messages=messages,
)
print(output)
assert False
def test_llm_extractor():
llm_extractor = LLMExtractor(model="gemma2:27b")
with open("data/pytest/receipt_ocr_test.txt", "r") as f:
receipt_ocr_text = "".join(f.readlines())
assert llm_extractor.extract_merchant_name(receipt_ocr_text) == 'Walmart'
assert llm_extractor.extract_receipt_date(receipt_ocr_text) == date.fromisoformat('2017-07-28')
assert llm_extractor.extract_receipt_total_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 98.21}
assert llm_extractor.extract_receipt_tip_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 0}
assert llm_extractor.extract_receipt_tax_amount(receipt_ocr_text) == {'currency': 'USD', 'amount': 4.59}
items = llm_extractor.extract_receipt_items(receipt_ocr_text)
# The exact items in these cannot be always identical as there will be some inherent variance in LLM output.
# Thus, we check if a few control totals match, at least approximately.
assert len(items) == 26
assert round(sum(x["amount"] for x in items), 0) == 94

7
tests/test_surya.py Normal file
View File

@@ -0,0 +1,7 @@
from src.ocr.surya import SuryaOCR
def test_ocr_sample_image():
image_path = "data/pytest/image_upload_test.png"
output = ocr_image(image_path=image_path)
assert output is not None

1789
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff