In my last post, I set up a basic 2-2-1 neural network by hand to solve the XOR truth table. I used mean squared error (MSE) as the network’s loss function:
$$ \mathcal{L} = \tfrac{1}{2}(\hat{y} - y)^2 $$

As a reminder, the loss function is how we calculate how far from the correct answer the network was during training.
However, as I discovered later, MSE is not well suited to this particular problem. It is better suited to regression than to a simple classification like this one, and it can run into a vanishing gradient problem very quickly: with a sigmoid output, the MSE gradient carries a factor of the sigmoid’s derivative, $\hat y(1-\hat y)$, which approaches zero as the output saturates toward 0 or 1. As a result, MSE doesn’t provide a strong enough gradient in response to “confidently wrong” answers, and the network’s ability to learn slows or even stalls.
With binary cross-entropy (BCE) loss, on the other hand, a very wrong network produces a very large gradient in response. This allows the network to be trained much more reliably and quickly.
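To make this concrete, here’s a minimal sketch (NumPy, assuming the sigmoid output activation from the last post) comparing the output-layer error terms for a confidently wrong prediction. The $z = -4.6$ value is just an arbitrary pre-activation I picked so that $\hat y \approx 0.01$:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

y = 1.0   # true label
z = -4.6  # pre-activation chosen so sigmoid(z) ~ 0.01: "confidently wrong"
y_hat = sigmoid(z)

# Output-layer error term (dL/dz) under each loss:
mse_grad = (y_hat - y) * y_hat * (1 - y_hat)  # MSE keeps the sigmoid derivative
bce_grad = y_hat - y                          # BCE cancels it (derived below)

print(f"y_hat          = {y_hat:.4f}")      # ~0.0099
print(f"MSE error term = {mse_grad:+.4f}")  # ~ -0.0097: tiny, learning stalls
print(f"BCE error term = {bce_grad:+.4f}")  # ~ -0.9901: strong correction
```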
I wanted to revisit the last post’s implementation, at a minimum updating the network to use the new loss function, which has implications for how we determine the error term for each layer.
Let’s start with the loss function. Binary cross-entropy loss is defined as:
$$ \mathcal{L}(y, \hat{y}) = - \big[ y \cdot \log(\hat{y}) + (1 - y) \cdot \log(1 - \hat{y}) \big] $$

This is important to call out for two reasons: we’re about to factor it into our updated error term implementation, and, as we learned in the last post, we’ll want to calculate the loss explicitly so we can track convergence during training.
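Here’s a minimal sketch of what that explicit calculation might look like in NumPy. The `eps` clipping is my own guard against $\log(0)$, not necessarily how the notebook handles it:

```python
import numpy as np

def bce_loss(y, y_hat, eps=1e-12):
    """Binary cross-entropy for a single prediction, clipped to avoid log(0)."""
    y_hat = np.clip(y_hat, eps, 1 - eps)
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

print(bce_loss(1.0, 0.99))  # confidently right: ~0.01
print(bce_loss(1.0, 0.01))  # confidently wrong: ~4.61
```

Note how a confidently wrong answer is penalized far more heavily than a confidently right one is rewarded.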
As a reminder, the error term is the derivative of the loss with respect to the pre-activation $z$. You can pretty easily find the answer via other online resources:
$$ \frac{\partial \mathcal{L}}{\partial z} = \hat y - y $$

However, I still want to see how the sausage is made, especially since I don’t live and breathe calculus every day. This derivation did end up being a little more involved than the last post’s, so I still wanted to go step by step, even if only for future me’s benefit. Feel free to expand the section below if you’re curious.
By now we know we’ll need the chain rule to expand this out a little bit.
$$ \frac{\partial \mathcal{L}}{\partial z} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial z} $$

Let’s look at the first factor, $\frac{\partial \mathcal{L}}{\partial \hat y}$. As a reminder, here’s our loss function (BCE):
$$ \mathcal{L}(y,\hat y)=-\Big[y\ln(\hat y)+(1-y)\ln(1-\hat y)\Big] $$

We can use the linearity of derivatives to break this down into terms:
$$ \frac{\partial}{\partial \hat y}\big[-(A+B)\big] = -\frac{\partial}{\partial \hat y}(A+B) = -\left(\frac{\partial A}{\partial \hat y} + \frac{\partial B}{\partial \hat y}\right), $$

where the left term is $A = y\ln\hat y$ and the right term is $B = (1-y)\ln(1-\hat y)$. We can treat each term individually, carrying the leading minus sign into each, and then add the results together at the end.
For the first term:
$$ \begin{aligned} \frac{\partial}{\partial \hat y} \Big(-y \ln \hat y\Big) &= -y \cdot \frac{\partial}{\partial \hat y} \ln \hat y \\ &= -y \cdot \frac{1}{\hat y} \\ &= \frac{-y \cdot 1}{\hat y} \\ &= -\frac{y}{\hat y} \end{aligned} $$

At a high level, this is:

- Pull the constant $-y$ out front (constant multiple rule).
- Differentiate $\ln \hat y$, which is $\frac{1}{\hat y}$.
- Multiply through and simplify.
The second term is a little more complicated, but uses some of the same steps. The main complication is that we’re passing $1-\hat y$ into the natural log, not just $\hat y$, so we need an extra chain rule step:
$$ \begin{aligned} \frac{\partial}{\partial \hat y} \Big(-(1-y)\ln(1-\hat y)\Big) &= -(1-y) \cdot \frac{\partial}{\partial \hat y} \ln(1-\hat y) \\ &= -(1-y) \cdot \left( \frac{1}{1-\hat y} \cdot \frac{\partial}{\partial \hat y}(1-\hat y) \right) \\ &= -(1-y) \cdot \left( \frac{1}{1-\hat y} \cdot (-1) \right) \\ &= -(1-y) \cdot \left( -\frac{1}{1-\hat y} \right) \\ &= (1-y) \cdot \left( \frac{1}{1-\hat y} \right) \\ &= \frac{1-y}{1-\hat y} \end{aligned} $$

Step-by-step:

- Pull the constant $-(1-y)$ out front.
- Apply the chain rule to $\ln(1-\hat y)$: the outer derivative is $\frac{1}{1-\hat y}$, and the inner derivative $\frac{\partial}{\partial \hat y}(1-\hat y)$ is $-1$.
- The two minus signs cancel, leaving $\frac{1-y}{1-\hat y}$.
Now we can re-join our two separated terms to get the loss derivative w.r.t. $\hat y$, and simplify:
$$ \begin{aligned} \frac{\partial \mathcal{L}}{\partial \hat y} &= -\frac{y}{\hat y} + \frac{1-y}{1-\hat y} \\ &= \frac{-y(1-\hat y)}{\hat y(1-\hat y)} + \frac{\hat y(1-y)}{\hat y(1-\hat y)} \\ &= \frac{-y(1-\hat y) + \hat y(1-y)}{\hat y(1-\hat y)} \\ &= \frac{-y + y\hat y + \hat y - y\hat y}{\hat y(1-\hat y)} \\ &= \frac{\hat y - y}{\hat y(1-\hat y)} \end{aligned} $$

Steps:

- Rewrite both fractions over the common denominator $\hat y(1-\hat y)$.
- Combine them and expand the numerator.
- The $y\hat y$ terms cancel, leaving $\hat y - y$ in the numerator.
So our loss derivative w.r.t. $\hat y$ is:
$$ \displaystyle \frac{\partial \mathcal{L}}{\partial \hat y} = \frac{\hat y - y}{\hat y(1-\hat y)} $$

Finally, we can go all the way back to our initial chain rule and substitute in the two factors.
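The second factor, $\frac{\partial \hat y}{\partial z}$, is just the derivative of the sigmoid activation (assuming, as in the last post, that $\hat y = \sigma(z) = \frac{1}{1+e^{-z}}$):

$$ \frac{\partial \hat y}{\partial z} = \frac{e^{-z}}{(1+e^{-z})^2} = \sigma(z)\big(1-\sigma(z)\big) = \hat y(1-\hat y) $$

Substituting both factors in: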
$$ \begin{aligned} \frac{\partial \mathcal{L}}{\partial z} &= \frac{\partial \mathcal{L}}{\partial \hat y}\cdot\frac{\partial \hat y}{\partial z} \\ &= \frac{\hat y - y}{\hat y(1-\hat y)}\cdot \hat y(1-\hat y) \\ &= \boxed{\hat y - y} \end{aligned} $$

Notice that the $\hat y(1-\hat y)$ factor cancels completely. That’s the same sigmoid derivative that shrinks the MSE gradient when the output saturates, which is exactly why BCE sidesteps the vanishing gradient problem here. I won’t always break things down in this much detail, but given I’m still getting my feet wet w.r.t. deep learning, and anything to do with calculus was a long time ago, it’s a good exercise for me right now.
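As a sanity check, here’s a small NumPy snippet (my own check, not part of the notebook) that compares the analytic error term $\hat y - y$ against a centered finite difference of the BCE loss through the sigmoid:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(y, y_hat):
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

eps = 1e-6  # finite-difference step size
for y in (0.0, 1.0):
    for z in (-2.0, 0.5, 3.0):
        # Centered difference of the loss with respect to z
        numeric = (bce(y, sigmoid(z + eps)) - bce(y, sigmoid(z - eps))) / (2 * eps)
        analytic = sigmoid(z) - y  # y_hat - y
        print(f"y={y:.0f} z={z:+.1f}  numeric={numeric:+.6f}  analytic={analytic:+.6f}")
```

The two columns agree to several decimal places, which is good evidence the derivation above holds.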
Below is a notebook containing the updated implementation. Try it out! But first, here is a summary of the changes from the previous version. In my research I encountered a lot of ideas for improvement but settled on these. I believe all of them meaningfully contributed to the improvement in convergence reliability, but they also depend on each other in complex ways, so it’s possible the list could be whittled down to achieve the same or even greater benefit. Nevertheless, here they are:
Here’s the updated notebook:
View this interactive marimo notebook below (best on desktop), or in a separate tab.
I am definitely observing an improvement in convergence reliability. Where the previous approach would often require 20 or more retries (though sometimes very few), the new approach very rarely requires more than 1 or 2. That’s just an estimate from running both versions by hand many times, but it’s what I’m observing.
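For reference, here’s roughly what I mean by retries. This is only a sketch: `train_xor()` is a hypothetical stand-in for the notebook’s training loop, and here it just returns a fake final loss so the snippet runs on its own:

```python
import numpy as np

rng = np.random.default_rng()

def train_xor():
    """Hypothetical stand-in for the real training loop: re-randomizes the
    weights and returns the final training loss (faked here for the sketch)."""
    return rng.uniform(0.0, 0.05)

def train_until_converged(max_retries=20, loss_threshold=0.01):
    # Retry with fresh random weights until the final loss clears the bar.
    for attempt in range(1, max_retries + 1):
        final_loss = train_xor()
        if final_loss < loss_threshold:
            print(f"converged on attempt {attempt} (loss={final_loss:.4f})")
            return final_loss
    raise RuntimeError("never converged; every initialization got stuck")

train_until_converged()
```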
Can retries be avoided?
We’re able to reduce the number of retries by increasing convergence reliability, but can we eliminate the need for them entirely and guarantee convergence every time? In general, yes, but from here we’d have to either:

- choose a different activation function, or
- add more neurons.
While the changes above do result in an improvement, I did not choose a different activation function or add more neurons, and such changes may well make this network much more reliable. However, I’m finding that the time spent on this is having diminishing returns: a 2-2-1 network is great for learning the general concepts, but optimizing it is not that worthwhile, at least for me. I’ll be moving on to larger networks and more real-world problems in future blog posts.