Refactor image corruption toolbox by iamksuresh · Pull Request #319 · aiverify-foundation/aiverify

Merged · 14 commits · Sep 20, 2024
1 change: 1 addition & 0 deletions stock-plugins/.gitignore
@@ -3,3 +3,4 @@ build
temp
__pycache__
.pytest_cache
+output
@@ -8,3 +8,52 @@

## Developers:
* AI Verify

## Develop plugin locally
#### Execute the bash script below from the project root
```
#!/bin/bash

# setup virtual environment
python -m venv .venv
source .venv/bin/activate

# go to the plugin directory
cd aiverify/stock-plugins/aiverify.stock.image-corruption-toolbox/algorithms/blur_corruptions/

# install the plugin in editable mode with its dev dependencies (this also installs test-engine-core)
pip install -e '.[dev]'

# run the plugin
python -m aiverify_blur_corruptions --data_path <data_path> --model_path <model_path> --ground_truth_path <ground_truth_path> --ground_truth <str> --model_type CLASSIFICATION --run_pipeline --set_seed <int> --annotated_ground_truth_path <annotated_file_path> --file_name_label <str>

```
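In the invocation above, `--annotated_ground_truth_path` points at the file holding the annotated labels and `--file_name_label` names the column that maps each image file name to its label; both are shown filled in concretely in the example below.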
#### Example:
```
#!/bin/bash

root_path="<PATH_TO_FOLDER>/aiverify/stock-plugins/user_defined_files"

python -m aiverify_blur_corruptions \
--data_path $root_path/data/raw_fashion_image_10 \
--model_path $root_path/pipeline/multiclass_classification_image_mnist_fashion \
--ground_truth_path $root_path/data/pickle_pandas_fashion_mnist_annotated_labels_10.sav \
--ground_truth label \
--model_type CLASSIFICATION \
--run_pipeline \
--annotated_ground_truth_path $root_path/data/pickle_pandas_fashion_mnist_annotated_labels_10.sav \
--set_seed 10 \
--file_name_label file_name
```
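After a successful run, the results are validated against the output schema and written to `output/results.json` under the working directory (see `algo_init.py` in this PR). A minimal Python sketch to load and inspect them:

```
import json
from pathlib import Path

# assumes the plugin was run from the current working directory
results_path = Path.cwd() / "output" / "results.json"
with open(results_path) as f:
    results = json.load(f)
print(list(results.keys()))
```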

## Build Plugin
```
cd aiverify/stock-plugins/aiverify.stock.image-corruption-toolbox/algorithms/blur_corruptions/
hatch build
```
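Assuming hatch's defaults, the build artifacts (an sdist and a wheel) are written to the `dist/` folder.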
## Tests
Pytest is used as the testing framework. Execute the steps below to run the unit and integration tests in the `tests/` folder:
```
cd aiverify/stock-plugins/aiverify.stock.image-corruption-toolbox/algorithms/blur_corruptions/
pytest .
```

This file was deleted.

@@ -0,0 +1,3 @@
"""
Main package for aiverify blur corruptions plugin.
"""
@@ -0,0 +1,35 @@
"""
Allow aiverify_blur_corruptions to be executable through
`python3 -m aiverify_blur_corruptions`
"""
import sys
from importlib.metadata import version
from pathlib import Path

from aiverify_blur_corruptions.plugin_init import run


def main() -> None:
"""
Print the version message, then invoke the algorithm
"""
print("*" * 20)
print(version_msg())
print("*" * 20)
# invoke algorithm
run()


def version_msg():
"""
Return the aiverify_blur_corruptions version, location and Python powering it.
"""
python_version = sys.version
location = Path(__file__).resolve().parent.parent

return f"Aiverify Blur Corruptions - {version('aiverify_blur_corruptions')} from {location} \
(Python {python_version})"


if __name__ == "__main__":
main()
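For reference, `python -m aiverify_blur_corruptions` prints a banner of the form below (placeholders resolved at runtime) before handing off to `run()`:

```
********************
Aiverify Blur Corruptions - <version> from <install location> (Python <sys.version>)
********************
```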
@@ -18,8 +18,8 @@
"input.schema.json",
"LICENSE",
"output.schema.json",
"blur_corruptions.meta.json",
"blur_corruptions.py",
"algo.meta.json",
"algo.py",
"README.md",
"requirements.txt",
"syntax_checker.py",
@@ -7,6 +7,7 @@

import numpy as np
import pandas as pd
+from aiverify_blur_corruptions.utils import blur
from PIL import Image
from sklearn.metrics import accuracy_score
from test_engine_core.interfaces.ialgorithm import IAlgorithm
@@ -19,7 +20,6 @@
from test_engine_core.plugins.metadata.plugin_metadata import PluginMetadata
from test_engine_core.utils.json_utils import load_schema_file, validate_json
from test_engine_core.utils.simple_progress import SimpleProgress
-from utils import blur


# =====================================================================================
@@ -140,21 +140,25 @@ def __init__(
self._data = None
self._results = {"results": [0]}
self._ordered_ground_truth = None
-self._tmp_path = self._base_path / "temp"
-self._save_path = self._base_path.parents[1] / "widgets" / "blur_images"
+# write all output to the output folder
+output_folder = Path.cwd() / "output"
+output_folder.mkdir(parents=True, exist_ok=True)
+self._tmp_path = output_folder / "temp"
+self._save_path = output_folder / "widgets" / "blur_images"

# Algorithm input schema defined in input.schema.json
# By defining the input schema, it allows the front-end to know what algorithm input params is
# required by this plugin. This allows this algorithm plug-in to receive the arguments values it requires.
+current_file_dir = Path(__file__).parent
self._input_schema = load_schema_file(
str(self._base_path / "input.schema.json")
str(current_file_dir / "input.schema.json")
)

# Algorithm output schema defined in output.schema.json
# By defining the output schema, this plug-in validates the result with the output schema.
# This allows the result to be validated against the schema before passing it to the front-end for display.
self._output_schema = load_schema_file(
str(self._base_path / "output.schema.json")
str(current_file_dir / "output.schema.json")
)

# Retrieve the input parameters defined in the input schema and store them
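The net effect of this hunk: run artifacts now land under the caller's working directory, while the bundled schemas are resolved relative to the module file, so the plugin behaves the same from a source checkout or an installed wheel. A minimal sketch of the two path bases:

```
from pathlib import Path

output_folder = Path.cwd() / "output"  # user-writable: temp files, widget images, results
schema_dir = Path(__file__).parent     # read-only: schemas shipped inside the package
print(output_folder / "widgets" / "blur_images")
print(schema_dir / "input.schema.json")
```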
@@ -497,9 +501,7 @@ def _get_rand_display(

image_name = str(severity) + ".png"
image_path = images["image_directory"].iloc[index]
-image_relative_path = str(
-    Path(image_path).relative_to(Path(self._base_path.parents[1]))
-)
+image_relative_path = str(Path(image_path).relative_to(Path().absolute()))

Path(self._save_path / corruption).mkdir(parents=True, exist_ok=True)
shutil.copy(
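With this change the widget receives image paths relative to the process working directory rather than the plugin's install location. A sketch of the behaviour (the corruption folder name is hypothetical):

```
from pathlib import Path

# Path().absolute() is the current working directory
image_path = Path.cwd() / "output" / "widgets" / "blur_images" / "gaussian_blur" / "1.png"
print(image_path.relative_to(Path().absolute()))
# -> output/widgets/blur_images/gaussian_blur/1.png
```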
39 changes: 22 additions & 17 deletions ...hms/blur_corruptions/tests/plugin_test.py → ...ns/aiverify_blur_corruptions/algo_init.py
100755 → 100644
@@ -1,11 +1,12 @@
import copy
+import importlib
import json
import logging
import sys
from pathlib import Path
from typing import Dict, Tuple, Union

-from blur_corruptions import Plugin
+from aiverify_blur_corruptions.algo import Plugin
from test_engine_core.interfaces.idata import IData
from test_engine_core.interfaces.imodel import IModel
from test_engine_core.interfaces.ipipeline import IPipeline
@@ -24,9 +25,9 @@
# =====================================================================================
# NOTE: Do not modify this file unless you know what you are doing.
# =====================================================================================
-class PluginTest:
+class AlgoInit:
"""
-The PluginTest class specifies methods in supporting testing for the plugin.
+The AlgoInit class specifies methods in supporting testing for the plugin.
"""

@staticmethod
@@ -37,7 +38,7 @@ def progress_callback(completion_value: int):
Args:
completion_value (int): Current progress completion
"""
print(f"[PluginTest] Progress Update: {completion_value}")
print(f"[AlgoInit] Progress Update: {completion_value}")

def __init__(
self,
@@ -57,7 +58,9 @@ def __init__(

# Store the input arguments as private vars
if core_modules_path == "":
core_modules_path = "../../../../test-engine-core-modules"
core_modules_path = Path(
importlib.util.find_spec("test_engine_core").origin
).parent
self._core_modules_path: str = core_modules_path
self._data_path: str = str(self._base_path / data_path)
self._model_path: str = str(self._base_path / model_path)
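Resolving the default core-modules path from the installed `test_engine_core` package replaces the fragile `../../../../` relative path. A sketch of what `find_spec` yields, assuming `test_engine_core` is pip-installed:

```
import importlib.util
from pathlib import Path

spec = importlib.util.find_spec("test_engine_core")
print(spec.origin)               # e.g. .../site-packages/test_engine_core/__init__.py
print(Path(spec.origin).parent)  # the directory used as core_modules_path
```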
@@ -100,7 +103,7 @@ def run(self) -> None:
print(f"[DETECTED_PLUGINS]: {PluginManager.get_printable_plugins()}")

# Create logger
self._logger_instance = logging.getLogger("PluginTestLogger")
self._logger_instance = logging.getLogger("AlgoInitLogger")
self._logger_instance.setLevel(logging.DEBUG)
log_format = logging.Formatter(
fmt="%(levelname)s %(asctime)s \t %(pathname)s %(funcName)s L%(lineno)s - %(message)s",
@@ -190,7 +193,6 @@ def run(self) -> None:
)
print(f"[GROUND_TRUTH]: {self._ground_truth}")
print(f"[MODEL_TYPE]: {self._model_type}")
print(f"[DATA FEATURES]: {self._data_instance.read_labels()}")
print(
f"[GROUND_TRUTH FEATURES]: {self._ground_truth_instance.read_labels()}"
)
@@ -216,7 +218,6 @@ def run(self) -> None:
"truth feature exists in the data specified in ground truth path file.)"
)

print(f"[DATA FEATURES]: {self._data_instance.read_labels()}")
print(
f"[GROUND_TRUTH FEATURES]: {self._ground_truth_instance.read_labels()}"
)
@@ -230,9 +231,7 @@ def run(self) -> None:
self._input_arguments["ground_truth"] = self._ground_truth
self._input_arguments["model_type"] = self._model_type
self._input_arguments["logger"] = self._logger_instance
-self._input_arguments[
-    "progress_callback"
-] = PluginTest.progress_callback
+self._input_arguments["progress_callback"] = AlgoInit.progress_callback
self._input_arguments["project_base_path"] = self._base_path

# Run the plugin with the arguments and instances
@@ -257,11 +256,17 @@ def run(self) -> None:
print("Verifying results with output schema...")
is_success, error_messages = self._verify_task_results(results)
if is_success:
-# Print the output results
-print(json.dumps(results))
-
-# Exit successfully
-sys.exit(0)
+# Save the output results
+output_folder = Path.cwd() / "output"
+output_folder.mkdir(parents=True, exist_ok=True)
+json_file_path = output_folder / "results.json"
+
+# Write the data to the JSON file
+with open(json_file_path, "w") as json_file:
+    json.dump(results, json_file, indent=4)
+print("*" * 20)
+print(f"Check the results here: {json_file_path}")
+print("*" * 20)
else:
raise RuntimeError(error_messages)
else:
@@ -298,7 +303,7 @@ def _verify_task_results(self, task_result: Dict) -> Tuple[bool, str]:
# Check that it meets the required format before sending out to the UI for display
if not validate_json(
task_result,
-load_schema_file(str(self._base_path / "output.schema.json")),
+load_schema_file(str(Path(__file__).parent / "output.schema.json")),
):
is_success = False
error_message = "Failed schema validation"