forked from sheetjs/docs.sheetjs.com
93 lines
3.4 KiB
TypeScript
93 lines
3.4 KiB
TypeScript
|
import { useState, useCallback } from "kaioken";
|
||
|
import { TensorContainerObject, data, layers, linspace, train, sequential } from "@tensorflow/tfjs";
|
||
|
import { read, utils } from "xlsx";
|
||
|
|
||
|
import type { Tensor, Rank } from "@tensorflow/tfjs";
|
||
|
import type { WorkSheet } from "xlsx";
|
||
|
|
||
|
/**
 * One dataset element split into a feature part (`xs`) and a label part
 * (`ys`). Used to type the dataset produced by `data.csv` in the component
 * below (presumably the label part holds the columns configured with
 * `isLabel` — confirm against the tf.js CSV config).
 */
interface Data extends TensorContainerObject {
  /** feature side of the element */
  xs: Tensor;
  /** label side of the element */
  ys: Tensor;
}

/** A tf.js dataset whose elements are feature/label pairs. */
type DSet = data.Dataset<Data>;
|
||
|
|
||
|
/**
 * Convert a worksheet to CSV and expose it as a Blob URL, because tf.js
 * `data.csv` loads its input from a URL.
 *
 * The caller owns the returned URL and should release it with
 * `URL.revokeObjectURL` once the data has been consumed.
 */
function worksheet_to_csv_url(worksheet: WorkSheet): string {
  /* generate CSV */
  const csv = utils.sheet_to_csv(worksheet);

  /* CSV -> Uint8Array -> Blob */
  const u8 = new TextEncoder().encode(csv);
  const blob = new Blob([u8], { type: "text/csv" });

  /* generate a blob URL */
  return URL.createObjectURL(blob);
}

/**
 * Demo component: downloads https://docs.sheetjs.com/cd.xls, converts its
 * first worksheet to CSV, and trains a single-unit dense model with tf.js to
 * predict `Miles_per_Gallon` from `Horsepower`. Predictions for nine evenly
 * spaced (normalized) inputs are mapped back to original units and rendered
 * as a table.
 */
export default function SheetJSToTFJSCSV() {
  /* "epoch:loss" progress while training, or an error message */
  const [output, setOutput] = useState("");
  /* [horsepower, predicted MPG] pairs in the data's original units */
  const [results, setResults] = useState<[number, number][]>([]);
  /* true while a run is in progress; disables the button */
  const [disabled, setDisabled] = useState(false);

  const doit = useCallback(async () => {
    setResults([]); setOutput(""); setDisabled(true);
    /* declared outside `try` so `finally` can release it */
    let url: string | undefined;
    try {
      /* fetch file */
      const f = await fetch("https://docs.sheetjs.com/cd.xls");
      /* without this check an HTTP error page would be handed to the parser */
      if (!f.ok) throw new Error(`fetch failed: ${f.status} ${f.statusText}`);
      const ab = await f.arrayBuffer();

      /* parse file and get first worksheet */
      const wb = read(ab);
      const ws = wb.Sheets[wb.SheetNames[0]];
      if (!ws) throw new Error("workbook contains no worksheets");

      /* generate blob URL */
      url = worksheet_to_csv_url(ws);

      /* feed to tf.js */
      const dataset = data.csv(url, {
        hasHeader: true,
        configuredColumnsOnly: true,
        columnConfigs: {
          "Horsepower": { required: false, default: 0 },
          "Miles_per_Gallon": { required: false, default: 0, isLabel: true }
        }
      });

      /* pre-process data:
       * - the cast tells TS each element is an {xs, ys} pair (label column configured above)
       * - Object.values flattens each side into a plain array of column values
       * - rows with any non-positive value are dropped (missing cells default to 0) */
      let flat = (dataset as unknown as DSet)
        .map(({ xs, ys }) => ({ xs: Object.values(xs), ys: Object.values(ys) }))
        .filter(({ xs, ys }) => [...xs, ...ys].every(v => v > 0));

      /* normalize manually :( -- min-max scale both columns to [0, 1] */
      let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
      await flat.forEachAsync(({ xs, ys }) => {
        minX = Math.min(minX, xs[0]); maxX = Math.max(maxX, xs[0]);
        minY = Math.min(minY, ys[0]); maxY = Math.max(maxY, ys[0]);
      });
      /* an empty dataset or a constant column would make the scaling below divide by zero */
      if (!(maxX > minX) || !(maxY > minY)) {
        throw new Error("not enough distinct data points to normalize");
      }
      const rangeX = maxX - minX, rangeY = maxY - minY;
      flat = flat.map(({ xs, ys }) => ({
        xs: xs.map(v => (v - minX) / rangeX),
        ys: ys.map(v => (v - minY) / rangeY)
      }));
      const batch = flat.batch(32);

      /* build and train model */
      const model = sequential();
      try {
        model.add(layers.dense({ inputShape: [1], units: 1 }));
        /* NOTE(review): with inputs scaled to [0, 1], a learning rate of 1e-6 is
         * very small for 100 epochs -- confirm the value is intentional */
        model.compile({ optimizer: train.sgd(0.000001), loss: 'meanSquaredError' });
        await model.fitDataset(batch, { epochs: 100, callbacks: { onEpochEnd: async (epoch, logs) => {
          setOutput(`${epoch}:${logs?.loss}`);
        }}});

        /* predict values */
        const inp = linspace(0, 1, 9);
        /* single-output model, so predict yields one tensor */
        const pred = model.predict(inp) as Tensor<Rank>;
        try {
          const [xs, ys] = await Promise.all([inp.data(), pred.data()]);
          /* undo the min-max scaling to report values in original units */
          setResults(Array.from(xs).map((x, i) => [ x * rangeX + minX, ys[i] * rangeY + minY ]));
        } finally {
          /* tf.js tensors hold backend memory that is not garbage collected */
          inp.dispose(); pred.dispose();
        }
      } finally {
        model.dispose();
      }
      setOutput("");
    } catch (e) {
      setOutput(`ERROR: ${String(e)}`);
    } finally {
      /* release the Blob behind the CSV URL; all dataset iteration is finished here */
      if (url) URL.revokeObjectURL(url);
      setDisabled(false);
    }
  }, []);

  return (
    <>
      <button onclick={doit} disabled={disabled}>Click to run</button><br/>
      {output ? <pre>{output}</pre> : <></>}
      {results.length > 0 ? (
        <table><thead><tr><th>Horsepower</th><th>MPG</th></tr></thead><tbody>
          {results.map((r, i) => <tr key={i}><td>{r[0]}</td><td>{r[1].toFixed(2)}</td></tr>)}
        </tbody></table>
      ) : <></>}
    </>
  );
}
|