Skip to content

Commit

Permalink
Merge pull request #80 from COS301-SE-2024/feat/functions
Browse files Browse the repository at this point in the history
Feat/functions
  • Loading branch information
James-Fitzsimons authored Sep 6, 2024
2 parents 75ecd5a + f6b34e8 commit 25c4273
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 136 deletions.
199 changes: 117 additions & 82 deletions functions/index.js
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
/**
* Import function triggers from their respective submodules:
*
* const {onCall} = require("firebase-functions/v2/https");
* const {onDocumentWritten} = require("firebase-functions/v2/firestore");
*
* See a full list of supported triggers at https://firebase.google.com/docs/functions
*/
// logger.info('Hello logs!', { structuredData: true });

// Dependencies for callable functions.
const { onCall, HttpsError } = require('firebase-functions/v2/https');
const { logger } = require('firebase-functions/v2');
//dependencies
const { getDatabase } = require('firebase-admin/database');
const { initializeApp } = require('firebase-admin/app');
const { getFirestore } = require('firebase-admin/firestore');

initializeApp();
let pipeline;
let categoryEmbeddings;
let pipe;

async function initialize() {
const { pipeline } = await import('@xenova/transformers');
pipe = await pipeline('feature-extraction', 'Supabase/gte-small');
await initializeCategoryEmbeddings(pipe);
}

async function initializeCategoryEmbeddings(pipe) {
categoryEmbeddings = await Promise.all(
categories.map(async (category) => {
const output = await pipe(category, { pooling: 'mean', normalize: true });
return Array.from(output.data);
})
);
}

const Months = [
'january',
Expand Down Expand Up @@ -49,16 +53,6 @@ const categories = [

async function useModel(transactionDescription) {
try {
if (!pipeline) {
await initializePipeline();
}

const pipe = await pipeline('feature-extraction', 'Supabase/gte-small');

if (!categoryEmbeddings) {
await initializeCategoryEmbeddings(pipe);
}

// Generate embedding for the transaction description
const descriptionOutput = await pipe(transactionDescription, {
pooling: 'mean',
Expand All @@ -83,24 +77,10 @@ async function useModel(transactionDescription) {

return closestCategory;
} catch (error) {
logger.log(error);
logger.error('Error in useModel:', error);
}
}

async function initializePipeline() {
const transformers = await import('@xenova/transformers');
pipeline = transformers.pipeline;
}

async function initializeCategoryEmbeddings(pipe) {
categoryEmbeddings = await Promise.all(
categories.map(async (category) => {
const output = await pipe(category, { pooling: 'mean', normalize: true });
return Array.from(output.data);
})
);
}

function cosineSimilarity(embedding1, embedding2) {
const dotProduct = embedding1.reduce(
(sum, val, i) => sum + val * embedding2[i],
Expand Down Expand Up @@ -224,54 +204,109 @@ function useKnownList(description) {
return '';
}

exports.categoriseExpenses = onCall(async (request) => {
const year = request.data.year;
const uid = request.auth.uid;
// const doc = await getFirestore().doc(`transaction_data_${year}/${uid}`);
// const docSnap = await doc.get();
// .collection(db, `transaction_data_${year}`);
const db = getDatabase();
const accRef = getFirestore().collection(`transaction_data_${year}`);
const snapshot = await accRef.where('uid', '==', uid).get();
// exports.categoriseExpenses = onCall(
// { timeoutSeconds: 300, memory: '1GiB', cpu: 2 },
// async (request) => {
// await initialize();

// const year = request.data.year;
// const uid = request.auth.uid;
// const accRef = getFirestore().collection(`transaction_data_${year}`);
// const snapshot = await accRef.where('uid', '==', uid).get();

// snapshot.forEach(async (doc) => {
// let updateFlag = false;
// //can categorize and set
// for (month of Months) {
// if (doc.data()[month]) {
// const IncomingMonthData = JSON.parse(doc.data()[month]);
// for (transaction of IncomingMonthData) {
// //do this all async
// logger.info(month);
// if (transaction.category == '') {
// updateFlag = true;
// let newCategory = '';
// if (transaction.amount > 0) {
// newCategory = 'Income';
// }
// if (newCategory == '') {
// newCategory = await useKnownList(transaction.description);
// }
// if (newCategory == '') {
// newCategory = await useModel(transaction.description);
// }
// if (newCategory == 'Fuel') {
// if (Math.abs(parseFloat(transaction.amount)) < 100) {
// newCategory = 'Eating Out';
// } else {
// newCategory = 'Transport';
// }
// }
// transaction.category = newCategory;
// }
// }
// if (updateFlag) {
// await getFirestore()
// .doc(`transaction_data_${year}/${doc.id}`)
// .update({ [month]: JSON.stringify(IncomingMonthData) });
// }
// }
// }
// });
// }
// );

// const q = getFirestore().query(accRef, where('uid', '==', user.uid));
// const querySnapshot = await getDocs(q);
exports.categoriseExpenses = onCall(
{ timeoutSeconds: 300, memory: '1GiB', cpu: 2 },
async (request) => {
await initialize();

snapshot.forEach(async (doc) => {
let updateFlag = false;
//can categorize and set
for (month of Months) {
if (doc.data()[month]) {
const IncomingMonthData = JSON.parse(doc.data()[month]);
for (transaction of IncomingMonthData) {
if (transaction.category == '') {
updateFlag = true;
let newCategory = '';
if (transaction.amount > 0) {
newCategory = 'Income';
}
if (newCategory == '') {
newCategory = await useKnownList(transaction.description);
}
if (newCategory == '') {
newCategory = await useModel(transaction.description);
}
if (newCategory == 'Fuel') {
if (Math.abs(parseFloat(transaction.amount)) < 100) {
newCategory = 'Eating Out';
} else {
newCategory = 'Transport';
const year = request.data.year;
const uid = request.auth.uid;
const accRef = getFirestore().collection(`transaction_data_${year}`);
const snapshot = await accRef.where('uid', '==', uid).get();

const updatePromises = snapshot.docs.map(async (doc) => {
for (const month of Months) {
let updateFlag = false;
if (doc.data()[month]) {
const IncomingMonthData = JSON.parse(doc.data()[month]);
const updatedTransactions = await Promise.all(
IncomingMonthData.map(async (transaction) => {
if (transaction.category === '') {
updateFlag = true;
let newCategory = '';
if (transaction.amount > 0) {
newCategory = 'Income';
}
if (newCategory === '') {
newCategory = useKnownList(transaction.description);
}
if (newCategory === '') {
newCategory = await useModel(transaction.description);
}
if (newCategory === 'Fuel') {
if (Math.abs(parseFloat(transaction.amount)) < 100) {
newCategory = 'Eating Out';
} else {
newCategory = 'Transport';
}
}
return { ...transaction, category: newCategory };
}
}
transaction.category = newCategory;
return transaction;
})
);
if (updateFlag) {
await getFirestore()
.doc(`transaction_data_${year}/${doc.id}`)
.update({ [month]: JSON.stringify(updatedTransactions) });
}
}
if (updateFlag) {
await getFirestore()
.doc(`transaction_data_${year}/${doc.id}`)
.update({ [month]: JSON.stringify(IncomingMonthData) });
}
}
}
});
});
});

// Wait for all updates to complete
await Promise.all(updatePromises);
}
);
Loading

0 comments on commit 25c4273

Please sign in to comment.