A comprehensive question-answering system leveraging state space models for Indic languages (Hindi, Telugu, Marathi).
MAMBA_QA is a research project focused on exploring and implementing various state space models (SSMs) for question-answering tasks in Indic languages. The system supports multiple cutting-edge SSM architectures including Mamba, Mamba2, Jamba, Falcon, and FalconV2.
- Question-answering capabilities for Hindi, Telugu, and Marathi languages
- Support for multiple state space model architectures:
  - Mamba (original)
  - Mamba2 (improved successor to Mamba)
  - Jamba (hybrid Mamba-Transformer)
  - Falcon
  - FalconV2
- Extensible architecture for adding new models and languages
- Training and evaluation pipelines for QA datasets
- Python 3.8+
- PyTorch
- Mamba SSM library
- Hugging Face Transformers
- CUDA support for GPU acceleration
torch>=2.0.0
transformers>=4.30.0
datasets>=2.10.0
mamba-ssm>=1.0.0
evaluate>=0.4.0
numpy>=1.22.0
pandas>=1.5.0
tqdm>=4.65.0
matplotlib>=3.7.0
scikit-learn>=1.2.0
# Clone the repository
git clone git@github.com:mrinal18/MAMBA_QA.git
cd MAMBA_QA
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Install additional model dependencies
pip install mamba-ssm
This implementation uses the IndicQA dataset from the Hugging Face Hub, which contains question-answer pairs in various Indic languages. The system focuses on Hindi, Telugu, and Marathi, but you can modify the configuration to include other languages available in the dataset.
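For reference, IndicQA can be pulled straight from the Hub with the `datasets` library already listed in the requirements. The sketch below is illustrative only: the `ai4bharat/IndicQA` dataset id, the `indicqa.hi` configuration name, and the SQuAD-style field names are assumptions based on the public dataset card, so verify them against the version you download.

```python
# Minimal sketch -- dataset id, config name, split names, and fields are assumptions;
# check the ai4bharat/IndicQA dataset card before relying on this.
from datasets import load_dataset

indicqa_hi = load_dataset("ai4bharat/IndicQA", "indicqa.hi")
print(indicqa_hi)  # inspect the available splits and columns

# Examples are assumed to be SQuAD-style: context, question, answers.
split_name = list(indicqa_hi.keys())[0]
example = indicqa_hi[split_name][0]
print(example["question"], example["answers"])
```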
from models.mamba import MambaForQuestionAnswering
from utils.trainer import train_model
from utils.data_loader import load_indic_qa_dataset
# Load dataset
train_data, val_data = load_indic_qa_dataset("hindi")
# Initialize model
model = MambaForQuestionAnswering("state-spaces/mamba-370m")
# Train the model
train_model(model, train_data, val_data,
            epochs=3,
            batch_size=8,
            learning_rate=5e-5)
from models.mamba import MambaForQuestionAnswering
from utils.inference import get_answer
model = MambaForQuestionAnswering("path/to/saved/model")
model.eval()
# Context: "The Prime Minister of India is Narendra Modi. He is from Gujarat."
context = "भारत के प्रधानमंत्री नरेंद्र मोदी हैं। वे गुजरात के रहने वाले हैं।"
# Question: "Who is the Prime Minister of India?"
question = "भारत के प्रधानमंत्री कौन हैं?"
answer = get_answer(model, question, context)
print(f"Answer: {answer}")  # Output: नरेंद्र मोदी ("Narendra Modi")
The system uses the Mamba2 state space model as the backbone, with a question-answering head on top that predicts answer spans. State space models like Mamba2 process sequences efficiently and scale linearly with sequence length, which makes them well suited to tasks involving long contexts. The QA head outputs two sets of logits (see the sketch after the list below):
- Start logits - predicting the start position of the answer in the context
- End logits - predicting the end position of the answer in the context
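As a rough illustration of this span-prediction setup (not the exact code in models/mamba.py), a QA head over the backbone's per-token hidden states can be a single linear layer that emits two scores per token; the hidden size and random inputs below are placeholders.

```python
import torch
import torch.nn as nn

class SpanQAHead(nn.Module):
    """Illustrative span head: per-token hidden states -> start/end logits."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Two scores per token: "answer starts here" and "answer ends here".
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (batch, seq_len, hidden_size) from the SSM backbone
        logits = self.qa_outputs(hidden_states)           # (batch, seq_len, 2)
        start_logits, end_logits = logits.unbind(dim=-1)  # each (batch, seq_len)
        return start_logits, end_logits

# Placeholder usage with random hidden states; real decoding additionally
# enforces start <= end and a maximum answer length when picking the span.
head = SpanQAHead(hidden_size=768)
start_logits, end_logits = head(torch.randn(1, 128, 768))
```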
The system is evaluated using standard QA metrics (see the example after the list):
- Exact Match (EM): Percentage of predictions that exactly match the ground truth
- F1 Score: Harmonic mean of precision and recall at the token level
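Because `evaluate` is already in the requirements, both metrics can be computed with its SQuAD metric; the ids and answer texts below are dummy values for illustration.

```python
import evaluate

squad_metric = evaluate.load("squad")  # reports exact_match and f1

# Dummy prediction/reference pair purely for illustration.
predictions = [{"id": "0", "prediction_text": "नरेंद्र मोदी"}]
references = [{"id": "0", "answers": {"text": ["नरेंद्र मोदी"], "answer_start": [20]}}]

print(squad_metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}
```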
After training, the system generates plots showing:
- Training and validation loss curves
- Evaluation metrics (EM and F1) across epochs
These plots are saved in the output directory specified in the configuration.
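A minimal matplotlib sketch of such a loss plot is shown below; the per-epoch values and the output path are placeholders, not the project's actual logging format.

```python
import matplotlib.pyplot as plt

# Placeholder per-epoch histories; real values come from the training loop.
train_loss = [2.1, 1.4, 1.1]
val_loss = [1.9, 1.5, 1.3]
epochs = range(1, len(train_loss) + 1)

plt.figure()
plt.plot(epochs, train_loss, label="train loss")
plt.plot(epochs, val_loss, label="val loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("outputs/loss_curves.png")  # assumed path; use the configured output directory
```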
To extend the system to other Indic languages supported by IndicQA:
- Modify the languages list in config.py to include the additional language codes (see the sketch below)
- Ensure that the SSM tokenizer supports the scripts used by the target languages
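The exact structure of config.py is not shown here; as a hypothetical sketch, the language selection might look like the following, using IndicQA-style language codes.

```python
# Hypothetical excerpt of utils/config.py -- adapt to the real config structure.
# Language codes follow the IndicQA naming (Hindi, Telugu, Marathi by default).
LANGUAGES = ["hi", "te", "mr"]

# Example: adding Bengali and Tamil, assuming IndicQA provides those configurations.
# LANGUAGES = ["hi", "te", "mr", "bn", "ta"]
```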
MAMBA_QA/
├── data/               # Dataset storage and preprocessing
├── models/             # Model implementations
│   ├── mamba.py        # Mamba, Mamba2, Jamba, Falcon, FalconV2 implementations
│   └── ...
├── utils/              # Utility functions
│   ├── config.py       # Configuration settings
│   ├── data_loader.py  # Data loading utilities
│   └── ...
├── experiments/        # Experiment scripts and results
├── notebooks/          # Jupyter notebooks for analysis
├── requirements.txt    # Dependencies
└── README.md           # Project documentation