// (source-viewer listing metadata: 69 lines, 2.3 KiB, JavaScript)
const XLSX = require('xlsx');
const tf = require("@tensorflow/tfjs");
//const tf = require("@tensorflow/tfjs-node");
/**
 * Serialize a SheetJS worksheet to CSV and expose it as a `blob:` URL, so that
 * URL-based loaders such as `tf.data.csv` can read it.
 *
 * The caller owns the returned URL and should release it with
 * `URL.revokeObjectURL(url)` once it is no longer needed; until then the
 * underlying CSV data is kept alive by the URL.
 *
 * @param {object} worksheet - SheetJS worksheet object.
 * @param {object} [opts] - optional options forwarded to
 *   `XLSX.utils.sheet_to_csv` (e.g. a custom field separator); when omitted the
 *   library defaults are used, exactly as before.
 * @returns {string} a `blob:` URL referencing UTF-8 encoded `text/csv` data.
 */
function worksheet_to_csv_url(worksheet, opts) {
  /* generate CSV */
  const csv = XLSX.utils.sheet_to_csv(worksheet, opts);

  /* CSV -> Blob: string parts are UTF-8 encoded by the Blob constructor,
     which is what the former explicit TextEncoder step produced. */
  const blob = new Blob([csv], { type: "text/csv" });

  /* generate a blob URL */
  return URL.createObjectURL(blob);
}
/**
 * Scan a dataset whose elements look like `{xs: number[], ys: number[]}` and
 * report the range of the first feature (`xs[0]`) and first label (`ys[0]`).
 * Consumes one full pass over the dataset.
 *
 * @param {object} dataset - a tf.data Dataset (anything exposing `forEachAsync`).
 * @returns {Promise<{minX: number, maxX: number, minY: number, maxY: number, count: number}>}
 *   the extremes plus the number of elements seen; when `count` is 0 the
 *   extremes are still the +/-Infinity starting values.
 */
async function featureLabelRange(dataset) {
  let minX = Infinity;
  let maxX = -Infinity;
  let minY = Infinity;
  let maxY = -Infinity;
  let count = 0;
  await dataset.forEachAsync(({ xs, ys }) => {
    minX = Math.min(minX, xs[0]);
    maxX = Math.max(maxX, xs[0]);
    minY = Math.min(minY, ys[0]);
    maxY = Math.max(maxY, ys[0]);
    count += 1;
  });
  return { minX, maxX, minY, maxY, count };
}

/*
 * Demo pipeline: download cd.xls, convert its first worksheet to CSV, fit a
 * one-variable linear model (Miles_per_Gallon from Horsepower) with tf.js, log
 * per-epoch loss to stderr and print predictions on a 9-point grid to stdout.
 */
(async () => {
  // Declared outside `try` so `finally` can always release it.
  let url = null;
  try {
    /* fetch file */
    const res = await fetch("https://docs.sheetjs.com/cd.xls");
    // fetch() resolves on HTTP error statuses; don't parse an error page as XLS.
    if (!res.ok) {
      throw new Error(`fetch failed: ${res.status} ${res.statusText}`);
    }
    const ab = await res.arrayBuffer();

    /* parse file and get first worksheet */
    const wb = XLSX.read(ab);
    const ws = wb.Sheets[wb.SheetNames[0]];

    /* generate blob URL */
    url = worksheet_to_csv_url(ws);

    /* feed to tf.js */
    const dataset = tf.data.csv(url, {
      hasHeader: true,
      configuredColumnsOnly: true,
      columnConfigs: {
        "Horsepower": { required: false, default: 0 },
        "Miles_per_Gallon": { required: false, default: 0, isLabel: true },
      },
    });

    /* pre-process data: flatten each record to arrays of values, then keep only
       rows where every value is > 0 (missing cells default to 0 above) */
    const flat = dataset
      .map(({ xs, ys }) => ({ xs: Object.values(xs), ys: Object.values(ys) }))
      .filter(({ xs, ys }) => [...xs, ...ys].every((v) => v > 0));

    /* normalize manually (min-max scaling to [0, 1]) */
    const { minX, maxX, minY, maxY, count } = await featureLabelRange(flat);
    if (count === 0) {
      throw new Error("no rows with positive Horsepower and Miles_per_Gallon");
    }
    // A zero span (all values equal) would divide by zero; fall back to 1.
    const spanX = (maxX - minX) || 1;
    const spanY = (maxY - minY) || 1;
    const batched = flat
      .map(({ xs, ys }) => ({
        xs: xs.map((v) => (v - minX) / spanX),
        ys: ys.map((v) => (v - minY) / spanY),
      }))
      .batch(32);

    /* build and train model */
    const model = tf.sequential();
    model.add(tf.layers.dense({ inputShape: [1], units: 1 }));
    // Inputs and labels are scaled to [0, 1]; a rate suited to the raw values
    // (1e-6) would leave the weights essentially at their random initialization.
    model.compile({ optimizer: tf.train.sgd(0.1), loss: "meanSquaredError" });
    await model.fitDataset(batched, {
      epochs: 100,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.error(`${epoch}:${logs.loss}`);
        },
      },
    });

    /* predict values on a grid over the normalized input range */
    const [inp, pred] = tf.tidy(() => {
      const grid = tf.linspace(0, 1, 9);
      // The model takes a [batch, 1] input; tidy frees the reshaped intermediate.
      return [grid, model.predict(grid.reshape([-1, 1]))];
    });
    const [xs, ys] = await Promise.all([inp.data(), pred.data()]);
    tf.dispose([inp, pred]);

    /* map back to original units and print "horsepower mpg" pairs */
    for (let i = 0; i < xs.length; ++i) {
      console.log([xs[i] * spanX + minX, ys[i] * spanY + minY].join(" "));
    }
  } catch (e) {
    console.error(`ERROR: ${String(e)}`);
  } finally {
    // The blob URL keeps the CSV alive until it is revoked.
    if (url !== null) {
      URL.revokeObjectURL(url);
    }
  }
})();