Skip to content

Vision AutoML

FastAPI app init for the vision AutoML service.

AudioClassificationTask

Bases: BaseModel

Configuration for audio classification tasks.

audio_dir is the root directory containing audio files. labels_file is a CSV with audio_path and label columns.

Source code in app/vision_automl/models.py
87
88
89
90
91
92
93
94
95
96
97
98
99
class AudioClassificationTask(BaseModel):
    """Configuration for audio classification tasks.

    ``audio_dir`` is the root directory containing audio files.
    ``labels_file`` is a CSV with ``audio_path`` and ``label`` columns.
    """

    # Discriminator identifying this task type.
    task_type: Literal["audio_classification"] = "audio_classification"
    # Root directory containing the audio files.
    audio_dir: Path
    # CSV mapping audio_path -> label.
    labels_file: Path

    class Config:
        # Pydantic v1-style config permitting arbitrary field types.
        arbitrary_types_allowed = True

CausalLMTask

Bases: TextTask

Configuration for causal language modelling tasks.

CSV must have a text column.

Source code in app/vision_automl/models.py
126
127
128
129
130
131
132
class CausalLMTask(TextTask):
    """Configuration for causal language modelling tasks.

    CSV must have a ``text`` column.
    """

    # Discriminator identifying this task type.
    task_type: Literal["causal_lm"] = "causal_lm"

ImageClassificationTask

Bases: ImageTask

Configuration for single-label image classification tasks.

Source code in app/vision_automl/models.py
38
39
40
41
class ImageClassificationTask(ImageTask):
    """Configuration for single-label image classification tasks."""

    # Discriminator identifying this task type.
    task_type: Literal["image_classification"] = "image_classification"

ImageMultiLabelClassificationTask

Bases: ImageTask

Configuration for multi-label image classification tasks.

Source code in app/vision_automl/models.py
158
159
160
161
162
class ImageMultiLabelClassificationTask(ImageTask):
    """Configuration for multi-label image classification tasks."""

    # Literal discriminator (was a plain ``str``), matching every other task
    # model in this module so the value is validated by pydantic and the
    # field can act as a tagged-union key.
    task_type: Literal["image_multilabel_classification"] = (
        "image_multilabel_classification"
    )
    label_format: Literal["csv", "json"] = "csv"  # required

ImageRegressionTask

Bases: ImageTask

Configuration for image regression tasks (predict numeric values).

Source code in app/vision_automl/models.py
165
166
167
168
169
class ImageRegressionTask(ImageTask):
    """Configuration for image regression tasks (predict numeric values)."""

    # Literal discriminator (was a plain ``str``), consistent with the other
    # task models so the value is validated by pydantic.
    task_type: Literal["image_regression"] = "image_regression"
    label_format: Literal["csv"] = "csv"  # regression needs exact values

ImageSegmentationTask

Bases: ImageTask

Configuration for semantic/panoptic image segmentation tasks.

Source code in app/vision_automl/models.py
44
45
46
47
class ImageSegmentationTask(ImageTask):
    """Configuration for semantic/panoptic image segmentation tasks."""

    # Discriminator identifying this task type.
    task_type: Literal["image_segmentation"] = "image_segmentation"

ImageTask

Bases: BaseModel

Base Pydantic model describing common image task inputs.

Source code in app/vision_automl/models.py
12
13
14
15
16
17
18
19
20
21
class ImageTask(BaseModel):
    """Base Pydantic model describing common image task inputs."""

    # Root directory of training images.
    train_dir: Path
    # Optional held-out image directory.
    test_dir: Path | None = None
    # 'folder' derives labels from directory names; 'csv' reads labels_file.
    label_format: Literal["folder", "csv"] = "folder"
    labels_file: Path | None = None  # used if label_format != 'folder'

    class Config:
        # Pydantic v1-style config permitting arbitrary field types.
        arbitrary_types_allowed = True

KeypointDetectionTask

Bases: ImageTask

Configuration for keypoint detection tasks.

The labels CSV must include a keypoints column with JSON-encoded [x, y, visibility] lists.

Source code in app/vision_automl/models.py
72
73
74
75
76
77
78
79
class KeypointDetectionTask(ImageTask):
    """Configuration for keypoint detection tasks.

    The labels CSV must include a ``keypoints`` column with JSON-encoded
    ``[x, y, visibility]`` lists.
    """

    # Discriminator identifying this task type.
    task_type: Literal["keypoint_detection"] = "keypoint_detection"

MaskedLMTask

Bases: TextTask

Configuration for masked language modelling tasks.

CSV must have a text column.

Source code in app/vision_automl/models.py
144
145
146
147
148
149
150
class MaskedLMTask(TextTask):
    """Configuration for masked language modelling tasks.

    CSV must have a ``text`` column.
    """

    # Discriminator identifying this task type.
    task_type: Literal["masked_lm"] = "masked_lm"

ObjectDetectionTask

Bases: ImageTask

Configuration for object detection tasks.

The labels CSV must include boxes and class_labels columns (JSON-encoded lists per row).

Source code in app/vision_automl/models.py
50
51
52
53
54
55
56
57
58
class ObjectDetectionTask(ImageTask):
    """Configuration for object detection tasks.

    The labels CSV must include ``boxes`` and ``class_labels`` columns
    (JSON-encoded lists per row).
    """

    # Discriminator identifying this task type.
    task_type: Literal["object_detection"] = "object_detection"
    # Detection annotations come from the CSV (see required columns above).
    label_format: Literal["csv"] = "csv"

QuestionAnsweringTask

Bases: TextTask

Configuration for extractive question answering tasks.

CSV must have question, context, answer_start, and answer_text columns.

Source code in app/vision_automl/models.py
116
117
118
119
120
121
122
123
class QuestionAnsweringTask(TextTask):
    """Configuration for extractive question answering tasks.

    CSV must have ``question``, ``context``, ``answer_start``, and
    ``answer_text`` columns.
    """

    # Discriminator identifying this task type.
    task_type: Literal["question_answering"] = "question_answering"

Seq2SeqLMTask

Bases: TextTask

Configuration for sequence-to-sequence tasks.

CSV must have input_text and target_text columns.

Source code in app/vision_automl/models.py
135
136
137
138
139
140
141
class Seq2SeqLMTask(TextTask):
    """Configuration for sequence-to-sequence tasks.

    CSV must have ``input_text`` and ``target_text`` columns.
    """

    # Discriminator identifying this task type.
    task_type: Literal["seq2seq_lm"] = "seq2seq_lm"

SequenceClassificationTask

Bases: TextTask

Configuration for text sequence classification tasks.

CSV must have text and label columns.

Source code in app/vision_automl/models.py
107
108
109
110
111
112
113
class SequenceClassificationTask(TextTask):
    """Configuration for text sequence classification tasks.

    CSV must have ``text`` and ``label`` columns.
    """

    # Discriminator identifying this task type.
    task_type: Literal["text_classification"] = "text_classification"

TextTask

Bases: BaseModel

Base Pydantic model for text-based tasks.

Source code in app/vision_automl/models.py
24
25
26
27
28
29
30
class TextTask(BaseModel):
    """Base Pydantic model for text-based tasks."""

    data_file: Path  # CSV with the required columns for the task type

    class Config:
        # Pydantic v1-style config permitting arbitrary field types.
        arbitrary_types_allowed = True

VideoClassificationTask

Bases: ImageTask

Configuration for video classification tasks.

The labels CSV must include a video_path column pointing to video files relative to train_dir.

Source code in app/vision_automl/models.py
61
62
63
64
65
66
67
68
69
class VideoClassificationTask(ImageTask):
    """Configuration for video classification tasks.

    The labels CSV must include a ``video_path`` column pointing to video
    files relative to ``train_dir``.
    """

    # Discriminator identifying this task type.
    task_type: Literal["video_classification"] = "video_classification"
    # Labels come from the CSV (see the required ``video_path`` column).
    label_format: Literal["csv"] = "csv"

Route definitions for the vision AutoML service.

find_best_model_for_vision(request, user_id, dataset_id, dataset_version=None, filename_column='filename', label_column='label', task_type='image_classification', time_budget=60, model_size='small', dataset_split=None) async

Fetch a vision dataset from AutoDW, run AutoML training, and upload the best model.

Steps
  1. Fetch dataset metadata from AutoDW.
  2. Resolve the correct download URL (respecting splits if present).
  3. Download the dataset ZIP to a temporary directory and extract it.
  4. Validate CSV structure and image file presence.
  5. Train a vision AutoML model within the given time budget.
  6. Zip the model artifacts.
  7. Upload the model and leaderboard back to AutoDW.

Returns:

Type Description
JSONResponse

200 – success message and leaderboard summary.

JSONResponse

400 – validation error (bad inputs or unsupported dataset).

JSONResponse

502 – AutoDW communication failure.

JSONResponse

500 – unexpected runtime error.

Source code in app/vision_automl/router.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
@router.post("/best_model/")
async def find_best_model_for_vision(
    request: Request,
    user_id: Annotated[str, Form(..., description="User id from AutoDW")],
    dataset_id: Annotated[str, Form(..., description="Dataset id from AutoDW")],
    dataset_version: Annotated[
        str | None, Form(description="Optional dataset version")
    ] = None,
    filename_column: Annotated[
        str, Form(..., description="Filename column in labels.csv")
    ] = "filename",
    label_column: Annotated[
        str, Form(..., description="Label column in labels.csv")
    ] = "label",
    task_type: Annotated[
        str,
        Form(
            description=(
                "Vision task type. One of: "
                + ", ".join(sorted(SUPPORTED_VISION_TASK_TYPES))
            )
        ),
    ] = "image_classification",
    time_budget: Annotated[int, Form(..., description="Time budget in seconds")] = 60,
    model_size: Annotated[
        str, Form(..., description="Model size: small / medium / large")
    ] = "small",
    dataset_split: Annotated[
        str | None,
        Form(description="Dataset split to use for training (e.g., 'train')."),
    ] = None,
) -> JSONResponse:
    """
    Fetch a vision dataset from AutoDW, run AutoML training, and upload the best model.

    Steps:
      1. Fetch dataset metadata from AutoDW.
      2. Resolve the correct download URL (respecting splits if present).
      3. Download the dataset ZIP to a temporary directory and extract it.
      4. Validate CSV structure and image file presence.
      5. Train a vision AutoML model within the given time budget.
      6. Zip the model artifacts.
      7. Upload the model and leaderboard back to AutoDW.

    Returns:
        200 – success message and leaderboard summary.
        400 – validation error (bad inputs or unsupported dataset).
        502 – AutoDW communication failure.
        500 – unexpected runtime error.
    """
    # AutoDW base URL is environment-configurable; defaults to localhost.
    autodw_base = os.getenv("AUTODW_URL", "http://localhost:8000")
    upload_url = f"{autodw_base}/ai-models/upload/single/{user_id}"

    try:
        # 1. Metadata
        metadata = fetch_dataset_metadata(
            autodw_base, user_id, dataset_id, dataset_version
        )

        # Vision datasets must arrive as a single ZIP archive.
        if metadata.get("file_type") != "zip":
            return JSONResponse(
                status_code=400,
                content={"error": "Vision AutoML requires a ZIP dataset."},
            )

        # 2. Download URL
        download_url = resolve_download_url(
            autodw_base, user_id, dataset_id, dataset_version, metadata, dataset_split
        )

        with dataset_workspace(f"automl_{dataset_id}") as workdir:
            # 3. Download & extract
            zip_path = download_dataset(
                download_url, workdir, metadata.get("original_filename", "dataset.zip")
            )
            csv_path, images_dir = extract_and_locate_dataset(zip_path, workdir)

            # 4. Validate
            # NOTE(review): task_type could be checked before the download to
            # fail fast; kept here to preserve the existing response order.
            if task_type not in SUPPORTED_VISION_TASK_TYPES:
                return JSONResponse(
                    status_code=400,
                    content={
                        "error": f"Unsupported task_type '{task_type}'. "
                        f"Supported: {sorted(SUPPORTED_VISION_TASK_TYPES)}"
                    },
                )

            validation_error = validate_vision_inputs(
                csv_path, images_dir, filename_column, label_column, task_type
            )
            if validation_error:
                return JSONResponse(
                    status_code=400, content={"error": validation_error}
                )

            # 5. Train
            optuna_result = await train_automl(
                csv_path,
                images_dir,
                filename_column,
                label_column,
                time_budget,
                model_size,
                workdir=workdir,
                task_type=task_type,
            )

            # 6. Serialize
            # ``zip_path`` is reused here: from now on it points at the model
            # archive, not the dataset archive downloaded above.
            zip_path = serialize_and_zip_model(optuna_result, workdir)
            leaderboard_json, leaderboard_str = convert_leaderboard_safely(
                optuna_result
            )

            # 7. Upload
            _, payload = build_upload_payload(
                dataset_id, dataset_version, metadata, task_type, leaderboard_json
            )
            upload_resp = upload_model(
                upload_url, zip_path, payload, request.headers.get("X-Task-ID")
            )

            # Propagate AutoDW's failure status and body verbatim to the caller.
            if upload_resp.status_code >= 400:
                logger.error("Model upload failed: %s", upload_resp.text)
                return JSONResponse(
                    status_code=upload_resp.status_code,
                    content={"error": f"Failed to upload model: {upload_resp.text}"},
                )

        logger.info("Vision AutoML training completed and model uploaded successfully.")
        return JSONResponse(
            status_code=200,
            content={
                "message": "Vision AutoML training completed successfully and model uploaded to AutoDW",
                "leaderboard": leaderboard_str,
            },
        )

    except DatasetValidationError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
    except AutodwError as e:
        return JSONResponse(status_code=502, content={"error": f"AutoDW error: {e}"})
    except Exception as e:
        # Last-resort handler: log the full traceback, return a generic 500.
        logger.exception("Unexpected error during vision AutoML")
        return JSONResponse(status_code=500, content={"error": str(e)})

Service layer for vision AutoML workflows.

Mirrors the structure of tabular_automl/services.py so both pipelines share a consistent public API consumed by their respective main.py files.

AutodwError

Bases: Exception

Raised on AutoDW communication failures.

Source code in app/vision_automl/services.py
204
205
class AutodwError(Exception):
    """Raised on AutoDW communication failures (router maps this to HTTP 502)."""

DatasetValidationError

Bases: ValueError

Raised when the uploaded dataset fails structural validation.

Source code in app/vision_automl/services.py
200
201
class DatasetValidationError(ValueError):
    """Raised when the uploaded dataset fails structural validation (router maps this to HTTP 400)."""

build_upload_payload(dataset_id, dataset_version, metadata, task_type, leaderboard_json)

Return (model_id, form_data_dict) for the AutoDW upload request.

Mirrors tabular's build_upload_payload.

Source code in app/vision_automl/services.py
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def build_upload_payload(
    dataset_id: str,
    dataset_version: str | None,
    metadata: dict,
    task_type: str,
    leaderboard_json: dict,
) -> tuple[str, dict]:
    """
    Return (model_id, form_data_dict) for the AutoDW upload request.

    Mirrors tabular's ``build_upload_payload``.
    """
    model_id = (
        f"vision_automl_{dataset_id}_{int(datetime.datetime.utcnow().timestamp())}"
    )
    data = {
        "model_id": model_id,
        "name": f"Vision AutoML Model - {dataset_id}",
        "description": "AutoML trained vision model",
        "framework": "pytorch",
        "model_type": task_type,
        "training_dataset": str(dataset_id),
        "training_dataset_version": dataset_version or metadata.get("version", "v1"),
        "leaderboard": json.dumps(leaderboard_json),
    }
    return model_id, data

collect_missing_files(df, images_dir, filename_col, label_col)

Return a list of filenames referenced in the CSV but absent on disk.

Source code in app/vision_automl/services.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def collect_missing_files(
    df: pd.DataFrame, images_dir: Path, filename_col: str, label_col: str
) -> list[str]:
    """Return filenames listed in the CSV that cannot be located under ``images_dir``."""
    not_found: list[str] = []
    for _, record in df.iterrows():
        name = record[filename_col]

        # Fast path: the file sits directly under the images root.
        if (images_dir / name).exists():
            continue

        # Slow path: search recursively in case the archive nests subfolders.
        hits = list(images_dir.rglob(name))
        if len(hits) == 1:
            continue
        if len(hits) > 1:
            # Ambiguous matches are reported and still treated as missing.
            logger.warning("Multiple matches for %s: %s", name, hits)

        not_found.append(name)
    return not_found

convert_leaderboard_safely(optuna_result)

Extract leaderboard information from an Optuna result dict.

Returns (leaderboard_json, leaderboard_str) — mirrors the tabular convert_leaderboard_safely signature so main.py can treat both pipelines identically.

Source code in app/vision_automl/services.py
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def convert_leaderboard_safely(optuna_result: dict) -> tuple[dict, str]:
    """
    Build a compact leaderboard summary from an Optuna result dict.

    Returns ``(leaderboard_json, leaderboard_str)``, matching the tabular
    pipeline's ``convert_leaderboard_safely`` so callers can treat both
    pipelines identically.  Missing keys simply become ``None``.
    """
    # Map output keys to the corresponding Optuna result keys.
    summary = {
        out_key: optuna_result.get(src_key)
        for out_key, src_key in (
            ("best_loss", "best_value"),
            ("best_params", "best_params"),
            ("trials", "n_trials"),
        )
    }
    return summary, json.dumps(summary, indent=2)

download_dataset(download_url, workdir, original_filename)

Stream-download the ZIP dataset and return its local path.

Source code in app/vision_automl/services.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def download_dataset(download_url: str, workdir: Path, original_filename: str) -> Path:
    """Download the dataset ZIP in streaming mode and return its local path.

    Args:
        download_url: Fully-resolved AutoDW download endpoint.
        workdir: Directory the archive is written into.
        original_filename: Name to give the downloaded file on disk.
    """
    destination = workdir / original_filename
    response = requests.get(
        download_url,
        stream=True,
        timeout=60,
        headers={"Accept-Encoding": "gzip, deflate"},
    )
    with response as resp:
        resp.raise_for_status()
        # Write 1 MiB at a time so large archives never live fully in memory.
        with open(destination, "wb") as out:
            for block in resp.iter_content(chunk_size=1024 * 1024):
                out.write(block)
    logger.info("Dataset ZIP saved to %s", destination)
    return destination

extract_and_locate_dataset(zip_path, workdir)

Extract a vision dataset ZIP and return (csv_path, images_dir).

Raises DatasetValidationError for structural problems.

Source code in app/vision_automl/services.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def extract_and_locate_dataset(zip_path: Path, workdir: Path) -> tuple[Path, Path]:
    """
    Unpack the dataset archive and locate its labels CSV and images directory.

    Returns a ``(csv_path, images_dir)`` pair.
    Raises DatasetValidationError for structural problems.
    """
    target = workdir / "dataset"
    target.mkdir(exist_ok=True)
    shutil.unpack_archive(zip_path, target)

    # The helpers raise DatasetValidationError when the layout is unusable.
    root = _find_valid_dataset_root(target)
    labels_csv = _find_csv_file(root)
    return labels_csv, _find_or_resolve_images_dir(root, labels_csv)

fetch_dataset_metadata(autodw_base, user_id, dataset_id, dataset_version)

Fetch and return dataset metadata from AutoDW.

Source code in app/vision_automl/services.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def fetch_dataset_metadata(
    autodw_base: str,
    user_id: str,
    dataset_id: str,
    dataset_version: str | None,
) -> dict:
    """Return the dataset's metadata document from AutoDW as a dict.

    Non-2xx responses raise via ``raise_for_status``.
    """
    url = _build_metadata_url(autodw_base, user_id, dataset_id, dataset_version)
    logger.debug("Fetching dataset metadata: %s", url)
    response = requests.get(url, timeout=15)
    response.raise_for_status()
    return response.json()

get_num_params_if_available(repo_id, revision=None)

Try to retrieve number of parameters for a HF model, if available.

Source code in app/vision_automl/services.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def get_num_params_if_available(
    repo_id: str, revision: str | None = None
) -> int | None:
    """Return the parameter count for a HF model, or None when unavailable."""
    logger.debug("Fetching parameter count for model %s", repo_id)
    api = HfApi()
    try:
        info = api.model_info(repo_id, revision=revision, files_metadata=True)
        # Parameter totals are only exposed via the safetensors metadata.
        safetensors_meta = getattr(info, "safetensors", None)
        if safetensors_meta is not None:
            return safetensors_meta.total
    except Exception as e:
        # Best-effort lookup: log and fall through to None on any failure.
        logger.warning("Failed to retrieve num_params for %s: %s", repo_id, e)
    return None

normalize_dataframe_filenames(df, filename_column, csv_path)

Normalize filenames to basenames and persist CSV back to disk.

Source code in app/vision_automl/services.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def normalize_dataframe_filenames(
    df: pd.DataFrame, filename_column: str, csv_path: Path
) -> pd.DataFrame:
    """Strip directory components from the filename column and re-save the CSV.

    Leaves the frame untouched (with a warning) when the column is absent.
    """
    logger.info("Normalizing filenames in column '%s'", filename_column)
    if filename_column not in df.columns:
        logger.warning(
            "Filename column '%s' not found during normalization", filename_column
        )
        return df

    def _basename(value) -> str:
        # Normalize Windows-style separators before taking the basename.
        return os.path.basename(str(value).replace("\\", "/"))

    df[filename_column] = df[filename_column].astype(str).map(_basename)
    df.to_csv(csv_path, index=False)
    logger.debug("Normalized filenames saved to %s", csv_path)
    return df

resolve_download_url(autodw_base, user_id, dataset_id, dataset_version, metadata, split)

Determine the correct dataset download URL, accounting for splits.

Source code in app/vision_automl/services.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def resolve_download_url(
    autodw_base: str,
    user_id: str,
    dataset_id: str,
    dataset_version: str | None,
    metadata: dict,
    split: str | None,
) -> str:
    """Build the dataset download URL, appending a split query when applicable."""
    base_url = _build_metadata_url(autodw_base, user_id, dataset_id, dataset_version)
    download_url = f"{base_url}/download"

    # A requested split is only honoured when the metadata declares one.
    has_split = bool(metadata.get("custom_metadata", {}).get("split"))

    if split and has_split:
        download_url = f"{download_url}?split={split}"
        logger.info(
            "Dataset has splits; downloading '%s' split from: %s", split, download_url
        )
        return download_url

    if split and not has_split:
        logger.warning(
            "split='%s' was requested but dataset has no splits; "
            "downloading full dataset.",
            split,
        )
    logger.debug("Downloading full dataset ZIP: %s", download_url)
    return download_url

resolve_images_root(images_dir)

Resolve common nested packaging patterns inside uploaded image zips.

Source code in app/vision_automl/services.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def resolve_images_root(images_dir: Path) -> Path:
    """Resolve common nested packaging patterns inside uploaded image zips."""
    logger.info("Resolving image directory structure at %s", images_dir)

    # Case 1: the zip contains an extra 'images/' wrapper directory.
    candidate = images_dir / "images"
    if candidate.is_dir():
        logger.debug("Detected nested 'images' folder, using it as root")
        images_dir = candidate

    # Case 2: everything lives inside a single top-level subdirectory.
    try:
        entries = list(images_dir.iterdir())
        subdirs = [e for e in entries if e.is_dir()]
        files = [e for e in entries if e.is_file()]
        if not files and len(subdirs) == 1:
            logger.debug("Detected single top-level directory: %s", subdirs[0])
            images_dir = subdirs[0]
    except Exception as e:
        # Best-effort: fall back to the current root on any filesystem error.
        logger.warning("Error resolving image root: %s", e)

    return images_dir

search_hf_for_pytorch_models_with_estimated_parameters(filter='image-classification', limit=3, sort='downloads')

Search HF for PyTorch image-classification models annotated with param counts.

Source code in app/vision_automl/services.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def search_hf_for_pytorch_models_with_estimated_parameters(
    filter: str = "image-classification", limit: int = 3, sort: str = "downloads"
) -> list[dict[str, Any]]:
    """Search HF for PyTorch image-classification models annotated with param counts."""
    logger.info("Searching Hugging Face models for filter='%s'", filter)
    hub = HfApi()
    candidates = hub.list_models(
        filter=filter,
        library="pytorch",
        sort=sort,
        direction=-1,
        limit=limit,
    )

    annotated: list[dict[str, Any]] = []
    for model in candidates:
        params = get_num_params_if_available(model.id)
        if not params:
            # Models without a resolvable parameter count are dropped.
            continue
        annotated.append(
            {
                "model_id": model.id,
                "downloads": getattr(model, "downloads", None),
                "likes": getattr(model, "likes", None),
                "last_modified": getattr(model, "lastModified", None),
                "private": getattr(model, "private", None),
                "num_params": params,
            }
        )

    logger.info("Found %d models with parameter info", len(annotated))
    return annotated

serialize_and_zip_model(result, workdir)

Package the trained model directory into a ZIP archive.

Returns the path to the ZIP file. Mirrors tabular's serialize_and_zip_predictor.

Source code in app/vision_automl/services.py
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
def serialize_and_zip_model(result: dict, workdir: Path) -> Path:
    """
    Package the trained model directory into a ZIP archive.

    Returns the path to the ZIP file.
    Mirrors tabular's ``serialize_and_zip_predictor``.
    """
    model_dir = workdir / "model"
    model_dir.mkdir(exist_ok=True)

    # Best-effort: bundle deployment instructions with the model artifacts.
    # The file must be opened for *writing* (the previous read-mode open made
    # every f.write() raise, so nothing was ever written) and must live inside
    # model_dir so make_archive actually includes it in the ZIP.
    try:
        with open(model_dir / "vision_deployment_instructions.md", "w") as f:
            f.write(deployment_instructions())
    except Exception as e:
        logger.debug(f"No deployment_instructions found, {e}")

    zip_base = workdir / "vision_model"
    shutil.make_archive(str(zip_base), "zip", model_dir)
    zip_path = zip_base.with_suffix(".zip")
    logger.debug("Model artifacts zipped to %s", zip_path)
    return zip_path

sort_models_by_size(models, size_tier)

Filter and sort models by size tier based on estimated parameter counts.

Source code in app/vision_automl/services.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def sort_models_by_size(
    models: list[dict[str, Any]], size_tier: str
) -> list[dict[str, Any]]:
    """Filter and sort models by size tier based on estimated parameter counts."""
    logger.info("Sorting models by size tier: %s", size_tier)
    tier = str(size_tier).strip().lower()

    # Tier boundaries are configurable via environment variables.
    small_max = int(os.getenv("MODEL_SMALL_MAX_PARAM_SIZE", 50_000_000))
    medium_max = int(os.getenv("MODEL_MEDIUM_MAX_PARAM_SIZE", 200_000_000))

    def matches_tier(entry: dict[str, Any]) -> bool:
        params = entry.get("num_params")
        if params is None:
            return False
        if tier == "small":
            return 0 <= params <= small_max
        if tier == "medium":
            return small_max + 1 <= params <= medium_max
        if tier == "large":
            return params >= medium_max + 1
        # Unrecognized tier names match everything.
        return True

    candidates = [m for m in models if matches_tier(m)]
    if not candidates:
        logger.warning("No models matched tier '%s'; falling back to all models", tier)
        candidates = models

    # Ascending by parameter count; entries lacking a count sort last.
    return sorted(
        candidates, key=lambda m: (m.get("num_params") is None, m.get("num_params", 0))
    )

train_automl(csv_path, images_dir, filename_column, label_column, time_budget, model_size, workdir, task_type='image_classification') async

Run Optuna-based vision AutoML and return the result dict.

Source code in app/vision_automl/services.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
async def train_automl(
    csv_path: Path,
    images_dir: Path,
    filename_column: str,
    label_column: str,
    time_budget: int,
    model_size: str,
    workdir: Path,
    task_type: str = "image_classification",
) -> dict:
    """Run Optuna-based vision AutoML and return the result dict.

    The blocking search runs in a threadpool so the event loop stays free.
    ``n_trials`` scales with the time budget (one trial per minute of budget),
    clamped to the range [1, 25]; ``timeout`` caps total wall-clock seconds.
    """
    return await run_in_threadpool(
        run_optuna_search,
        task_type=task_type,
        csv_path=csv_path,
        images_dir=images_dir,
        filename_column=filename_column,
        label_column=label_column,
        n_trials=max(1, min(25, time_budget // 60)),
        timeout=time_budget,
        model_size=model_size,
        workdir=workdir,
    )

upload_model(upload_url, zip_path, payload, task_id)

Upload the zipped model to AutoDW and return the raw response.

Source code in app/vision_automl/services.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
def upload_model(
    upload_url: str,
    zip_path: Path,
    payload: dict,
    task_id: str | None,
) -> requests.Response:
    """Upload the zipped model to AutoDW and return the raw response."""
    # Propagate the caller's task id so AutoDW can correlate the upload.
    headers = {"X-Task-ID": task_id} if task_id else {}
    logger.debug("Uploading vision model to %s", upload_url)
    with open(zip_path, "rb") as archive:
        return requests.post(
            upload_url,
            headers=headers,
            files={"file": (zip_path.name, archive, "application/octet-stream")},
            data=payload,
            timeout=120,
        )

validate_vision_inputs(csv_path, images_dir, filename_column, label_column, task_type='image_classification')

Validate dataset structure for the given task type.

Returns an error string on failure, or None if everything is valid. Mirrors the signature/contract of tabular's validate_tabular_inputs.

Parameters:

Name Type Description Default
csv_path Path

Path to the labels CSV.

required
images_dir Path

Root directory containing image/audio/video files. Unused for pure text tasks.

required
filename_column str

Column name containing file paths (image/audio tasks).

required
label_column str

Column name containing labels (classification tasks).

required
task_type str

One of the supported task type slugs.

'image_classification'
Source code in app/vision_automl/services.py
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
def validate_vision_inputs(
    csv_path: Path,
    images_dir: Path,
    filename_column: str,
    label_column: str,
    task_type: str = "image_classification",
) -> str | None:
    """Validate dataset structure for the given task type.

    Returns an error string on failure, or None if everything is valid.
    Mirrors the signature/contract of tabular's ``validate_tabular_inputs``.

    Args:
        csv_path: Path to the labels CSV.
        images_dir: Root directory containing image/audio/video files.
            Unused for pure text tasks.
        filename_column: Column name containing file paths (image/audio tasks).
        label_column: Column name containing labels (classification tasks).
        task_type: One of the supported task type slugs.
    """
    # Audio task — validate audio dir + CSV
    if task_type == "audio_classification":
        if not csv_path.exists():
            return f"Labels CSV not found: {csv_path}"
        if not images_dir.exists():
            return f"Audio directory not found: {images_dir}"
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            return f"Could not read labels CSV: {e}"
        for col, role in [(filename_column, "Filename"), (label_column, "Label")]:
            if col not in df.columns:
                return f"{role} column '{col}' not found in labels CSV"
        return None

    # Text tasks — validate CSV + required columns
    if task_type in _TEXT_REQUIRED_COLUMNS:
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            return f"Could not read labels CSV: {e}"
        required = _TEXT_REQUIRED_COLUMNS[task_type]
        missing_cols = [c for c in required if c not in df.columns]
        if missing_cols:
            return f"Required column(s) missing for {task_type}: {missing_cols}"
        return None

    # Image tasks — existing CSV + image presence checks
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        return f"Could not read labels CSV: {e}"

    for col, role in [(filename_column, "Filename"), (label_column, "Label")]:
        if col not in df.columns:
            return f"{role} column '{col}' not found in labels CSV"

    # Detection/segmentation tasks — validate annotation columns
    if task_type in _DETECTION_EXTRA_COLUMNS:
        extra = _DETECTION_EXTRA_COLUMNS[task_type]
        missing_cols = [c for c in extra if c not in df.columns]
        if missing_cols:
            return (
                f"Required annotation column(s) missing for {task_type}: {missing_cols}"
            )

    df = normalize_dataframe_filenames(df, filename_column, csv_path)

    missing = collect_missing_files(df, images_dir, filename_column, label_column)
    if missing:
        preview = missing[:5]
        suffix = "..." if len(missing) > 5 else ""
        return f"Missing {len(missing)} image file(s): {preview}{suffix}"

    return None

ML engine

Per-task hyperparameter and model config loader.

load_task_config(task_type)

Load and return the JSON config for the given task type.

Parameters:

Name Type Description Default
task_type str

One of the supported task type slugs.

required

Returns:

Type Description
dict

Dict with keys: small_models, medium_models, large_models,

dict

lr_low, lr_high, batch_sizes, weight_decay_low, weight_decay_high,

dict

max_epochs, early_stopping_patience.

Raises:

Type Description
ValueError

If the task type is not supported.

Source code in app/vision_automl/ml_engine/configs/__init__.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def load_task_config(task_type: str) -> dict:
    """Load and return the JSON config for the given task type.

    Args:
        task_type: One of the supported task type slugs.

    Returns:
        Dict with keys: small_models, medium_models, large_models,
        lr_low, lr_high, batch_sizes, weight_decay_low, weight_decay_high,
        max_epochs, early_stopping_patience.

    Raises:
        ValueError: If the task type is not supported.
    """
    if task_type in SUPPORTED_TASK_TYPES:
        # Each supported task has a matching <task_type>.json config file.
        with open(_CONFIGS_DIR / f"{task_type}.json") as fh:
            return json.load(fh)
    raise ValueError(
        f"Unknown task type '{task_type}'. "
        f"Supported: {sorted(SUPPORTED_TASK_TYPES)}"
    )

AudioClassificationDataModule

Datamodule for audio classification tasks.

CSV columns: audio_path (relative to root_dir) and label. Audio is loaded with torchaudio (must be installed separately).

Source code in app/vision_automl/ml_engine/datamodule.py
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
class AudioClassificationDataModule:
    """Datamodule for audio classification tasks.

    CSV columns: ``audio_path`` (relative to ``root_dir``) and ``label``.
    Audio is loaded with ``torchaudio`` (must be installed separately).
    """

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        audio_col: str = "audio_path",
        label_col: str = "label",
        sampling_rate: int = 16000,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "facebook/wav2vec2-base",
    ) -> None:
        # Dataset layout: csv_file lists audio paths relative to root_dir.
        self.csv_file = Path(csv_file)
        self.root_dir = Path(root_dir)
        self.audio_col = audio_col
        self.label_col = label_col
        # Target sampling rate; clips are resampled to this in __getitem__.
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # Populated by setup() below.
        self.num_classes: int = 0
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.feature_extractor: AutoFeatureExtractor | None = None
        self.train_df: pd.DataFrame | None = None
        self.val_df: pd.DataFrame | None = None
        self.test_df: pd.DataFrame | None = None
        # Eagerly split the data and load the feature extractor.
        self.setup()

    def setup(self) -> None:
        """Read the CSV, build label maps, split, and load the extractor."""
        df = pd.read_csv(self.csv_file)
        # Stable label ordering: sorted unique label values -> contiguous ids.
        classes = sorted(df[self.label_col].unique().tolist())
        self.num_classes = len(classes)
        self.id2label = {i: c for i, c in enumerate(classes)}
        self.label2id = {c: i for i, c in enumerate(classes)}
        df = df.copy()
        df[self.label_col] = df[self.label_col].map(self.label2id)

        # First split train vs (val+test), stratified on the integer labels.
        train_df, temp_df = train_test_split(
            df,
            test_size=self.val_split + self.test_split,
            stratify=df[self.label_col],
            random_state=self.seed,
        )
        # Then split the holdout so val/test keep their requested ratio.
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df,
            test_size=1 - relative_val,
            stratify=temp_df[self.label_col],
            random_state=self.seed,
        )
        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.hf_model_id)

    def _make_dataset(self, df: pd.DataFrame) -> Dataset:
        """Wrap a split dataframe in a lazy-loading torch ``Dataset``."""
        # torchaudio is an optional dependency; fail with a clear message.
        try:
            import torchaudio
        except ImportError as e:
            raise ImportError(
                "torchaudio is required for audio tasks. "
                "Install it with: pip install torchaudio"
            ) from e

        # Bind configuration to locals for use inside the nested class.
        root = self.root_dir
        audio_col = self.audio_col
        label_col = self.label_col
        target_sr = self.sampling_rate

        class _AudioDataset(Dataset):
            def __init__(self, df):
                self.df = df

            def __len__(self):
                return len(self.df)

            def __getitem__(self, idx):
                # Load, resample to target_sr if needed, then downmix to mono.
                row = self.df.iloc[idx]
                waveform, sr = torchaudio.load(str(root / str(row[audio_col])))
                if sr != target_sr:
                    waveform = torchaudio.functional.resample(waveform, sr, target_sr)
                waveform = waveform.mean(0)  # mono
                return waveform, torch.tensor(int(row[label_col]), dtype=torch.long)

        return _AudioDataset(df)

    def _collate_fn(self, batch):
        """Pad a batch of variable-length waveforms via the HF extractor."""
        waveforms, labels = zip(*batch)
        if self.feature_extractor is None:
            raise RuntimeError("Feature extractor not initialized.")
        inputs = self.feature_extractor(
            [w.numpy() for w in waveforms],
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        return {
            "input_values": inputs.input_values,
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Build a DataLoader over the given split dataframe."""
        return DataLoader(
            self._make_dataset(df),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_df, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_df, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_df, shuffle=False)

CausalLMDataModule

Datamodule for causal language modelling tasks.

CSV column: text. Labels are produced by shifting input_ids right by one position (handled by the model internally when labels equals input_ids).

Source code in app/vision_automl/ml_engine/datamodule.py
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
class CausalLMDataModule:
    """Datamodule for causal language modelling tasks.

    CSV column: ``text``.  Labels are produced by shifting ``input_ids``
    right by one position (handled by the model internally when ``labels``
    equals ``input_ids``).
    """

    def __init__(
        self,
        csv_file: Path,
        text_col: str = "text",
        max_length: int = 256,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "distilgpt2",
    ) -> None:
        # Dataset location and tokenisation settings.
        self.csv_file = Path(csv_file)
        self.text_col = text_col
        self.max_length = max_length
        # Loader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Split settings.
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # Populated by setup().
        self.tokenizer: AutoTokenizer | None = None
        self.train_dataset: CausalLMFromCSVDataset | None = None
        self.val_dataset: CausalLMFromCSVDataset | None = None
        self.test_dataset: CausalLMFromCSVDataset | None = None
        # Kept empty; present for interface parity with classification modules.
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.setup()

    def setup(self) -> None:
        """Split the CSV into train/val/test and load the tokenizer."""
        frame = pd.read_csv(self.csv_file)
        holdout_fraction = self.val_split + self.test_split
        train_frame, holdout = train_test_split(
            frame, test_size=holdout_fraction, random_state=self.seed
        )
        val_fraction = self.val_split / holdout_fraction
        val_frame, test_frame = train_test_split(
            holdout, test_size=1 - val_fraction, random_state=self.seed
        )

        self.train_dataset = CausalLMFromCSVDataset(train_frame, self.text_col)
        self.val_dataset = CausalLMFromCSVDataset(val_frame, self.text_col)
        self.test_dataset = CausalLMFromCSVDataset(test_frame, self.text_col)

        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_id)
        # GPT-style tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _collate_fn(self, batch: list[str]) -> dict[str, torch.Tensor]:
        """Tokenize a batch of raw strings into model-ready tensors."""
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized.")
        enc = self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # For causal LM, labels = input_ids (model shifts internally)
        return {
            "input_ids": enc.input_ids,
            "attention_mask": enc.attention_mask,
            "labels": enc.input_ids.clone(),
        }

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader using this module's collate function."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False)

ImageClassificationDataModule

Handles dataset preparation and dataloaders for image classification tasks.

Source code in app/vision_automl/ml_engine/datamodule.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class ImageClassificationDataModule:
    """Handles dataset preparation and dataloaders for image classification tasks."""

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        img_col: str = "filename",
        label_col: str = "label",
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        transform: Callable | None = None,
        shuffle: bool = True,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = DEFAULT_IMAGE_CLASSIFIER_HF_ID,
    ) -> None:
        # Dataset layout: csv_file lists image paths (img_col) relative to
        # root_dir, with a string label per row (label_col).
        self.csv_file = Path(csv_file)
        self.root_dir = Path(root_dir)
        self.img_col: str = img_col
        self.label_col: str = label_col
        # Loader settings; ``shuffle`` only applies to the train loader.
        self.batch_size: int = batch_size
        self.num_workers: int = num_workers
        self.transform: Callable | None = transform
        self.shuffle: bool = shuffle
        # Split settings.
        self.val_split: float = val_split
        self.test_split: float = test_split
        self.seed: int = seed
        self.hf_model_id: str = hf_model_id

        # Populated by setup() below.
        self.num_classes: int = 0
        self.train_dataset: ImageClassificationFromCSVDataset | None = None
        self.val_dataset: ImageClassificationFromCSVDataset | None = None
        self.test_dataset: ImageClassificationFromCSVDataset | None = None
        self.processor: AutoImageProcessor | None = None
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}

        logger.info("Initializing ImageClassificationDataModule with CSV: %s", csv_file)
        # Splits and processor are prepared eagerly at construction time.
        self.setup()

    def setup(self) -> None:
        """Create train/val/test splits, datasets, label maps, and processor."""
        logger.info("Reading dataset from %s", self.csv_file)
        df: pd.DataFrame = pd.read_csv(self.csv_file)

        # First split train vs (val+test), stratified on the label column.
        train_df, temp_df = train_test_split(
            df,
            test_size=self.val_split + self.test_split,
            stratify=df[self.label_col],
            random_state=self.seed,
        )

        # Then split the holdout so val/test keep their requested ratio.
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df,
            test_size=1 - relative_val,
            stratify=temp_df[self.label_col],
            random_state=self.seed,
        )

        logger.info(
            "Split completed: train=%d, val=%d, test=%d",
            len(train_df),
            len(val_df),
            len(test_df),
        )

        self.train_dataset = ImageClassificationFromCSVDataset(
            csv_file=train_df,
            root_dir=self.root_dir,
            img_col=self.img_col,
            label_col=self.label_col,
            transform=self.transform,
        )
        self.val_dataset = ImageClassificationFromCSVDataset(
            csv_file=val_df,
            root_dir=self.root_dir,
            img_col=self.img_col,
            label_col=self.label_col,
            transform=self.transform,
        )
        self.test_dataset = ImageClassificationFromCSVDataset(
            csv_file=test_df,
            root_dir=self.root_dir,
            img_col=self.img_col,
            label_col=self.label_col,
            transform=self.transform,
        )

        # Label maps are derived from the train split's class list.
        self.num_classes = len(self.train_dataset.classes)
        self.id2label = {i: c for i, c in enumerate(self.train_dataset.classes)}
        self.label2id = {c: i for i, c in enumerate(self.train_dataset.classes)}

        self.processor = AutoImageProcessor.from_pretrained(self.hf_model_id)
        logger.info("Loaded processor from: %s", self.hf_model_id)

    def _collate_fn(self, batch: list[tuple[Any, Any]]) -> dict[str, torch.Tensor]:
        """Run the HF image processor on a batch and tensorize labels."""
        images, labels = zip(*batch)
        if self.processor is None:
            raise RuntimeError("Processor not initialized. Call setup() first.")
        pixel_values = self.processor(
            images=list(images), return_tensors="pt"
        ).pixel_values
        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def train_dataloader(self) -> DataLoader:
        """Return the (optionally shuffled) training DataLoader."""
        if self.train_dataset is None:
            raise RuntimeError("Train dataset not initialized. Call setup() first.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader (never shuffled)."""
        if self.val_dataset is None:
            raise RuntimeError(
                "Validation dataset not initialized. Call setup() first."
            )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader (never shuffled)."""
        if self.test_dataset is None:
            raise RuntimeError("Test dataset not initialized. Call setup() first.")
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

setup()

Create train/val/test splits, datasets, label maps, and processor.

Source code in app/vision_automl/ml_engine/datamodule.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def setup(self) -> None:
    """Create train/val/test splits, datasets, label maps, and processor."""
    logger.info("Reading dataset from %s", self.csv_file)
    frame: pd.DataFrame = pd.read_csv(self.csv_file)

    # Split off the combined holdout first, stratified on the label column.
    holdout_size = self.val_split + self.test_split
    train_frame, holdout = train_test_split(
        frame,
        test_size=holdout_size,
        stratify=frame[self.label_col],
        random_state=self.seed,
    )

    # Divide the holdout so val/test keep their requested ratio.
    val_frame, test_frame = train_test_split(
        holdout,
        test_size=1 - self.val_split / holdout_size,
        stratify=holdout[self.label_col],
        random_state=self.seed,
    )

    logger.info(
        "Split completed: train=%d, val=%d, test=%d",
        len(train_frame),
        len(val_frame),
        len(test_frame),
    )

    def _build(split_frame: pd.DataFrame) -> ImageClassificationFromCSVDataset:
        # One dataset per split; all share the same layout and transform.
        return ImageClassificationFromCSVDataset(
            csv_file=split_frame,
            root_dir=self.root_dir,
            img_col=self.img_col,
            label_col=self.label_col,
            transform=self.transform,
        )

    self.train_dataset = _build(train_frame)
    self.val_dataset = _build(val_frame)
    self.test_dataset = _build(test_frame)

    # Label maps come from the train split's class list.
    classes = self.train_dataset.classes
    self.num_classes = len(classes)
    self.id2label = dict(enumerate(classes))
    self.label2id = {name: idx for idx, name in enumerate(classes)}

    self.processor = AutoImageProcessor.from_pretrained(self.hf_model_id)
    logger.info("Loaded processor from: %s", self.hf_model_id)

ImageSegmentationDataModule

Bases: ImageClassificationDataModule

Datamodule for image segmentation tasks.

Uses the same CSV + class-subdir image layout as image classification. The collate function passes labels (pixel-level segmentation maps) to the processor. The labels CSV must contain a mask_filename column pointing to the segmentation mask image (same class-subdir layout).

Source code in app/vision_automl/ml_engine/datamodule.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class ImageSegmentationDataModule(ImageClassificationDataModule):
    """Datamodule for image segmentation tasks.

    Uses the same CSV + class-subdir image layout as image classification.
    The collate function passes ``labels`` (pixel-level segmentation maps)
    to the processor.  The labels CSV must contain a ``mask_filename``
    column pointing to the segmentation mask image (same class-subdir layout).
    """

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        img_col: str = "filename",
        label_col: str = "label",
        mask_col: str = "mask_filename",
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = DEFAULT_IMAGE_CLASSIFIER_HF_ID,
    ) -> None:
        # Record the mask column before the base class __init__ runs setup().
        self.mask_col = mask_col
        super().__init__(
            csv_file=csv_file,
            root_dir=root_dir,
            img_col=img_col,
            label_col=label_col,
            batch_size=batch_size,
            num_workers=num_workers,
            val_split=val_split,
            test_split=test_split,
            seed=seed,
            hf_model_id=hf_model_id,
        )

    def _collate_fn(self, batch: list[tuple[Any, Any]]) -> dict[str, torch.Tensor]:
        """Process images with the HF processor and stack target tensors."""
        images, targets = zip(*batch)
        if self.processor is None:
            raise RuntimeError("Processor not initialized.")
        encoding = self.processor(images=list(images), return_tensors="pt")
        as_tensors = [
            t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in targets
        ]
        return {
            "pixel_values": encoding.pixel_values,
            "labels": torch.stack(as_tensors),
        }

KeypointDetectionDataModule

Bases: ImageClassificationDataModule

Datamodule for keypoint detection tasks.

Uses the same CSV + image layout as image classification. The keypoints_col should contain a JSON list of [x, y, visibility] entries (one per keypoint).

Source code in app/vision_automl/ml_engine/datamodule.py
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
class KeypointDetectionDataModule(ImageClassificationDataModule):
    """Datamodule for keypoint detection tasks.

    Uses the same CSV + image layout as image classification.
    The ``keypoints_col`` should contain a JSON list of
    ``[x, y, visibility]`` entries (one per keypoint).
    """

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        img_col: str = "filename",
        label_col: str = "label",
        keypoints_col: str = "keypoints",
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = DEFAULT_IMAGE_CLASSIFIER_HF_ID,
    ) -> None:
        # Record the keypoints column before the base class __init__ runs setup().
        self.keypoints_col = keypoints_col
        super().__init__(
            csv_file=csv_file,
            root_dir=root_dir,
            img_col=img_col,
            label_col=label_col,
            batch_size=batch_size,
            num_workers=num_workers,
            val_split=val_split,
            test_split=test_split,
            seed=seed,
            hf_model_id=hf_model_id,
        )

    def _collate_fn(self, batch):
        """Process images with the HF processor and stack target tensors."""
        images, targets = zip(*batch)
        if self.processor is None:
            raise RuntimeError("Processor not initialized.")
        encoding = self.processor(images=list(images), return_tensors="pt")
        stacked = torch.stack(
            [t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in targets]
        )
        return {"pixel_values": encoding.pixel_values, "labels": stacked}

MaskedLMDataModule

Datamodule for masked language modelling tasks.

CSV column: text. Uses DataCollatorForLanguageModeling to randomly mask tokens at runtime.

Source code in app/vision_automl/ml_engine/datamodule.py
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
class MaskedLMDataModule:
    """Datamodule for masked language modelling tasks.

    CSV column: ``text``.  Uses ``DataCollatorForLanguageModeling`` to
    randomly mask tokens at runtime.
    """

    def __init__(
        self,
        csv_file: Path,
        text_col: str = "text",
        mlm_probability: float = 0.15,
        max_length: int = 256,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "bert-base-uncased",
    ) -> None:
        self.csv_file = Path(csv_file)
        self.text_col = text_col
        # Fraction of tokens the collator masks per batch.
        self.mlm_probability = mlm_probability
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # Populated by setup() below.
        self.tokenizer: AutoTokenizer | None = None
        self.data_collator: DataCollatorForLanguageModeling | None = None
        self.train_dataset: CausalLMFromCSVDataset | None = None
        self.val_dataset: CausalLMFromCSVDataset | None = None
        self.test_dataset: CausalLMFromCSVDataset | None = None
        # Kept empty; present for interface parity with classification modules.
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.setup()

    def setup(self) -> None:
        """Split the CSV and initialize the tokenizer + masking collator."""
        df = pd.read_csv(self.csv_file)
        train_df, temp_df = train_test_split(
            df, test_size=self.val_split + self.test_split, random_state=self.seed
        )
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df, test_size=1 - relative_val, random_state=self.seed
        )

        # Reuse CausalLMFromCSVDataset as it just returns text strings
        self.train_dataset = CausalLMFromCSVDataset(train_df, self.text_col)
        self.val_dataset = CausalLMFromCSVDataset(val_df, self.text_col)
        self.test_dataset = CausalLMFromCSVDataset(test_df, self.text_col)

        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_id)
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=True,
            mlm_probability=self.mlm_probability,
        )

    def _tokenize(self, batch: list[str]) -> dict[str, torch.Tensor]:
        """Tokenize a batch of raw strings with padding/truncation."""
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized.")
        return self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _collate_fn(self, batch: list[str]) -> dict[str, torch.Tensor]:
        """Tokenize and apply random masking to build MLM inputs/labels."""
        # Guard mirrors the tokenizer check in _tokenize(); previously a
        # missing collator surfaced as an opaque "not callable" TypeError.
        if self.data_collator is None:
            raise RuntimeError("Data collator not initialized.")
        encoding = self._tokenize(batch)
        # data_collator applies random masking and returns input_ids + labels
        collated = self.data_collator(
            [{"input_ids": ids} for ids in encoding.input_ids]
        )
        return {
            "input_ids": collated["input_ids"],
            "attention_mask": encoding.attention_mask,
            "labels": collated["labels"],
        }

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader using this module's collate function."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False)

ObjectDetectionDataModule

Datamodule for object detection tasks.

CSV columns: filename (image file), boxes (JSON list of [x_min, y_min, x_max, y_max]), class_labels (JSON list of int class IDs). Images live in class-neutral flat layout under root_dir/images/.

Source code in app/vision_automl/ml_engine/datamodule.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
class ObjectDetectionDataModule:
    """Datamodule for object detection tasks.

    CSV columns: ``filename`` (image file), ``boxes`` (JSON list of
    ``[x_min, y_min, x_max, y_max]``), ``class_labels`` (JSON list of
    int class IDs).  Images live in class-neutral flat layout under
    ``root_dir/images/``.

    ``setup()`` runs eagerly from ``__init__``: construction reads the CSV,
    builds the splits and loads the Hugging Face image processor.
    """

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        img_col: str = "filename",
        boxes_col: str = "boxes",
        class_labels_col: str = "class_labels",
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "facebook/detr-resnet-50",
    ) -> None:
        self.csv_file = Path(csv_file)
        self.root_dir = Path(root_dir)
        self.img_col = img_col
        self.boxes_col = boxes_col
        self.class_labels_col = class_labels_col
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # The attributes below are populated by setup().
        self.processor: AutoImageProcessor | None = None
        self.num_classes: int = 0
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.train_df: pd.DataFrame | None = None
        self.val_df: pd.DataFrame | None = None
        self.test_df: pd.DataFrame | None = None
        self.setup()

    def setup(self) -> None:
        """Read the CSV, derive the label set, split, and load the processor."""
        import json as _json

        df = pd.read_csv(self.csv_file)
        # The class universe is the union of every image's class-label list.
        all_labels: set[int] = set()
        for row in df[self.class_labels_col]:
            all_labels.update(_json.loads(row))
        self.num_classes = len(all_labels)
        # Labels are raw integer IDs; their "names" are the stringified IDs.
        self.id2label = {i: str(i) for i in sorted(all_labels)}
        self.label2id = {v: k for k, v in self.id2label.items()}

        # Two-stage split: carve off val+test together, then divide that
        # remainder so the final fractions match val_split / test_split.
        # NOTE(review): assumes val_split + test_split > 0 (division below).
        train_df, temp_df = train_test_split(
            df, test_size=self.val_split + self.test_split, random_state=self.seed
        )
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df, test_size=1 - relative_val, random_state=self.seed
        )
        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)
        self.processor = AutoImageProcessor.from_pretrained(self.hf_model_id)

    def _make_dataset(self, df: pd.DataFrame) -> Dataset:
        """Wrap a split DataFrame in a Dataset yielding (PIL image, target dict)."""
        import json as _json
        from PIL import Image as _Image

        root = self.root_dir

        class _DetectionDataset(Dataset):
            def __init__(self, df, root, img_col, boxes_col, class_labels_col):
                self.df = df
                self.root = root
                self.img_col = img_col
                self.boxes_col = boxes_col
                self.class_labels_col = class_labels_col

            def __len__(self):
                return len(self.df)

            def __getitem__(self, idx):
                # Paths resolve directly against root_dir (root / filename).
                # NOTE(review): the class docstring says images live under
                # root_dir/images/ — confirm CSV filenames carry that prefix.
                row = self.df.iloc[idx]
                img = _Image.open(self.root / str(row[self.img_col])).convert("RGB")
                boxes = _json.loads(row[self.boxes_col])
                class_labels = _json.loads(row[self.class_labels_col])
                return img, {
                    "boxes": torch.tensor(boxes, dtype=torch.float32),
                    "class_labels": torch.tensor(class_labels, dtype=torch.long),
                }

        return _DetectionDataset(
            df, root, self.img_col, self.boxes_col, self.class_labels_col
        )

    def _collate_fn(self, batch):
        """Batch images through the HF processor; targets pass through as-is.

        NOTE(review): the processor may resize images, but the box targets are
        forwarded unrescaled — confirm the model/training step accounts for it.
        """
        images, targets = zip(*batch)
        encoding = self.processor(images=list(images), return_tensors="pt")
        return {"pixel_values": encoding.pixel_values, "labels": list(targets)}

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Wrap a split DataFrame in a DataLoader using this module's settings."""
        return DataLoader(
            self._make_dataset(df),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_df, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_df, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_df, shuffle=False)

QuestionAnsweringDataModule

Datamodule for extractive question answering tasks.

CSV columns: question, context, answer_start (char offset), answer_text.

Source code in app/vision_automl/ml_engine/datamodule.py
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
class QuestionAnsweringDataModule:
    """Datamodule for extractive question answering tasks.

    CSV columns: ``question``, ``context``, ``answer_start`` (char offset),
    ``answer_text``.

    Requires a *fast* tokenizer (the ``AutoTokenizer`` default), since the
    collate function relies on ``offset_mapping`` and ``sequence_ids`` to
    convert character offsets into token positions.
    """

    def __init__(
        self,
        csv_file: Path,
        question_col: str = "question",
        context_col: str = "context",
        answer_start_col: str = "answer_start",
        answer_text_col: str = "answer_text",
        max_length: int = 384,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "distilbert-base-uncased-distilled-squad",
    ) -> None:
        self.csv_file = Path(csv_file)
        self.question_col = question_col
        self.context_col = context_col
        self.answer_start_col = answer_start_col
        self.answer_text_col = answer_text_col
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # Populated by setup() below.
        self.tokenizer: AutoTokenizer | None = None
        self.train_dataset: QuestionAnsweringFromCSVDataset | None = None
        self.val_dataset: QuestionAnsweringFromCSVDataset | None = None
        self.test_dataset: QuestionAnsweringFromCSVDataset | None = None
        # QA tasks do not use id2label
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.setup()

    def setup(self) -> None:
        """Split the CSV into train/val/test datasets and load the tokenizer."""
        df = pd.read_csv(self.csv_file)
        # Two-stage split: carve off val+test together, then divide that
        # remainder so the final fractions match val_split / test_split.
        train_df, temp_df = train_test_split(
            df, test_size=self.val_split + self.test_split, random_state=self.seed
        )
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df, test_size=1 - relative_val, random_state=self.seed
        )

        self.train_dataset = QuestionAnsweringFromCSVDataset(
            train_df,
            self.question_col,
            self.context_col,
            self.answer_start_col,
            self.answer_text_col,
        )
        self.val_dataset = QuestionAnsweringFromCSVDataset(
            val_df,
            self.question_col,
            self.context_col,
            self.answer_start_col,
            self.answer_text_col,
        )
        self.test_dataset = QuestionAnsweringFromCSVDataset(
            test_df,
            self.question_col,
            self.context_col,
            self.answer_start_col,
            self.answer_text_col,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_id)

    def _collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        """Tokenise question/context pairs and derive token-level answer spans.

        Fix: the character-to-token search is restricted to *context* tokens
        (``sequence_ids(i) == 1``).  Question tokens also carry character
        offsets — relative to the question string — so scanning them too could
        spuriously match an answer offset and emit wrong start/end labels.
        Both positions remain 0 when the answer was truncated away.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized.")
        questions = [b["question"] for b in batch]
        contexts = [b["context"] for b in batch]
        answer_starts = [b["answer_start"] for b in batch]
        answer_texts = [b["answer_text"] for b in batch]

        encoding = self.tokenizer(
            questions,
            contexts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_offsets_mapping=True,
        )
        offset_mapping = encoding.pop("offset_mapping")

        # Convert character-level answer positions to token positions.
        start_positions = []
        end_positions = []
        for i, (start_char, answer) in enumerate(zip(answer_starts, answer_texts)):
            end_char = start_char + len(answer)
            offsets = offset_mapping[i].tolist()
            # None = special token, 0 = question, 1 = context.
            sequence_ids = encoding.sequence_ids(i)
            token_start = token_end = 0
            for j, (s, e) in enumerate(offsets):
                if sequence_ids[j] != 1:
                    continue  # only context tokens can contain the answer
                if s <= start_char < e:
                    token_start = j
                if s < end_char <= e:
                    token_end = j
                    break
            start_positions.append(token_start)
            end_positions.append(token_end)

        return {
            "input_ids": encoding.input_ids,
            "attention_mask": encoding.attention_mask,
            "start_positions": torch.tensor(start_positions, dtype=torch.long),
            "end_positions": torch.tensor(end_positions, dtype=torch.long),
        }

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap a dataset in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False)

Seq2SeqLMDataModule

Datamodule for sequence-to-sequence tasks.

CSV columns: input_text and target_text.

Source code in app/vision_automl/ml_engine/datamodule.py
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
class Seq2SeqLMDataModule:
    """Datamodule for sequence-to-sequence tasks.

    CSV columns: ``input_text`` and ``target_text``.  Construction eagerly
    runs ``setup()``, which reads the CSV, splits it and loads the tokenizer.
    """

    def __init__(
        self,
        csv_file: Path,
        input_col: str = "input_text",
        target_col: str = "target_text",
        max_source_length: int = 256,
        max_target_length: int = 128,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "t5-small",
    ) -> None:
        # Configuration.
        self.csv_file = Path(csv_file)
        self.input_col = input_col
        self.target_col = target_col
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # State filled in by setup().
        self.tokenizer: AutoTokenizer | None = None
        self.train_dataset: Seq2SeqFromCSVDataset | None = None
        self.val_dataset: Seq2SeqFromCSVDataset | None = None
        self.test_dataset: Seq2SeqFromCSVDataset | None = None
        # Seq2seq has no categorical label vocabulary.
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.setup()

    def setup(self) -> None:
        """Read the CSV, create the splits, and initialise the tokenizer."""
        frame = pd.read_csv(self.csv_file)
        holdout = self.val_split + self.test_split
        train_frame, rest = train_test_split(
            frame, test_size=holdout, random_state=self.seed
        )
        val_fraction = self.val_split / holdout
        val_frame, test_frame = train_test_split(
            rest, test_size=1 - val_fraction, random_state=self.seed
        )

        self.train_dataset, self.val_dataset, self.test_dataset = (
            Seq2SeqFromCSVDataset(split, self.input_col, self.target_col)
            for split in (train_frame, val_frame, test_frame)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_id)

    def _collate_fn(self, batch: list[tuple[str, str]]) -> dict[str, torch.Tensor]:
        """Tokenise source/target pairs; padding in the labels becomes -100."""
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized.")
        sources, targets = zip(*batch)
        source_enc = self.tokenizer(
            list(sources),
            padding=True,
            truncation=True,
            max_length=self.max_source_length,
            return_tensors="pt",
        )
        target_enc = self.tokenizer(
            list(targets),
            padding=True,
            truncation=True,
            max_length=self.max_target_length,
            return_tensors="pt",
        )
        labels = target_enc.input_ids.clone()
        pad_id = self.tokenizer.pad_token_id
        labels[labels == pad_id] = -100  # -100 is ignored by the LM loss
        return {
            "input_ids": source_enc.input_ids,
            "attention_mask": source_enc.attention_mask,
            "labels": labels,
        }

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader bound to this module's collate function."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "shuffle": shuffle,
            "num_workers": self.num_workers,
            "collate_fn": self._collate_fn,
        }
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled loader over the test split."""
        return self._make_loader(self.test_dataset, shuffle=False)

SequenceClassificationDataModule

Datamodule for text sequence classification tasks.

CSV columns: text and label.

Source code in app/vision_automl/ml_engine/datamodule.py
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
class SequenceClassificationDataModule:
    """Datamodule for text sequence classification tasks.

    CSV columns: ``text`` and ``label``.

    ``setup()`` runs eagerly from ``__init__``: construction reads the CSV,
    builds stratified train/val/test splits and loads the tokenizer.
    """

    def __init__(
        self,
        csv_file: Path,
        text_col: str = "text",
        label_col: str = "label",
        max_length: int = 128,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "distilbert-base-uncased",
    ) -> None:
        self.csv_file = Path(csv_file)
        self.text_col = text_col
        self.label_col = label_col
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # The attributes below are populated by setup().
        self.num_classes: int = 0
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.tokenizer: AutoTokenizer | None = None
        self.train_dataset: TextClassificationFromCSVDataset | None = None
        self.val_dataset: TextClassificationFromCSVDataset | None = None
        self.test_dataset: TextClassificationFromCSVDataset | None = None
        self.setup()

    def setup(self) -> None:
        """Split the CSV (stratified by label) and build datasets/tokenizer."""
        df = pd.read_csv(self.csv_file)
        # Two-stage stratified split: carve off val+test together, then divide
        # that remainder so the final fractions match val_split / test_split.
        # NOTE(review): assumes val_split + test_split > 0 (division below).
        train_df, temp_df = train_test_split(
            df,
            test_size=self.val_split + self.test_split,
            stratify=df[self.label_col],
            random_state=self.seed,
        )
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df,
            test_size=1 - relative_val,
            stratify=temp_df[self.label_col],
            random_state=self.seed,
        )

        self.train_dataset = TextClassificationFromCSVDataset(
            train_df, self.text_col, self.label_col
        )
        self.val_dataset = TextClassificationFromCSVDataset(
            val_df, self.text_col, self.label_col
        )
        self.test_dataset = TextClassificationFromCSVDataset(
            test_df, self.text_col, self.label_col
        )

        # The label vocabulary is taken from the *train* split's dataset.
        # NOTE(review): each split's dataset builds its own label->index map;
        # stratification should keep the maps consistent, but confirm every
        # class appears in every split for small or imbalanced datasets.
        classes = self.train_dataset.classes
        self.num_classes = len(classes)
        self.id2label = {i: str(c) for i, c in enumerate(classes)}
        self.label2id = {str(c): i for i, c in enumerate(classes)}

        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_id)

    def _collate_fn(self, batch: list[tuple[str, int]]) -> dict[str, torch.Tensor]:
        """Tokenise raw texts and stack integer labels into a batch dict."""
        texts, labels = zip(*batch)
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not initialized.")
        encoding = self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding.input_ids,
            "attention_mask": encoding.attention_mask,
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap a dataset in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False)

VideoClassificationDataModule

Datamodule for video classification tasks.

CSV columns: video_path (relative to root_dir) and label. Frames are decoded using torchvision.io.read_video.

Source code in app/vision_automl/ml_engine/datamodule.py
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
class VideoClassificationDataModule:
    """Datamodule for video classification tasks.

    CSV columns: ``video_path`` (relative to ``root_dir``) and ``label``.
    Frames are decoded using ``torchvision.io.read_video``.

    ``setup()`` runs eagerly from ``__init__``: construction reads the CSV,
    encodes labels, builds stratified splits and loads the HF processor.
    """

    def __init__(
        self,
        csv_file: Path,
        root_dir: Path,
        video_col: str = "video_path",
        label_col: str = "label",
        num_frames: int = 8,
        batch_size: int = DEFAULT_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        val_split: float = DEFAULT_VAL_SPLIT,
        test_split: float = DEFAULT_TEST_SPLIT,
        seed: int = 42,
        hf_model_id: str = "MCG-NJU/videomae-base",
    ) -> None:
        self.csv_file = Path(csv_file)
        self.root_dir = Path(root_dir)
        self.video_col = video_col
        self.label_col = label_col
        self.num_frames = num_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.hf_model_id = hf_model_id
        # Populated by setup() below.
        self.num_classes: int = 0
        self.id2label: dict[int, str] = {}
        self.label2id: dict[str, int] = {}
        self.processor: AutoImageProcessor | None = None
        self.train_df: pd.DataFrame | None = None
        self.val_df: pd.DataFrame | None = None
        self.test_df: pd.DataFrame | None = None
        self.setup()

    def setup(self) -> None:
        """Read the CSV, encode labels, create stratified splits, load processor."""
        df = pd.read_csv(self.csv_file)
        classes = sorted(df[self.label_col].unique().tolist())
        self.num_classes = len(classes)
        self.id2label = {i: c for i, c in enumerate(classes)}
        self.label2id = {c: i for i, c in enumerate(classes)}
        # Work on a copy so the caller's DataFrame is never mutated.
        df = df.copy()
        df[self.label_col] = df[self.label_col].map(self.label2id)

        # Two-stage stratified split: carve off val+test together, then divide
        # that remainder so the final fractions match val_split / test_split.
        train_df, temp_df = train_test_split(
            df,
            test_size=self.val_split + self.test_split,
            stratify=df[self.label_col],
            random_state=self.seed,
        )
        relative_val = self.val_split / (self.val_split + self.test_split)
        val_df, test_df = train_test_split(
            temp_df,
            test_size=1 - relative_val,
            stratify=temp_df[self.label_col],
            random_state=self.seed,
        )
        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.test_df = test_df.reset_index(drop=True)
        self.processor = AutoImageProcessor.from_pretrained(self.hf_model_id)

    def _make_dataset(self, df: pd.DataFrame) -> Dataset:
        """Wrap a split DataFrame in a Dataset yielding (clip, label) pairs."""
        from torchvision.io import read_video

        root = self.root_dir
        num_frames = self.num_frames
        video_col = self.video_col
        label_col = self.label_col

        class _VideoDataset(Dataset):
            def __init__(self, df):
                self.df = df

            def __len__(self):
                return len(self.df)

            def __getitem__(self, idx):
                row = self.df.iloc[idx]
                video_path = str(root / str(row[video_col]))
                frames, _, _ = read_video(
                    video_path, output_format="TCHW", pts_unit="sec"
                )
                total = frames.shape[0]
                if total == 0:
                    # Decoding failures yield zero frames; fail loudly with the
                    # offending path instead of a cryptic indexing error.
                    raise ValueError(f"No frames decoded from video: {video_path}")
                # Sample num_frames evenly across the clip (indices may repeat
                # when the clip is shorter than num_frames).
                indices = torch.linspace(0, total - 1, num_frames).long()
                frames = frames[indices]  # (T, C, H, W)
                return frames.float() / 255.0, torch.tensor(
                    int(row[label_col]), dtype=torch.long
                )

        return _VideoDataset(df)

    def _collate_fn(self, batch):
        """Stack per-sample clips and labels into a batch dict.

        Uses ``torch.stack`` for the labels: they are already 0-dim tensors,
        and ``torch.tensor(sequence_of_tensors)`` is deprecated
        copy-construction that emits a UserWarning.
        NOTE(review): the HF processor is loaded but not applied here; clips
        are only scaled to [0, 1] — confirm the model expects that.
        """
        clips, labels = zip(*batch)
        return {
            "pixel_values": torch.stack(clips),  # (B, T, C, H, W)
            "labels": torch.stack(labels),
        }

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Wrap a split DataFrame in a DataLoader using this module's settings."""
        return DataLoader(
            self._make_dataset(df),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_df, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_df, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_df, shuffle=False)

CausalLMFromCSVDataset

Bases: Dataset

Dataset for causal language modelling tasks.

Expected CSV column: text. The datamodule tokenises and shifts labels automatically.

Source code in app/vision_automl/ml_engine/dataset.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
class CausalLMFromCSVDataset(Dataset):
    """Dataset for causal language modelling tasks.

    Expected CSV column: ``text``.  The datamodule tokenises and shifts
    labels automatically.

    ``csv_file`` may be a path (``str`` or ``Path``) to a CSV file, or an
    already-loaded ``pandas.DataFrame``.
    """

    def __init__(
        self,
        csv_file: Union[str, Path, pd.DataFrame],
        text_col: str = "text",
    ):
        # Accept plain string paths as well as Path objects; previously a
        # str path fell through to the error branch below.
        if isinstance(csv_file, (str, Path)):
            self.df = pd.read_csv(csv_file)
        elif isinstance(csv_file, pd.DataFrame):
            self.df = csv_file.reset_index(drop=True)
        else:
            raise ValueError("csv_file must be a path or DataFrame")

        self.text_col = text_col

    def __len__(self) -> int:
        """Number of text rows."""
        return len(self.df)

    def __getitem__(self, idx: int) -> str:
        """Return the raw text string at ``idx`` (tensor indices unwrapped)."""
        if torch.is_tensor(idx):
            idx = idx.item()
        return str(self.df.iloc[idx][self.text_col])

ImageClassificationFromCSVDataset

Bases: Dataset

Torch dataset that reads image paths and labels from a CSV/DataFrame.

Source code in app/vision_automl/ml_engine/dataset.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class ImageClassificationFromCSVDataset(Dataset):
    """Torch dataset that reads image paths and labels from a CSV/DataFrame.

    ``csv_file`` may be a path (``str`` or ``Path``) or a loaded DataFrame.
    Images are resolved at ``root_dir / <class name> / <filename>``.
    """

    def __init__(
        self,
        csv_file: Union[str, Path, pd.DataFrame],
        root_dir: Path,
        img_col: str = "image",
        label_col: str = "label",
        transform: Optional[T.Compose] = None,
    ):
        # Accept plain string paths as well as Path objects; previously a
        # str path fell through to the error branch below.
        if isinstance(csv_file, (str, Path)):
            self.label_csv = pd.read_csv(csv_file)
        elif isinstance(csv_file, pd.DataFrame):
            self.label_csv = csv_file.reset_index(drop=True)
        else:
            raise ValueError("csv_file must be a path or DataFrame")

        # Normalise so the '/' joins in __getitem__ always work.
        self.root_dir = Path(root_dir)
        self.img_col = img_col
        self.label_col = label_col
        # By default, do not apply torchvision transforms so that a Hugging Face
        # AutoImageProcessor can handle preprocessing in a DataLoader collate_fn.
        self.transform = transform

        if self.label_csv[self.label_col].dtype not in [int, float]:
            # String labels: build a stable (sorted) class <-> index mapping
            # and re-encode the label column in place.
            self.classes = sorted(self.label_csv[self.label_col].unique().tolist())
            self.class_to_idx = {
                cls_name: idx for idx, cls_name in enumerate(self.classes)
            }
            self.idx_to_class = {
                idx: cls_name for cls_name, idx in self.class_to_idx.items()
            }
            self.label_csv[self.label_col] = self.label_csv[self.label_col].map(
                self.class_to_idx
            )
        else:
            # Numeric labels are used as-is; the mappings are identities.
            # NOTE(review): with numeric labels, label_name in __getitem__ is a
            # number — joining it into a path may fail; confirm intended usage.
            self.classes = sorted(self.label_csv[self.label_col].unique().tolist())
            self.class_to_idx = {cls: cls for cls in self.classes}
            self.idx_to_class = {cls: cls for cls in self.classes}

    def __len__(self):
        """Return number of samples."""
        return len(self.label_csv)

    def __getitem__(self, idx):
        """Load and return ``(image, label_tensor)`` for row ``idx``.

        Raises:
            FileNotFoundError: if the resolved image path does not exist.
        """
        if torch.is_tensor(idx):
            idx = idx.item()

        row = self.label_csv.iloc[idx]
        label_idx = int(row[self.label_col])
        label_name = self.idx_to_class[label_idx]

        filename = str(row[self.img_col]).strip()

        img_path = self.root_dir / label_name / filename
        if not img_path.exists():
            # Gather directory listings defensively for the error message:
            # the class directory may itself be missing, and letting
            # os.listdir raise here would mask the informative error below
            # (the previous debug prints had exactly that problem).
            try:
                root_entries = os.listdir(self.root_dir)
            except OSError:
                root_entries = []
            try:
                class_entries = os.listdir(self.root_dir / label_name)
            except OSError:
                class_entries = []

            raise FileNotFoundError(
                f"Image not found\n"
                f"Expected path: {img_path}\n"
                f"root_dir: {self.root_dir}\n"
                f"root_dir entries: {root_entries[:20]}\n"
                f"label_name: {label_name}\n"
                f"class dir entries: {class_entries[:20]}\n"
                f"filename: {repr(filename)}"
            )

        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label_idx, dtype=torch.long)

__len__()

Return number of samples.

Source code in app/vision_automl/ml_engine/dataset.py
53
54
55
def __len__(self):
    """Return number of samples."""
    # One row of the label CSV corresponds to one (image, label) sample.
    return len(self.label_csv)

QuestionAnsweringFromCSVDataset

Bases: Dataset

Dataset for extractive QA tasks.

Expected CSV columns: question, context, answer_start (int), answer_text (str). Returns raw strings; the datamodule tokenises them.

Source code in app/vision_automl/ml_engine/dataset.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class QuestionAnsweringFromCSVDataset(Dataset):
    """Dataset for extractive QA tasks.

    Expected CSV columns: ``question``, ``context``, ``answer_start`` (int),
    ``answer_text`` (str).  Returns raw strings; the datamodule tokenises them.

    ``csv_file`` may be a path (``str`` or ``Path``) to a CSV file, or an
    already-loaded ``pandas.DataFrame``.
    """

    def __init__(
        self,
        csv_file: Union[str, Path, pd.DataFrame],
        question_col: str = "question",
        context_col: str = "context",
        answer_start_col: str = "answer_start",
        answer_text_col: str = "answer_text",
    ):
        # Accept plain string paths as well as Path objects; previously a
        # str path fell through to the error branch below.
        if isinstance(csv_file, (str, Path)):
            self.df = pd.read_csv(csv_file)
        elif isinstance(csv_file, pd.DataFrame):
            self.df = csv_file.reset_index(drop=True)
        else:
            raise ValueError("csv_file must be a path or DataFrame")

        self.question_col = question_col
        self.context_col = context_col
        self.answer_start_col = answer_start_col
        self.answer_text_col = answer_text_col

    def __len__(self) -> int:
        """Number of QA rows."""
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """Return the raw QA example at ``idx`` as a plain dict."""
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.df.iloc[idx]
        return {
            "question": str(row[self.question_col]),
            "context": str(row[self.context_col]),
            "answer_start": int(row[self.answer_start_col]),
            "answer_text": str(row[self.answer_text_col]),
        }

Seq2SeqFromCSVDataset

Bases: Dataset

Dataset for sequence-to-sequence tasks.

Expected CSV columns: input_text and target_text.

Source code in app/vision_automl/ml_engine/dataset.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
class Seq2SeqFromCSVDataset(Dataset):
    """Dataset for sequence-to-sequence tasks.

    Expected CSV columns: ``input_text`` and ``target_text``.

    Args:
        csv_file: Path (``str`` or ``pathlib.Path``) to a CSV file, or an
            in-memory DataFrame with the expected columns.
        input_col: Column holding the source text.
        target_col: Column holding the target text.

    Raises:
        ValueError: If ``csv_file`` is neither a path nor a DataFrame.
    """

    def __init__(
        self,
        csv_file: Union[str, Path, pd.DataFrame],
        input_col: str = "input_text",
        target_col: str = "target_text",
    ):
        # Accept plain string paths as well as pathlib.Path for convenience.
        if isinstance(csv_file, (str, Path)):
            self.df = pd.read_csv(csv_file)
        elif isinstance(csv_file, pd.DataFrame):
            # Reset the index so positional access via ``iloc`` is stable.
            self.df = csv_file.reset_index(drop=True)
        else:
            raise ValueError("csv_file must be a path or DataFrame")

        self.input_col = input_col
        self.target_col = target_col

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> tuple[str, str]:
        # DataLoader samplers may hand us a 0-d index tensor.
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.df.iloc[idx]
        return str(row[self.input_col]), str(row[self.target_col])

TextClassificationFromCSVDataset

Bases: Dataset

Torch dataset that reads text and labels from a CSV/DataFrame.

Expected columns: text (str) and label (str or int). Returns (text, label_idx) tuples — the collate function in the datamodule applies the tokeniser.

Source code in app/vision_automl/ml_engine/dataset.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class TextClassificationFromCSVDataset(Dataset):
    """Torch dataset that reads text and labels from a CSV/DataFrame.

    Expected columns: ``text`` (str) and ``label`` (str or int).
    Returns ``(text, label_idx)`` tuples — the collate function in the
    datamodule applies the tokeniser.

    Args:
        csv_file: Path (``str`` or ``pathlib.Path``) to a CSV file, or an
            in-memory DataFrame with the expected columns.
        text_col: Column holding the input text.
        label_col: Column holding the class label.

    Raises:
        ValueError: If ``csv_file`` is neither a path nor a DataFrame.
    """

    def __init__(
        self,
        csv_file: Union[str, Path, pd.DataFrame],
        text_col: str = "text",
        label_col: str = "label",
    ):
        # Accept plain string paths as well as pathlib.Path for convenience.
        if isinstance(csv_file, (str, Path)):
            self.df = pd.read_csv(csv_file)
        elif isinstance(csv_file, pd.DataFrame):
            self.df = csv_file.reset_index(drop=True)
        else:
            raise ValueError("csv_file must be a path or DataFrame")

        self.text_col = text_col
        self.label_col = label_col

        # Sorted unique labels give a deterministic class ordering; computed
        # once here instead of duplicating it in both branches below.
        self.classes = sorted(self.df[self.label_col].unique().tolist())
        if self.df[self.label_col].dtype == object:
            # String labels: remap to contiguous indices 0..n-1.
            self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
            self.df = self.df.copy()  # never mutate a caller-owned frame
            self.df[self.label_col] = self.df[self.label_col].map(self.class_to_idx)
        else:
            # Numeric labels are kept as-is (identity mapping); they are NOT
            # re-indexed to 0..n-1 — callers relying on contiguous ids must
            # supply them already contiguous.
            self.class_to_idx = {c: c for c in self.classes}

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> tuple[str, int]:
        # DataLoader samplers may hand us a 0-d index tensor.
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.df.iloc[idx]
        return str(row[self.text_col]), int(row[self.label_col])

AudioClassificationModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForAudioClassification.

Source code in app/vision_automl/ml_engine/model.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class AudioClassificationModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForAudioClassification."""

    def __init__(
        self,
        model_id: str,
        num_classes: int = 2,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        # Fall back to stringified indices when no label maps are supplied.
        resolved_id2label = id2label or {i: str(i) for i in range(num_classes)}
        resolved_label2id = label2id or {str(i): i for i in range(num_classes)}
        self.model = AutoModelForAudioClassification.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
            id2label=resolved_id2label,
            label2id=resolved_label2id,
        )

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Return classification logits for a batch of waveforms."""
        outputs = self.model(input_values=input_values)
        return outputs.logits

CausalLMModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForCausalLM.

Source code in app/vision_automl/ml_engine/model.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
class CausalLMModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForCausalLM."""

    def __init__(self, model_id: str):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Always returns the scalar language modelling loss."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs.loss

forward(input_ids, attention_mask=None, labels=None)

Always returns the scalar language modelling loss.

Source code in app/vision_automl/ml_engine/model.py
219
220
221
222
223
224
225
226
227
228
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
) -> torch.Tensor:
    """Always returns the scalar language modelling loss.

    Args:
        input_ids: Batch of token ids.
        attention_mask: Optional padding mask, forwarded unchanged.
        labels: Optional target ids, forwarded unchanged.
    """
    # Delegates to the wrapped HF model and surfaces only its ``.loss``.
    # NOTE(review): when ``labels`` is None, HF outputs typically carry a
    # None loss — confirm callers always pass labels.
    return self.model(
        input_ids=input_ids, attention_mask=attention_mask, labels=labels
    ).loss

ImageClassificationModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForImageClassification.

Source code in app/vision_automl/ml_engine/model.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class ImageClassificationModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForImageClassification."""

    def __init__(
        self,
        model_id: str = "google/vit-base-patch16-224",
        num_classes: int = 2,
        freeze_backbone: bool = True,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        # Fall back to stringified indices when no label maps are supplied.
        resolved_id2label = id2label or {i: str(i) for i in range(num_classes)}
        resolved_label2id = label2id or {str(i): i for i in range(num_classes)}
        self.model = AutoModelForImageClassification.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
            id2label=resolved_id2label,
            label2id=resolved_label2id,
        )
        if freeze_backbone:
            # Freeze every weight, then re-enable training on the head only.
            for p in self.model.parameters():
                p.requires_grad = False
            if hasattr(self.model, "classifier"):
                for p in self.model.classifier.parameters():
                    p.requires_grad = True

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return classification logits for a batch of images."""
        return self.model(pixel_values).logits

ImageSegmentationModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForImageSegmentation.

Source code in app/vision_automl/ml_engine/model.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class ImageSegmentationModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForImageSegmentation."""

    def __init__(
        self,
        model_id: str,
        num_classes: int = 2,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        # Fall back to stringified indices when no label maps are supplied.
        resolved_id2label = id2label or {i: str(i) for i in range(num_classes)}
        resolved_label2id = label2id or {str(i): i for i in range(num_classes)}
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
            id2label=resolved_id2label,
            label2id=resolved_label2id,
        )

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor | None = None):
        """Returns loss (scalar) when labels provided, else logits."""
        output = self.model(pixel_values=pixel_values, labels=labels)
        if labels is None:
            return output.logits
        return output.loss

forward(pixel_values, labels=None)

Returns loss (scalar) when labels provided, else logits.

Source code in app/vision_automl/ml_engine/model.py
74
75
76
77
def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor | None = None):
    """Returns loss (scalar) when labels provided, else logits."""
    # Single forward pass; the wrapped HF model computes its own loss
    # whenever ``labels`` is supplied.
    output = self.model(pixel_values=pixel_values, labels=labels)
    return output.loss if labels is not None else output.logits

KeypointDetectionModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForKeypointDetection.

Source code in app/vision_automl/ml_engine/model.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class KeypointDetectionModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForKeypointDetection."""

    def __init__(self, model_id: str):
        super().__init__()
        self.model = AutoModelForKeypointDetection.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
        )

    def forward(self, pixel_values: torch.Tensor, labels=None):
        """Returns loss when labels provided, else raw output."""
        output = self.model(pixel_values=pixel_values, labels=labels)
        if labels is None:
            return output
        return output.loss

forward(pixel_values, labels=None)

Returns loss when labels provided, else raw output.

Source code in app/vision_automl/ml_engine/model.py
130
131
132
133
def forward(self, pixel_values: torch.Tensor, labels=None):
    """Returns loss when labels provided, else raw output."""
    # The wrapped HF model computes its own loss when targets are given.
    output = self.model(pixel_values=pixel_values, labels=labels)
    return output.loss if labels is not None else output

MaskedLMModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForMaskedLM.

Source code in app/vision_automl/ml_engine/model.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class MaskedLMModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForMaskedLM."""

    def __init__(self, model_id: str):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Returns the scalar masked language modelling loss."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs.loss

forward(input_ids, attention_mask=None, labels=None)

Returns the scalar masked language modelling loss.

Source code in app/vision_automl/ml_engine/model.py
261
262
263
264
265
266
267
268
269
270
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
) -> torch.Tensor:
    """Returns the scalar masked language modelling loss.

    Args:
        input_ids: Batch of token ids (with mask tokens).
        attention_mask: Optional padding mask, forwarded unchanged.
        labels: Optional target ids, forwarded unchanged.
    """
    # Surfaces only the wrapped HF model's ``.loss`` field.
    # NOTE(review): with ``labels=None`` the HF loss is typically None —
    # confirm callers always pass labels.
    return self.model(
        input_ids=input_ids, attention_mask=attention_mask, labels=labels
    ).loss

ObjectDetectionModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForObjectDetection.

Source code in app/vision_automl/ml_engine/model.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class ObjectDetectionModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForObjectDetection."""

    def __init__(self, model_id: str, num_classes: int = 2):
        super().__init__()
        self.model = AutoModelForObjectDetection.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
        )

    def forward(self, pixel_values: torch.Tensor, labels=None):
        """Returns loss when labels provided (list of dicts), else raw output."""
        output = self.model(pixel_values=pixel_values, labels=labels)
        if labels is None:
            return output
        return output.loss

forward(pixel_values, labels=None)

Returns loss when labels provided (list of dicts), else raw output.

Source code in app/vision_automl/ml_engine/model.py
91
92
93
94
def forward(self, pixel_values: torch.Tensor, labels=None):
    """Returns loss when labels provided (list of dicts), else raw output."""
    # Detection targets are per-image annotation dicts, not tensors; the
    # wrapped HF model computes the loss internally.
    output = self.model(pixel_values=pixel_values, labels=labels)
    return output.loss if labels is not None else output

QuestionAnsweringModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForQuestionAnswering.

Source code in app/vision_automl/ml_engine/model.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class QuestionAnsweringModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForQuestionAnswering."""

    def __init__(self, model_id: str):
        super().__init__()
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
    ):
        """Returns loss scalar when start/end positions provided, else raw output."""
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        # Loss only exists when both span targets were supplied.
        has_targets = start_positions is not None and end_positions is not None
        return output.loss if has_targets else output

forward(input_ids, attention_mask=None, start_positions=None, end_positions=None)

Returns loss scalar when start/end positions provided, else raw output.

Source code in app/vision_automl/ml_engine/model.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    start_positions: torch.Tensor | None = None,
    end_positions: torch.Tensor | None = None,
):
    """Returns loss scalar when start/end positions provided, else raw output.

    Args:
        input_ids: Batch of token ids.
        attention_mask: Optional padding mask.
        start_positions: Optional answer-span start targets (training).
        end_positions: Optional answer-span end targets (training).
    """
    output = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
    )
    # Training mode: both span targets present -> surface the HF loss.
    if start_positions is not None and end_positions is not None:
        return output.loss
    return output

Seq2SeqLMModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForSeq2SeqLM.

Source code in app/vision_automl/ml_engine/model.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
class Seq2SeqLMModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForSeq2SeqLM."""

    def __init__(self, model_id: str):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Returns the scalar seq2seq loss."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
        )
        return outputs.loss

forward(input_ids, attention_mask=None, decoder_input_ids=None, labels=None)

Returns the scalar seq2seq loss.

Source code in app/vision_automl/ml_engine/model.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    decoder_input_ids: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
) -> torch.Tensor:
    """Returns the scalar seq2seq loss.

    Args:
        input_ids: Encoder token ids.
        attention_mask: Optional encoder padding mask.
        decoder_input_ids: Optional decoder input ids, forwarded unchanged.
        labels: Optional target ids, forwarded unchanged.
    """
    # Surfaces only the wrapped HF model's ``.loss`` field.
    # NOTE(review): with ``labels=None`` the HF loss is typically None —
    # confirm callers always pass labels.
    return self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        labels=labels,
    ).loss

SequenceClassificationModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForSequenceClassification.

Source code in app/vision_automl/ml_engine/model.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class SequenceClassificationModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForSequenceClassification."""

    def __init__(
        self,
        model_id: str,
        num_classes: int = 2,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        # Fall back to stringified indices when no label maps are supplied.
        resolved_id2label = id2label or {i: str(i) for i in range(num_classes)}
        resolved_label2id = label2id or {str(i): i for i in range(num_classes)}
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
            id2label=resolved_id2label,
            label2id=resolved_label2id,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return classification logits for a batch of token ids."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

VideoClassificationModel

Bases: Module

Thin nn.Module wrapping HF AutoModelForVideoClassification.

Source code in app/vision_automl/ml_engine/model.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class VideoClassificationModel(nn.Module):
    """Thin nn.Module wrapping HF AutoModelForVideoClassification."""

    def __init__(
        self,
        model_id: str,
        num_classes: int = 2,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        # Fall back to stringified indices when no label maps are supplied.
        resolved_id2label = id2label or {i: str(i) for i in range(num_classes)}
        resolved_label2id = label2id or {str(i): i for i in range(num_classes)}
        self.model = AutoModelForVideoClassification.from_pretrained(
            model_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
            id2label=resolved_id2label,
            label2id=resolved_label2id,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return classification logits for a batch of video clips."""
        outputs = self.model(pixel_values=pixel_values)
        return outputs.logits

EarlyStopping

Simple early stopping callback based on monitored metric.

Source code in app/vision_automl/ml_engine/trainer.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class EarlyStopping:
    """Simple early stopping callback based on monitored metric."""

    def __init__(
        self, monitor: str = "val_loss", patience: int = 3, min_delta: float = 0.0
    ) -> None:
        self.monitor: str = monitor
        self.patience: int = patience
        self.min_delta: float = min_delta
        # Lower is better: start from +inf so the first metric always wins.
        self.best: float = float("inf")
        self.counter: int = 0

    def on_epoch_end(
        self, trainer: "FabricTrainer", epoch: int, logs: dict[str, float]
    ) -> None:
        """Update state after epoch; may signal stopping on trainer."""
        current: float | None = logs.get(self.monitor)
        if current is None:
            logger.warning(
                f"Metric '{self.monitor}' not found in logs. Skipping early stopping check."
            )
            return

        improved = current < self.best - self.min_delta
        if improved:
            self.best = current
            self.counter = 0
            logger.info(f"New best {self.monitor}: {self.best:.4f}")
            return

        self.counter += 1
        logger.info(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            logger.info("Early stopping triggered!")
            trainer.epochs = epoch + 1

on_epoch_end(trainer, epoch, logs)

Update state after epoch; may signal stopping on trainer.

Source code in app/vision_automl/ml_engine/trainer.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def on_epoch_end(
    self, trainer: "FabricTrainer", epoch: int, logs: dict[str, float]
) -> None:
    """Update state after epoch; may signal stopping on trainer.

    Args:
        trainer: Owning trainer; its ``epochs`` attribute is shrunk to stop.
        epoch: Zero-based index of the epoch that just finished.
        logs: Mapping of metric name to value for the finished epoch.
    """
    current: float | None = logs.get(self.monitor)
    if current is None:
        # Metric absent: skip the check rather than stop training.
        logger.warning(
            f"Metric '{self.monitor}' not found in logs. Skipping early stopping check."
        )
        return

    # An improvement must beat the best value by more than ``min_delta``.
    if current < self.best - self.min_delta:
        self.best = current
        self.counter = 0
        logger.info(f"New best {self.monitor}: {self.best:.4f}")
    else:
        self.counter += 1
        logger.info(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            logger.info("Early stopping triggered!")
            # NOTE(review): this only stops training if the trainer's fit
            # loop re-reads ``self.epochs`` each iteration — confirm.
            trainer.epochs = epoch + 1

FabricTrainer

Minimal trainer using Lightning Fabric.

Supports both: (1) classification tasks, where the model returns logits and the trainer computes the loss via loss_fn (model_computes_loss=False); and (2) generative / structured-prediction tasks, where the model computes its own loss internally and returns a scalar tensor (model_computes_loss=True).

Source code in app/vision_automl/ml_engine/trainer.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
class FabricTrainer:
    """Minimal trainer using Lightning Fabric.

    Supports both:
    - Classification tasks where the model returns logits and the trainer
      computes the loss via ``loss_fn`` (``model_computes_loss=False``).
    - Generative / structured-prediction tasks where the model computes
      its own loss internally and returns a scalar tensor
      (``model_computes_loss=True``).
    """

    def __init__(
        self,
        datamodule: Any,
        model_class: type[nn.Module],
        model_kwargs: dict[str, Any] | None = None,
        optimizer_class: type[optim.Optimizer] = optim.AdamW,
        optimizer_kwargs: dict[str, Any] | None = None,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
        lr: float = 0.001,
        epochs: int = 1,
        time_limit: float | None = None,
        device: str = "auto",
        callbacks: list[Any] | None = None,
        input_dtype: torch.dtype = torch.float32,
        target_dtype: torch.dtype = torch.long,
        model_computes_loss: bool = False,
    ) -> None:
        """Build the trainer and eagerly set up model, optimizer and loaders.

        Args:
            datamodule: Object exposing ``train/val/test_dataloader()``.
            model_class: ``nn.Module`` subclass to instantiate.
            model_kwargs: Keyword args for ``model_class`` (default: none).
            optimizer_class: Optimizer type (default ``AdamW``).
            optimizer_kwargs: Optimizer kwargs; falls back to ``{"lr": lr}``.
            loss_fn: Loss used when ``model_computes_loss`` is False.
            lr: Learning rate, used only when ``optimizer_kwargs`` is empty.
            epochs: Maximum number of training epochs.
            time_limit: Optional wall-clock budget in seconds.
            device: Fabric device spec (e.g. ``"auto"``, ``"cpu"``).
            callbacks: Objects with ``on_epoch_end(trainer, epoch, logs)``.
            input_dtype: dtype floating-point inputs are cast to.
            target_dtype: dtype target tensors are cast to.
            model_computes_loss: If True, the model's forward returns its
                own loss and no accuracy is computed.
        """
        # ``None`` sentinels replace the previous mutable defaults ({} / []),
        # which were shared across every trainer instance.
        self.datamodule: Any = datamodule
        self.model_class: type[nn.Module] = model_class
        self.model_kwargs: dict[str, Any] = model_kwargs or {}
        self.optimizer_class: type[optim.Optimizer] = optimizer_class
        self.optimizer_kwargs: dict[str, Any] = optimizer_kwargs or {"lr": lr}
        self.loss_fn: nn.Module = loss_fn
        self.epochs: int = epochs
        self.time_limit: float | None = time_limit
        self.device: str = device
        self.callbacks: list[Any] = callbacks if callbacks is not None else []
        self.input_dtype: torch.dtype = input_dtype
        self.target_dtype: torch.dtype = target_dtype
        self.model_computes_loss: bool = model_computes_loss

        self.fabric: L.Fabric = L.Fabric(devices=self.device)
        self._setup_model_optimizer()

    def _setup_model_optimizer(self) -> None:
        """Instantiate model and optimizer and prepare loaders with Fabric."""
        logger.info("Setting up model and optimizer.")
        self.model: nn.Module = self.model_class(**self.model_kwargs)
        self.optimizer: optim.Optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_kwargs
        )

        train_loader: Any = self.datamodule.train_dataloader()
        val_loader: Any = self.datamodule.val_dataloader()
        (
            self.model,
            self.optimizer,
            self.train_loader,
            self.val_loader,
        ) = self.fabric.setup(self.model, self.optimizer, train_loader, val_loader)
        # The test loader is intentionally not wrapped by Fabric; batches are
        # moved to the device manually in ``_move_batch``.
        self.test_loader: Any = self.datamodule.test_dataloader()
        logger.info("Model and optimizer setup complete.")

    def _move_batch(self, batch: Any) -> dict[str, Any]:
        """Move batch to the Fabric device.

        Handles arbitrary dict batches (all modalities) and legacy
        ``(images, labels)`` tuple batches.  Non-tensor values (e.g. list
        of annotation dicts for object detection) are passed through as-is.
        Integer tensors (``input_ids``, etc.) are moved without dtype coercion.
        """
        if isinstance(batch, dict):
            moved: dict[str, Any] = {}
            for k, v in batch.items():
                if not isinstance(v, torch.Tensor):
                    moved[k] = v  # keep non-tensors (e.g. list of dicts)
                elif k in _TARGET_KEYS:
                    moved[k] = v.to(self.fabric.device, dtype=self.target_dtype)
                elif v.dtype.is_floating_point:
                    moved[k] = v.to(self.fabric.device, dtype=self.input_dtype)
                else:
                    # int/long tensors (input_ids, etc.) — preserve dtype
                    moved[k] = v.to(self.fabric.device)
            return moved
        else:
            # Legacy tuple batch: normalise into the dict form.
            imgs, batch_labels = batch
            return {
                "pixel_values": imgs.to(self.fabric.device, dtype=self.input_dtype),
                "labels": batch_labels.to(self.fabric.device, dtype=self.target_dtype),
            }

    def _check_time_limit(self, start_time: float) -> bool:
        """Return True if configured time limit has been exceeded."""
        elapsed: float = time.time() - start_time
        if self.time_limit and elapsed > self.time_limit:
            logger.warning(f"Time limit reached ({elapsed:.2f}s). Stopping training.")
            return True
        return False

    def _compute_loss_and_logits(
        self, moved: dict[str, Any]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run forward pass, return (loss, logits_or_None).

        Mutates ``moved`` in the trainer-side-loss case (pops "labels").
        """
        if self.model_computes_loss:
            outputs = self.model(**moved)
            loss = outputs if isinstance(outputs, torch.Tensor) else outputs.loss
            return loss, None
        else:
            labels = moved.pop("labels")
            outputs = self.model(**moved)
            loss = self.loss_fn(outputs, labels)
            return loss, outputs

    def train_epoch(self, epoch: int, start_time: float) -> float:
        """Train for a single epoch and return average training loss."""
        self.model.train()
        running_loss: float = 0.0
        batch_count: int = len(self.train_loader)

        for batch in tqdm(
            self.train_loader, desc=f"Epoch {epoch+1} Training", leave=False
        ):
            if self._check_time_limit(start_time):
                return running_loss / max(1, batch_count)

            moved = self._move_batch(batch)
            self.optimizer.zero_grad()
            loss, _ = self._compute_loss_and_logits(moved)
            self.fabric.backward(loss)
            self.optimizer.step()
            running_loss += loss.item()

        # Guard against an empty train loader, matching the early-return path.
        avg_loss: float = running_loss / max(1, batch_count)
        logger.info(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")
        return avg_loss

    def validate(self, start_time: float) -> tuple[float, float]:
        """Evaluate on validation set; return (avg_loss, accuracy).

        Accuracy is only meaningful when ``model_computes_loss`` is False;
        otherwise it stays 0 (no predictions are decoded).
        """
        self.model.eval()
        val_loss: float = 0.0
        correct: int = 0
        total: int = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation", leave=False):
                if self._check_time_limit(start_time):
                    break

                moved = self._move_batch(batch)

                if self.model_computes_loss:
                    outputs = self.model(**moved)
                    loss = (
                        outputs if isinstance(outputs, torch.Tensor) else outputs.loss
                    )
                    val_loss += loss.item()
                else:
                    labels = moved.pop("labels")
                    outputs = self.model(**moved)
                    loss = self.loss_fn(outputs, labels)
                    val_loss += loss.item()
                    preds = outputs.argmax(dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

        avg_loss: float = val_loss / max(1, len(self.val_loader))
        accuracy: float = correct / max(1, total)
        logger.info(f"Validation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        return avg_loss, accuracy

    def test(self) -> tuple[float, float]:
        """Evaluate on test set; return (avg_loss, accuracy).

        Accuracy is only meaningful when ``model_computes_loss`` is False;
        otherwise it stays 0 (no predictions are decoded).
        """
        self.model.eval()
        test_loss: float = 0.0
        correct: int = 0
        total: int = 0

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc="Testing"):
                moved = self._move_batch(batch)

                if self.model_computes_loss:
                    outputs = self.model(**moved)
                    loss = (
                        outputs if isinstance(outputs, torch.Tensor) else outputs.loss
                    )
                    test_loss += loss.item()
                else:
                    labels = moved.pop("labels")
                    outputs = self.model(**moved)
                    loss = self.loss_fn(outputs, labels)
                    test_loss += loss.item()
                    preds = outputs.argmax(dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

        # Guard against an empty test loader, consistent with validate().
        avg_loss: float = test_loss / max(1, len(self.test_loader))
        accuracy: float = correct / max(1, total)
        logger.info(f"Test Results - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        return avg_loss, accuracy

    def fit(self, trial: optuna.Trial | None = None) -> tuple[float, float]:
        """Run the train/validate loop, then evaluate on the test set.

        Args:
            trial: Optional Optuna trial; validation loss is reported per
                epoch and pruning is honoured.

        Returns:
            (test_loss, test_accuracy) from the final ``test()`` call.

        Raises:
            optuna.TrialPruned: If the trial decides to prune.
        """
        logger.info("Starting training loop.")
        start_time: float = time.time()

        # A ``while`` loop (not ``for epoch in range(self.epochs)``) so that
        # callbacks shrinking ``self.epochs`` (e.g. EarlyStopping) actually
        # terminate training — ``range`` would capture the original value
        # once and ignore the update.
        epoch = 0
        while epoch < self.epochs:
            train_loss = self.train_epoch(epoch, start_time)
            val_loss, val_acc = self.validate(start_time)

            if trial is not None:
                trial.report(val_loss, step=epoch)
                if trial.should_prune():
                    raise optuna.TrialPruned()

            logs = {
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_acc": val_acc,
            }

            for cb in self.callbacks:
                cb.on_epoch_end(self, epoch, logs)

            if self._check_time_limit(start_time):
                break
            epoch += 1

        return self.test()

test()

Evaluate on test set; return (avg_loss, accuracy).

Source code in app/vision_automl/ml_engine/trainer.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def test(self) -> tuple[float, float]:
    """Evaluate on the test set; return ``(avg_loss, accuracy)``.

    When ``model_computes_loss`` is set, only the loss is meaningful and
    ``accuracy`` stays ``0.0`` (no labels are extracted on that path).

    Returns:
        tuple[float, float]: average test loss and top-1 accuracy.
    """
    self.model.eval()
    test_loss: float = 0.0
    correct: int = 0
    total: int = 0

    with torch.no_grad():
        for batch in tqdm(self.test_loader, desc="Testing"):
            moved = self._move_batch(batch)

            if self.model_computes_loss:
                # Model returns either a raw loss tensor or an object
                # exposing a ``.loss`` attribute (e.g. a HF ModelOutput).
                outputs = self.model(**moved)
                loss = (
                    outputs if isinstance(outputs, torch.Tensor) else outputs.loss
                )
                test_loss += loss.item()
            else:
                labels = moved.pop("labels")
                outputs = self.model(**moved)
                loss = self.loss_fn(outputs, labels)
                test_loss += loss.item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

    # Guard against an empty loader (consistent with ``validate``),
    # which would otherwise raise ZeroDivisionError.
    avg_loss: float = test_loss / max(1, len(self.test_loader))
    accuracy: float = correct / max(1, total)
    logger.info(f"Test Results - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy

train_epoch(epoch, start_time)

Train for a single epoch and return average training loss.

Source code in app/vision_automl/ml_engine/trainer.py
211–232
def train_epoch(self, epoch: int, start_time: float) -> float:
    """Train for a single epoch and return the average training loss.

    Respects the wall-clock budget: if ``_check_time_limit`` trips
    mid-epoch, training stops immediately and the partial average over
    the batches seen so far (denominator = full loader length) is
    returned.

    Args:
        epoch: Zero-based epoch index (used only for progress/logging).
        start_time: ``time.time()`` stamp from the start of training.
    """
    self.model.train()
    running_loss: float = 0.0
    batch_count: int = len(self.train_loader)

    for batch in tqdm(
        self.train_loader, desc=f"Epoch {epoch+1} Training", leave=False
    ):
        if self._check_time_limit(start_time):
            return running_loss / max(1, batch_count)

        moved = self._move_batch(batch)
        self.optimizer.zero_grad()
        loss, _ = self._compute_loss_and_logits(moved)
        # Fabric handles mixed precision / distributed backward.
        self.fabric.backward(loss)
        self.optimizer.step()
        running_loss += loss.item()

    # Guard against an empty loader (mirrors the early-return path),
    # which would otherwise raise ZeroDivisionError.
    avg_loss: float = running_loss / max(1, batch_count)
    logger.info(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")
    return avg_loss

validate(start_time)

Evaluate on validation set; return (avg_loss, accuracy).

Source code in app/vision_automl/ml_engine/trainer.py
234–266
def validate(self, start_time: float) -> tuple[float, float]:
    """Run one pass over the validation set; return (avg_loss, accuracy).

    Accuracy remains 0.0 when the model computes its own loss, since no
    labels are pulled out of the batch on that path.  The loop breaks
    early if the wall-clock budget runs out.
    """
    self.model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(self.val_loader, desc="Validation", leave=False):
            # Respect the time budget even during evaluation.
            if self._check_time_limit(start_time):
                break

            inputs = self._move_batch(batch)

            if not self.model_computes_loss:
                targets = inputs.pop("labels")
                logits = self.model(**inputs)
                total_loss += self.loss_fn(logits, targets).item()
                predictions = logits.argmax(dim=1)
                n_correct += (predictions == targets).sum().item()
                n_seen += targets.size(0)
            else:
                # Model yields either a bare loss tensor or an object
                # carrying a ``.loss`` attribute.
                result = self.model(**inputs)
                if isinstance(result, torch.Tensor):
                    total_loss += result.item()
                else:
                    total_loss += result.loss.item()

    mean_loss: float = total_loss / max(1, len(self.val_loader))
    acc: float = n_correct / max(1, n_seen)
    logger.info(f"Validation - Loss: {mean_loss:.4f}, Accuracy: {acc:.4f}")
    return mean_loss, acc

Run an Optuna hyperparameter search for the given task type.

Dispatches to the appropriate per-task objective via OBJECTIVE_REGISTRY. extra_kwargs are forwarded to the objective (e.g. text_column for text tasks).

Raises:

Type Description
ValueError

If task_type is not in OBJECTIVE_REGISTRY.

Source code in app/vision_automl/ml_engine/trainer.py
844–917
def run_optuna_search(
    *,
    task_type: str = "image_classification",
    csv_path: Path,
    images_dir: Path | None = None,
    filename_column: str = "filename",
    label_column: str = "label",
    n_trials: int = 3,
    timeout: int | None = None,
    model_size: str = "small",
    workdir: Path,
    **extra_kwargs,
) -> dict:
    """Run an Optuna hyperparameter search for the given task type.

    Dispatches to the appropriate per-task objective via ``OBJECTIVE_REGISTRY``.
    ``extra_kwargs`` are forwarded to the objective (e.g. ``text_column`` for
    text tasks).

    Args:
        task_type: Key into ``OBJECTIVE_REGISTRY`` selecting the objective.
        csv_path: Dataset CSV consumed by the objective.
        images_dir: Root directory of images (image tasks only).
        filename_column: CSV column holding the file name.
        label_column: CSV column holding the label.
        n_trials: Number of Optuna trials to attempt.
        timeout: Overall search budget in seconds; also split evenly
            across trials as ``timeout_per_trial``.
        model_size: Backbone size hint forwarded to the objective.
        workdir: Directory under which ``optuna/`` artifacts are written.

    Returns:
        dict with ``best_value``, ``best_params``, ``n_trials`` (total,
        including pruned/failed trials), and ``model_dir`` pointing at the
        best trial's output directory.

    Raises:
        ValueError: If ``task_type`` is not in ``OBJECTIVE_REGISTRY``.
        RuntimeError: If no trial completed successfully.
    """
    if task_type not in OBJECTIVE_REGISTRY:
        raise ValueError(
            f"Unknown task type '{task_type}'. "
            f"Supported: {sorted(OBJECTIVE_REGISTRY)}"
        )

    config = load_task_config(task_type)
    objective_fn = OBJECTIVE_REGISTRY[task_type]

    run_dir = workdir / "optuna"
    # parents=True so a not-yet-created workdir does not abort the search.
    run_dir.mkdir(parents=True, exist_ok=True)

    # Successive halving prunes unpromising trials early; a fixed TPE
    # seed keeps the search reproducible across runs.
    pruner = optuna.pruners.SuccessiveHalvingPruner(
        min_resource=10,
        reduction_factor=3,
        min_early_stopping_rate=0,
    )
    sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(direction="minimize", sampler=sampler, pruner=pruner)
    # Split the global budget evenly so a single slow trial cannot
    # consume the whole timeout.
    timeout_per_trial = timeout / max(n_trials, 1) if timeout else None

    # Build keyword arguments for the objective
    objective_kwargs: dict = {
        "csv_path": csv_path,
        "images_dir": images_dir,
        "filename_column": filename_column,
        "label_column": label_column,
        "model_size": model_size,
        "timeout_per_trial": timeout_per_trial,
        "config": config,
        **extra_kwargs,
    }

    study.optimize(
        functools.partial(objective_fn, **objective_kwargs),
        n_trials=n_trials,
        timeout=timeout,
    )

    # ``study.best_value`` raises if every trial failed or was pruned;
    # surface a clearer, actionable error instead.
    completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if not completed:
        raise RuntimeError(
            f"All {len(study.trials)} Optuna trial(s) failed or were pruned. "
            "Check your dataset, model IDs, and time budget."
        )

    return {
        "best_value": study.best_value,
        "best_params": study.best_params,
        "n_trials": len(study.trials),
        "model_dir": run_dir / f"trial_{study.best_trial.number}",
    }