Hi everyone,
I’m very new to neural networks and just trying to figure out the right approach for my project and why my current code is not working as planned. I apologize in advance for my stupid questions.
I am trying to use tensorflow.js in a game, where I have 18 inputs, such as the x- and y-positions of the enemies, and 3 outputs: the cosine and sine of an angle between -180 and 180 degrees (as suggested in the Cross Validated question "Encoding Angle Data for Neural Network"), used for the movement direction, and a third output that controls moving or not moving (-1: do not move, 1: move).
Because the cos and sin results are between -1 and 1, I am using -1 and 1 for the third output as well and tanh as the activation function. I pass in 36 inputs (previous and current frame) to predict the direction to go in during the current frame and whether to move or not. The net then returns 6 outputs (3 for each frame), so I take the last 3 and put them into the player controller. I use returnSequences: true for the last LSTM layer because I cannot figure out how to make it output only the last 3 values without an error.
When I train the model using game data and then try to use it in-game, the outputs all center around certain values and barely vary, which leads to the player never moving at all or always moving in one direction. EDIT: This appears to be an issue of the amount of training/data. I’m beginning to see results with more training.
However, I still want to ask: Is my current method valid? What should I try to improve my approach? Training on my laptop takes very long, so I would like to improve training performance.
I also cannot find the memory leak — one tensor remains in memory at the end. Can you find it? I think it comes from compile(), but using tf.tidy() or startScope/endScope did not help.
See my code below (just a test environment, not the entire game):
</div>
<div>
<span>Output after training: </span>
</div>
<div>
<span id="Output">0</span>
</div>
</div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
<script>
// Build the model: three stacked LSTMs (all outputs kept per timestep) feeding
// a 3-unit dense head. Every layer uses tanh so outputs stay in [-1, 1].
const net = tf.sequential();
console.log("number of tensors in memory before adding layers:", tf.memory().numTensors);
// First LSTM declares the input shape: [timesteps (variable), 18 features].
net.add(tf.layers.lstm({ units: 18, inputShape: [null, 18], activation: "tanh", returnSequences: true }));
// Two identical hidden LSTM layers.
for (let layerIdx = 0; layerIdx < 2; layerIdx += 1) {
  net.add(tf.layers.lstm({ units: 18, activation: "tanh", returnSequences: true }));
}
// Dense head: cos(angle), sin(angle), move/no-move — one triple per timestep.
net.add(tf.layers.dense({ units: 3, activation: "tanh" }));
console.log("number of tensors in memory after adding layers:", tf.memory().numTensors);
net.compile({ loss: "meanSquaredError", optimizer: tf.train.sgd(0.3) });
console.log("number of tensors in memory after compiling model:", tf.memory().numTensors);
//Array of 360 items; one timestep is supposed to consist of 18 items, and the total length of the array should be variable
// 20 timesteps x 18 features, flattened frame by frame.
// NOTE(review): feature meaning is assumed from the post above (player/enemy x-y positions etc.) — confirm against the game code.
let totalInputData = [0.5,0.5,1.0265625847740325,0.4968331195707686,-0.027321307508046434,0.8432027740193291,-0.026595301826361762,0.43737337534322807,0.86542210185598,-0.02739700386481555,-0.027123958327189335,0.21580715437649517,1.026609147819722,0.5748529301190116,1.026614258178386,0.5789125974595326,1.02669943068863,0.3698218427429599,0.5,0.5,1.0218751695480652,0.49686131087028745,-0.02339261501609287,0.8406458165394973,-0.021940603652723524,0.43792694662348824,0.8627524432575345,-0.0235440077296311,-0.02299791665437867,0.21803166249386177,1.0219682956394436,0.5741932731696758,1.0219785163567718,0.5782179363814965,1.02214886137726,0.3709465539025108,0.5,0.5,1.0171877543220977,0.4968895021698064,-0.019463922524139303,0.8380888590596655,-0.01728590547908529,0.4384805179037484,0.860082784659089,-0.01969101159444665,-0.018871874981568004,0.22025617061122835,1.0173274434591655,0.5735336162203402,1.0173427745351575,0.5775232753034605,1.01759829206589,0.3720712650620616,0.5,0.5,1.0125003390961305,0.49691769346932524,-0.01553523003218574,0.8355319015798337,-0.012631207305447053,0.43903408918400866,0.8574131260606436,-0.0158380154592622,-0.014745833308757339,0.22248067872859495,1.0126865912788872,0.5728739592710044,1.0127070327135435,0.5768286142254244,1.01304772275452,0.3731959762216125,0.5,0.5,1.007812923870163,0.49694588476884416,-0.011606537540232175,0.8329749441000018,-0.007976509131808817,0.43958766046426884,0.8547434674621981,-0.01198501932407775,-0.010619791635946672,0.22470518684596152,1.008045739098609,0.5722143023216686,1.0080712908919294,0.5761339531473884,1.00849715344315,0.3743206873811633,0.5,0.5,1.0031255086441955,0.4969740760683631,-0.007677845048278611,0.83041798662017,-0.0033218109581705816,0.440141231744529,0.8520738088637525,-0.0081320231888933,-0.006493749963136006,0.2269296949633281,1.003404886918331,0.5715546453723329,1.0034355490703153,0.5754392920693523,1.0039465841317798,0.3754453985407141,0.5,0.5,0.9984380934182282,0.49700226736788194,-0.00374
9152556325047,0.8278610291403382,0.0013328872154676542,0.4406948030247892,0.849404150265307,-0.00427902705370885,-0.002367708290325339,0.2291542030806947,0.9987640347380526,0.5708949884229971,0.9987998072487011,0.5747446309913162,0.99939601482041,0.376570109700265,0.5,0.5,0.9937506781922607,0.49703045866740087,0.00017953993562851712,0.8253040716605063,0.00598758538910589,0.4412483743050494,0.8467344916668615,-0.00042603091852440046,0.0017583333824853276,0.23137871119806128,0.9941231825577745,0.5702353314736615,0.994164065427087,0.5740499699132802,0.9948454455090399,0.3776948208598158,0.5,0.5,0.9890632629662933,0.49705864996691973,0.004108232427582082,0.8227471141806746,0.010642283562744126,0.4418019455853096,0.8440648330684161,0.0034269652166600494,0.005884375055295994,0.23360321931542788,0.9894823303774963,0.5695756745243257,0.9895283236054728,0.5733553088352441,0.9902948761976699,0.3788195320193667,0.5,0.5,0.9843758477403259,0.49708684126643865,0.008036924919535646,0.8201901567008427,0.015296981736382362,0.4423555168655698,0.8413951744699706,0.0072799613518444994,0.010010416728106661,0.23582772743279445,0.9848414781972181,0.5689160175749899,0.9848925817838587,0.572660647757208,0.9857443068862999,0.3799442431789175,0.5,0.5,0.9796884325143586,0.49711503256595757,0.01196561741148921,0.8176331992210109,0.019951679910020597,0.44290908814582997,0.838725515871525,0.011132957487028949,0.014136458400917328,0.23805223555016103,0.9802006260169399,0.5682563606256542,0.9802568399622447,0.571965986679172,0.9811937375749299,0.3810689543384683,0.5,0.5,0.9750010172883912,0.49714322386547644,0.015894309903442774,0.8150762417411791,0.024606378083658835,0.44346265942609014,0.8360558572730795,0.014985953622213399,0.018262500073727993,0.24027674366752763,0.9755597738366617,0.5675967036763184,0.9756210981406305,0.5712713256011359,0.9766431682635599,0.3821936654980192,0.5,0.5,0.9703136020624237,0.49717141516499536,0.01982300239539634,0.8125192842613472,0.02926107625729707,0.4440162307063
503,0.833386198674634,0.01883894975739785,0.02238854174653866,0.2425012517848942,0.9709189216563836,0.5669370467269828,0.9709853563190164,0.5705766645230999,0.9720925989521898,0.38331837665757,0.5,0.5,0.9656261868364563,0.4971996064645142,0.023751694887349902,0.8099623267815155,0.03391577443093531,0.44456980198661056,0.8307165400761886,0.022691945892582298,0.026514583419349324,0.2447257599022608,0.9662780694761054,0.566277389777647,0.9663496144974022,0.5698820034450638,0.9675420296408198,0.3844430878171209,0.5,0.5,0.9609387716104889,0.49722779776403314,0.027680387379303468,0.8074053693016836,0.038570472604573545,0.44512337326687074,0.8280468814777431,0.02654494202776675,0.03064062509215999,0.24695026801962738,0.9616372172958272,0.5656177328283112,0.9617138726757881,0.5691873423670277,0.9629914603294498,0.3855677989766717,0.5,0.5,0.9562513563845215,0.49725598906355206,0.03160907987125703,0.8048484118218517,0.04322517077821179,0.4456769445471309,0.8253772228792975,0.0303979381629512,0.03476666676497066,0.24917477613699396,0.956996365115549,0.5649580758789755,0.957078130854174,0.5684926812889917,0.9584408910180798,0.3866925101362225,0.5,0.5,0.9515639411585541,0.4972841803630709,0.035537772363210596,0.80229145434202,0.04787986895185002,0.4462305158273911,0.822707564280852,0.03425093429813565,0.03889270843778132,0.25139928425436053,0.9523555129352708,0.5642984189296397,0.9524423890325598,0.5677980202109556,0.9538903217067098,0.3878172212957734,0.5,0.5,0.9468765259325866,0.49731237166258985,0.03946646485516416,0.7997344968621881,0.05253456712548826,0.4467840871076513,0.8200379056824065,0.038103930433320096,0.04301875011059199,0.25362379237172716,0.9477146607549927,0.563638761980304,0.9478066472109458,0.5671033591329195,0.9493397523953397,0.3889419324553242,0.5,0.5,0.9421891107066193,0.4973405629621087,0.04339515734711773,0.7971775393823564,0.057189265299126504,0.4473376583879115,0.817368247083961,0.04195692656850455,0.047144791783402654,0.25584830048909374,0.9430738085747
145,0.5629791050309683,0.9431709053893316,0.5664086980548835,0.9447891830839698,0.3900666436148751,0.5,0.5,0.9375016954806519,0.49736875426162763,0.0473238498390713,0.7946205819025245,0.06184396347276474,0.4478912296681717,0.8146985884855156,0.045809922703689,0.05127083345621332,0.2580728086064603,0.9384329563944362,0.5623194480816325,0.9385351635677175,0.5657140369768474,0.9402386137725998,0.3911913547744259];
console.log("length of input data", totalInputData.length)
//Array of 60 items; one timestep is supposed to consist of 3 items, and the total length of the array should be variable
// 20 timesteps x 3 targets: [cos(angle), sin(angle), move flag] per frame — constant [1, 0, -1] in this test.
let totalOutputData = [1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1,1,0,-1];
// Targets for the second timestep — the value predictOutput() is expected to reproduce.
console.log("expected output:", totalOutputData.slice(3, 6));
console.log("length of output data:", totalOutputData.length)
/**
 * Trains the global `net` on one long flattened sequence of game data.
 *
 * @param {number[]} inputs  - Flat array, length must be a multiple of 18
 *                             (18 features per timestep).
 * @param {number[]} outputs - Flat array, length must be a multiple of 3
 *                             (3 targets per timestep).
 * @returns {Promise<void>} Resolves when training finishes.
 */
async function trainNet(inputs, outputs) {
  console.log("training input data:", inputs, "training output data:", outputs);
  console.time("training completed in");
  console.log("number of tensors in memory before training:", tf.memory().numTensors)
  // One batch of variable-length sequences: [batch=1, timesteps, features].
  let xs = tf.tensor3d(inputs, [1, inputs.length / 18, 18]);
  let ys = tf.tensor3d(outputs, [1, outputs.length / 3, 3]);
  try {
    // NOTE(review): there is only 1 sample (batch dim = 1), so batchSize: 180
    // is effectively capped at 1 — presumably a leftover; confirm intent.
    const res = await net.fit(xs, ys, {
      batchSize: 180,
      epochs: 10
    });
    console.log(res);
    console.log("loss after training:", res.history.loss);
  } finally {
    // Dispose in `finally` so xs/ys are not leaked if net.fit() rejects
    // (the original disposed them only on the success path).
    xs.dispose();
    ys.dispose();
    console.timeEnd("training completed in");
    console.log("number of tensors in memory after training:", tf.memory().numTensors)
  }
}
/**
 * Runs one forward pass on the first two frames of the training data,
 * displays the prediction for the current (second) frame, then frees the
 * model and its optimizer.
 */
function predictOutput() {
  tf.engine().startScope();
  let predictionData = [0.5,0.5,1.0265625847740325,0.4968331195707686,-0.027321307508046434,0.8432027740193291,-0.026595301826361762,0.43737337534322807,0.86542210185598,-0.02739700386481555,-0.027123958327189335,0.21580715437649517,1.026609147819722,0.5748529301190116,1.026614258178386,0.5789125974595326,1.02669943068863,0.3698218427429599,0.5,0.5,1.0218751695480652,0.49686131087028745,-0.02339261501609287,0.8406458165394973,-0.021940603652723524,0.43792694662348824,0.8627524432575345,-0.0235440077296311,-0.02299791665437867,0.21803166249386177,1.0219682956394436,0.5741932731696758,1.0219785163567718,0.5782179363814965,1.02214886137726,0.3709465539025108];
  console.log("length of prediction data:", predictionData.length);
  //Predicting the output based on a 3d tensor with 2 timesteps (predictionData corresponds to first 36 items in totalInputData)
  let predictionTensor = tf.tensor3d(predictionData, [1, 2, 18]);
  const predictions = net.predict(predictionTensor);
  let netOutput = predictions.dataSync();
  predictions.dispose();
  predictionTensor.dispose();
  //Displaying the last 3 outputs (outputs for the current frame) in the output sequence.
  //textContent is preferred over innerHTML for plain (non-markup) data.
  document.getElementById("Output").textContent = [netOutput[3], netOutput[4], netOutput[5]];
  tf.engine().endScope();
  //FIX for the one leftover tensor: tf.train.sgd(0.3) in compile() keeps a
  //learning-rate scalar that net.dispose() does NOT free. Dispose the
  //optimizer explicitly before disposing the model.
  if (net.optimizer) {
    net.optimizer.dispose();
  }
  net.dispose();
  console.log("number of tensors in memory after predicting and disposing net:", tf.memory().numTensors)
}
//Entry point: train first, then run a single prediction.
//Added .catch so a training failure is reported instead of becoming an
//unhandled promise rejection (the original chain had no rejection handler).
trainNet(totalInputData, totalOutputData)
  .then(predictOutput)
  .catch((err) => console.error("training/prediction failed:", err));
//Intended output: [1,0,-1]
//Example output at 1000 epochs: [0.9696730375289917,-0.0013935886090621352,-0.9711311459541321]
</script>